aboutsummaryrefslogtreecommitdiff
path: root/mediaapi/routing/download.go
diff options
context:
space:
mode:
Diffstat (limited to 'mediaapi/routing/download.go')
-rw-r--r--mediaapi/routing/download.go164
1 files changed, 146 insertions, 18 deletions
diff --git a/mediaapi/routing/download.go b/mediaapi/routing/download.go
index fa1c417a..c812b9d6 100644
--- a/mediaapi/routing/download.go
+++ b/mediaapi/routing/download.go
@@ -21,7 +21,9 @@ import (
"io"
"io/fs"
"mime"
+ "mime/multipart"
"net/http"
+ "net/textproto"
"net/url"
"os"
"path/filepath"
@@ -31,6 +33,7 @@ import (
"sync"
"unicode"
+ "github.com/google/uuid"
"github.com/matrix-org/dendrite/mediaapi/fileutils"
"github.com/matrix-org/dendrite/mediaapi/storage"
"github.com/matrix-org/dendrite/mediaapi/thumbnailer"
@@ -61,6 +64,9 @@ type downloadRequest struct {
ThumbnailSize types.ThumbnailSize
Logger *log.Entry
DownloadFilename string
+ multipartResponse bool // whether we need to return a multipart/mixed response (for requests coming in over federation)
+ fedClient fclient.FederationClient
+ origin spec.ServerName
}
// Taken from: https://github.com/matrix-org/synapse/blob/c3627d0f99ed5a23479305dc2bd0e71ca25ce2b1/synapse/media/_base.py#L53C1-L84
@@ -111,11 +117,17 @@ func Download(
cfg *config.MediaAPI,
db storage.Database,
client *fclient.Client,
+ fedClient fclient.FederationClient,
activeRemoteRequests *types.ActiveRemoteRequests,
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
isThumbnailRequest bool,
customFilename string,
+ federationRequest bool,
) {
+ // This happens if we call Download for a federation request
+ if federationRequest && origin == "" {
+ origin = cfg.Matrix.ServerName
+ }
dReq := &downloadRequest{
MediaMetadata: &types.MediaMetadata{
MediaID: mediaID,
@@ -126,7 +138,10 @@ func Download(
"Origin": origin,
"MediaID": mediaID,
}),
- DownloadFilename: customFilename,
+ DownloadFilename: customFilename,
+ multipartResponse: federationRequest,
+ origin: cfg.Matrix.ServerName,
+ fedClient: fedClient,
}
if dReq.IsThumbnailRequest {
@@ -355,7 +370,7 @@ func (r *downloadRequest) respondFromLocalFile(
}).Trace("Responding with file")
responseFile = file
responseMetadata = r.MediaMetadata
- if err := r.addDownloadFilenameToHeaders(w, responseMetadata); err != nil {
+ if err = r.addDownloadFilenameToHeaders(w, responseMetadata); err != nil {
return nil, err
}
}
@@ -367,14 +382,61 @@ func (r *downloadRequest) respondFromLocalFile(
" plugin-types application/pdf;" +
" style-src 'unsafe-inline';" +
" object-src 'self';"
- w.Header().Set("Content-Security-Policy", contentSecurityPolicy)
- if _, err := io.Copy(w, responseFile); err != nil {
- return nil, fmt.Errorf("io.Copy: %w", err)
+ if !r.multipartResponse {
+ w.Header().Set("Content-Security-Policy", contentSecurityPolicy)
+ if _, err = io.Copy(w, responseFile); err != nil {
+ return nil, fmt.Errorf("io.Copy: %w", err)
+ }
+ } else {
+ var written int64
+ written, err = multipartResponse(w, r, string(responseMetadata.ContentType), responseFile)
+ if err != nil {
+ return nil, err
+ }
+ responseMetadata.FileSizeBytes = types.FileSizeBytes(written)
}
return responseMetadata, nil
}
+func multipartResponse(w http.ResponseWriter, r *downloadRequest, contentType string, responseFile io.Reader) (int64, error) {
+ // Update the header to be multipart/mixed; boundary=$randomBoundary
+ boundary := uuid.NewString()
+ w.Header().Set("Content-Type", "multipart/mixed; boundary="+boundary)
+
+ w.Header().Del("Content-Length") // let Go handle the content length
+ mw := multipart.NewWriter(w)
+ defer func() {
+ if err := mw.Close(); err != nil {
+ r.Logger.WithError(err).Error("Failed to close multipart writer")
+ }
+ }()
+
+ if err := mw.SetBoundary(boundary); err != nil {
+ return 0, fmt.Errorf("failed to set multipart boundary: %w", err)
+ }
+
+ // JSON object part
+ jsonWriter, err := mw.CreatePart(textproto.MIMEHeader{
+ "Content-Type": {"application/json"},
+ })
+ if err != nil {
+ return 0, fmt.Errorf("failed to create json writer: %w", err)
+ }
+ if _, err = jsonWriter.Write([]byte("{}")); err != nil {
+ return 0, fmt.Errorf("failed to write to json writer: %w", err)
+ }
+
+ // media part
+ mediaWriter, err := mw.CreatePart(textproto.MIMEHeader{
+ "Content-Type": {contentType},
+ })
+ if err != nil {
+ return 0, fmt.Errorf("failed to create media writer: %w", err)
+ }
+ return io.Copy(mediaWriter, responseFile)
+}
+
func (r *downloadRequest) addDownloadFilenameToHeaders(
w http.ResponseWriter,
responseMetadata *types.MediaMetadata,
@@ -722,8 +784,7 @@ func (r *downloadRequest) fetchRemoteFileAndStoreMetadata(
return nil
}
-func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, body *io.ReadCloser, maxFileSizeBytes config.FileSizeBytes) (int64, io.Reader, error) {
- reader := *body
+func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, reader io.ReadCloser, maxFileSizeBytes config.FileSizeBytes) (int64, io.Reader, error) {
var contentLength int64
if contentLengthHeader != "" {
@@ -742,7 +803,7 @@ func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string,
// We successfully parsed the Content-Length, so we'll return a limited
// reader that restricts us to reading only up to this size.
- reader = io.NopCloser(io.LimitReader(*body, parsedLength))
+ reader = io.NopCloser(io.LimitReader(reader, parsedLength))
contentLength = parsedLength
} else {
// Content-Length header is missing. If we have a maximum file size
@@ -751,7 +812,7 @@ func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string,
// ultimately it will get rewritten later when the temp file is written
// to disk.
if maxFileSizeBytes > 0 {
- reader = io.NopCloser(io.LimitReader(*body, int64(maxFileSizeBytes)))
+ reader = io.NopCloser(io.LimitReader(reader, int64(maxFileSizeBytes)))
}
contentLength = 0
}
@@ -759,6 +820,11 @@ func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string,
return contentLength, reader, nil
}
+// mediaMeta contains information about a multipart media response.
+// TODO: extend once something is defined.
+type mediaMeta struct{}
+
+// nolint: gocyclo
func (r *downloadRequest) fetchRemoteFile(
ctx context.Context,
client *fclient.Client,
@@ -767,19 +833,38 @@ func (r *downloadRequest) fetchRemoteFile(
) (types.Path, bool, error) {
r.Logger.Debug("Fetching remote file")
- // create request for remote file
- resp, err := client.CreateMediaDownloadRequest(ctx, r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID))
+ // Attempt to download via authenticated media endpoint
+ isAuthed := true
+ resp, err := r.fedClient.DownloadMedia(ctx, r.origin, r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID))
if err != nil || (resp != nil && resp.StatusCode != http.StatusOK) {
- if resp != nil && resp.StatusCode == http.StatusNotFound {
- return "", false, fmt.Errorf("File with media ID %q does not exist on %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
+ isAuthed = false
+ // try again on the unauthed endpoint
+ // create request for remote file
+ resp, err = client.CreateMediaDownloadRequest(ctx, r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID))
+ if err != nil || (resp != nil && resp.StatusCode != http.StatusOK) {
+ if resp != nil && resp.StatusCode == http.StatusNotFound {
+ return "", false, fmt.Errorf("File with media ID %q does not exist on %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
+ }
+ return "", false, fmt.Errorf("file with media ID %q could not be downloaded from %s: %w", r.MediaMetadata.MediaID, r.MediaMetadata.Origin, err)
}
- return "", false, fmt.Errorf("file with media ID %q could not be downloaded from %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
}
defer resp.Body.Close() // nolint: errcheck
- // The reader returned here will be limited either by the Content-Length
- // and/or the configured maximum media size.
- contentLength, reader, parseErr := r.GetContentLengthAndReader(resp.Header.Get("Content-Length"), &resp.Body, maxFileSizeBytes)
+ // If this wasn't a multipart response, set the Content-Type now. Will be overwritten
+ // by the multipart Content-Type below.
+ r.MediaMetadata.ContentType = types.ContentType(resp.Header.Get("Content-Type"))
+
+ var contentLength int64
+ var reader io.Reader
+ var parseErr error
+ if isAuthed {
+ parseErr, contentLength, reader = parseMultipartResponse(r, resp, maxFileSizeBytes)
+ } else {
+ // The reader returned here will be limited either by the Content-Length
+ // and/or the configured maximum media size.
+ contentLength, reader, parseErr = r.GetContentLengthAndReader(resp.Header.Get("Content-Length"), resp.Body, maxFileSizeBytes)
+ }
+
if parseErr != nil {
return "", false, parseErr
}
@@ -790,7 +875,6 @@ func (r *downloadRequest) fetchRemoteFile(
}
r.MediaMetadata.FileSizeBytes = types.FileSizeBytes(contentLength)
- r.MediaMetadata.ContentType = types.ContentType(resp.Header.Get("Content-Type"))
dispositionHeader := resp.Header.Get("Content-Disposition")
if _, params, e := mime.ParseMediaType(dispositionHeader); e == nil {
@@ -844,6 +928,50 @@ func (r *downloadRequest) fetchRemoteFile(
return types.Path(finalPath), duplicate, nil
}
+func parseMultipartResponse(r *downloadRequest, resp *http.Response, maxFileSizeBytes config.FileSizeBytes) (error, int64, io.Reader) {
+ _, params, err := mime.ParseMediaType(resp.Header.Get("Content-Type"))
+ if err != nil {
+ return err, 0, nil
+ }
+ if params["boundary"] == "" {
+ return fmt.Errorf("no boundary header found on media %s from %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin), 0, nil
+ }
+ mr := multipart.NewReader(resp.Body, params["boundary"])
+
+ // Get the first, JSON, part
+ p, err := mr.NextPart()
+ if err != nil {
+ return err, 0, nil
+ }
+ defer p.Close() // nolint: errcheck
+
+ if p.Header.Get("Content-Type") != "application/json" {
+ return fmt.Errorf("first part of the response must be application/json"), 0, nil
+ }
+ // Try to parse media meta information
+ meta := mediaMeta{}
+ if err = json.NewDecoder(p).Decode(&meta); err != nil {
+ return err, 0, nil
+ }
+ defer p.Close() // nolint: errcheck
+
+ // Get the actual media content
+ p, err = mr.NextPart()
+ if err != nil {
+ return err, 0, nil
+ }
+
+ redirect := p.Header.Get("Location")
+ if redirect != "" {
+ return fmt.Errorf("Location header is not yet supported"), 0, nil
+ }
+
+ contentLength, reader, err := r.GetContentLengthAndReader(p.Header.Get("Content-Length"), p, maxFileSizeBytes)
+ // For multipart requests, we need to get the Content-Type of the second part, which is the actual media
+ r.MediaMetadata.ContentType = types.ContentType(p.Header.Get("Content-Type"))
+ return err, contentLength, reader
+}
+
// contentDispositionFor returns the Content-Disposition for a given
// content type.
func contentDispositionFor(contentType types.ContentType) string {