use pydantic to validate config

This commit is contained in:
2025-10-29 00:24:13 +08:00
parent 4ae07c57cc
commit f56675c486
3 changed files with 103 additions and 18 deletions

View File

@ -8,10 +8,76 @@ from time import time
from rich.logging import RichHandler
from rich.progress import Progress
from pickle import dumps, loads
from typing import Optional, Callable
from typing import Optional, Callable,Literal
import atexit
import re
import get_frame
import pydantic as pyd
from pydantic import BaseModel,Field,field_validator,model_validator
class Config(BaseModel):
save_to:Literal["single","multi"] = Field("single",description="保存到单文件夹或者每个子文件夹创建compress_dir")
crf: Optional[int] = Field(None, ge=0, le=51, description="CRF值范围0-51")
bitrate: Optional[str] = Field(None, description="比特率,格式如: 1000k, 2.5M, 1500B")
codec: str = Field("h264",description="ffmpeg的codec如果使用GPU需要对应设置")
hwaccel:Optional[Literal["amf","qsv","cuda"]] = Field(None,description="使用GPU加速")
extra:Optional[list[str]] = Field(None,description="插入到ffmpeg输出前的自定义参数")
ffmpeg:str = "ffmpeg"
ffprobe:str = "ffprobe"
manual:Optional[list[str]] = Field(None,description=r"手动设置ffmpeg命令ffmpeg -i {input} {manual} {output}")
video_ext:list[str] = Field([".mp4", ".mkv"],description="视频文件后缀,含.")
compress_dir_name:str = Field("compress",description="压缩文件夹名称")
resolution: Optional[str] = Field(None,description="统一到特定尺寸None为不使用缩放")
fps:int = Field(30,description="fps",ge=0)
test_video_resolution:str = "1920x1080"
test_video_fps:int = Field(30,ge=0)
test_video_input:str = "compress_video_test.mp4"
test_video_output:str = "compressed_video_test.mp4"
disable_hwaccel_when_fail:bool = Field(True,description="当运行失败时,禁用硬件加速")
@field_validator('bitrate')
@classmethod
def validate_bitrate(cls, v: Optional[str]) -> Optional[str]:
if v is None:
return v
pattern = r'^[\d\.]+[MkB]*$'
if not re.match(pattern, v):
raise ValueError('bitrate格式不正确应为数字+单位(M/k/B),如: 1000k, 2.5M')
return v
@field_validator('resolution')
@classmethod
def validate_resolution(cls, v: Optional[str]) -> Optional[str]:
if v is None:
return v
pattern = r'^((-1)|\d+):((-1)|\d+)$'
if not re.match(pattern, v):
raise ValueError('resolution格式不正确应为{数字/-1}:{数字/-1}')
return v
@field_validator("compress_dir_name")
# @field_validator("test_video_input")
# @field_validator("test_video_output")
@classmethod
def valid_path(cls, v:str) -> str:
if re.search(r'[\\/:*?"<>|\x00-\x1F]',v):
raise ValueError("某配置不符合目录名语法")
return v
@model_validator(mode='after')
def validate_mutual_exclusive(self):
crf_none = self.crf is None
bitrate_none = self.bitrate is None
# 有且只有一者为None
if crf_none == bitrate_none:
raise ValueError('crf和bitrate必须互斥有且只有一个为None')
return self
root = None
CFG_FILE = Path(sys.path[0]) / "config.json"
@ -85,7 +151,7 @@ def get_cmd(video_path: str | Path, output_file: str | Path) -> list[str]:
"-b:v",
CFG["bitrate"],
"-r",
CFG["fps"],
str(CFG["fps"]),
"-y",
]
)
@ -104,7 +170,7 @@ def get_cmd(video_path: str | Path, output_file: str | Path) -> list[str]:
"-global_quality",
str(CFG["crf"]),
"-r",
CFG["fps"],
str(CFG["fps"]),
"-y",
]
)
@ -390,7 +456,7 @@ def test():
exit(-1)
try:
ret = subprocess.run(
f"ffmpeg -hide_banner -f lavfi -i testsrc=duration=1:size={CFG['test_video_resolution']}:rate={CFG['test_video_fps']} -c:v libx264 -y -pix_fmt yuv420p {CFG['test_video_input']}".split(),
f"{CFG['ffmpeg']} -hide_banner -f lavfi -i testsrc=duration=1:size={CFG['test_video_resolution']}:rate={CFG['test_video_fps']} -c:v libx264 -y -pix_fmt yuv420p {CFG['test_video_input']}".split(),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
@ -413,6 +479,9 @@ def test():
logging.debug(ret.stderr)
logging.error("Error termination via test failed.")
exit(-1)
if get_frame.get_video_frame_count("compress_video_test.mp4") is None:
logging.error("测试读取帧数失败,将无法正确显示进度。")
os.remove("compress_video_test.mp4")
os.remove("compressed_video_test.mp4")
except KeyboardInterrupt as e:
@ -446,7 +515,11 @@ def main(_root=None):
import json
cfg: dict = json.loads(CFG_FILE.read_text())
CFG.update(cfg)
cfg_model = Config(**cfg)
CFG.update(cfg_model.model_dump())
get_frame.ffprobe = CFG["ffprobe"]
logging.debug(cfg_model)
logging.debug(CFG)
except KeyboardInterrupt as e:
raise e
except Exception as e: