0
0
forked from flt/tools
Files
pdf_unlock/cord/main.py
2025-08-28 22:23:42 +08:00

241 lines
8.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
import io
import streamlit as st
# import scienceplots
# plt.style.use(['nature', 'no-latex',"cjk-sc-font"])
# Matplotlib font setup: a sans-serif family headed by SimHei so the Chinese
# labels render, and an ASCII hyphen for negative numbers because the Unicode
# minus sign does not display correctly with this font.
plt.rcParams.update({
    "font.family": "sans-serif",
    "font.sans-serif": ["SimHei"],
    "axes.unicode_minus": False,
})
def cubic_bezier_with_zero_derivatives(p0, p1, t_array, influence_factor):
    """Sample a cubic Bezier curve from ``p0`` to ``p1`` that is flat at both ends.

    The two inner control points copy the y value of their neighbouring end
    point, so dy/dt is zero at t=0 and t=1. Along x, the first control point
    sits ``influence_factor[0] * (x1 - x0)`` to the right of ``p0``. The second
    sits ``influence_factor[1] * (x1 - x0)`` to the left of ``p1``. With a
    factor of 0 the control point coincides with its end point, and the curve
    is then no longer guaranteed to leave horizontally.

    Args:
        p0: start point ``(x0, y0)``.
        p1: end point ``(x1, y1)``.
        t_array: curve parameter values in [0, 1] (a numpy array).
        influence_factor: pair ``(start_factor, end_factor)`` controlling the
            horizontal placement of the two control points.

    Returns:
        ``(x, y)``: curve coordinates at every value of ``t_array``.
    """
    (x0, y0), (x1, y1) = p0, p1
    width = x1 - x0
    ctrl_x_start = x0 + width * influence_factor[0]
    ctrl_x_end = x1 - width * influence_factor[1]

    # Cubic Bernstein basis, shared by the x and y evaluations.
    s = 1 - t_array
    w0 = s**3
    w1 = 3 * s**2 * t_array
    w2 = 3 * s * t_array**2
    w3 = t_array**3

    x = w0 * x0 + w1 * ctrl_x_start + w2 * ctrl_x_end + w3 * x1
    # Control-point y values equal the end-point y values (zero end slopes).
    y = w0 * y0 + w1 * y0 + w2 * y1 + w3 * y1
    return x, y
# @st.cache_resource
def plot_reaction_coordinate(changed=None, _lines=None):
    """Draw the reaction-coordinate diagram for the module-level ``data``.

    Reads these module globals:
      - ``data``: DataFrame with "Name" and "Energy" columns.
      - ``K_POS``: per-point label offsets, as fractions of the energy span.
      - ``INFLU_FACTORS``: two Bezier control factors per segment.

    Point ``i`` is placed at x = 2*i + 1. Consecutive points are joined by a
    black cubic Bezier curve that is flat at both ends. Points with a
    non-missing "Name" are labelled at their energy plus ``K_POS[i]`` times the
    energy span. If ``st.session_state["xylim"]`` is set, the x/y limits are
    taken from the "xmin"/"xmax"/"ymin"/"ymax" session-state entries.

    Args:
        changed: unused; kept for call-site compatibility.
        _lines: unused; kept for call-site compatibility.

    Returns:
        ``(fig, lines)``: the new figure and the list of curve Line2D objects,
        one per segment.
    """
    segments = []
    fig, ax = plt.subplots(figsize=(9, 6))
    energies = data["Energy"]
    top = energies.max()
    bottom = energies.min()
    span = top - bottom
    t = np.linspace(0, 1, 300)

    prev_point = None
    for idx in range(data.shape[0]):
        row = data.loc[idx]
        energy = row["Energy"]
        if prev_point is None:
            # The first point only anchors the diagram; it has no incoming curve.
            point = (1, energy)
        else:
            point = (prev_point[0] + 2, energy)
            # Segment idx-1 -> idx uses factors [2*idx-2] (leaving) and [2*idx-1] (arriving).
            factors = INFLU_FACTORS[(idx * 2 - 2):idx * 2]
            xs, ys = cubic_bezier_with_zero_derivatives(prev_point, point, t, factors)
            segments.append(ax.plot(xs, ys, "-", color="black")[0])
        if not pd.isna(row["Name"]):
            label_xy = (point[0], point[1] + span * K_POS[idx])
            ax.annotate(str(row["Name"]), label_xy, ha='center')
        prev_point = point

    ax.set_xlabel("Reaction Coordinate")
    ax.xaxis.set_ticks([])
    ax.set_ylabel("Energy (kcal/mol)")
    ax.set_ylim(bottom - span * 0.1, top + span * 0.1)
    # Once the axis-limit sliders exist, their values override the defaults.
    if st.session_state.get("xylim", None) is not None:
        ax.set_xlim(st.session_state["xmin"], st.session_state["xmax"])
        ax.set_ylim(st.session_state["ymin"], st.session_state["ymax"])
    return fig, segments
