I have a big simulation with 100+ variables, and I do not want to rewrite every one of them as st.session_state.variable.
An easy solution I developed is to loop through Python's globals() dict.
All module-level variables are always stored there (including the ones in your script right now).
When the user clicks Pause, I save the variables to st.session_state
with a simple callback function on the button, which is executed before Streamlit reruns the script.
When Streamlit reruns my simulation, I simply load all variables from st.session_state
back into globals(), so that every variable is restored.
import streamlit as st
from time import sleep
def save_session_state(var_to_save):
    """Copy the named module-level variables into Streamlit's session state.

    Args:
        var_to_save: Iterable of global variable names. Each name is looked
            up in ``globals()`` and stored under the same key in
            ``st.session_state``.

    Raises:
        KeyError: If a name is not present in ``globals()``.
    """
    module_vars = globals()
    state = st.session_state
    for name in var_to_save:
        state[name] = module_vars[name]
def delete_session_state():
    """Remove every entry from ``st.session_state``."""
    state = st.session_state
    state.clear()
def load_session_state():
    """Copy every entry of ``st.session_state`` into this module's globals.

    Each session-state key becomes (or overwrites) a global variable of the
    same name.
    """
    module_vars = globals()
    state = st.session_state
    for key in state:
        module_vars[key] = state[key]
def pause_simulation(variables_to_save):
    """Flip the global ``pause`` flag and save the listed globals.

    Passed as the ``on_click`` callback of the pause/continue button.

    Args:
        variables_to_save: Names of the global variables to copy into
            ``st.session_state`` via ``save_session_state``.
    """
    global pause
    # Toggle first so that the new flag value is what gets saved.
    pause = False if pause else True
    save_session_state(variables_to_save)
    print(f'saved {count} Pause: {pause}')
# Default simulation state; loading from session state below may overwrite it.
count = 0
other_var = 'blabla'

if 'pause' in st.session_state:
    # State was saved earlier (e.g. the user clicked a button): restore it.
    load_session_state()
else:
    # Nothing saved yet: start in the running (not paused) state.
    pause = False

# Names of the globals that the pause callback copies into session state.
variables_to_save = ['count', 'other_var', 'pause']

btn_text = "Continue" if pause else "Pause"
pause_bt = st.button(
    btn_text,
    on_click=pause_simulation,
    args=(variables_to_save,),
)
rerun_bt = st.button("Rerun", on_click=delete_session_state)

# Simulation loop: advance one step per second until count reaches 100.
while count < 100:
    count += 1
    st.write(f"count: {count} Pause: {pause}")
    sleep(1)
    if not pause:
        continue
    # Paused: report the current iteration and stop until the next reload.
    print('clicked')
    st.write(f"Simulation paused at iteration: {count}")
    break
If you really want to store all variables, you don't need variables_to_save;
you can simply iterate over globals() to store everything and exclude some unnecessary entries, like this:
def save_session_state(skip_names=(), skip_types=()):
    """Save every module-level variable into ``st.session_state``.

    Iterates over ``globals()`` and stores each entry under its own name,
    printing one log line per variable that is saved or skipped by type.

    Args:
        skip_names: Collection of variable names that must not be saved.
        skip_types: Iterable of classes; any variable whose value is an
            instance of one of them is skipped.

    Variables whose name starts with ``__`` are always skipped.
    """
    print('----------------------------------')
    print('save_session_state')
    # isinstance() needs a tuple of classes; build it once, not per variable.
    excluded_types = tuple(skip_types)
    for var, value in globals().items():
        if var in skip_names or var.startswith('__'):
            continue
        if isinstance(value, excluded_types):
            print('skipped', var, type(value))
            continue
        print('saved', var, type(value))
        st.session_state[var] = value
def load_session_state(skip_names=(), skip_types=()):
    """Load the entries of ``st.session_state`` back into module globals.

    Each session-state key that is not filtered out becomes (or overwrites)
    a global variable of the same name. One log line is printed per entry
    that is loaded or skipped by type.

    Args:
        skip_names: Collection of keys that must not be loaded.
        skip_types: Iterable of classes; any entry whose value is an
            instance of one of them is skipped.

    Keys starting with ``__`` are always skipped.
    """
    print('----------------------------------')
    print('load_session_state')
    # isinstance() needs a tuple of classes; build it once, not per key.
    excluded_types = tuple(skip_types)
    module_vars = globals()
    for var in st.session_state:
        if var in skip_names or var.startswith('__'):
            continue
        # Read the stored value once instead of re-indexing for every use.
        value = st.session_state[var]
        if isinstance(value, excluded_types):
            print('skipped', var, type(value))
            continue
        print('loaded', var, type(value))
        module_vars[var] = value
# Value types to leave out when saving/loading: functions, modules,
# matplotlib figure/axes objects and Streamlit delta generators.
skip_types = [
    types.FunctionType,
    matplotlib.figure.Figure,
    Axes,
    types.ModuleType,
    types.BuiltinFunctionType,
    st.delta_generator.DeltaGenerator,
]
# Variable names to leave out when saving/loading.
skip_variables = [
    "pause_bt",
    "_",
    "rerun_bt",
    "uploaded_file",
    "download-csv",
    "csv",
    "Normalize",
    "fileUploaderField",
    "AxesSubplot",
]