cord example

This commit is contained in:
2025-08-28 22:57:07 +08:00
parent c2fd6857cc
commit e70435e807

View File

@ -124,8 +124,10 @@ def on_save():
@st.cache_resource @st.cache_resource
def load_data(file): def load_data(file):
# 读取数据文件 # 读取数据文件
if file is not None:
try: try:
if st.session_state.get("use_example", False):
data = create_example()[0]
else:
data = pd.read_excel(file) if file.name.endswith((".xlsx", ".xls")) else pd.read_csv(file) data = pd.read_excel(file) if file.name.endswith((".xlsx", ".xls")) else pd.read_csv(file)
if data.columns.tolist() != ["Name", "Energy"]: if data.columns.tolist() != ["Name", "Energy"]:
st.warning("Format should be Name, Energy. Modified automatically.") st.warning("Format should be Name, Energy. Modified automatically.")
@ -133,8 +135,6 @@ def load_data(file):
except Exception as e: except Exception as e:
st.error(f"Error reading file: {e}") st.error(f"Error reading file: {e}")
exit() exit()
else:
exit()
INFLU_FACTORS = [0.5] * data.shape[0] * 2 # 动态创建数组 INFLU_FACTORS = [0.5] * data.shape[0] * 2 # 动态创建数组
@ -147,34 +147,52 @@ def load_data(file):
return data, INFLU_FACTORS,K_POS return data, INFLU_FACTORS,K_POS
@st.cache_data
def create_example():
tmp_file = io.BytesIO()
example = pd.DataFrame({"Name":["reactant","TS","result"], "Energy":[-400.310327,-400.210017,-400.341576,]})
example.to_excel(tmp_file, index=False)
return example,tmp_file
out_file = io.BytesIO() out_file = io.BytesIO()
st.set_page_config( st.set_page_config(
page_title="反应坐标绘制", page_title="反应坐标绘制",
page_icon=":chart_with_upwards_trend:", page_icon=":chart_with_upwards_trend:",
initial_sidebar_state="expanded" initial_sidebar_state="expanded",layout="wide"
) )
st.title("反应坐标绘制") st.title("反应坐标绘制")
st.write("---") st.write("---")
file = st.file_uploader("上传能量文件", type=["xlsx", "xls", "csv"],key="file") file = st.file_uploader("上传能量文件", type=["xlsx", "xls", "csv"],key="file")
if not file:
st.set_page_config(layout="centered") if not file and not st.session_state.get("use_example", False) and "datas" not in st.session_state:
# st.set_page_config(layout="centered")
st.write("按照下列格式上传表格。请保证列名和范例一致,或直接下载。") st.write("按照下列格式上传表格。请保证列名和范例一致,或直接下载。")
st.warning("注意Energy单位为Hatree程序将自动转换为kcal/mol的相对能量") st.warning("注意Energy单位为Hatree程序将自动转换为kcal/mol的相对能量")
example = pd.DataFrame({"Name":["reactant","TS","result"], "Energy":[-400.310327,-400.210017,-400.341576,]}) example,tmp_file = create_example()
st.dataframe(example,hide_index=True) st.dataframe(example,hide_index=True)
tmp_file = io.BytesIO()
example.to_excel(tmp_file, index=False)
st.download_button("下载模板",data=tmp_file,file_name="reaction_coordinate_example.xlsx") st.download_button("下载模板",data=tmp_file,file_name="reaction_coordinate_example.xlsx")
def use_tmp():
global file
st.session_state["use_example"] = True
file = tmp_file
st.button("使用样例使用",on_click=use_tmp)
st.stop() st.stop()
else: else:
st.set_page_config(layout="wide") pass
# st.set_page_config(layout="wide")
col1,col2 = st.columns([0.4,0.6],gap="medium") col1,col2 = st.columns([0.4,0.6],gap="medium")
with col2: with col2:
if "datas" not in st.session_state:
data, INFLU_FACTORS,K_POS = load_data(file) data, INFLU_FACTORS,K_POS = load_data(file)
st.session_state["datas"] = (data, INFLU_FACTORS,K_POS)
else:
data, INFLU_FACTORS, K_POS = st.session_state["datas"]
fig,lines = plot_reaction_coordinate() fig,lines = plot_reaction_coordinate()
stfig = st.pyplot(fig,False) stfig = st.pyplot(fig,False)
@ -211,7 +229,6 @@ with col1:
xmin,xmax = plt.xlim() xmin,xmax = plt.xlim()
ymin,ymax = plt.ylim() ymin,ymax = plt.ylim()
st.session_state["xylim"] = (xmin,xmax,ymin,ymax) st.session_state["xylim"] = (xmin,xmax,ymin,ymax)
st.info(st.session_state["xylim"])
xmin,xmax,ymin,ymax = st.session_state["xylim"] xmin,xmax,ymin,ymax = st.session_state["xylim"]
dxmin,dxmax,dymin,dymax = abs(xmin)*0.5,abs(xmax)*0.5,abs(ymin)*0.5,abs(ymax)*0.5 dxmin,dxmax,dymin,dymax = abs(xmin)*0.5,abs(xmax)*0.5,abs(ymin)*0.5,abs(ymax)*0.5
@ -234,7 +251,9 @@ with col1:
st.slider("字体大小",8,20, value=12, key="font_size", st.slider("字体大小",8,20, value=12, key="font_size",
on_change=lambda: plt.rcParams.update({'font.size': st.session_state.get("font_size", 12)})) on_change=lambda: plt.rcParams.update({'font.size': st.session_state.get("font_size", 12)}))
if st.button("重置",type="primary"):
st.session_state.clear()
st.rerun()
st.write("---") st.write("---")
st.dataframe(data) st.dataframe(data)