Error when I try to cache a Keras model

Hi, I am trying to cache a Keras model and keep getting an error. Here is the code:

import streamlit as st
from keras.models import load_model, model_from_json

@st.cache(allow_output_mutation=True)
def add_model():
    model = load_model('models/keras/model.h5')
    return model

@st.cache(allow_output_mutation=True)
def add_weights(model):
    model.load_weights('models/keras/weights.h5')
    return model

st.sidebar.text('Loading data...')
model = add_model()
with open('models/keras/architecture.json') as f:
    model = model_from_json(f.read())
model = add_weights(model)
st.sidebar.text('Loading Done!')

The page doesn’t seem to load any faster, and I keep getting this error:

Cached function mutated its input arguments

When decorating a function with @st.cache, the arguments should not be mutated inside the function body, as that breaks the caching mechanism. Please update the code of add_weights to bypass the mutation.

See the Streamlit docs for more info.

I don’t see where anything could be changing inside the “add_weights” function. I’d appreciate any insight, thanks!

Hey @jcr592,

Looks like the root of the issue is that you’re generating the output of add_weights by mutating the model input.

Because that model variable is used as both the input and the output, @st.cache’s sanity check fails: it hashes the input arguments when the function is called, hashes them again after the function returns, and finds that the value changed along the way.

That’s why you’re getting that error: the model argument is mutated by model.load_weights('models/keras/weights.h5'). It’s also why you’re not getting any speedup: the value of model keeps changing, so every call is a cache “MISS”.
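To make that check concrete, here is a minimal pure-Python sketch of the idea, not Streamlit’s actual implementation: fingerprint the argument before the call, fingerprint it again afterwards, and flag any difference. ToyModel, fingerprint, checked_call and count_weights are hypothetical names; ToyModel stands in for a Keras model so the example runs without Keras installed.

```python
import hashlib


class ToyModel:
    """Hypothetical stand-in for a Keras model: it only holds a list of weights."""

    def __init__(self):
        self.weights = [0.0, 0.0, 0.0]

    def load_weights(self, values):
        # Mutates the object in place, like Keras' model.load_weights().
        self.weights = list(values)


def fingerprint(obj):
    # Hash the object's attributes; a simplification of Streamlit's own hasher.
    state = repr(sorted(vars(obj).items()))
    return hashlib.sha256(state.encode()).hexdigest()


def checked_call(func, arg):
    """Call func(arg) and report whether the call mutated its argument."""
    before = fingerprint(arg)
    result = func(arg)
    after = fingerprint(arg)
    return result, before != after


def add_weights(model):
    # Same shape as the question's add_weights: changes its argument in place.
    model.load_weights([0.1, 0.2, 0.3])
    return model


def count_weights(model):
    # Only reads the argument, so its fingerprint is unchanged.
    return len(model.weights)


_, mutated = checked_call(add_weights, ToyModel())
print(mutated)  # True

_, mutated = checked_call(count_weights, ToyModel())
print(mutated)  # False
```

The mutating function trips the check and the read-only one does not, which mirrors why Streamlit singles out add_weights.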

If all you’re trying to do is make sure all of this expensive loading happens only once, you may as well combine it into one function that doesn’t take any arguments:

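import streamlit as st
from keras.models import load_model

@st.cache(allow_output_mutation=True)
def load_full_model():
    # Named differently from keras' load_model so the call below is not recursive
    model = load_model('models/keras/model.h5')
    # add anything else you want to do to the model here
    model.load_weights('models/keras/weights.h5')
    return model

model = load_full_model()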

I’m assuming the simple use case here, which is that you just want to load a specific model, add weights to it, and use that one model for the duration of your app.
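Why a single zero-argument function avoids both problems can be sketched with the standard library’s functools.lru_cache standing in for @st.cache (a swapped-in technique, not Streamlit’s cache). With no arguments there is only one cache key, so the expensive body runs once; and the in-place change happens to a local object rather than to an argument, so there is nothing for an input-mutation check to catch. load_full_model and LOAD_CALLS are hypothetical, and a plain dict stands in for the Keras model.

```python
from functools import lru_cache

LOAD_CALLS = 0  # counts how often the expensive body actually runs


@lru_cache(maxsize=None)
def load_full_model():
    """Hypothetical stand-in for the expensive load: build once, then reuse."""
    global LOAD_CALLS
    LOAD_CALLS += 1
    model = {"architecture": "model.h5", "weights": None}
    # The in-place change is made to a local object, not to an argument.
    model["weights"] = "weights.h5"
    return model


first = load_full_model()
second = load_full_model()
print(LOAD_CALLS)       # 1
print(first is second)  # True: the same cached object is returned
```

Like allow_output_mutation=True, this hands back the very same object on every call, so any later change a caller makes to it is visible to every other caller.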
