feat: 添加文件下载响应生成方法,支持不同存储后端的下载处理

This commit is contained in:
2025-11-14 22:58:14 +08:00
parent ed845a3d9a
commit e404764ea9
3 changed files with 97 additions and 86 deletions

View File

@@ -3,15 +3,7 @@ import os
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List from typing import Any, Dict, List
from flask import ( from flask import Blueprint, Response, abort, jsonify, redirect, render_template, request
Blueprint,
Response,
abort,
jsonify,
redirect,
render_template,
request,
)
from storages.factory import StorageFactory from storages.factory import StorageFactory
@@ -77,18 +69,12 @@ def build_file_entry(obj: Dict[str, Any], prefix: str) -> Dict[str, Any] | None:
return entry return entry
def build_directory_entry( def build_directory_entry(prefix_value: str | None, current_prefix: str) -> Dict[str, Any] | None:
prefix_value: str | None, current_prefix: str
) -> Dict[str, Any] | None:
"""根据前缀构建目录条目。""" """根据前缀构建目录条目。"""
if not prefix_value: if not prefix_value:
return None return None
rel = ( rel = prefix_value[len(current_prefix) :].rstrip("/") if current_prefix else prefix_value.rstrip("/")
prefix_value[len(current_prefix) :].rstrip("/")
if current_prefix
else prefix_value.rstrip("/")
)
return {"name": rel, "key": prefix_value, "is_dir": True} return {"name": rel, "key": prefix_value, "is_dir": True}
@@ -198,7 +184,7 @@ def serve_file(file_path):
@main_route.route("/download/<path:file_path>") @main_route.route("/download/<path:file_path>")
def download_file(file_path): def download_file(file_path):
"""为 GitHub 存储提供下载支持,添加 Content-Disposition 头以强制下载""" """下载文件,支持所有存储类型"""
try: try:
# 验证文件存在 # 验证文件存在
try: try:
@@ -206,57 +192,24 @@ def download_file(file_path):
except Exception: except Exception:
abort(404) abort(404)
# 获取存储类型 # 使用存储后端的统一接口生成下载响应
storage_type = type(storage).__name__ download_response = storage.generate_download_response(file_path)
# GitHub 存储:通过服务器中继以添加 Content-Disposition 头
if storage_type == "GitHubStorage":
try:
file_obj = storage.get_object(file_path)
file_name = file_path.split("/")[-1] if "/" in file_path else file_path
# 获取完整内容用于返回
body = file_obj.get("Body")
if hasattr(body, "read"):
content = body.read()
elif hasattr(body, "data"):
content = body.data
else:
content = body
# 使用 RFC 5987 编码处理文件名中的特殊字符
from urllib.parse import quote
encoded_filename = quote(file_name.encode("utf-8"), safe="")
headers = {
"Content-Type": file_obj.get(
"ContentType", "application/octet-stream"
),
"Content-Disposition": f"attachment; filename=\"{file_name}\"; filename*=UTF-8''{encoded_filename}",
"Cache-Control": "public, max-age=86400",
}
return Response(
content,
headers=headers,
mimetype=file_obj.get("ContentType", "application/octet-stream"),
)
except Exception as e:
print(f"GitHub download error: {e}")
abort(404)
# R2 和其他存储:直接重定向
presigned = storage.generate_presigned_url(file_path)
if presigned:
return redirect(presigned)
public_url = storage.get_public_url(file_path)
if public_url:
return redirect(public_url)
if not download_response:
abort(403) abort(403)
# 根据响应类型处理
if download_response["type"] == "redirect":
return redirect(download_response["url"])
elif download_response["type"] == "content":
return Response(
download_response["content"],
headers=download_response["headers"],
mimetype=download_response["mimetype"],
)
else:
abort(500)
except Exception as e: except Exception as e:
print(f"Download error: {e}") print(f"Download error: {e}")
abort(500) abort(500)
@@ -459,9 +412,7 @@ def copy_item():
if not source or not destination: if not source or not destination:
return ( return (
jsonify( jsonify({"success": False, "error": "Source or destination not provided"}),
{"success": False, "error": "Source or destination not provided"}
),
400, 400,
) )
@@ -493,9 +444,7 @@ def move_item():
if not source or not destination: if not source or not destination:
return ( return (
jsonify( jsonify({"success": False, "error": "Source or destination not provided"}),
{"success": False, "error": "Source or destination not provided"}
),
400, 400,
) )

View File