# 创建图形和坐标轴
def callback_gen(x, typ=0):
    """Build an ``on_change`` callback for one of the parameter sliders.

    Args:
        x: slider index. If ``typ`` is truthy, it is the point index: the
            callback copies session-state key ``text_slider_{x}`` (default
            0.05) into ``K_POS[x]``. Otherwise it is the Bezier slider number:
            the callback copies ``slider_{x}`` (default 0.5) into
            ``INFLU_FACTORS[x - 1]``.
        typ: truthy for a label-position slider, falsy for a curve-shape slider.

    Returns:
        A zero-argument callable suitable for a Streamlit ``on_change``.
    """
    # The callbacks only record the new value. Streamlit reruns the whole
    # script right after an on_change callback, and that rerun redraws the
    # figure. Drawing here as well only created an extra pyplot figure that
    # was never shown or closed, leaking one figure per slider move.
    # Item assignment mutates the shared lists in place, so no `global` is needed.
    if typ:
        def callback():
            K_POS[x] = st.session_state.get(f'text_slider_{x}', 0.05)
    else:
        def callback():
            INFLU_FACTORS[x - 1] = st.session_state.get(f'slider_{x}', 0.5)
    return callback
def on_save():
    """Render the module-level figure ``fig`` into an in-memory file.

    The output format comes from the ``export_format`` session-state value,
    an extension such as ".tiff" (the default); its leading dot is dropped
    to get the matplotlib format name. The image is written at 300 dpi with
    a tight bounding box. The module-level ``out_file`` is rebound to the new
    buffer, rewound to the start.

    Returns:
        bytes: the encoded image, suitable for ``st.download_button``.
    """
    global out_file
    # Work on `fig` explicitly instead of pyplot's implicit "current figure".
    # Pyplot state is process-global while Streamlit runs sessions on separate
    # threads, so plt.draw()/plt.tight_layout() could touch another session's figure.
    fig.canvas.draw_idle()
    fig.tight_layout()
    out_file = io.BytesIO()
    fmt = st.session_state.get("export_format", ".tiff")[1:]
    fig.savefig(out_file, format=fmt, dpi=300, bbox_inches='tight')
    out_file.seek(0)
    return out_file.getvalue()
@st.cache_resource
def load_data(file):
    """Read the uploaded energy table and derive the initial plot parameters.

    The table is read with ``pd.read_excel`` for ``.xlsx``/``.xls`` names (case
    insensitive) and ``pd.read_csv`` otherwise. If its columns are not exactly
    ["Name", "Energy"], a warning is shown and they are renamed to those two
    names. Energies (Hartree) are converted to kcal/mol relative to the first
    row.

    Because of ``st.cache_resource``, the returned objects are shared between
    reruns. The slider callbacks mutate ``INFLU_FACTORS`` and ``K_POS`` in
    place, so their changes persist.

    Args:
        file: uploaded file object with a ``name`` attribute, or None.

    Returns:
        ``(data, INFLU_FACTORS, K_POS)``:
          - ``data``: DataFrame with "Name" and relative "Energy" in kcal/mol.
          - ``INFLU_FACTORS``: ``2 * len(data)`` Bezier control factors, all 0.5.
          - ``K_POS``: one label offset per point, as a fraction of the energy
            span. The first point gets -0.05 (below). Each later point gets
            0.03 (above) if it is higher than the previous point, otherwise
            -0.05 (below).

    Stops the Streamlit script run (``st.stop()``) when no file is given, the
    file cannot be parsed, or it has no rows.
    """
    hartree_to_kcal_mol = 627.509
    if file is None:
        st.stop()
    try:
        is_excel = file.name.lower().endswith((".xlsx", ".xls"))
        data = pd.read_excel(file) if is_excel else pd.read_csv(file)
        if data.columns.tolist() != ["Name", "Energy"]:
            st.warning("Format should be Name, Energy. Modified automatically.")
            data.columns = ["Name", "Energy"]
    except Exception as e:
        st.error(f"Error reading file: {e}")
        # st.stop() ends this run cleanly; exit() raised SystemExit inside the app.
        st.stop()
    if data.empty:
        st.error("File contains no data rows.")
        st.stop()

    INFLU_FACTORS = [0.5] * data.shape[0] * 2  # two control factors per point
    ene = data["Energy"].to_numpy()
    # Compare each point with its predecessor (the original `ene[:1]` compared
    # every point with the first one): rising points such as transition states
    # get their label above, falling points below.
    rising = ene[1:] > ene[:-1]
    K_POS = [-0.05] + np.where(rising, 0.03, -0.05).tolist()
    data["Energy"] -= data["Energy"].iloc[0]
    data["Energy"] *= hartree_to_kcal_mol
    return data, INFLU_FACTORS, K_POS
