use pydantic to validate config

This commit is contained in:
2025-10-29 00:24:13 +08:00
parent 4ae07c57cc
commit f56675c486
3 changed files with 103 additions and 18 deletions

View File

@ -5,7 +5,8 @@
"codec": "h264_qsv",
"hwaccel": "qsv",
"extra": [],
"ffmpeg": "ffmpeg",
"ffmpeg": "C:/tools/ffmpeg/bin/ffmpeg.exe",
"ffprobe": "C:/tools/ffmpeg/bin/ffprobe",
"manual": null,
"video_ext": [
".mp4",

View File

@ -1,19 +1,18 @@
import json
import shutil
import logging
import subprocess
from fractions import Fraction
from decimal import Decimal
from typing import Optional, Tuple
ffprobe:str = "ffprobe"
class FFProbeError(RuntimeError):
pass
def _run_ffprobe(args: list[str]) -> dict:
"""运行 ffprobe 并以 JSON 返回,若失败抛异常。"""
if not shutil.which("ffprobe"):
raise FileNotFoundError("未找到 ffprobe请先安装 FFmpeg 并确保 ffprobe 在 PATH 中。")
# 始终要求 JSON 输出,便于稳健解析
base = ["ffprobe", "-v", "error", "-print_format", "json"]
base = [ffprobe, "-v", "error", "-print_format", "json"]
proc = subprocess.run(base + args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
if proc.returncode != 0:
raise FFProbeError(proc.stderr.strip() or "ffprobe 调用失败")
@ -30,6 +29,7 @@ def _try_nb_frames(path: str, stream_index: int) -> Optional[int]:
])
streams = data.get("streams") or []
if not streams:
logging.debug("_try_nb_frames: failed no stream")
return None
nb = streams[0].get("nb_frames")
if nb and nb != "N/A":
@ -37,7 +37,9 @@ def _try_nb_frames(path: str, stream_index: int) -> Optional[int]:
n = int(nb)
return n if n >= 0 else None
except ValueError:
logging.debug(f"_try_nb_frames: failed nb not positive int: {nb}")
return None
logging.debug(f"_try_nb_frames: failed nb NA: {nb}")
return None
def _try_avgfps_times_duration(path: str, stream_index: int) -> Optional[int]:
@ -58,6 +60,7 @@ def _try_avgfps_times_duration(path: str, stream_index: int) -> Optional[int]:
f = _run_ffprobe(["-show_entries", "format=duration", path])
dur_str = (f.get("format") or {}).get("duration")
if not dur_str:
logging.debug(f"_try_avgfps_times_duration: failed no dur_str, {f}")
return None
try:
@ -66,7 +69,8 @@ def _try_avgfps_times_duration(path: str, stream_index: int) -> Optional[int]:
# 四舍五入到最近整数,避免系统性低估
est = int(dur * Decimal(fps.numerator) / Decimal(fps.denominator) + Decimal("0.5"))
return est if est >= 0 else None
except Exception:
except Exception as e:
logging.debug("_try_avgfps_times_duration: failed",exc_info=e)
return None
def _try_count_packets(path: str, stream_index: int) -> Optional[int]:
@ -79,12 +83,14 @@ def _try_count_packets(path: str, stream_index: int) -> Optional[int]:
])
streams = data.get("streams") or []
if not streams:
logging.debug("_try_count_packets: failed no stream")
return None
nbp = streams[0].get("nb_read_packets")
try:
n = int(nbp)
return n if n >= 0 else None
except Exception:
except Exception as e:
logging.debug("_try_count_packets: failed",exc_info=e)
return None
def get_video_frame_count(
@ -116,12 +122,17 @@ def get_video_frame_count(
}
for key in fallback_order:
func = methods.get(key)
if not func:
continue
n = func(path, stream_index)
if isinstance(n, int) and n >= 0:
return n
try:
func = methods.get(key)
if not func:
continue
n = func(path, stream_index)
if isinstance(n, int) and n >= 0:
return n
else:
logging.debug(f"Failed to get frame with {key}")
except Exception as e:
logging.debug(f"Errored to get frame with {key}.",exc_info=e)
return None
raise RuntimeError("无法获取或估计帧数:所有回退方法均失败。")

View File

@ -8,10 +8,76 @@ from time import time
from rich.logging import RichHandler
from rich.progress import Progress
from pickle import dumps, loads
from typing import Optional, Callable
from typing import Optional, Callable,Literal
import atexit
import re
import get_frame
import pydantic as pyd
from pydantic import BaseModel,Field,field_validator,model_validator
class Config(BaseModel):
save_to:Literal["single","multi"] = Field("single",description="保存到单文件夹或者每个子文件夹创建compress_dir")
crf: Optional[int] = Field(None, ge=0, le=51, description="CRF值范围0-51")
bitrate: Optional[str] = Field(None, description="比特率,格式如: 1000k, 2.5M, 1500B")
codec: str = Field("h264",description="ffmpeg的codec如果使用GPU需要对应设置")
hwaccel:Optional[Literal["amf","qsv","cuda"]] = Field(None,description="使用GPU加速")
extra:Optional[list[str]] = Field(None,description="插入到ffmpeg输出前的自定义参数")
ffmpeg:str = "ffmpeg"
ffprobe:str = "ffprobe"
manual:Optional[list[str]] = Field(None,description=r"手动设置ffmpeg命令ffmpeg -i {input} {manual} {output}")
video_ext:list[str] = Field([".mp4", ".mkv"],description="视频文件后缀,含.")
compress_dir_name:str = Field("compress",description="压缩文件夹名称")
resolution: Optional[str] = Field(None,description="统一到特定尺寸None为不使用缩放")
fps:int = Field(30,description="fps",ge=0)
test_video_resolution:str = "1920x1080"
test_video_fps:int = Field(30,ge=0)
test_video_input:str = "compress_video_test.mp4"
test_video_output:str = "compressed_video_test.mp4"
disable_hwaccel_when_fail:bool = Field(True,description="当运行失败时,禁用硬件加速")
@field_validator('bitrate')
@classmethod
def validate_bitrate(cls, v: Optional[str]) -> Optional[str]:
if v is None:
return v
pattern = r'^[\d\.]+[MkB]*$'
if not re.match(pattern, v):
raise ValueError('bitrate格式不正确应为数字+单位(M/k/B),如: 1000k, 2.5M')
return v
@field_validator('resolution')
@classmethod
def validate_resolution(cls, v: Optional[str]) -> Optional[str]:
if v is None:
return v
pattern = r'^((-1)|\d+):((-1)|\d+)$'
if not re.match(pattern, v):
raise ValueError('resolution格式不正确应为{数字/-1}:{数字/-1}')
return v
@field_validator("compress_dir_name")
# @field_validator("test_video_input")
# @field_validator("test_video_output")
@classmethod
def valid_path(cls, v:str) -> str:
if re.search(r'[\\/:*?"<>|\x00-\x1F]',v):
raise ValueError("某配置不符合目录名语法")
return v
@model_validator(mode='after')
def validate_mutual_exclusive(self):
crf_none = self.crf is None
bitrate_none = self.bitrate is None
# 有且只有一者为None
if crf_none == bitrate_none:
raise ValueError('crf和bitrate必须互斥有且只有一个为None')
return self
root = None
CFG_FILE = Path(sys.path[0]) / "config.json"
@ -85,7 +151,7 @@ def get_cmd(video_path: str | Path, output_file: str | Path) -> list[str]:
"-b:v",
CFG["bitrate"],
"-r",
CFG["fps"],
str(CFG["fps"]),
"-y",
]
)
@ -104,7 +170,7 @@ def get_cmd(video_path: str | Path, output_file: str | Path) -> list[str]:
"-global_quality",
str(CFG["crf"]),
"-r",
CFG["fps"],
str(CFG["fps"]),
"-y",
]
)
@ -390,7 +456,7 @@ def test():
exit(-1)
try:
ret = subprocess.run(
f"ffmpeg -hide_banner -f lavfi -i testsrc=duration=1:size={CFG['test_video_resolution']}:rate={CFG['test_video_fps']} -c:v libx264 -y -pix_fmt yuv420p {CFG['test_video_input']}".split(),
f"{CFG['ffmpeg']} -hide_banner -f lavfi -i testsrc=duration=1:size={CFG['test_video_resolution']}:rate={CFG['test_video_fps']} -c:v libx264 -y -pix_fmt yuv420p {CFG['test_video_input']}".split(),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
@ -413,6 +479,9 @@ def test():
logging.debug(ret.stderr)
logging.error("Error termination via test failed.")
exit(-1)
if get_frame.get_video_frame_count("compress_video_test.mp4") is None:
logging.error("测试读取帧数失败,将无法正确显示进度。")
os.remove("compress_video_test.mp4")
os.remove("compressed_video_test.mp4")
except KeyboardInterrupt as e:
@ -446,7 +515,11 @@ def main(_root=None):
import json
cfg: dict = json.loads(CFG_FILE.read_text())
CFG.update(cfg)
cfg_model = Config(**cfg)
CFG.update(cfg_model.model_dump())
get_frame.ffprobe = CFG["ffprobe"]
logging.debug(cfg_model)
logging.debug(CFG)
except KeyboardInterrupt as e:
raise e
except Exception as e: