aboutsummaryrefslogtreecommitdiff
path: root/mediaapi/fileutils/fileutils.go
diff options
context:
space:
mode:
Diffstat (limited to 'mediaapi/fileutils/fileutils.go')
-rw-r--r--mediaapi/fileutils/fileutils.go191
1 files changed, 191 insertions, 0 deletions
diff --git a/mediaapi/fileutils/fileutils.go b/mediaapi/fileutils/fileutils.go
new file mode 100644
index 00000000..36b2c5b8
--- /dev/null
+++ b/mediaapi/fileutils/fileutils.go
@@ -0,0 +1,191 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fileutils
+
+import (
+ "bufio"
+ "crypto/sha256"
+ "encoding/base64"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "strings"
+
+ "github.com/matrix-org/dendrite/common/config"
+ "github.com/matrix-org/dendrite/mediaapi/types"
+ log "github.com/sirupsen/logrus"
+)
+
+// GetPathFromBase64Hash evaluates the path to a media file from its Base64Hash
+// 3 subdirectories are created for more manageable browsing and use the remainder as the file name.
+// For example, if Base64Hash is 'qwerty', the path will be 'q/w/erty/file'.
+func GetPathFromBase64Hash(base64Hash types.Base64Hash, absBasePath config.Path) (string, error) {
+ if len(base64Hash) < 3 {
+ return "", fmt.Errorf("Invalid filePath (Base64Hash too short - min 3 characters): %q", base64Hash)
+ }
+ if len(base64Hash) > 255 {
+ return "", fmt.Errorf("Invalid filePath (Base64Hash too long - max 255 characters): %q", base64Hash)
+ }
+
+ filePath, err := filepath.Abs(filepath.Join(
+ string(absBasePath),
+ string(base64Hash[0:1]),
+ string(base64Hash[1:2]),
+ string(base64Hash[2:]),
+ "file",
+ ))
+ if err != nil {
+ return "", fmt.Errorf("Unable to construct filePath: %q", err)
+ }
+
+ // check if the absolute absBasePath is a prefix of the absolute filePath
+ // if so, no directory escape has occurred and the filePath is valid
+ // Note: absBasePath is already absolute
+ if !strings.HasPrefix(filePath, string(absBasePath)) {
+ return "", fmt.Errorf("Invalid filePath (not within absBasePath %v): %v", absBasePath, filePath)
+ }
+
+ return filePath, nil
+}
+
+// MoveFileWithHashCheck checks for hash collisions when moving a temporary file to its final path based on metadata
+// The final path is based on the hash of the file.
+// If the final path exists and the file size matches, the file does not need to be moved.
+// In error cases where the file is not a duplicate, the caller may decide to remove the final path.
+// Returns the final path of the file, whether it is a duplicate and an error.
+func MoveFileWithHashCheck(tmpDir types.Path, mediaMetadata *types.MediaMetadata, absBasePath config.Path, logger *log.Entry) (types.Path, bool, error) {
+ // Note: in all error and success cases, we need to remove the temporary directory
+ defer RemoveDir(tmpDir, logger)
+ duplicate := false
+ finalPath, err := GetPathFromBase64Hash(mediaMetadata.Base64Hash, absBasePath)
+ if err != nil {
+ return "", duplicate, fmt.Errorf("failed to get file path from metadata: %q", err)
+ }
+
+ var stat os.FileInfo
+ // Note: The double-negative is intentional as os.IsExist(err) != !os.IsNotExist(err).
+ // The functions are error checkers to be used in different cases.
+ if stat, err = os.Stat(finalPath); !os.IsNotExist(err) {
+ duplicate = true
+ if stat.Size() == int64(mediaMetadata.FileSizeBytes) {
+ return types.Path(finalPath), duplicate, nil
+ }
+ return "", duplicate, fmt.Errorf("downloaded file with hash collision but different file size (%v)", finalPath)
+ }
+ err = moveFile(
+ types.Path(filepath.Join(string(tmpDir), "content")),
+ types.Path(finalPath),
+ )
+ if err != nil {
+ return "", duplicate, fmt.Errorf("failed to move file to final destination (%v): %q", finalPath, err)
+ }
+ return types.Path(finalPath), duplicate, nil
+}
+
+// RemoveDir removes a directory and logs a warning in case of errors
+func RemoveDir(dir types.Path, logger *log.Entry) {
+ dirErr := os.RemoveAll(string(dir))
+ if dirErr != nil {
+ logger.WithError(dirErr).WithField("dir", dir).Warn("Failed to remove directory")
+ }
+}
+
+// WriteTempFile writes to a new temporary file
+func WriteTempFile(reqReader io.Reader, maxFileSizeBytes config.FileSizeBytes, absBasePath config.Path) (hash types.Base64Hash, size types.FileSizeBytes, path types.Path, err error) {
+ size = -1
+
+ tmpFileWriter, tmpFile, tmpDir, err := createTempFileWriter(absBasePath)
+ if err != nil {
+ return
+ }
+ defer (func() { err = tmpFile.Close() })()
+
+ // The amount of data read is limited to maxFileSizeBytes. At this point, if there is more data it will be truncated.
+ limitedReader := io.LimitReader(reqReader, int64(maxFileSizeBytes))
+ // Hash the file data. The hash will be returned. The hash is useful as a
+ // method of deduplicating files to save storage, as well as a way to conduct
+ // integrity checks on the file data in the repository.
+ hasher := sha256.New()
+ teeReader := io.TeeReader(limitedReader, hasher)
+ bytesWritten, err := io.Copy(tmpFileWriter, teeReader)
+ if err != nil && err != io.EOF {
+ return
+ }
+
+ err = tmpFileWriter.Flush()
+ if err != nil {
+ return
+ }
+
+ hash = types.Base64Hash(base64.RawURLEncoding.EncodeToString(hasher.Sum(nil)[:]))
+ size = types.FileSizeBytes(bytesWritten)
+ path = tmpDir
+ return
+}
+
+// moveFile attempts to move the file src to dst
+func moveFile(src types.Path, dst types.Path) error {
+ dstDir := filepath.Dir(string(dst))
+
+ err := os.MkdirAll(dstDir, 0770)
+ if err != nil {
+ return fmt.Errorf("Failed to make directory: %q", err)
+ }
+ err = os.Rename(string(src), string(dst))
+ if err != nil {
+ return fmt.Errorf("Failed to move directory: %q", err)
+ }
+ return nil
+}
+
+func createTempFileWriter(absBasePath config.Path) (*bufio.Writer, *os.File, types.Path, error) {
+ tmpDir, err := createTempDir(absBasePath)
+ if err != nil {
+ return nil, nil, "", fmt.Errorf("Failed to create temp dir: %q", err)
+ }
+ writer, tmpFile, err := createFileWriter(tmpDir)
+ if err != nil {
+ return nil, nil, "", fmt.Errorf("Failed to create file writer: %q", err)
+ }
+ return writer, tmpFile, tmpDir, nil
+}
+
+// createTempDir creates a tmp/<random string> directory within baseDirectory and returns its path
+func createTempDir(baseDirectory config.Path) (types.Path, error) {
+ baseTmpDir := filepath.Join(string(baseDirectory), "tmp")
+ if err := os.MkdirAll(baseTmpDir, 0770); err != nil {
+ return "", fmt.Errorf("Failed to create base temp dir: %v", err)
+ }
+ tmpDir, err := ioutil.TempDir(baseTmpDir, "")
+ if err != nil {
+ return "", fmt.Errorf("Failed to create temp dir: %v", err)
+ }
+ return types.Path(tmpDir), nil
+}
+
+// createFileWriter creates a buffered file writer with a new file
+// The caller should flush the writer before closing the file.
+// Returns the file handle as it needs to be closed when writing is complete
+func createFileWriter(directory types.Path) (*bufio.Writer, *os.File, error) {
+ filePath := filepath.Join(string(directory), "content")
+ file, err := os.Create(filePath)
+ if err != nil {
+ return nil, nil, fmt.Errorf("Failed to create file: %v", err)
+ }
+
+ return bufio.NewWriter(file), file, nil
+}