So I managed to get my app working on Streamlit Sharing, but it crashes after I move a slider or click options a few times. Whenever I slide to a new value, the app reruns (which I assume means it executes the entire script again), and the SHAP values get recomputed on the new data. Every time it does so, memory usage increases by about 500 MB, which I assume is what eventually causes the crash. Without computing the SHAP values, my app runs smoothly with no significant memory increase.
I am wondering why this is happening. My model loading is cached:
import joblib
import streamlit as st
from pathlib import Path

## LOAD TRAINED RANDOM FOREST MODEL
cloud_model_location = '1PkTZnHK_K4LBTSkAbCfgtDsk-K9S8rLe' # hosted on GD

@st.cache(allow_output_mutation=True)
def load_model():
    save_dest = Path('model')
    save_dest.mkdir(exist_ok=True)
    f_checkpoint = Path("model/rf_compressed.pkl")
    # download from GD if model not present
    if not f_checkpoint.exists():
        with st.spinner("Downloading model... this may take a while! \n Don't stop it!"):
            from GD_download import download_file_from_google_drive
            download_file_from_google_drive(cloud_model_location, f_checkpoint)
    return joblib.load(f_checkpoint)

rfr = load_model()  # load model
And this line, which recomputes the SHAP explainer every time a slider value is changed, is what causes the memory increase:
explainer = shap.TreeExplainer(rfr)
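In case it helps with debugging, here is how I'm planning to measure the per-rerun growth with tracemalloc. This is a self-contained sketch: build_explainer is just a stand-in that allocates a big structure, the way the real shap.TreeExplainer(rfr) call seems to on every rerun.

```python
import tracemalloc

def build_explainer(model):
    # Stand-in for shap.TreeExplainer(rfr): allocates a large structure on
    # each call, mimicking what happens on every Streamlit rerun.
    return [list(range(100_000)) for _ in range(10)]

tracemalloc.start()
before = tracemalloc.take_snapshot()

explainer = build_explainer(None)   # the suspect line in the real app

after = tracemalloc.take_snapshot()
growth = sum(stat.size_diff for stat in after.compare_to(before, 'lineno'))
print(f"memory grown since snapshot: {growth / 1024 / 1024:.1f} MiB")
tracemalloc.stop()
```

Running two snapshots around the suspect line in the real script should show whether the ~500 MB really comes from constructing the explainer or from computing the SHAP values afterwards.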
Does the model rfr get 'cached' or 'saved' whenever the SHAP values are computed, thus leading to the memory increase? I suspect this because I get the error CachedObjectMutationWarning: Return value of load_model() was mutated between runs., which goes away when I insert allow_output_mutation=True. So I assume the load_model() output rfr gets mutated at every rerun. But when I remove @st.cache() from the model-loading function, the same thing happens.
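To check my understanding of the warning, I wrote myself a toy imitation. This is not Streamlit's actual code, just my mental model of it: the cache stores a hash of the returned object and re-hashes it on the next run, so any in-place mutation of the cached object in between produces a mismatch.

```python
import hashlib
import pickle

_cache = {}
mutation_warnings = []

def toy_cache(func):
    # Very rough imitation of how (I believe) st.cache detects mutation:
    # store a hash of the returned object, re-hash it on the next run,
    # and warn if the two hashes disagree.
    def wrapper():
        key = func.__name__
        if key not in _cache:
            value = func()
            _cache[key] = (value, hashlib.md5(pickle.dumps(value)).hexdigest())
        value, saved_hash = _cache[key]
        if hashlib.md5(pickle.dumps(value)).hexdigest() != saved_hash:
            mutation_warnings.append(f"Return value of {key}() was mutated between runs.")
        return value
    return wrapper

@toy_cache
def load_model():
    return {"trees": [1, 2, 3]}   # stands in for the random forest

model = load_model()              # first "run": computed and cached
model["trees"].append(4)          # something mutates the cached object in place
model = load_model()              # second "run": hash mismatch -> warning
print(mutation_warnings)
```

If this mental model is right, then something in the SHAP computation is mutating rfr in place, which would explain the warning; but it still wouldn't explain why the memory grows even with @st.cache() removed.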
Any ideas on how to approach this?