Computing SHAP keeps increasing memory usage after every user input change

So I managed to get my app working on Streamlit Sharing, but it crashes after sliding or clicking options a few times. Whenever I slide to a new value, the app refreshes (which I assume reruns the entire script), and the SHAP values get recomputed on the new data. Every time this happens, memory usage increases by about 500 MB, which I assume is what causes the crash. Without computing the SHAP values, my app runs smoothly with no significant memory increase.

I am wondering why this is happening. My model loading is cached:

## LOAD TRAINED RANDOM FOREST MODEL
from pathlib import Path

import joblib
import streamlit as st

cloud_model_location = '1PkTZnHK_K4LBTSkAbCfgtDsk-K9S8rLe'  # hosted on Google Drive

@st.cache(allow_output_mutation=True)
def load_model():
    save_dest = Path('model')
    save_dest.mkdir(exist_ok=True)
    f_checkpoint = Path("model/rf_compressed.pkl")

    # download from Google Drive if the model is not already present
    if not f_checkpoint.exists():
        with st.spinner("Downloading model... this may take a while! \n Don't stop it!"):
            from GD_download import download_file_from_google_drive
            download_file_from_google_drive(cloud_model_location, f_checkpoint)

    return joblib.load(f_checkpoint)

rfr = load_model()  # load model

And this line, which rebuilds the SHAP explainer every time a slider value is changed, causes the memory increase:
explainer = shap.TreeExplainer(rfr)

Does the model rfr get ‘cached’ or ‘saved’ whenever the SHAP is computed, thus leading to the memory increase? I suspect this because I get the error CachedObjectMutationWarning: Return value of load_model() was mutated between runs., which goes away when I add allow_output_mutation=True. So I assume the load_model() output rfr gets mutated at every rerun. But the same thing happens even when I remove @st.cache() from the model-loading function.
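
As far as I understand the legacy caching behaviour (my own reading, so take it with a grain of salt), st.cache re-hashes the cached return value on every rerun and raises that warning when the hash changes; allow_output_mutation=True simply skips the check. A toy example that reproduces the warning:

import streamlit as st

@st.cache  # without allow_output_mutation, Streamlit re-hashes the return value each run
def load_data():
    return {'count': 0}

data = load_data()
data['count'] += 1  # mutating the cached object triggers CachedObjectMutationWarning on the next rerun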

Any ideas on how to approach this?

Ok, so what I did to reduce memory usage was to run explainer = shap.TreeExplainer(rfr) locally, save the explainer with joblib, and load the explainer together with the model.

By doing this, I don’t have to explicitly run explainer = shap.TreeExplainer(rfr) in the app, which solves the issue of memory increasing at every rerun.
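
For anyone following along, a minimal sketch of that workflow (file names are illustrative, not my exact code):

# --- run once locally, outside the app ---
import joblib
import shap

rfr = joblib.load('rf_compressed.pkl')       # the trained random forest
explainer = shap.TreeExplainer(rfr)          # build the explainer offline
joblib.dump(explainer, 'rf_explainer.pkl')   # save it alongside the model

# --- then, in the app's cached loader, load both ---
# model = joblib.load('model/rf_compressed.pkl')
# explainer = joblib.load('model/rf_explainer.pkl')
# return model, explainer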

Hi Teyang,

Hope you are doing well. Fantastic application!

I’m trying to develop a Streamlit application that incorporates SHAP as well. Similar to you, I ran into memory issues when calculating the SHAP values within the application.

I tried to follow your code and approach: run the explainer locally, save it with joblib, download it from Google Drive, then load the explainer. However, I run into a ‘TypeError’ when loading my explainer (see below). I have no issues when using the same approach to save/download/load my trained model.

TypeError: ((4, b'U\r\r\n', b'\xe3\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00C\x00\x00\x00s\x04\x00\x00\x00|\x00S\x00)\x01\xfa\x1c A no-op link function.\n \xa9\x00\xa9\x01\xda\x01xr\x02\x00\x00\x00r\x02\x00\x00\x00\xfa9C:\\Users\\Jethro\\anaconda3\\lib\\site-packages\\shap\\links.py\xda\x08identity\x04\x00\x00\x00s\x02\x00\x00\x00\x00\x04'), {'__name__': 'shap.links'}, 'identity', None) is not a callable object
Traceback:
File "/home/appuser/.local/lib/python3.7/site-packages/streamlit/script_runner.py", line 332, in _run_script
    exec(code, module.__dict__)
File "/app/ssepe/ssEPE_V2.py", line 66, in <module>
    model, explainer = load_model()
File "/home/appuser/.local/lib/python3.7/site-packages/streamlit/caching.py", line 591, in wrapped_func
    return get_or_create_cached_value()
File "/home/appuser/.local/lib/python3.7/site-packages/streamlit/caching.py", line 575, in get_or_create_cached_value
    return_value = func(*args, **kwargs)
File "/app/ssepe/ssEPE_V2.py", line 62, in load_model
    explainer = joblib.load(f_checkpoint2)
File "/home/appuser/.local/lib/python3.7/site-packages/joblib/numpy_pickle.py", line 585, in load
    obj = _unpickle(fobj, filename, mmap_mode)
File "/home/appuser/.local/lib/python3.7/site-packages/joblib/numpy_pickle.py", line 504, in _unpickle
    obj = unpickler.load()
File "/usr/local/lib/python3.7/pickle.py", line 1088, in load
    dispatch[key[0]](self)
File "/usr/local/lib/python3.7/pickle.py", line 1436, in load_reduce
    stack[-1] = func(*args)
File "/home/appuser/.local/lib/python3.7/site-packages/numba/core/serialize.py", line 40, in _rebuild_reduction
    return cls._rebuild(*args)
File "/home/appuser/.local/lib/python3.7/site-packages/numba/core/dispatcher.py", line 825, in _rebuild
    self = cls(py_func, locals, targetoptions, impl_kind)
File "/home/appuser/.local/lib/python3.7/site-packages/numba/core/dispatcher.py", line 748, in __init__
    pysig = utils.pysignature(py_func)
File "/usr/local/lib/python3.7/inspect.py", line 3083, in signature
    return Signature.from_callable(obj, follow_wrapped=follow_wrapped)
File "/usr/local/lib/python3.7/inspect.py", line 2833, in from_callable
    follow_wrapper_chains=follow_wrapped)
File "/usr/local/lib/python3.7/inspect.py", line 2208, in _signature_from_callable
    raise TypeError('{!r} is not a callable object'.format(obj))

I was wondering if I could get your guidance on how you were able to avoid this issue. I wonder if it has anything to do with how we ran our explainer or saved it as a .pkl file.
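
In case it helps narrow things down, my working theory from the traceback is an environment mismatch: the b'U\r\r\n' magic number in the error looks like Python 3.8 bytecode, while the deployed app runs Python 3.7, so numba may be failing to rebuild the compiled shap.links.identity function across versions. A quick check I could run in both environments (just a sketch):

import sys
import numba
import shap

print('python:', sys.version)
print('numba :', numba.__version__)
print('shap  :', shap.__version__)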

Thanks in advance!