Hi,
I have a class instance that I load into my app script from a pickle file. It stores a model along with dataframes, features and a predict method.
For example:
import joblib
import streamlit as st

@st.cache(persist=True)
def load_class_model(model_file):
    # joblib.load accepts a path directly, so no open file handle is left behind
    return joblib.load(model_file)

a = load_class_model(file)
a.model.predict(args)
Usually I use @st.cache(allow_output_mutation=True) on functions whose args can change, so that the result is not cached.
How can I do the same for imported instance and method objects? In this case for a.model.predict?
Would it be better to pickle the objects individually?
Would really appreciate any guidance.
No, allow_output_mutation is not used for that purpose. Please check the docs.
Do you mean how to cache the output of a.model.predict()?
If yes, why not try something like this?
@st.cache(persist=True)
def load_class_model(model_file):
    # joblib.load accepts a path directly
    return joblib.load(model_file)

@st.cache(persist=True)
def predict(file, args):
    a = load_class_model(file)
    return a.model.predict(args)

output = predict(file, args)
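The two-level pattern above can be tried outside Streamlit. The following is only a minimal sketch: it uses functools.lru_cache from the standard library as a stand-in for @st.cache (it is not how Streamlit implements caching), and FakeModel, FakeInstance and the two counters are hypothetical names that replace the pickled class instance. It shows the loader running once per file while predict recomputes only when its arguments change.

```python
from functools import lru_cache

LOAD_CALLS = 0      # counts how often the (hypothetical) model file is read
PREDICT_CALLS = 0   # counts how often a prediction is actually computed


class FakeModel:
    """Hypothetical stand-in for the pickled object's `model` attribute."""

    def predict(self, args):
        user, num = args
        return [f"{user}-{i}" for i in range(num)]


class FakeInstance:
    """Hypothetical stand-in for the unpickled class instance."""

    def __init__(self):
        self.model = FakeModel()


@lru_cache(maxsize=None)
def load_class_model(model_file):
    # In the real app this would be joblib.load(model_file).
    global LOAD_CALLS
    LOAD_CALLS += 1
    return FakeInstance()


@lru_cache(maxsize=None)
def predict(model_file, args):
    # lru_cache needs hashable arguments, hence a tuple rather than a list.
    global PREDICT_CALLS
    PREDICT_CALLS += 1
    a = load_class_model(model_file)
    return tuple(a.model.predict(args))


first = predict("model.pkl", ("alice", 2))   # computed, loads the model
again = predict("model.pkl", ("alice", 2))   # same arguments: served from cache
other = predict("model.pkl", ("bob", 3))     # new arguments: recomputed, model reused
```

After these three calls the model has been loaded once and two predictions have been computed. In a Streamlit app the same reasoning applies to widget-driven arguments: a new widget value is a new cache key, so only the functions that receive it should need to run again.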
@GokulNC Thanks for the reply and explaining the proper usage of allow_output_mutation.
I made the changes you suggested but it seemed to slow down the app significantly.
The args for predict come from user inputs captured via widgets, and every change to a user input causes all cached functions to re-execute.
@st.cache(persist=True)
def load_class_model(model_file):
    # joblib.load accepts a path directly
    return joblib.load(model_file)

@st.cache(persist=True)
def predict(file, args):
    a = load_class_model(file)
    return a.model.predict(args)

@st.cache(persist=True)
def load_dataframe(file, results):
    a = load_class_model(file)
    df = a.df1
    return df[df["some_column"].isin(results)]

user = st.selectbox("Select user", users, format_func=lambda x: "Select user" if x == "" else x)
if user != "":
    num = st.number_input("Pick Number", min_value=None, max_value=10, value=0)
    if num > 0:
        output = predict("model.pkl", [user, num])
        for p, n in enumerate(output):
            st.write(f"Result{p+1}:", n)
        # Only look up the metadata when the sidebar checkbox is ticked
        if st.sidebar.checkbox("Show Data", False):
            metadata = load_dataframe("model.pkl", output)
            st.table(metadata)
The pickled class instance stores all the objects and attributes that the scripts access. Would you recommend loading large objects such as dataframes from the class instance in their own functions decorated with @st.cache, outside of the nested cached functions?
df1 = load_df1("model.pkl")
df2 = load_df2("model.pkl")
.....
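One way to read the load_df1 / load_df2 idea is as thin accessors over a single cached loader, so the pickle is deserialized once no matter how many attributes are pulled from it. The sketch below is an illustration only and not a recommendation from this thread: functools.lru_cache again stands in for @st.cache, the df2 attribute and the FILE_READS counter are assumptions, and plain lists of dicts replace the real dataframes.

```python
from functools import lru_cache
from types import SimpleNamespace

FILE_READS = 0  # counts how often the (hypothetical) pickle is deserialized


@lru_cache(maxsize=None)
def load_class_model(model_file):
    # Stand-in for joblib.load(model_file); the attribute values are made up.
    global FILE_READS
    FILE_READS += 1
    return SimpleNamespace(
        df1=[{"some_column": "a"}, {"some_column": "b"}],
        df2=[{"other_column": 1}],
    )


def load_df1(model_file):
    # Thin accessor: the instance is already cached, so this is only an attribute lookup.
    return load_class_model(model_file).df1


def load_df2(model_file):
    return load_class_model(model_file).df2


df1 = load_df1("model.pkl")
df2 = load_df2("model.pkl")
```

Both accessors share the one cached instance, so FILE_READS stays at 1 here. Whether this helps in the real app depends on how expensive it is for Streamlit to hash and persist the cached objects, which this sketch does not measure.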