@@ -208,3 +208,29 @@ class BaseStorage(ABC):
创建成功返回 True失败返回 False 创建成功返回 True失败返回 False
""" """
pass pass
def generate_download_response(self, key: str) -> Dict[str, Any]:
"""
生成文件下载响应
Args:
key: 对象键名(文件路径)
Returns:
包含下载信息的字典,包括:
- type: "redirect""content"
- url: 重定向URL当type为redirect时
- content: 文件内容当type为content时
- headers: HTTP响应头
- mimetype: MIME类型
"""
# 默认实现返回重定向URL
presigned = self.generate_presigned_url(key)
if presigned:
return {"type": "redirect", "url": presigned}
public_url = self.get_public_url(key)
if public_url:
return {"type": "redirect", "url": public_url}
return None

View File

@@ -59,13 +59,9 @@ class GitHubStorage(BaseStorage):
self.raw_proxy_url = os.getenv("GITHUB_RAW_PROXY_URL", "").rstrip("/") self.raw_proxy_url = os.getenv("GITHUB_RAW_PROXY_URL", "").rstrip("/")
if not all([self.repo_owner, self.repo_name, self.access_token]): if not all([self.repo_owner, self.repo_name, self.access_token]):
raise RuntimeError( raise RuntimeError("GITHUB_REPO_OWNER, GITHUB_REPO_NAME, GITHUB_ACCESS_TOKEN must be set")
"GITHUB_REPO_OWNER, GITHUB_REPO_NAME, GITHUB_ACCESS_TOKEN must be set"
)
self.api_base_url = ( self.api_base_url = f"https://api.github.com/repos/{self.repo_owner}/{self.repo_name}"
f"https://api.github.com/repos/{self.repo_owner}/{self.repo_name}"
)
# 如果配置了代理 URL则使用代理 URL否则使用官方 raw.githubusercontent.com # 如果配置了代理 URL则使用代理 URL否则使用官方 raw.githubusercontent.com
if self.raw_proxy_url: if self.raw_proxy_url:
@@ -121,11 +117,7 @@ class GitHubStorage(BaseStorage):
try: try:
# 移除末尾的 / 以保持 GitHub API 的一致性 # 移除末尾的 / 以保持 GitHub API 的一致性
prefix = prefix.rstrip("/") if prefix else "" prefix = prefix.rstrip("/") if prefix else ""
url = ( url = f"{self.api_base_url}/contents/{prefix}" if prefix else f"{self.api_base_url}/contents"
f"{self.api_base_url}/contents/{prefix}"
if prefix
else f"{self.api_base_url}/contents"
)
response = requests.get(url, headers=self._headers()) response = requests.get(url, headers=self._headers())
response.raise_for_status() response.raise_for_status()
@@ -212,9 +204,7 @@ class GitHubStorage(BaseStorage):
return { return {
"Body": StreamWrapper(content), "Body": StreamWrapper(content),
"ContentLength": len(content), "ContentLength": len(content),
"ContentType": response.headers.get( "ContentType": response.headers.get("Content-Type", "application/octet-stream"),
"Content-Type", "application/octet-stream"
),
} }
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to get object: {str(e)}") from e raise RuntimeError(f"Failed to get object: {str(e)}") from e
@@ -498,3 +488,49 @@ class GitHubStorage(BaseStorage):
except Exception as e: except Exception as e:
print(f"Create folder failed: {str(e)}") print(f"Create folder failed: {str(e)}")
return False return False
def generate_download_response(self, key: str) -> Dict[str, Any]:
"""
生成文件下载响应GitHub 特有实现)
GitHub 存储需要通过服务器中继以添加 Content-Disposition 头
Args:
key: 对象键名(文件路径)
Returns:
包含下载信息的字典
"""
try:
file_obj = self.get_object(key)
file_name = key.split("/")[-1] if "/" in key else key
# 获取完整内容
body = file_obj.get("Body")
if hasattr(body, "read"):
content = body.read()
elif hasattr(body, "data"):
content = body.data
else:
content = body
# 使用 RFC 5987 编码处理文件名中的特殊字符
from urllib.parse import quote
encoded_filename = quote(file_name.encode("utf-8"), safe="")
headers = {
"Content-Type": file_obj.get("ContentType", "application/octet-stream"),
"Content-Disposition": f"attachment; filename=\"{file_name}\"; filename*=UTF-8''{encoded_filename}",
"Cache-Control": "public, max-age=86400",
}
return {
"type": "content",
"content": content,
"headers": headers,
"mimetype": file_obj.get("ContentType", "application/octet-stream"),
}
except Exception as e:
print(f"GitHub download response generation failed: {str(e)}")
return None