diff --git a/cord/main.py b/cord/main.py index 6b7ea2f..2b2ef56 100644 --- a/cord/main.py +++ b/cord/main.py @@ -124,16 +124,16 @@ def on_save(): @st.cache_resource def load_data(file): # 读取数据文件 - if file is not None: - try: + try: + if st.session_state.get("use_example", False): + data = create_example()[0] + else: data = pd.read_excel(file) if file.name.endswith((".xlsx", ".xls")) else pd.read_csv(file) - if data.columns.tolist() != ["Name", "Energy"]: - st.warning("Format should be Name, Energy. Modified automatically.") - data.columns = ["Name", "Energy"] - except Exception as e: - st.error(f"Error reading file: {e}") - exit() - else: + if data.columns.tolist() != ["Name", "Energy"]: + st.warning("Format should be Name, Energy. Modified automatically.") + data.columns = ["Name", "Energy"] + except Exception as e: + st.error(f"Error reading file: {e}") exit() INFLU_FACTORS = [0.5] * data.shape[0] * 2 # 动态创建数组 @@ -147,34 +147,52 @@ def load_data(file): return data, INFLU_FACTORS,K_POS +@st.cache_data +def create_example(): + tmp_file = io.BytesIO() + example = pd.DataFrame({"Name":["reactant","TS","result"], "Energy":[-400.310327,-400.210017,-400.341576,]}) + example.to_excel(tmp_file, index=False) + return example,tmp_file + out_file = io.BytesIO() + st.set_page_config( page_title="反应坐标绘制", page_icon=":chart_with_upwards_trend:", - initial_sidebar_state="expanded" + initial_sidebar_state="expanded",layout="wide" ) st.title("反应坐标绘制") st.write("---") + file = st.file_uploader("上传能量文件", type=["xlsx", "xls", "csv"],key="file") -if not file: - st.set_page_config(layout="centered") + +if not file and not st.session_state.get("use_example", False) and "datas" not in st.session_state: + # st.set_page_config(layout="centered") st.write("按照下列格式上传表格。请保证列名和范例一致,或直接下载。") st.warning("注意,Energy单位为Hatree,程序将自动转换为kcal/mol的相对能量") - example = pd.DataFrame({"Name":["reactant","TS","result"], "Energy":[-400.310327,-400.210017,-400.341576,]}) + example,tmp_file = create_example() st.dataframe(example,hide_index=True) - tmp_file = io.BytesIO() - example.to_excel(tmp_file, index=False) st.download_button("下载模板",data=tmp_file,file_name="reaction_coordinate_example.xlsx") + def use_tmp(): + global file + st.session_state["use_example"] = True + file = tmp_file + st.button("使用样例使用",on_click=use_tmp) st.stop() else: - st.set_page_config(layout="wide") + pass + # st.set_page_config(layout="wide") col1,col2 = st.columns([0.4,0.6],gap="medium") with col2: - data, INFLU_FACTORS,K_POS = load_data(file) + if "datas" not in st.session_state: + data, INFLU_FACTORS,K_POS = load_data(file) + st.session_state["datas"] = (data, INFLU_FACTORS,K_POS) + else: + data, INFLU_FACTORS, K_POS = st.session_state["datas"] fig,lines = plot_reaction_coordinate() stfig = st.pyplot(fig,False) @@ -211,7 +229,6 @@ with col1: xmin,xmax = plt.xlim() ymin,ymax = plt.ylim() st.session_state["xylim"] = (xmin,xmax,ymin,ymax) - st.info(st.session_state["xylim"]) xmin,xmax,ymin,ymax = st.session_state["xylim"] dxmin,dxmax,dymin,dymax = abs(xmin)*0.5,abs(xmax)*0.5,abs(ymin)*0.5,abs(ymax)*0.5 @@ -234,7 +251,9 @@ with col1: st.slider("字体大小",8,20, value=12, key="font_size", on_change=lambda: plt.rcParams.update({'font.size': st.session_state.get("font_size", 12)})) - + if st.button("重置",type="primary"): + st.session_state.clear() + st.rerun() st.write("---") st.dataframe(data)