diff --git a/cord/main.py b/cord/main.py index 2d04204..6b7ea2f 100644 --- a/cord/main.py +++ b/cord/main.py @@ -60,7 +60,6 @@ def plot_reaction_coordinate(changed=None, _lines=None): last=(-1,-1) - maxy = data["Energy"].max() miny = data["Energy"].min() varyy = maxy - miny @@ -87,6 +86,11 @@ def plot_reaction_coordinate(changed=None, _lines=None): ax1.xaxis.set_ticks([]) ax1.set_ylabel("Energy (kcal/mol)") ax1.set_ylim(miny-varyy*0.1, maxy+varyy*0.1) + + if st.session_state.get("xylim", None) is not None: + ax1.set_xlim(st.session_state["xmin"], st.session_state["xmax"]) + ax1.set_ylim(st.session_state["ymin"], st.session_state["ymax"]) + return fig,lines # 创建图形和坐标轴 @@ -105,6 +109,7 @@ def callback_gen(x,typ=0): return callback + def on_save(): global out_file # for slider in slides: @@ -122,6 +127,9 @@ def load_data(file): if file is not None: try: data = pd.read_excel(file) if file.name.endswith((".xlsx", ".xls")) else pd.read_csv(file) + if data.columns.tolist() != ["Name", "Energy"]: + st.warning("Format should be Name, Energy. Modified automatically.") + data.columns = ["Name", "Energy"] except Exception as e: st.error(f"Error reading file: {e}") exit() @@ -133,7 +141,6 @@ def load_data(file): ene = data["Energy"].to_numpy() K_POS = np.where(ene[1:]>ene[:1],0.03,-0.05) K_POS = [-0.05] + K_POS.tolist() - st.info(K_POS) data["Energy"] -= data["Energy"][0] data["Energy"]*=627.509 @@ -145,55 +152,89 @@ out_file = io.BytesIO() st.set_page_config( page_title="反应坐标绘制", page_icon=":chart_with_upwards_trend:", - layout="wide", initial_sidebar_state="expanded" ) st.title("反应坐标绘制") st.write("---") -col1,col2,col3 = st.columns([0.4,0.25,0.25],gap="medium") +file = st.file_uploader("上传能量文件", type=["xlsx", "xls", "csv"],key="file") +if not file: + st.set_page_config(layout="centered") + st.write("按照下列格式上传表格。请保证列名和范例一致,或直接下载。") + st.warning("注意,Energy单位为Hatree,程序将自动转换为kcal/mol的相对能量") + example = pd.DataFrame({"Name":["reactant","TS","result"], "Energy":[-400.310327,-400.210017,-400.341576,]}) + st.dataframe(example,hide_index=True) + tmp_file = io.BytesIO() + example.to_excel(tmp_file, index=False) + st.download_button("下载模板",data=tmp_file,file_name="reaction_coordinate_example.xlsx") + st.stop() +else: + st.set_page_config(layout="wide") +col1,col2 = st.columns([0.4,0.6],gap="medium") -with col1: - file = st.file_uploader("上传能量文件", type=["xlsx", "xls", "csv"]) + +with col2: + data, INFLU_FACTORS,K_POS = load_data(file) fig,lines = plot_reaction_coordinate() stfig = st.pyplot(fig,False) + +with col1: + st.title("作图参数设置") + with st.expander("调整曲线形状(贝塞尔参数)"): + for i in range(data.shape[0]): + if i!=0: + st.slider( + f'{data.loc[i,"Name"]} 左', + 0.0, 1.0, value=INFLU_FACTORS[i*2-1], + key=f'slider_{i*2}', + on_change=callback_gen(i*2) + ) + if i!= data.shape[0] - 1: + st.slider( + f'{data.loc[i,"Name"]} 右', + 0.0, 1.0, value=INFLU_FACTORS[i*2], + key=f'slider_{i*2+1}', + on_change=callback_gen(i*2+1) + ) + with st.expander("调整文字位置"): + for i in range(data.shape[0]): + st.slider( + f'{data.loc[i,"Name"]}', + -0.1, 0.1, value=K_POS[i], + key=f'text_slider_{i}', + on_change=callback_gen(i,1) + ) + + with st.expander("调整坐标系极限"): + if st.session_state.get("xylim", None) is None: + xmin,xmax = plt.xlim() + ymin,ymax = plt.ylim() + st.session_state["xylim"] = (xmin,xmax,ymin,ymax) + st.info(st.session_state["xylim"]) + + xmin,xmax,ymin,ymax = st.session_state["xylim"] + dxmin,dxmax,dymin,dymax = abs(xmin)*0.5,abs(xmax)*0.5,abs(ymin)*0.5,abs(ymax)*0.5 + + st.slider("x min", xmin-dxmin, xmin+dxmin,value=xmin, key="xmin") + st.slider("x max", xmax-dxmax, xmax+dxmax,value=xmax, key="xmax") + st.slider("y min", ymin-dymin, ymin+dymin,value=ymin, key="ymin") + st.slider("y max", ymax-dymax, ymax+dymax,value=ymax, key="ymax") + + with st.expander("导出"): + st.selectbox("导出文件拓展名",[".tiff",".pdf",".png",".pgf"],key="export_format") + btn = st.button("生成文件") + if btn: + st.download_button( + label="下载", + data=on_save(), + file_name="reaction_coordinate"+st.session_state.get("export_format", ".tiff"), + # mime="image/tiff" + ) + st.slider("字体大小",8,20, value=12, key="font_size", on_change=lambda: plt.rcParams.update({'font.size': st.session_state.get("font_size", 12)})) - st.selectbox("导出文件拓展名",[".tiff",".pdf",".png",".pgf"],key="export_format") - st.download_button( - label="Download Plot", - data=on_save(), - file_name="reaction_coordinate"+st.session_state.get("export_format", ".tiff"), - # mime="image/tiff" - ) -with col2: - st.write("调整滑块以改变反应坐标图曲线形状。") - for i in range(data.shape[0]): - if i!=0: - st.slider( - f'{data.loc[i,"Name"]} 左', - 0.0, 1.0, value=INFLU_FACTORS[i*2-1], - key=f'slider_{i*2}', - on_change=callback_gen(i*2) - ) - if i!= data.shape[0] - 1: - st.slider( - f'{data.loc[i,"Name"]} 右', - 0.0, 1.0, value=INFLU_FACTORS[i*2], - key=f'slider_{i*2+1}', - on_change=callback_gen(i*2+1) - ) -with col3: - st.write("调整参数以改变文字位置。") - for i in range(data.shape[0]): - st.slider( - f'{data.loc[i,"Name"]}', - -0.1, 0.1, value=K_POS[i], - key=f'text_slider_{i}', - on_change=callback_gen(i,1) - ) st.write("---") st.dataframe(data)