def _limit_slider(label, value, fallback_half_width, key):
    """Render a slider for one axis limit, centred on ``value``.

    The slider spans ``value +/- |value| / 2``. When ``value`` is exactly 0,
    that would be a zero-width range (min == max), which ``st.slider`` does
    not accept and which would be useless anyway. ``fallback_half_width`` is
    used instead in that case.
    """
    half_width = abs(value) * 0.5 or fallback_half_width
    st.slider(label, value - half_width, value + half_width, value=value, key=key)


# Streamlit page script: re-executed top to bottom on every interaction.
# Holds the most recently exported image; on_save() rebinds it.
out_file = io.BytesIO()
st.set_page_config(
    page_title="反应坐标绘制",
    page_icon=":chart_with_upwards_trend:",
    initial_sidebar_state="expanded"
)
st.title("反应坐标绘制")
st.write("---")
file = st.file_uploader("上传能量文件", type=["xlsx", "xls", "csv"], key="file")
if not file:
    # Nothing uploaded yet: show the expected table layout and offer a template.
    st.set_page_config(layout="centered")
    st.write("按照下列格式上传表格。请保证列名和范例一致,或直接下载。")
    st.warning("注意Energy单位为Hartree程序将自动转换为kcal/mol的相对能量")
    example = pd.DataFrame({
        "Name": ["reactant", "TS", "result"],
        "Energy": [-400.310327, -400.210017, -400.341576],
    })
    st.dataframe(example, hide_index=True)
    tmp_file = io.BytesIO()
    example.to_excel(tmp_file, index=False)
    st.download_button("下载模板", data=tmp_file, file_name="reaction_coordinate_example.xlsx")
    st.stop()
else:
    st.set_page_config(layout="wide")
    col1, col2 = st.columns([0.4, 0.6], gap="medium")
    with col2:
        data, INFLU_FACTORS, K_POS = load_data(file)
        fig, lines = plot_reaction_coordinate()
        stfig = st.pyplot(fig, clear_figure=False)
    try:
        with col1:
            st.title("作图参数设置")
            with st.expander("调整曲线形状(贝塞尔参数)"):
                # Every point except the first has an "arriving" slider, and
                # every point except the last has a "leaving" slider.
                # Slider key slider_{k} drives INFLU_FACTORS[k - 1].
                for i in range(data.shape[0]):
                    if i != 0:
                        st.slider(
                            f'{data.loc[i,"Name"]}',
                            0.0, 1.0, value=INFLU_FACTORS[i * 2 - 1],
                            key=f'slider_{i*2}',
                            on_change=callback_gen(i * 2)
                        )
                    if i != data.shape[0] - 1:
                        st.slider(
                            f'{data.loc[i,"Name"]}',
                            0.0, 1.0, value=INFLU_FACTORS[i * 2],
                            key=f'slider_{i*2+1}',
                            on_change=callback_gen(i * 2 + 1)
                        )
            with st.expander("调整文字位置"):
                for i in range(data.shape[0]):
                    st.slider(
                        f'{data.loc[i,"Name"]}',
                        -0.1, 0.1, value=K_POS[i],
                        key=f'text_slider_{i}',
                        on_change=callback_gen(i, 1)
                    )
            with st.expander("调整坐标系极限"):
                if st.session_state.get("xylim", None) is None:
                    # Read the autoscaled limits from this run's own axes.
                    # pyplot's "current figure" is process-global and may
                    # belong to another session's thread.
                    ax = fig.axes[0]
                    xmin, xmax = ax.get_xlim()
                    ymin, ymax = ax.get_ylim()
                    st.session_state["xylim"] = (xmin, xmax, ymin, ymax)
                st.info(st.session_state["xylim"])
                xmin, xmax, ymin, ymax = st.session_state["xylim"]
                # A limit of exactly 0 happens, e.g., for a single point
                # (default xlim 0..1) or 11 points (autoscaled xmin == 0).
                x_fallback = (xmax - xmin) * 0.5 or 1.0
                y_fallback = (ymax - ymin) * 0.5 or 1.0
                _limit_slider("x min", xmin, x_fallback, "xmin")
                _limit_slider("x max", xmax, x_fallback, "xmax")
                _limit_slider("y min", ymin, y_fallback, "ymin")
                _limit_slider("y max", ymax, y_fallback, "ymax")
            with st.expander("导出"):
                st.selectbox("导出文件拓展名", [".tiff", ".pdf", ".png", ".pgf"], key="export_format")
                btn = st.button("生成文件")
                if btn:
                    st.download_button(
                        label="下载",
                        data=on_save(),
                        file_name="reaction_coordinate" + st.session_state.get("export_format", ".tiff"),
                    )
            st.slider(
                "字体大小", 8, 20, value=12, key="font_size",
                on_change=lambda: plt.rcParams.update({'font.size': st.session_state.get("font_size", 12)})
            )
        st.write("---")
        st.dataframe(data)
    finally:
        # st.pyplot already rasterised the figure and on_save has run, so close
        # it. Otherwise pyplot keeps one more figure alive on every rerun.
        plt.close(fig)