Compare commits
5 Commits
5fca3520f6
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| db56f1da62 | |||
| 9ba34f8d2e | |||
| cae41d9bb0 | |||
| 6f304a634c | |||
| 5e94b202b5 |
@ -1,22 +1,22 @@
|
||||
{
|
||||
"save_to": "single",
|
||||
"bitrate": null,
|
||||
"crf": 26,
|
||||
"crf": null,
|
||||
"bitrate": "15M",
|
||||
"codec": "h264_qsv",
|
||||
"hwaccel": "qsv",
|
||||
"extra": [],
|
||||
"ffmpeg": "C:/tools/ffmpeg/bin/ffmpeg.exe",
|
||||
"ffprobe": "C:/tools/ffmpeg/bin/ffprobe",
|
||||
"ffmpeg": "ffmpeg",
|
||||
"ffprobe": "ffprobe",
|
||||
"manual": null,
|
||||
"video_ext": [
|
||||
".mp4",
|
||||
".mkv"
|
||||
],
|
||||
"compress_dir_name": "compress_qsv",
|
||||
"compress_dir_name": "compress",
|
||||
"resolution": null,
|
||||
"fps": "30",
|
||||
"fps": 30,
|
||||
"test_video_resolution": "1920x1080",
|
||||
"test_video_fps": "30",
|
||||
"test_video_fps": 30,
|
||||
"test_video_input": "compress_video_test.mp4",
|
||||
"test_video_output": "compressed_video_test.mp4",
|
||||
"disable_hwaccel_when_fail": true
|
||||
|
||||
@ -8,5 +8,14 @@
|
||||
".mp4",
|
||||
".mkv"
|
||||
],
|
||||
"resolution": null
|
||||
"resolution": "1920x1080",
|
||||
"extra": [],
|
||||
"manual": null,
|
||||
"compress_dir_name": "compress",
|
||||
"fps": 30,
|
||||
"test_video_resolution": "1920x1080",
|
||||
"test_video_fps": 30,
|
||||
"test_video_input": "compress_video_test.mp4",
|
||||
"test_video_output": "compressed_video_test.mp4",
|
||||
"disable_hwaccel_when_fail": true
|
||||
}
|
||||
@ -123,10 +123,14 @@ def get_video_frame_count(
|
||||
|
||||
for key in fallback_order:
|
||||
try:
|
||||
func = methods.get(key)
|
||||
if not func:
|
||||
try:
|
||||
func = methods.get(key)
|
||||
if not func:
|
||||
continue
|
||||
n = func(path, stream_index)
|
||||
except Exception:
|
||||
logging.debug(f"Errored to get frame with {key}.",exc_info=True)
|
||||
continue
|
||||
n = func(path, stream_index)
|
||||
if isinstance(n, int) and n >= 0:
|
||||
return n
|
||||
else:
|
||||
|
||||
@ -8,101 +8,170 @@ from time import time
|
||||
from rich.logging import RichHandler
|
||||
from rich.progress import Progress
|
||||
from pickle import dumps, loads
|
||||
from typing import Optional, Callable,Literal
|
||||
from typing import Optional, Callable, Literal, List, Any, TYPE_CHECKING
|
||||
import atexit
|
||||
import re
|
||||
import get_frame
|
||||
from pydantic import BaseModel,Field,field_validator,model_validator
|
||||
import json
|
||||
import argparse
|
||||
import shutil
|
||||
|
||||
class Config(BaseModel):
|
||||
save_to:Literal["single","multi"] = Field("single",description="保存到单文件夹,或者每个子文件夹创建compress_dir")
|
||||
crf: Optional[int] = Field(None, ge=0, le=51, description="CRF值,范围0-51")
|
||||
bitrate: Optional[str] = Field(None, description="比特率,格式如: 1000k, 2.5M, 1500B")
|
||||
codec: str = Field("h264",description="ffmpeg的codec,如果使用GPU需要对应设置")
|
||||
hwaccel:Optional[Literal["amf","qsv","cuda"]] = Field(None,description="使用GPU加速")
|
||||
extra:Optional[list[str]] = Field(None,description="插入到ffmpeg输出前的自定义参数")
|
||||
ffmpeg:str = "ffmpeg"
|
||||
ffprobe:str = "ffprobe"
|
||||
manual:Optional[list[str]] = Field(None,description=r"手动设置ffmpeg,命令ffmpeg -i {input} {manual} {output}")
|
||||
video_ext:list[str] = Field([".mp4", ".mkv"],description="视频文件后缀,含.")
|
||||
compress_dir_name:str = Field("compress",description="压缩文件夹名称")
|
||||
resolution: Optional[str] = Field(None,description="统一到特定尺寸,None为不使用缩放")
|
||||
fps:int = Field(30,description="fps",ge=0)
|
||||
test_video_resolution:str = "1920x1080"
|
||||
test_video_fps:int = Field(30,ge=0)
|
||||
test_video_input:str = "compress_video_test.mp4"
|
||||
test_video_output:str = "compressed_video_test.mp4"
|
||||
disable_hwaccel_when_fail:bool = Field(True,description="当运行失败时,禁用硬件加速")
|
||||
try:
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
HAS_PYDANTIC = True
|
||||
|
||||
class Config(BaseModel):
|
||||
save_to: Literal["single", "multi"] = Field("single", description="保存到单文件夹,或者每个子文件夹创建compress_dir")
|
||||
crf: Optional[int] = Field(18, ge=0, le=51, description="CRF值,范围0-51")
|
||||
bitrate: Optional[str] = Field(None, description="比特率,格式如: 1000k, 2.5M, 1500B")
|
||||
codec: str = Field("h264", description="ffmpeg的codec,如果使用GPU需要对应设置")
|
||||
hwaccel: Optional[Literal["amf", "qsv", "cuda"]] = Field(None, description="使用GPU加速")
|
||||
extra: List[str] = Field([], description="插入到ffmpeg输出前的自定义参数")
|
||||
ffmpeg: str = "ffmpeg"
|
||||
ffprobe: str = "ffprobe"
|
||||
manual: Optional[List[str]] = Field(None, description=r"手动设置ffmpeg,命令ffmpeg -i {input} {manual} {output}")
|
||||
video_ext: List[str] = Field([".mp4", ".mkv"], description="视频文件后缀,含.")
|
||||
compress_dir_name: str = Field("compress", description="压缩文件夹名称")
|
||||
resolution: Optional[str] = Field(None, description="统一到特定尺寸,None为不使用缩放")
|
||||
fps: int = Field(30, description="fps", ge=0)
|
||||
test_video_resolution: str = "1920x1080"
|
||||
test_video_fps: int = Field(30, ge=0)
|
||||
test_video_input: str = "compress_video_test.mp4"
|
||||
test_video_output: str = "compressed_video_test.mp4"
|
||||
disable_hwaccel_when_fail: bool = Field(True, description="当运行失败时,禁用硬件加速")
|
||||
|
||||
|
||||
@field_validator('bitrate')
|
||||
@classmethod
|
||||
def validate_bitrate(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v is None:
|
||||
@field_validator('bitrate')
|
||||
@classmethod
|
||||
def validate_bitrate(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v is None:
|
||||
return v
|
||||
pattern = r'^[\d\.]+[MkB]*$'
|
||||
if not re.match(pattern, v):
|
||||
raise ValueError('bitrate格式不正确,应为数字+单位(M/k/B),如: 1000k, 2.5M')
|
||||
return v
|
||||
pattern = r'^[\d\.]+[MkB]*$'
|
||||
if not re.match(pattern, v):
|
||||
raise ValueError('bitrate格式不正确,应为数字+单位(M/k/B),如: 1000k, 2.5M')
|
||||
return v
|
||||
|
||||
@field_validator('resolution')
|
||||
@classmethod
|
||||
def validate_resolution(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v is None:
|
||||
@field_validator('resolution')
|
||||
@classmethod
|
||||
def validate_resolution(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v is None:
|
||||
return v
|
||||
pattern = r'^((-1)|\d+):((-1)|\d+)$'
|
||||
if not re.match(pattern, v):
|
||||
raise ValueError('resolution格式不正确,应为{数字/-1}:{数字/-1}')
|
||||
return v
|
||||
pattern = r'^((-1)|\d+):((-1)|\d+)$'
|
||||
if not re.match(pattern, v):
|
||||
raise ValueError('resolution格式不正确,应为{数字/-1}:{数字/-1}')
|
||||
return v
|
||||
|
||||
@field_validator("compress_dir_name")
|
||||
# @field_validator("test_video_input")
|
||||
# @field_validator("test_video_output")
|
||||
@classmethod
|
||||
def valid_path(cls, v:str) -> str:
|
||||
if re.search(r'[\\/:*?"<>|\x00-\x1F]',v):
|
||||
raise ValueError("某配置不符合目录名语法")
|
||||
return v
|
||||
@field_validator("compress_dir_name")
|
||||
@classmethod
|
||||
def valid_path(cls, v: str) -> str:
|
||||
if re.search(r'[\\/:*?"<>|\x00-\x1F]', v):
|
||||
raise ValueError("某配置不符合目录名语法")
|
||||
return v
|
||||
|
||||
@model_validator(mode='after')
|
||||
def validate_mutual_exclusive(self):
|
||||
crf_none = self.crf is None
|
||||
bitrate_none = self.bitrate is None
|
||||
|
||||
@model_validator(mode='after')
|
||||
def validate_mutual_exclusive(self):
|
||||
crf_none = self.crf is None
|
||||
bitrate_none = self.bitrate is None
|
||||
# 有且只有一者为None
|
||||
if crf_none == bitrate_none:
|
||||
raise ValueError('crf和bitrate必须互斥:有且只有一个为None')
|
||||
|
||||
# 有且只有一者为None
|
||||
if crf_none == bitrate_none:
|
||||
raise ValueError('crf和bitrate必须互斥:有且只有一个为None')
|
||||
return self
|
||||
|
||||
return self
|
||||
def dump(self):
|
||||
return self.model_dump()
|
||||
|
||||
except ImportError:
|
||||
HAS_PYDANTIC = False
|
||||
from dataclasses import dataclass, asdict
|
||||
import copy
|
||||
@dataclass
|
||||
class Config:
|
||||
save_to: str = "single"
|
||||
crf: Optional[int] = 18
|
||||
bitrate: Optional[str] = None
|
||||
codec: str = "h264"
|
||||
hwaccel: Optional[str] = None
|
||||
extra: List[str] = []
|
||||
ffmpeg: str = "ffmpeg"
|
||||
ffprobe: str = "ffprobe"
|
||||
manual: Optional[List[str]] = None
|
||||
video_ext: List[str] = [".mp4", ".mkv"]
|
||||
compress_dir_name: str = "compress"
|
||||
resolution: Optional[str] = None
|
||||
fps: int = 30
|
||||
test_video_resolution: str = "1920x1080"
|
||||
test_video_fps: int = 30
|
||||
test_video_input: str = "compress_video_test.mp4"
|
||||
test_video_output: str = "compressed_video_test.mp4"
|
||||
disable_hwaccel_when_fail: bool = True
|
||||
|
||||
def update(self, other):
|
||||
if isinstance(other, dict):
|
||||
d = other
|
||||
elif isinstance(other, Config):
|
||||
d = asdict(other)
|
||||
else:
|
||||
return
|
||||
|
||||
for k, v in d.items():
|
||||
if hasattr(self, k):
|
||||
setattr(self, k, v)
|
||||
|
||||
def copy(self):
|
||||
return copy.deepcopy(self)
|
||||
|
||||
def dump(self):
|
||||
return asdict(self)
|
||||
|
||||
|
||||
root = None
|
||||
if os.environ.get("INSTALL", "0") == "1":
|
||||
CFG_FILE= Path(os.getenv("APPDATA", "C:/")) / "VideoCompress" / "config.json"
|
||||
CFG_FILE = Path(os.getenv("APPDATA", "C:/")) / "VideoCompress" / "config.json"
|
||||
else:
|
||||
CFG_FILE= Path(sys.path[0]) / "config.json"
|
||||
CFG = {
|
||||
"save_to": "single",
|
||||
"crf": "18",
|
||||
"bitrate": None,
|
||||
"codec": "h264",
|
||||
"hwaccel": None,
|
||||
"extra": [],
|
||||
"ffmpeg": "ffmpeg",
|
||||
"manual": None,
|
||||
"video_ext": [".mp4", ".mkv"],
|
||||
"compress_dir_name": "compress",
|
||||
"resolution": None,
|
||||
"fps": "30",
|
||||
"test_video_resolution": "1920x1080",
|
||||
"test_video_fps": "30",
|
||||
"test_video_input": "compress_video_test.mp4",
|
||||
"test_video_output": "compressed_video_test.mp4",
|
||||
"disable_hwaccel_when_fail": True,
|
||||
}
|
||||
CFG_FILE = Path(sys.path[0]) / "config.json"
|
||||
|
||||
if CFG_FILE.exists():
|
||||
try:
|
||||
import json
|
||||
|
||||
if HAS_PYDANTIC:
|
||||
assert BaseModel # type: ignore
|
||||
assert issubclass(Config, BaseModel)
|
||||
CFG = Config.model_validate_json(CFG_FILE.read_text())
|
||||
else:
|
||||
assert Config
|
||||
cfg:dict[str, Any] = json.loads(CFG_FILE.read_text())
|
||||
CFG = Config(**cfg)
|
||||
|
||||
get_frame.ffprobe = CFG.ffprobe
|
||||
logging.debug(CFG)
|
||||
except KeyboardInterrupt as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.warning("Invalid config file, ignored.")
|
||||
logging.debug(e)
|
||||
else:
|
||||
try:
|
||||
if HAS_PYDANTIC:
|
||||
if TYPE_CHECKING:
|
||||
assert BaseModel # type: ignore
|
||||
assert issubclass(Config, BaseModel)
|
||||
CFG = Config() # type: ignore
|
||||
CFG_FILE.write_text(CFG.model_dump_json(indent=4))
|
||||
else:
|
||||
import json
|
||||
if TYPE_CHECKING:
|
||||
assert Config
|
||||
assert asdict # type: ignore
|
||||
CFG = Config() # type: ignore
|
||||
CFG_FILE.write_text(json.dumps(asdict(CFG), indent=4))
|
||||
|
||||
logging.info("Config file created.")
|
||||
except Exception as e:
|
||||
logging.warning("Failed to create config file.", exc_info=e)
|
||||
|
||||
current_running_file:Optional[Path] = None
|
||||
|
||||
def get_cmd(video_path: str | Path, output_file: str | Path) -> list[str]:
|
||||
if isinstance(video_path, Path):
|
||||
@ -110,23 +179,23 @@ def get_cmd(video_path: str | Path, output_file: str | Path) -> list[str]:
|
||||
if isinstance(output_file, Path):
|
||||
output_file = str(output_file.resolve())
|
||||
|
||||
if CFG["manual"] is not None:
|
||||
command = [CFG["ffmpeg"], "-hide_banner", "-i", video_path]
|
||||
command.extend(CFG["manual"])
|
||||
if CFG.manual is not None:
|
||||
command = [CFG.ffmpeg, "-hide_banner", "-i", video_path]
|
||||
command.extend(CFG.manual)
|
||||
command.append(output_file)
|
||||
return command
|
||||
|
||||
command = [
|
||||
CFG["ffmpeg"],
|
||||
CFG.ffmpeg,
|
||||
"-hide_banner",
|
||||
]
|
||||
if CFG["hwaccel"] is not None:
|
||||
if CFG.hwaccel is not None:
|
||||
command.extend(
|
||||
[
|
||||
"-hwaccel",
|
||||
CFG["hwaccel"],
|
||||
CFG.hwaccel,
|
||||
"-hwaccel_output_format",
|
||||
CFG["hwaccel"],
|
||||
CFG.hwaccel,
|
||||
|
||||
]
|
||||
)
|
||||
@ -137,69 +206,74 @@ def get_cmd(video_path: str | Path, output_file: str | Path) -> list[str]:
|
||||
]
|
||||
)
|
||||
|
||||
if CFG["bitrate"] is not None:
|
||||
if CFG.bitrate is not None:
|
||||
|
||||
if CFG["resolution"] is not None:
|
||||
if CFG.resolution is not None:
|
||||
command.extend(
|
||||
[
|
||||
"-vf",
|
||||
f"scale={CFG['resolution']}",
|
||||
f"scale={CFG.resolution}",
|
||||
]
|
||||
)
|
||||
command.extend(
|
||||
[
|
||||
"-c:v",
|
||||
CFG["codec"],
|
||||
CFG.codec,
|
||||
"-b:v",
|
||||
CFG["bitrate"],
|
||||
CFG.bitrate,
|
||||
"-r",
|
||||
str(CFG["fps"]),
|
||||
str(CFG.fps),
|
||||
"-y",
|
||||
]
|
||||
)
|
||||
else:
|
||||
if CFG["resolution"] is not None:
|
||||
if CFG.resolution is not None:
|
||||
command.extend(
|
||||
[
|
||||
"-vf",
|
||||
f"scale={CFG['resolution']}",
|
||||
f"scale={CFG.resolution}",
|
||||
]
|
||||
)
|
||||
command.extend(
|
||||
[
|
||||
"-c:v",
|
||||
CFG["codec"],
|
||||
CFG.codec,
|
||||
"-global_quality",
|
||||
str(CFG["crf"]),
|
||||
str(CFG.crf),
|
||||
"-r",
|
||||
str(CFG["fps"]),
|
||||
str(CFG.fps),
|
||||
"-y",
|
||||
]
|
||||
)
|
||||
|
||||
command.extend(CFG["extra"])
|
||||
command.extend(CFG.extra)
|
||||
command.append(output_file)
|
||||
logging.debug(f"Create CMD: {command}")
|
||||
return command
|
||||
|
||||
|
||||
# 配置logging
|
||||
def setup_logging():
|
||||
def setup_logging(verbose: bool = False):
|
||||
log_dir = Path("logs")
|
||||
log_dir.mkdir(exist_ok=True)
|
||||
log_file = log_dir / f"video_compress_{datetime.now().strftime('%Y%m%d')}.log"
|
||||
stream = RichHandler(rich_tracebacks=True, tracebacks_show_locals=True)
|
||||
stream.setLevel(logging.INFO)
|
||||
stream = RichHandler(level=logging.DEBUG if verbose else logging.INFO,
|
||||
rich_tracebacks=True, tracebacks_show_locals=True)
|
||||
# stream.setLevel(logging.DEBUG if verbose else logging.INFO)
|
||||
stream.setFormatter(logging.Formatter("%(message)s"))
|
||||
|
||||
file = logging.FileHandler(log_file, encoding="utf-8")
|
||||
file.setLevel(logging.DEBUG)
|
||||
|
||||
# 清除现有的handlers,避免多次调用basicConfig无效
|
||||
logging.getLogger().handlers.clear()
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s - %(levelname) 7s - %(message)s",
|
||||
handlers=[file, stream],
|
||||
handlers=[stream, file],
|
||||
)
|
||||
logging.debug("Logging is set up.")
|
||||
|
||||
|
||||
def fmt_time(t: float | int) -> str:
|
||||
@ -216,10 +290,11 @@ def process_video(
|
||||
compress_dir: Optional[Path] = None,
|
||||
update_func: Optional[Callable[[Optional[int], Optional[str]], None]] = None,
|
||||
):
|
||||
global current_running_file
|
||||
|
||||
if compress_dir is None:
|
||||
# 在视频文件所在目录下创建 compress 子目录(如果不存在)
|
||||
compress_dir = video_path.parent / CFG["compress_dir_name"]
|
||||
compress_dir = video_path.parent / CFG.compress_dir_name
|
||||
else:
|
||||
assert root
|
||||
compress_dir /= video_path.parent.relative_to(root)
|
||||
@ -235,6 +310,7 @@ def process_video(
|
||||
|
||||
video_path_str = str(video_path.absolute())
|
||||
command = get_cmd(video_path_str, output_file)
|
||||
current_running_file = output_file
|
||||
|
||||
try:
|
||||
result = subprocess.Popen(
|
||||
@ -268,6 +344,8 @@ def process_video(
|
||||
rate = match.group(0) if match else None
|
||||
update_func(frame_number, rate)
|
||||
|
||||
current_running_file = None
|
||||
|
||||
if result.returncode != 0:
|
||||
logging.error(
|
||||
f"处理文件 {video_path_str} 失败"
|
||||
@ -277,7 +355,7 @@ def process_video(
|
||||
assert result.stdout is not None
|
||||
logging.debug(result.stdout.read())
|
||||
logging.debug(total)
|
||||
if CFG["hwaccel"] == "mediacodec" and CFG["codec"] in [
|
||||
if CFG.hwaccel == "mediacodec" and CFG.codec in [
|
||||
"h264_mediacodec",
|
||||
"hevc_mediacodec",
|
||||
]:
|
||||
@ -285,37 +363,49 @@ def process_video(
|
||||
"mediacodec硬件加速器已知在较短片段上存在异常,将禁用加速重试。"
|
||||
)
|
||||
output_file.unlink(missing_ok=True)
|
||||
bak = CFG.copy()
|
||||
CFG["hwaccel"] = None
|
||||
CFG["codec"] = "h264" if CFG["codec"] == "h264_mediacodec" else "hevc"
|
||||
bak = CFG.codec, CFG.hwaccel
|
||||
CFG.hwaccel = None
|
||||
CFG.codec = "h264" if CFG.codec == "h264_mediacodec" else "hevc"
|
||||
assert not output_file.exists()
|
||||
ret = process_video(video_path, compress_dir, update_func)
|
||||
CFG.update(bak)
|
||||
CFG.codec, CFG.hwaccel = bak
|
||||
if not ret:
|
||||
logging.error("重试仍然失败。")
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
elif CFG["disable_hwaccel_when_fail"] and CFG["hwaccel"] is not None:
|
||||
elif CFG.disable_hwaccel_when_fail and CFG.hwaccel is not None:
|
||||
logging.info("正在禁用硬件加速器重试,进度条可能发生混乱。")
|
||||
output_file.unlink(missing_ok=True)
|
||||
bak = CFG.copy()
|
||||
CFG["hwaccel"] = None
|
||||
if TYPE_CHECKING:
|
||||
assert BaseModel # type: ignore
|
||||
assert isinstance(CFG, BaseModel)
|
||||
|
||||
bak = CFG.codec, CFG.hwaccel
|
||||
CFG.hwaccel = None
|
||||
if (
|
||||
CFG["codec"].endswith("_mediacodec")
|
||||
or CFG["codec"].endswith("_qsv")
|
||||
or CFG["codec"].endswith("_nvenc")
|
||||
or CFG["codec"].endswith("_amf")
|
||||
CFG.codec.endswith("_mediacodec")
|
||||
or CFG.codec.endswith("_qsv")
|
||||
or CFG.codec.endswith("_nvenc")
|
||||
or CFG.codec.endswith("_amf")
|
||||
):
|
||||
CFG["codec"] = CFG["codec"].split("_")[0]
|
||||
CFG.codec = CFG.codec.split("_")[0]
|
||||
assert not output_file.exists()
|
||||
ret = process_video(video_path, compress_dir, update_func)
|
||||
CFG.update(bak)
|
||||
CFG.codec, CFG.hwaccel = bak
|
||||
if not ret:
|
||||
logging.error("重试仍然失败。")
|
||||
return False
|
||||
else:
|
||||
logging.debug(f"文件处理成功: {video_path_str} -> {output_file}")
|
||||
if video_path.stat().st_size <= output_file.stat().st_size:
|
||||
logging.info(
|
||||
f"压缩后文件比原文件大,直接复制原文件: {video_path_str}"
|
||||
)
|
||||
output_file.unlink(missing_ok=True)
|
||||
shutil.copy2(video_path, output_file)
|
||||
return True
|
||||
else:
|
||||
logging.debug(f"文件处理成功: {video_path_str} -> {output_file}")
|
||||
|
||||
except KeyboardInterrupt as e:
|
||||
raise e
|
||||
@ -324,12 +414,15 @@ def process_video(
|
||||
f"执行 ffmpeg 命令时发生异常, 文件:{str(video_path_str)},cmd={' '.join(map(str,command))}",
|
||||
exc_info=e,
|
||||
)
|
||||
if current_running_file is not None:
|
||||
current_running_file.unlink(missing_ok=True)
|
||||
current_running_file = None
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def traverse_directory(root_dir: Path):
|
||||
video_extensions = set(CFG["video_ext"])
|
||||
video_extensions = set(CFG.video_ext)
|
||||
sm = None
|
||||
# 获取视频文件列表和帧数信息
|
||||
video_files:list[Path] = []
|
||||
@ -338,8 +431,8 @@ def traverse_directory(root_dir: Path):
|
||||
d = que.pop()
|
||||
for file in d.glob("*") if d.is_dir() else [d]:
|
||||
if (
|
||||
file.parent.name == CFG["compress_dir_name"]
|
||||
or file.name == CFG["compress_dir_name"]
|
||||
file.parent.name == CFG.compress_dir_name
|
||||
or file.name == CFG.compress_dir_name
|
||||
):
|
||||
continue
|
||||
if file.is_file() and file.suffix.lower() in video_extensions:
|
||||
@ -426,9 +519,9 @@ def traverse_directory(root_dir: Path):
|
||||
)
|
||||
prog.update(main_task, completed=completed_start + x)
|
||||
|
||||
if CFG["save_to"] == "single":
|
||||
if CFG.save_to == "single":
|
||||
process_video(
|
||||
file, root_dir / CFG["compress_dir_name"], update_progress
|
||||
file, root_dir / CFG.compress_dir_name, update_progress
|
||||
)
|
||||
else:
|
||||
process_video(file, None, update_progress)
|
||||
@ -450,7 +543,7 @@ def test():
|
||||
|
||||
try:
|
||||
subprocess.run(
|
||||
[CFG["ffmpeg"], "-version"], stdout=-3, stderr=-3
|
||||
[CFG.ffmpeg, "-version"], stdout=-3, stderr=-3
|
||||
).check_returncode()
|
||||
except Exception as e:
|
||||
print(__file__)
|
||||
@ -458,7 +551,7 @@ def test():
|
||||
exit(-1)
|
||||
try:
|
||||
ret = subprocess.run(
|
||||
f"{CFG['ffmpeg']} -hide_banner -f lavfi -i testsrc=duration=1:size={CFG['test_video_resolution']}:rate={CFG['test_video_fps']} -c:v libx264 -y -pix_fmt yuv420p {CFG['test_video_input']}".split(),
|
||||
f"{CFG.ffmpeg} -hide_banner -f lavfi -i testsrc=duration=1:size={CFG.test_video_resolution}:rate={CFG.test_video_fps} -c:v libx264 -y -pix_fmt yuv420p {CFG.test_video_input}".split(),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
@ -469,8 +562,8 @@ def test():
|
||||
logging.debug(ret.stderr)
|
||||
ret.check_returncode()
|
||||
cmd = get_cmd(
|
||||
CFG["test_video_input"],
|
||||
CFG["test_video_output"],
|
||||
CFG.test_video_input,
|
||||
CFG.test_video_output,
|
||||
)
|
||||
ret = subprocess.run(
|
||||
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
|
||||
@ -501,52 +594,47 @@ def exit_pause():
|
||||
elif os.name == "posix":
|
||||
os.system("read -p 'Press Enter to continue...'")
|
||||
|
||||
def finalize():
|
||||
global current_running_file
|
||||
if current_running_file is not None:
|
||||
try:
|
||||
current_running_file.unlink(missing_ok=True)
|
||||
except Exception as e:
|
||||
try:
|
||||
logging.error(
|
||||
"Failed to delete incomplete output file after keyboard interrupt. CHECK IF LAST PROCSSING VIDEO IS COMPLETED",
|
||||
exc_info=e,
|
||||
)
|
||||
except Exception:
|
||||
print("Failed to delete incomplete output file after keyboard interrupt. CHECK IF LAST PROCSSING VIDEO IS COMPLETED")
|
||||
current_running_file = None
|
||||
|
||||
def main(_root=None):
|
||||
|
||||
atexit.register(exit_pause)
|
||||
atexit.register(finalize)
|
||||
|
||||
global root, current_running_file
|
||||
|
||||
if _root is not None:
|
||||
setup_logging()
|
||||
root = Path(_root)
|
||||
else:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("directory", nargs="?", help="目标目录路径")
|
||||
parser.add_argument("--verbose", "-v", action="store_true", help="启用详细日志记录")
|
||||
args = parser.parse_args()
|
||||
if not args.directory:
|
||||
print("Error termination via invalid input.")
|
||||
sys.exit(1)
|
||||
root = Path(args.directory)
|
||||
setup_logging(args.verbose)
|
||||
|
||||
global root
|
||||
setup_logging()
|
||||
tot_bgn = time()
|
||||
logging.info("-------------------------------")
|
||||
logging.info(datetime.now().strftime("Video Compress started at %Y/%m/%d %H:%M"))
|
||||
|
||||
if CFG_FILE.exists():
|
||||
try:
|
||||
import json
|
||||
|
||||
cfg: dict = json.loads(CFG_FILE.read_text())
|
||||
cfg_model = Config(**cfg)
|
||||
CFG.update(cfg_model.model_dump())
|
||||
get_frame.ffprobe = CFG["ffprobe"]
|
||||
logging.debug(cfg_model)
|
||||
logging.debug(CFG)
|
||||
except KeyboardInterrupt as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.warning("Invalid config file, ignored.")
|
||||
logging.debug(e)
|
||||
else:
|
||||
try:
|
||||
import json
|
||||
|
||||
CFG_FILE.write_text(json.dumps(CFG, indent=4))
|
||||
logging.info("Config file created.")
|
||||
except Exception as e:
|
||||
logging.warning("Failed to create config file.", exc_info=e)
|
||||
|
||||
if _root is not None:
|
||||
root = Path(_root)
|
||||
else:
|
||||
# 通过命令行参数传入需要遍历的目录
|
||||
if len(sys.argv) < 2:
|
||||
print(f"用法:python {__file__} <目标目录>")
|
||||
logging.warning("Error termination via invalid input.")
|
||||
sys.exit(1)
|
||||
root = Path(sys.argv[1])
|
||||
|
||||
if root.name.lower() == CFG["compress_dir_name"].lower():
|
||||
if root.name.lower() == CFG.compress_dir_name.lower():
|
||||
logging.critical("请修改目标目录名为非compress。")
|
||||
logging.error("Error termination via invalid input.")
|
||||
sys.exit(1)
|
||||
@ -555,7 +643,7 @@ def main(_root=None):
|
||||
test()
|
||||
|
||||
if not root.is_dir():
|
||||
print("提供的路径不是一个有效目录。")
|
||||
logging.critical("提供的路径不是一个有效目录。")
|
||||
logging.warning("Error termination via invalid input.")
|
||||
sys.exit(1)
|
||||
|
||||
@ -566,15 +654,15 @@ def main(_root=None):
|
||||
logging.info("Normal termination of Video Compress.")
|
||||
except KeyboardInterrupt:
|
||||
logging.warning(
|
||||
"Error termination via keyboard interrupt, CHECK IF LAST PROCSSING VIDEO IS COMPLETED."
|
||||
"Error termination via keyboard interrupt."
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
"Error termination via unhandled error, CHECK IF LAST PROCSSING VIDEO IS COMPLETED.",
|
||||
"Error termination via unhandled error",
|
||||
exc_info=e,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.argv.append(r'C:\Users\flt\Documents\WeChat Files\wxid_m8h0igh8p52p22\FileStorage\Video')
|
||||
# sys.argv.append(r'C:\Users\flt\Documents\WeChat Files\wxid_m8h0igh8p52p22\FileStorage\Video')
|
||||
main()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
4
calc_utils/.gitignore
vendored
Normal file
4
calc_utils/.gitignore
vendored
Normal file
@ -0,0 +1,4 @@
|
||||
dist
|
||||
__pycache__
|
||||
*.egg-info
|
||||
.venv
|
||||
73
calc_utils/README.md
Normal file
73
calc_utils/README.md
Normal file
@ -0,0 +1,73 @@
|
||||
# calc-utils
|
||||
|
||||
个人使用的计算化学工具集,主要基于 [ASE (Atomic Simulation Environment)](https://wiki.fysik.dtu.dk/ase/) 和 [RDKit](https://www.rdkit.org/)。
|
||||
|
||||
包含了一些方便的转换工具,以及针对特定服务器环境(PBS/Slurm/Custom)定制的 `ase.calculators.gaussian` 补丁。
|
||||
|
||||
专用软件,`futils.gaussian`在不同服务器环境中无法直接运行,必须予以修改。
|
||||
|
||||
## 安装
|
||||
|
||||
需要 Python 3.12+。
|
||||
|
||||
```bash
|
||||
git clone https://github.com/your-repo/calc-utils.git
|
||||
cd calc-utils
|
||||
pip install .
|
||||
```
|
||||
|
||||
## 功能模块
|
||||
|
||||
### 1. `futils.gaussian` (**Breaking Change**)
|
||||
|
||||
这是一个对 `ase.calculators.gaussian` 的深度定制和 Monkey Patch。
|
||||
|
||||
**注意:导入此模块会直接修改 `ase.calculators.gaussian` 中的类定义。**
|
||||
|
||||
主要修改内容:
|
||||
- **强制任务提交脚本**:计算器的 `command` 默认被设置为调用 `gsub_wait` 脚本。
|
||||
- 默认路径硬编码为 `/home/fanhj/calcs/lele/tools/gsub_wait`(需要在 `futils/gaussian.py` 中按需修改 `GSUB` 变量)。
|
||||
- **文件后缀变更**:输入文件使用 `.gin` 而非 `.gjf`,输出文件默认读取 `.out`。
|
||||
- **参数增强**:`__init__` 方法提供了更详细的 Type Hinting 和默认参数(如 `mem="30GB"`, `proc=32`)。
|
||||
- **辅助方法**:增加了 `mod()` 方法用于快速复制并修改计算器参数。
|
||||
|
||||
```python
|
||||
from futils.gaussian import Gaussian
|
||||
from ase import Atoms
|
||||
|
||||
# 使用定制后的 Gaussian 计算器
|
||||
# 注意:这会尝试调用 gsub_wait 提交任务
|
||||
calc = Gaussian(label='test_calc', method='B3LYP', basis='6-31G(d)')
|
||||
```
|
||||
|
||||
### 2. `futils.rdkit2ase`
|
||||
|
||||
提供 RDKit 分子对象 (`rdkit.Chem.Mol`) 与 ASE 原子对象 (`ase.Atoms`) 之间的无缝转换,**保留 3D 坐标**。
|
||||
|
||||
```python
|
||||
from futils.rdkit2ase import MolToAtoms, AtomsToMol
|
||||
from rdkit import Chem
|
||||
|
||||
# RDKit -> ASE
|
||||
mol = Chem.MolFromMolFile("molecule.mol")
|
||||
atoms = MolToAtoms(mol)
|
||||
|
||||
# ASE -> RDKit
|
||||
new_mol = AtomsToMol(atoms)
|
||||
```
|
||||
|
||||
### 3. `futils.rdkit_utils`
|
||||
|
||||
一些 RDKit 绘图辅助函数。
|
||||
- `draw2D(mol)`: 生成 SVG 格式的 2D 分子图。
|
||||
- `draw3D(mol)`: 使用 IPythonConsole 绘制 3D 分子图。
|
||||
|
||||
## 脚本工具 (`bin/`)
|
||||
|
||||
本项目包含了一些用于任务提交管理的 Shell 脚本,适用于特定的集群环境。
|
||||
|
||||
- **`gsub`**: 任务提交脚本。支持本地或通过 SSH 远程提交到名为 `cluster` 的主机。
|
||||
- **`gsub_wait`**: 提交任务并阻塞等待完成,用于 ASE Calculator 的 `command` 调用,以便实现 Python 脚本的同步执行。
|
||||
|
||||
**配置说明**:
|
||||
使用前请检查 `bin/` 下的脚本以及 `futils/gaussian.py` 中的 `GSUB` 路径,根据您的服务器环境进行调整。
|
||||
176
calc_utils/bin/gsub
Normal file
176
calc_utils/bin/gsub
Normal file
@ -0,0 +1,176 @@
|
||||
#!/bin/bash
|
||||
set -u
|
||||
|
||||
# Usage: gsub <jobname>
|
||||
|
||||
job=${1:-}
|
||||
if [[ -z "$job" ]]; then
|
||||
echo "Usage: $0 <jobname-without-extension>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# ==========================================
|
||||
# 0. 安全检测函数 (Safety Check)
|
||||
# ==========================================
|
||||
check_dangerous_path() {
|
||||
local path="${1:-}"
|
||||
|
||||
# 1. Empty check
|
||||
if [[ -z "$path" ]]; then
|
||||
echo "Error: Empty path is dangerous for deletion." >&2
|
||||
return 1
|
||||
fi
|
||||
|
||||
# 2. Root check
|
||||
if [[ "$path" == "/" ]]; then
|
||||
echo "Error: Root path '/' is dangerous for deletion." >&2
|
||||
return 1
|
||||
fi
|
||||
|
||||
# 3. Space check (optional, but good for safety)
|
||||
if [[ "$path" =~ ^[[:space:]]+$ ]]; then
|
||||
echo "Error: Whitespace path is dangerous." >&2
|
||||
return 1
|
||||
fi
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
# ==========================================
|
||||
# 1. 检查运行环境 (Check Host)
|
||||
# ==========================================
|
||||
# 如果不是 cluster,尝试通过 SSH 远程调用
|
||||
host_short=$(hostname -s 2>/dev/null || hostname)
|
||||
if [[ "$host_short" != "cluster" ]]; then
|
||||
# 假设本地挂载路径 /mnt/home 对应远程 /home (根据原脚本逻辑调整)
|
||||
cur_dir=$(pwd)
|
||||
remote_dir="${cur_dir//\/mnt\/home/\/home}"
|
||||
|
||||
# 定位当前脚本并转换为远程路径
|
||||
# 获取脚本所在目录的绝对路径
|
||||
script_dir=$(cd "$(dirname "$0")" && pwd)
|
||||
script_name=$(basename "$0")
|
||||
local_script="$script_dir/$script_name"
|
||||
|
||||
# 同样对脚本路径进行替换
|
||||
remote_script="${local_script//\/mnt\/home/\/home}"
|
||||
|
||||
# 尝试在远程执行自己
|
||||
echo "Running remotely on cluster: $remote_script" >&2
|
||||
ssh cluster "cd '$remote_dir' && '$remote_script' '$job'"
|
||||
exit $?
|
||||
fi
|
||||
|
||||
# ==========================================
|
||||
# 2. 准备作业 (Prepare Job)
|
||||
# ==========================================
|
||||
|
||||
gin_file="$job.gin"
|
||||
if [[ ! -f "$gin_file" ]]; then
|
||||
echo "Error: $gin_file not found in $(pwd)"
|
||||
exit 2
|
||||
fi
|
||||
|
||||
# 解析配置确定资源 (Parse Proc)
|
||||
# 查找 %NProcShared=XX
|
||||
proc=$(sed -n 's/^%NProcShared=\([0-9]\+\).*$/\1/pI' "$gin_file" | head -n 1)
|
||||
|
||||
queue=""
|
||||
ppn=""
|
||||
|
||||
if [[ "$proc" == "32" ]]; then
|
||||
queue="n32"
|
||||
ppn="32"
|
||||
elif [[ "$proc" == "20" ]]; then
|
||||
queue="n20"
|
||||
ppn="20"
|
||||
else
|
||||
echo "Error: Unsupported NProcShared=$proc in $gin_file. Only 20 or 32 allowed."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 清理旧文件 (Clean up old output)
|
||||
if [[ -f "$job.out" ]]; then
|
||||
# 原脚本逻辑:休眠并删除
|
||||
# echo "Warning: $job.out exists. Deleting..." >&2
|
||||
# 使用安全检查
|
||||
if check_dangerous_path "$job.out"; then
|
||||
rm "$job.out"
|
||||
else
|
||||
echo "Skipping deletion of unsafe path: $job.out" >&2
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
# ==========================================
|
||||
# 3. 生成作业脚本 (.job)
|
||||
# ==========================================
|
||||
job_file="$job.job"
|
||||
|
||||
# 使用 heredoc 动态生成 PBS 脚本
|
||||
# 整合了原 g16_32.pbs 的内容和 gsub32 的追加内容
|
||||
cat > "$job_file" <<EOF
|
||||
#!/bin/sh
|
||||
#PBS -l nodes=1:ppn=$ppn
|
||||
#PBS -q $queue
|
||||
#PBS -j oe
|
||||
#PBS -N $job
|
||||
|
||||
cd \$PBS_O_WORKDIR
|
||||
|
||||
# Define Safety Check Function in Job Script
|
||||
check_rm_path() {
|
||||
p="\$1"
|
||||
# Empty check
|
||||
if [ -z "\$p" ]; then
|
||||
echo "Refusing to delete empty path"
|
||||
return 1
|
||||
fi
|
||||
# Root check
|
||||
if [ "\$p" = "/" ]; then
|
||||
echo "Refusing to delete root path"
|
||||
return 1
|
||||
fi
|
||||
return 0
|
||||
}
|
||||
|
||||
export g16root=/share/apps/soft
|
||||
source \$g16root/g16/bsd/g16.profile
|
||||
|
||||
# Create Scratch Directory
|
||||
if [ -n "\$USER" ] && [ -n "\$PBS_JOBID" ]; then
|
||||
mkdir -p /scr/\$USER/\$PBS_JOBID
|
||||
export GAUSS_SCRDIR=/scr/\$USER/\$PBS_JOBID
|
||||
else
|
||||
echo "Error: USER or PBS_JOBID not set. Cannot setup scratch."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
NODES=\`cat \$PBS_NODEFILE | uniq\`
|
||||
echo "--------------------------------------------------------"
|
||||
echo " JOBID: \$PBS_JOBID"
|
||||
echo " The job was started at \`date\`"
|
||||
echo " The job was running at \$NODES."
|
||||
echo "--------------------------------------------------------"
|
||||
|
||||
# Run G16 Job
|
||||
echo "Executing: g16 < $gin_file > $job.out"
|
||||
g16 < $gin_file > $job.out
|
||||
|
||||
echo "--------------------------------------------------------"
|
||||
echo " The job was finished at \`date\`"
|
||||
echo "--------------------------------------------------------"
|
||||
|
||||
# Delete the tmp File (Cleanup Scratch)
|
||||
echo "Cleaning up \$GAUSS_SCRDIR"
|
||||
if check_rm_path "\$GAUSS_SCRDIR"; then
|
||||
rm -rf "\$GAUSS_SCRDIR"
|
||||
fi
|
||||
|
||||
EOF
|
||||
|
||||
# ==========================================
|
||||
# 4. 提交作业 (Submit)
|
||||
# ==========================================
|
||||
# qsub 会输出 Job ID,例如 12345.cluster
|
||||
qsub "$job_file"
|
||||
116
calc_utils/bin/gsub_wait
Normal file
116
calc_utils/bin/gsub_wait
Normal file
@ -0,0 +1,116 @@
|
||||
#!/bin/bash
# gsub_wait <jobname> -- submit a Gaussian job via gsub, then monitor the
# growing .out file until the PBS epilogue file appears.
set -u

job=${1:-}
if [[ -z "$job" ]]; then
    echo "Usage: $0 <jobname-without-extension>"
    exit 1
fi

# ==========================================
# 1. Submit the job
# ==========================================
# Prefer a gsub living in the current directory; otherwise use PATH.
if [[ -x "./gsub" ]]; then
    GSUB_CMD="./gsub"
else
    GSUB_CMD="gsub"
fi

# gsub may run remotely over SSH; either way it echoes the qsub output.
output=$($GSUB_CMD "$job")
echo "$output"

# ==========================================
# 2. Silent mode: submit only, no monitoring
# ==========================================
if [[ "${GSUB_SILENT:-0}" == "1" ]]; then
    exit 0
fi

# ==========================================
# 3. Monitor progress
# ==========================================
# Extract the job id (e.g. "67147.cluster" -> "67147").
jobid_full=$(echo "$output" | grep -oE '[0-9]+\.cluster|[0-9]+' | head -n 1 || true)

if [[ -z "$jobid_full" ]]; then
    echo "Could not parse Job ID from output. Monitor skipped."
    exit 0
fi

jobid=${jobid_full%%.*}
out_file="$job.out"
gin_file="$job.gin"
end_file="$job.job.o$jobid"     # PBS epilogue file: its presence marks completion

if [[ ! -f "$gin_file" ]]; then
    # gin file missing (remote path mismatch?) -- nothing to monitor against.
    echo "Warning: $gin_file not found nearby. Skipping monitor."
    exit 0
fi

# Total steps = number of "--link1--" separators + 1.
link_count=$(grep -c -- "--link1--" "$gin_file" || true)
total=$((link_count + 1))
cntDone=0
cntSCF=0
last_lines=0

# Scan lines of $out_file beyond $last_lines and update the SCF /
# termination counters.  Process substitution (not a pipe) keeps the
# counter updates in the current shell.
scan_new_lines() {
    local curr_lines
    [[ -f "$out_file" ]] || return 0
    curr_lines=$(wc -l < "$out_file" 2>/dev/null || echo 0)
    # File shrank (truncated or recreated) -> restart from the top.
    if (( curr_lines < last_lines )); then last_lines=0; fi
    if (( curr_lines > last_lines )); then
        while IFS= read -r line; do
            # "SCF Done:  E(...) =  <energy>  A.U. ..."
            if [[ "$line" =~ SCF[[:space:]]Done:.*E.*=[[:space:]]*([-0-9.]+)[[:space:]]*A\.U\. ]]; then
                energy="${BASH_REMATCH[1]}"
                cntSCF=$((cntSCF + 1))
                echo "$job: SCF Done: $energy [$cntSCF] ($cntDone/$total)"
            fi
            # "Normal/Error termination of Gaussian ..."
            if [[ "$line" == *"termination of Gaussian"* ]]; then
                cntDone=$((cntDone + 1))
                echo "$job: task done ($cntDone/$total)"
            fi
        done < <(tail -n "+$((last_lines + 1))" "$out_file")
        last_lines=$curr_lines
    fi
}

echo "Monitoring Job $jobid..."

while true; do
    # A. Job finished when the PBS epilogue file shows up.
    if [[ -f "$end_file" ]]; then
        echo "Job finished (found $end_file)."
        break
    fi
    # B. Scan any new output.
    scan_new_lines
    sleep 2
done

# Bug fix: drain output written between the last poll and job completion,
# otherwise the final tally can undercount and warn spuriously.
scan_new_lines

# C. Final sanity check.
if (( cntDone != total )); then
    echo "Warning: cntDone ($cntDone) != total ($total)"
fi
|
||||
19
calc_utils/futils/__init__.py
Normal file
19
calc_utils/futils/__init__.py
Normal file
@ -0,0 +1,19 @@
|
||||
"""Public API of the ``futils`` package.

Re-exports the RDKit/ASE conversion helpers, the drawing utilities and the
patched Gaussian calculator module.
"""

from rdkit import Chem
from rdkit.Chem import AllChem

from ase import atoms

from .rdkit2ase import MolToAtoms, AtomsToMol
from .rdkit_utils import draw2D, draw3D
# Bug fix: ``__all__`` listed ``gaussian`` while its import was commented
# out, so ``from futils import *`` raised AttributeError.  Import the
# sibling module so the exported name actually exists.
from . import gaussian

__all__ = [
    'MolToAtoms',
    'AtomsToMol',
    'draw2D',
    'draw3D',
    'Chem',
    'AllChem',
    'atoms',
    'gaussian',
]
|
||||
367
calc_utils/futils/gaussian.py
Normal file
367
calc_utils/futils/gaussian.py
Normal file
@ -0,0 +1,367 @@
|
||||
from copy import deepcopy
from typing import TYPE_CHECKING, Literal, Optional

from ase.calculators import gaussian
from ase.calculators.calculator import FileIOCalculator
from ase.io import read, write

# Blocking submission wrapper used as the calculator command.
GSUB = "/home/fanhj/calcs/lele/tools/gsub_wait"

# Keep handles to the stock ASE classes that are subclassed (and later
# monkey-patched) below.
gau_src = gaussian.Gaussian
gau_dyn_src = gaussian.GaussianDynamics

# The aliases below accept any string; the Literal members only exist to
# drive IDE completion for the common choices.
methodType = Optional[
    str
    | Literal[
        "HF", "MP2", "MP3", "MP4", "MP4(DQ)", "MP4(SDQ)", "MP5",
        "CCSD", "CCSDT", "QCISD", "CID", "CISD", "CIS",
        "B3LYP", "B3PW91", "BLYP", "PBE", "PBE0",
        "M06", "M062X", "M06L", "M06HF",
        "CAM-B3LYP", "wb97xd", "wb97xd3", "LC-wPBE", "HSE06",
        "LSDA", "SVWN", "PW91", "mPW1PW91",
        "HCTH", "HCTH147", "HCTH407",
        "TPSSh", "TPSS", "revPBE", "PBEPBE",
        "B2PLYP", "mPW2PLYP", "B2PLYPD3", "PBE0DH", "PBEQIDH",
    ]
]

basisType = Optional[
    str
    | Literal[
        "STO-3G", "3-21G",
        "6-31G", "6-31G(d)", "6-31G(d,p)",
        "6-31+G(d)", "6-31+G(d,p)", "6-31++G(d,p)",
        "6-311G", "6-311G(d)", "6-311G(d,p)",
        "6-311+G(d)", "6-311+G(d,p)", "6-311++G(d,p)",
        "cc-pVDZ", "cc-pVTZ", "cc-pVQZ", "cc-pV5Z", "cc-pV6Z",
        "aug-cc-pVDZ", "aug-cc-pVTZ", "aug-cc-pVQZ", "aug-cc-pV5Z",
        "def2SVP", "def2SVPD", "def2TZVP", "def2TZVPD",
        "def2QZVP", "def2QZVPP",
        "LANL2DZ", "LANL2MB", "SDD",
        "CEP-4G", "CEP-31G", "CEP-121G",
        "DGDZVP", "DGDZVP2",
        "Gen", "GenECP",
    ]
]

scrfSolventType = Optional[
    str
    | Literal[
        'Water', 'Acetone', 'Acetonitrile', 'Aniline', 'Benzene',
        'Bromoform', 'Butanol', 'CarbonDisulfide', 'CarbonTetrachloride',
        'Chlorobenzene', 'Chloroform', 'Cyclohexane', 'Dichloroethane',
        'Dichloromethane', 'Diethylether', 'Dimethylformamide',
        'Dimethylsulfoxide', 'Ethanol', 'Ethylacetate', 'Heptane', 'Hexane',
        'Methanol', 'Nitromethane', 'Octanol', 'Pyridine',
        'Tetrahydrofuran', 'Toluene', 'Xylene',
    ]
]
|
||||
|
||||
|
||||
|
||||
class Gaussian(gau_src):
    """Gaussian calculator wired to the cluster queue.

    Thin wrapper around :class:`ase.calculators.gaussian.Gaussian` that

    * submits jobs through the blocking ``gsub_wait`` helper (``GSUB``),
    * writes input files with the ``.gin`` suffix expected by ``gsub``,
    * accepts ``proc`` as a short alias for ``nprocshared``.
    """

    mem: int
    nprocshared: Literal[20, 32]
    charge: int
    mult: Optional[int]

    def __init__(self,
                 proc: Optional[Literal[20, 32]] = None,
                 charge: int = 0,
                 mult: Optional[int] = None,
                 mem="30GB",
                 label='Gaussian',
                 method: methodType = None,
                 basis: basisType = None,
                 fitting_basis: Optional[str] = None,
                 output_type: Literal['N', 'P'] = 'P',
                 basisfile: Optional[str] = None,
                 basis_set: Optional[str] = None,
                 xc: Optional[str] = None,
                 extra: Optional[str] = None,
                 ioplist: Optional[list[str]] = None,
                 addsec=None,
                 spinlist=None,
                 zefflist=None,
                 qmomlist=None,
                 nmagmlist=None,
                 znuclist=None,
                 radnuclearlist=None,
                 chk: Optional[str] = None,
                 oldchk: Optional[str] = None,
                 nprocshared: Optional[int] = None,
                 scrfSolvent: scrfSolventType = None,
                 scrf: Optional[str] = None,
                 em: Optional[Literal["GD2", "GD3", "GD4", "GD3BJ"] | str] = None,
                 **kwargs):
        '''
        Parameters
        ----------
        proc: int
            Short alias for ``nprocshared`` (20 or 32).  Effective default
            is 32 when neither is given.
        charge: int
            Total system charge.  If not provided, it is determined from
            the ``Atoms`` object's initial charges.
        mult: int
            Spin multiplicity (``spin + 1``).  If not provided, determined
            from the ``Atoms`` object's ``initial_magnetic_moments``.
        method: str
            Level of theory, e.g. ``hf``, ``ccsd``, ``mp2`` or ``b3lyp``.
            Overrides ``xc``.
        xc: str
            XC functional by common name (e.g. ``PBE``), translated to the
            internal Gaussian name (e.g. ``PBEPBE``).
        basis: str
            Basis set; may be omitted when ``basisfile`` is set (Gaussian
            then usually falls back to ``STO-3G``).
        fitting_basis: str
            Name of the density-fitting basis set.
        output_type: str
            Output verbosity: ``N`` (normal) or ``P`` (additional).
        basisfile: str
            External basis file; ``basis`` is then set to ``gen``.
        basis_set: str
            Inline basis definition — alternative to ``basisfile`` with the
            same content a basis file would have.
        extra: str
            Extra route-section text included verbatim.
        ioplist: list
            IOp definitions added to the route line.
        addsec: str
            Text appended after the molecular geometry specification,
            e.g. masses for ``freq=ReadIso``.
        spinlist, zefflist, qmomlist, nmagmlist, znuclist, radnuclearlist: list
            Per-nucleus properties (spins, effective charges, quadrupole
            moments, magnetic moments, nuclear charges, radii) written into
            the nuclear-properties section of the molecule specification.
        chk, oldchk: str
            Checkpoint file names (link0 section).
        scrfSolvent: str
            Convenience for ``scrf='Solvent=<name>'``; mutually exclusive
            with ``scrf``.
        em: str
            Empirical dispersion keyword (``GD2``, ``GD3``, ``GD3BJ``, ...).
        kwargs: dict
            Further link0/route keywords, forwarded to the ASE calculator.
        '''
        # Reconcile the proc/nprocshared alias.  Bug fix: ``proc`` used to
        # default to 32 (never None), so supplying nprocshared=20 alone
        # always raised ValueError even though the caller never set proc.
        if proc is not None and nprocshared is not None:
            if proc != nprocshared:
                raise ValueError("both nprocshared and proc provided, and inequal.")
            print("Providing both nprocshared and proc is not recommended")
        elif proc is None:
            proc = nprocshared if nprocshared is not None else 32

        if scrfSolvent is not None and scrf is not None:
            raise ValueError("scrfSolvent and scrf both not None")
        if scrfSolvent is not None:
            scrf = "Solvent=" + scrfSolvent

        # Forward the named optional keywords through **kwargs.  The
        # original iterated a long key list via ``locals()``, but only
        # these four keys are actual parameters of this method; every
        # other route keyword already arrives in ``kwargs``.
        for key, val in (('chk', chk), ('oldchk', oldchk),
                         ('scrf', scrf), ('em', em)):
            if val is not None:
                kwargs[key] = val

        super().__init__(
            nprocshared=proc,
            mem=mem,
            label=label,
            # Submit through gsub_wait, which blocks until the job ends.
            command=GSUB + " " + label,

            charge=charge,
            mult=mult,

            method=method,
            basis=basis,
            fitting_basis=fitting_basis,
            output_type=output_type,

            basisfile=basisfile,
            basis_set=basis_set,
            xc=xc,

            extra=extra,
            ioplist=ioplist,
            addsec=addsec,

            spinlist=spinlist,
            zefflist=zefflist,
            qmomlist=qmomlist,
            nmagmlist=nmagmlist,
            znuclist=znuclist,
            radnuclearlist=radnuclearlist,

            **kwargs
        )
        assert self.fileio_rules
        # gsub expects input files named <label>.gin.
        self.fileio_rules.stdin_name = '{prefix}.gin'

    def mod(self, charge: int = 0, mult: int = 1) -> "Gaussian":
        """Return a deep copy of this calculator with new charge/multiplicity."""
        clone = deepcopy(self)
        clone.charge = charge
        clone.mult = mult
        return clone

    def write_input(self, atoms, properties=None, system_changes=None):
        """Write the ``<label>.gin`` input file in Gaussian format."""
        FileIOCalculator.write_input(self, atoms, properties, system_changes)
        if TYPE_CHECKING:
            assert isinstance(self.label, str)
        assert self.parameters
        write(self.label + '.gin', atoms, properties=properties,
              format='gaussian-in', parallel=False, **self.parameters)

    def read_results(self):
        """Parse ``<label>.out`` and adopt its calculator results."""
        if TYPE_CHECKING:
            assert isinstance(self.label, str)
        output = read(self.label + '.out', format='gaussian-out')
        assert output
        self.calc = output.calc
        self.results = output.calc.results
|
||||
|
||||
|
||||
class GaussianDynamics(gau_dyn_src):
    """Driver that executes a Gaussian dynamics job (opt / irc) through
    the attached calculator and pulls the final geometry back in."""

    def __init__(self, atoms, calc=None):
        super().__init__(atoms, calc=calc)

    def run(self, **kwargs):
        """Run the job; return ``True`` when Gaussian converged."""
        previous_calc = self.atoms.calc
        saved_params = deepcopy(self.calc.parameters)

        # Rewrite the route keywords for this particular run.
        self.delete_keywords(kwargs)
        self.delete_keywords(self.calc.parameters)
        self.set_keywords(kwargs)

        self.calc.set(**kwargs)
        self.atoms.calc = self.calc

        # An OSError while computing/reading results counts as failure.
        try:
            self.atoms.get_potential_energy()
        except OSError:
            converged = False
        else:
            converged = True

        # Adopt the final structure from the output file.
        final = read(self.calc.label + '.out')
        self.atoms.cell = final.cell
        self.atoms.positions = final.positions
        self.atoms.calc = final.calc

        # Restore the calculator state and the caller's calculator.
        self.calc.parameters = saved_params
        self.calc.reset()
        if previous_calc is not None:
            self.atoms.calc = previous_calc

        return converged
|
||||
|
||||
|
||||
class GaussianOptimizer(GaussianDynamics):
    """Geometry optimisation via Gaussian's ``opt`` keyword."""

    keyword = 'opt'
    # Translate ASE-style run() options into Gaussian route options.
    special_keywords = {
        'fmax': '{}',
        'steps': 'maxcycle={}',
    }
|
||||
|
||||
|
||||
class GaussianIRC(GaussianDynamics):
    """Intrinsic reaction coordinate run via Gaussian's ``irc`` keyword."""

    keyword = 'irc'
    # Translate ASE-style run() options into Gaussian route options.
    special_keywords = {
        'direction': '{}',
        'steps': 'maxpoints={}',
    }
|
||||
|
||||
|
||||
# Patch the stock ASE module in place so code importing
# ``ase.calculators.gaussian`` transparently gets the variants above.
gaussian.Gaussian = Gaussian
gaussian.GaussianDynamics = GaussianDynamics
gaussian.GaussianOptimizer = GaussianOptimizer
gaussian.GaussianIRC = GaussianIRC

# Public API of this module.
__all__ = [
    'Gaussian',
    'GaussianDynamics',
    'GaussianOptimizer',
    'GaussianIRC',
    'GSUB',
]
|
||||
145
calc_utils/futils/rdkit2ase.py
Normal file
145
calc_utils/futils/rdkit2ase.py
Normal file
@ -0,0 +1,145 @@
|
||||
from rdkit import Chem
|
||||
from ase import atoms
|
||||
from typing import Optional
|
||||
import numpy as np
|
||||
from rdkit.Geometry import Point3D
|
||||
|
||||
|
||||
def MolToAtoms(mol: Chem.Mol, confID=-1) -> atoms.Atoms:
    """Convert an RDKit molecule conformer into an ASE ``Atoms`` object.

    Coordinates come from conformer ``confID`` and each atom's formal
    charge is copied into the Atoms' initial charges.
    """
    conformer = mol.GetConformer(confID)

    symbols = []
    positions = []
    charges = []
    # One pass over the atoms collects symbol, position and formal charge.
    for index, atom in enumerate(mol.GetAtoms()):
        symbols.append(atom.GetSymbol())
        point = conformer.GetAtomPosition(index)
        positions.append((point.x, point.y, point.z))
        charges.append(atom.GetFormalCharge())

    ase_atoms = atoms.Atoms(symbols=symbols, positions=positions)
    ase_atoms.set_initial_charges(charges)
    return ase_atoms
|
||||
|
||||
|
||||
def AtomsToMol(
    atoms: atoms.Atoms,
    mol: Optional[Chem.Mol] = None,
    conf_id: int = 0,
    charge: Optional[int] = None,
    allow_reorder: bool = False,
    inplace=False
) -> Chem.Mol:
    """
    Convert ASE Atoms -> RDKit Mol.

    If *mol* is provided:
      - verify atom count and element symbols match (unless allow_reorder=True)
      - update (or add) conformer coordinates so mol matches atoms

    If *mol* is None:
      - create a new Mol with atoms only, set 3D coords
      - determine bonds from geometry using rdDetermineBonds.DetermineBonds()

    Parameters
    ----------
    atoms : ase.Atoms
        Must have positions (Å).
    mol : rdkit.Chem.Mol | None
        Optional template mol.
    conf_id : int
        Conformer id to update/use. If mol has no conformer, one will be added.
    charge : int | None
        Total molecular charge passed to DetermineBonds when building a new
        mol (defaults to 0).  Ignored when a template mol is supplied, since
        that path never re-perceives bonds.
    allow_reorder : bool
        If True, only element *counts* are checked, not symbol-by-index
        equality.  Most of the time you want False to guarantee consistency.
    inplace : bool
        If True, modify the input mol instead of working on a copy.

    Returns
    -------
    rdkit.Chem.Mol

    Raises
    ------
    ValueError
        When the template mol does not match *atoms*.
    RuntimeError
        When bond perception fails for a newly built mol.
    """
    positions = np.asarray(atoms.get_positions(), dtype=float)
    symbols = list(atoms.get_chemical_symbols())
    n = len(symbols)

    # -------- case 1: update existing mol --------
    if mol is not None:
        if mol.GetNumAtoms() != n:
            raise ValueError(f"mol has {mol.GetNumAtoms()} atoms but ASE atoms has {n}.")

        mol_symbols = [a.GetSymbol() for a in mol.GetAtoms()]
        if not allow_reorder:
            if mol_symbols != symbols:
                raise ValueError(
                    "Element symbols mismatch by index between mol and atoms.\n"
                    f"mol: {mol_symbols}\n"
                    f"atoms: {symbols}\n"
                    "If you REALLY know what you're doing, set allow_reorder=True "
                    "and handle mapping yourself."
                )
        else:
            # only check multiset counts
            if sorted(mol_symbols) != sorted(symbols):
                raise ValueError("Element symbol counts differ between mol and atoms.")

        if not inplace:
            mol = Chem.Mol(mol)  # work on a copy
        if mol.GetNumConformers() == 0:
            conf = Chem.Conformer(n)
            conf.Set3D(True)
            mol.AddConformer(conf, assignId=True)

        # Pick the requested conformer; create one if the id doesn't exist.
        try:
            conf = mol.GetConformer(conf_id)
        except ValueError:
            conf = Chem.Conformer(n)
            conf.Set3D(True)
            mol.AddConformer(conf, assignId=True)
            conf = mol.GetConformer(mol.GetNumConformers() - 1)
            if conf_id != 0:
                print("Warning: Failed to pick conformer.")

        for i in range(n):
            x, y, z = map(float, positions[i])
            conf.SetAtomPosition(i, Point3D(x, y, z))

        # Bug fix (dead code removed): the original inferred ``charge`` from
        # formal charges here, but the value was never used on this path.
        return mol

    # -------- case 2: build mol + determine bonds --------
    rw = Chem.RWMol()
    for sym in symbols:
        rw.AddAtom(Chem.Atom(sym))

    new_mol = rw.GetMol()
    conf = Chem.Conformer(n)
    conf.Set3D(True)
    for i in range(n):
        x, y, z = map(float, positions[i])
        conf.SetAtomPosition(i, Point3D(x, y, z))
    new_mol.AddConformer(conf, assignId=True)

    # Determine bonds from geometry.
    if charge is None:
        charge = 0

    try:
        from rdkit.Chem import rdDetermineBonds
        rdDetermineBonds.DetermineBonds(new_mol, charge=charge)
    except Exception as e:
        # Chain the cause so the RDKit traceback is preserved.
        raise RuntimeError(
            "DetermineBonds failed. This can happen for metals/ions/odd geometries.\n"
            "Consider providing a template mol, or implement custom distance-based bonding.\n"
            f"Original error: {e}"
        ) from e

    return new_mol
|
||||
14
calc_utils/futils/rdkit_utils.py
Normal file
14
calc_utils/futils/rdkit_utils.py
Normal file
@ -0,0 +1,14 @@
|
||||
"""Small RDKit drawing helpers for notebook use."""

from rdkit import Chem
from rdkit.Chem import Draw
from IPython.display import SVG
from rdkit.Chem.Draw import IPythonConsole


def draw2D(mol: Chem.Mol, confId: int = -1, width: int = 250, height: int = 200):
    """Render *mol* as an inline SVG with atom indices shown.

    ``width``/``height`` generalize the previously hard-coded 250x200
    canvas; the defaults keep the original behaviour.
    """
    drawer = Draw.MolDraw2DSVG(width, height)
    drawer.drawOptions().addAtomIndices = True
    drawer.DrawMolecule(mol, confId=confId)
    drawer.FinishDrawing()
    return SVG(drawer.GetDrawingText())


def draw3D(m3d, confId=-1):
    """Render an interactive 3D view of *m3d* in the notebook."""
    # ``def`` instead of a name-bound lambda (PEP 8 E731): same behaviour,
    # but with a proper __name__ and a docstring.
    return IPythonConsole.drawMol3D(m3d, confId=confId)
|
||||
10
calc_utils/pyproject.toml
Normal file
10
calc_utils/pyproject.toml
Normal file
@ -0,0 +1,10 @@
|
||||
[project]
|
||||
name = "calc-utils"
|
||||
version = "0.1.0"
|
||||
description = "Utilities for Gaussian/ASE/RDKit computational chemistry workflows (cluster job submission, molecule conversion and drawing helpers)"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"ase>=3.27.0",
|
||||
"rdkit>=2025.9.3",
|
||||
]
|
||||
Reference in New Issue
Block a user