feat: 重构配置管理,集中管理环境变量并优化存储后端实现

This commit is contained in:
2025-11-14 23:26:21 +08:00
parent 07a8fafff2
commit 33fe06c59e
10 changed files with 347 additions and 193 deletions

View File

@@ -1,14 +1,11 @@
import os
from typing import Optional
import dotenv
from config import Config
from .base import BaseStorage
from .github import GitHubStorage
from .r2 import R2Storage
dotenv.load_dotenv()
class StorageFactory:
"""存储工厂类,根据配置创建对应的存储实例"""
@@ -29,24 +26,19 @@ class StorageFactory:
if cls._instance is not None:
return cls._instance
storage_type = os.getenv("STORAGE_TYPE")
storage_type = Config.STORAGE_TYPE
if not storage_type:
raise RuntimeError(
"STORAGE_TYPE environment variable is not set. "
"Supported types: r2, github"
)
storage_type = storage_type.lower()
raise RuntimeError("STORAGE_TYPE environment variable is not set. Supported types: r2, github")
if storage_type == "r2":
cls._instance = R2Storage()
elif storage_type == "github":
cls._instance = GitHubStorage()
else:
raise RuntimeError(
f"Unsupported storage type: {storage_type}. Supported types: r2, github"
)
raise RuntimeError(f"Unsupported storage type: {storage_type}. Supported types: r2, github")
return cls._instance
return cls._instance

View File

@@ -1,16 +1,14 @@
import base64
import os
from datetime import datetime
from io import BytesIO
from typing import Any, Dict
import dotenv
import requests
from PIL import Image
from .base import BaseStorage
from config import Config
dotenv.load_dotenv()
from .base import BaseStorage
class StreamWrapper:
@@ -50,29 +48,30 @@ class GitHubStorage(BaseStorage):
"""基于 GitHub 仓库的存储实现"""
def __init__(self):
self.repo_owner = os.getenv("GITHUB_REPO_OWNER")
self.repo_name = os.getenv("GITHUB_REPO_NAME")
self.access_token = os.getenv("GITHUB_ACCESS_TOKEN")
self.branch = os.getenv("GITHUB_BRANCH", "main")
# 反向代理 URL用于加速 GitHub raw 文件访问
# 例如https://raw. githubusercontent.com/ 的反向代理 URL
self.raw_proxy_url = os.getenv("GITHUB_RAW_PROXY_URL", "").rstrip("/")
"""初始化 GitHub 存储客户端"""
self.token = Config.GITHUB_TOKEN
repo_full = Config.GITHUB_REPO # 格式: owner/repo
self.branch = Config.GITHUB_BRANCH
if not all([self.repo_owner, self.repo_name, self.access_token]):
raise RuntimeError("GITHUB_REPO_OWNER, GITHUB_REPO_NAME, GITHUB_ACCESS_TOKEN must be set")
if not self.token or not repo_full:
raise RuntimeError("GITHUB_TOKEN and GITHUB_REPO must be set")
# 解析 owner/repo
repo_parts = repo_full.split("/")
if len(repo_parts) != 2:
raise RuntimeError(f"GITHUB_REPO must be in format 'owner/repo', got: {repo_full}")
self.repo_owner = repo_parts[0]
self.repo_name = repo_parts[1]
self.repo = repo_full
self.api_base_url = f"https://api.github.com/repos/{self.repo_owner}/{self.repo_name}"
# 如果配置了代理 URL则使用代理 URL否则使用官方 raw.githubusercontent.com
if self.raw_proxy_url:
self.raw_content_url = f"{self.raw_proxy_url}/https://raw.githubusercontent.com/{self.repo_owner}/{self.repo_name}/{self.branch}"
else:
self.raw_content_url = f"https://raw.githubusercontent.com/{self.repo_owner}/{self.repo_name}/{self.branch}"
self.raw_content_url = f"https://raw.githubusercontent.com/{self.repo_owner}/{self.repo_name}/{self.branch}"
def _headers(self) -> Dict[str, str]:
"""返回 API 请求的公共头部信息"""
return {
"Authorization": f"token {self.access_token}",
"Authorization": f"token {self.token}",
"Accept": "application/vnd.github.v3+json",
}

View File

@@ -3,29 +3,34 @@ from io import BytesIO
from typing import Any, Dict
import boto3
import dotenv
from botocore.config import Config
from botocore.config import Config as BotocoreConfig
from PIL import Image
from config import Config
from .base import BaseStorage
dotenv.load_dotenv()
class R2Storage(BaseStorage):
def __init__(self):
self.endpoint = os.getenv("R2_ENDPOINT_URL")
if not self.endpoint:
raise RuntimeError("R2_ENDPOINT_URL environment variable is not set")
"""Cloudflare R2 存储后端实现"""
self.access_key = os.getenv("ACCESS_KEY_ID") or os.getenv("ACCESS_KEY_ID")
self.secret_key = os.getenv("SECRET_ACCESS_KEY") or os.getenv("SECRET_ACCESS_KEY")
def __init__(self):
"""初始化 R2 存储客户端"""
# 从统一配置中读取
account_id = Config.R2_ACCOUNT_ID
if not account_id:
raise RuntimeError("R2_ACCOUNT_ID environment variable is not set")
self.endpoint = f"https://{account_id}.r2.cloudflarestorage.com"
self.access_key = Config.R2_ACCESS_KEY_ID
self.secret_key = Config.R2_SECRET_ACCESS_KEY
if not self.access_key or not self.secret_key:
raise RuntimeError("ACCESS_KEY_ID and SECRET_ACCESS_KEY must be set")
raise RuntimeError("R2_ACCESS_KEY_ID and R2_SECRET_ACCESS_KEY must be set")
self.region_name = os.getenv("R2_REGION", "auto")
self.bucket_name = os.getenv("R2_BUCKET_NAME")
self.region_name = "auto"
self.bucket_name = Config.R2_BUCKET_NAME
self.public_domain = Config.R2_PUBLIC_DOMAIN
def get_s3_client(self):
"""
@@ -36,7 +41,7 @@ class R2Storage(BaseStorage):
endpoint_url=self.endpoint,
aws_access_key_id=self.access_key,
aws_secret_access_key=self.secret_key,
config=Config(signature_version="s3v4"),
config=BotocoreConfig(signature_version="s3v4"),
region_name=self.region_name,
)
@@ -93,10 +98,9 @@ class R2Storage(BaseStorage):
"""
生成对象的公共访问 URL
"""
base_url = os.getenv("R2_PUBLIC_URL")
if not base_url:
if not self.public_domain:
return None
return f"{base_url.rstrip('/')}/{key}"
return f"{self.public_domain.rstrip('/')}/{key}"
def generate_thumbnail(self, file_path: str) -> bytes:
"""