Cache Keras trained model

Hi folks,
I have trained a model with the Keras framework, exported it with model.save('model.hdf5'), and now I want to integrate it with the awesome Streamlit.
Obviously, I do not want to load the model every time the end-user inserts a new input, but to load it once and for all.
So my code looks something like this:

import streamlit as st
from keras.models import load_model

@st.cache
def load_my_model():
    model = load_model('model.hdf5')
    model.summary()

    return model

if __name__ == '__main__':
    st.title('My first app')
    sentence = st.text_input('Input your sentence here:')
    model = load_my_model()
    if sentence:
        y_hat = model.predict(sentence)

With that, I get the

“streamlit.errors.UnhashableType: <exception str() failed>”

exception.
I tried @st.cache(allow_output_mutation=True), and when I run a query on the Streamlit page I get:

“TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor(“input_1:0”, shape=(?, 80), dtype=int32) is not an element of this graph.”

(Of course, without any cache decorators the model loads and works fine.)

How should I properly load and cache a trained Keras model?

Python ver: 2.7 (unfortunately)
Keras ver: 2.1.3
Tensorflow ver: 1.3.0
Streamlit ver: 0.55.2

Many thanks!

Hey @Daniel_Juravski and welcome to the forum :wave:,

Would it be possible to upgrade your Python version to 3.6 and your Streamlit version to 0.57.3? [Releases after 0.57.0 have an updated error message to help with debugging st.cache issues.]

Hi @tc1,
Unfortunately the source code I use is not trivial to migrate to Python 3.6 (and therefore I probably cannot upgrade my Streamlit version).
Any suggestions with my current versions?

Thank you.
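(Aside: one workaround that does not rely on st.cache at all, and should work even on Python 2.7, is a module-level lazy loader. Streamlit re-runs the app script on every interaction, but imported helper modules stay alive in the server process, so a cache held in one is only populated once. A minimal sketch with hypothetical names — `model_store`, `get_model`, and the loader argument are placeholders, not Streamlit APIs:)

```python
# model_store.py -- hypothetical helper module; import it from your app script.
# Streamlit re-executes the app script on each interaction, but this module is
# imported only once per server process, so _CACHE persists across reruns.

_CACHE = {}

def get_model(path, loader):
    """Load the model at `path` once with `loader`; reuse it afterwards."""
    if path not in _CACHE:
        _CACHE[path] = loader(path)
    return _CACHE[path]
```

In the app script you would then call `get_model('model.hdf5', load_model)` instead of calling `load_model` directly; the TF1 graph/session caveat discussed later in this thread still applies.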

Hey @Daniel_Juravski,

Thanks for the additional details. Asking the team internally to see if they have any good suggestions. @Jonathan_Rhone is going to ping the thread as soon as we have a bit more info.

Thank you @tc1 and @Jonathan_Rhone; if any additional details are needed, please let me know.


Solved, the solution was:

  1. adding a _make_predict_function() call
  2. returning the session

import streamlit as st
from keras import backend as K
from keras.models import load_model

@st.cache(allow_output_mutation=True)
def load_my_model():
    # renamed so it doesn't shadow keras.models.load_model
    model = load_model(MODEL_PATH)
    model._make_predict_function()
    model.summary()  # included to make it visible when the model is reloaded
    session = K.get_session()
    return model, session

if __name__ == '__main__':
    st.title('My first app')
    sentence = st.text_input('Input your sentence here:')
    model, session = load_my_model()
    if sentence:
        K.set_session(session)
        y_hat = model.predict(sentence)
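One caveat about the snippet above: model.predict expects a numeric array, not the raw string from st.text_input. Judging by the shape=(?, 80) in the earlier error message, the sentence presumably has to be tokenized and padded to length 80 first. A minimal sketch of that preprocessing, with a hypothetical word_index vocabulary (your real tokenizer, ids, and padding convention will differ):

```python
def encode_sentence(sentence, word_index, maxlen=80, oov_id=1):
    # Map each whitespace token to its integer id, falling back to the
    # out-of-vocabulary id, then left-pad with zeros to a fixed length.
    ids = [word_index.get(tok, oov_id) for tok in sentence.lower().split()]
    ids = ids[:maxlen]
    return [0] * (maxlen - len(ids)) + ids
```

The prediction call would then be something like model.predict(np.array([encode_sentence(sentence, word_index)])), giving an input of shape (1, 80).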

Thanks for the solution. In TensorFlow 2 the session API was removed, so use this instead:

import streamlit as st
from tensorflow.keras.models import load_model

@st.cache(allow_output_mutation=True)
def load_my_model():
    # renamed so it doesn't shadow tensorflow.keras.models.load_model
    model = load_model(MODEL_PATH)
    model._make_predict_function()
    model.summary()  # included to make it visible when the model is reloaded
    return model

if __name__ == '__main__':
    st.title('My first app')
    sentence = st.text_input('Input your sentence here:')
    model = load_my_model()
    if sentence:
        y_hat = model.predict(sentence)

Thanks for posting that @Amin_Taheri…with the rapid change of ML libraries, it feels like code snippets go out of date so quickly :+1:
