I have a google’s visual transformer model which I have trained in Tensorflow 2 and saved as an h5 file.
# Base model pre-trained on ImageNet-21k with the 224x224 image resolution
base_model = TFViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
# Freeze base model
base_model.trainable = False
# Create new model
inputs = keras.Input(shape = (3, 224, 224))
x = data_augmentation_vit(inputs)
vit = base_model.vit(inputs)[0]
vit = keras.layers.GlobalAveragePooling1D()(vit)
vit = tf.keras.layers.Dense(256, activation='relu')(vit)
vit = tf. keras.layers.Dropout(0.15)(vit)
outputs = tf.keras.layers.Dense(1, activation='sigmoid', name='outputs')(vit)
model_vit = tf.keras.Model(inputs, outputs)
print(model_vit.summary())
Model: "model_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_2 (InputLayer) [(None, 3, 224, 224)] 0
vit (TFViTMainLayer) TFBaseModelOutputWithPoo 86389248
ling(last_hidden_state=(
None, 197, 768),
pooler_output=(None, 76
8),
hidden_states=None, att
entions=None)
global_average_pooling1d (G (None, 768) 0
lobalAveragePooling1D)
dense_2 (Dense) (None, 256) 196864
dropout_37 (Dropout) (None, 256) 0
outputs (Dense) (None, 1) 257
=================================================================
So when I use the following code and run my app in Streamlit, it gives me this ValueError.
ValueError: Unknown layer: Custom>TFViTMainLayer. Please ensure this object is passed to the `custom_objects` argument. See https://www.tensorflow.org/guide/keras/save_and_serialize#registering_the_custom_object for details.
Traceback:
File "C:\Users\maria\anaconda3\envs\tfenv\lib\site-packages\streamlit\scriptrunner\script_runner.py", line 557, in _run_script
exec(code, module.__dict__)
File "app_extended.py", line 248, in <module>
model_loader = tf.keras.models.load_model(path_to_model)
File "C:\Users\maria\anaconda3\envs\tfenv\lib\site-packages\keras\utils\traceback_utils.py", line 67, in error_handler
raise e.with_traceback(filtered_tb) from None
File "C:\Users\maria\anaconda3\envs\tfenv\lib\site-packages\keras\utils\generic_utils.py", line 562, in class_and_config_for_serialized_keras_object
raise ValueError(
my code
import streamlit as st
import numpy as np
from PIL import Image
import tensorflow as tf
st.title("Binary Human Detection Web App")
st.markdown("Is there a human in office space? 🧍")
## Initialize tensorflow model (This can be loaded before anything else)
path_to_model = "C:/Users/myname/Jupiter_Notebooks/Dataset_Thermal_Project/Camera_videos/Saved_models/model_vit.h5"
model_loader = tf.keras.models.load_model(path_to_model)
model_vit = tf.keras.models.Model(model_loader.inputs, model_loader.outputs)
## Preprocess images
def preprocessImage(photo):
resize_photo = photo.resize((224,224))
normalized_photo = np.array(resize_photo)/255 # a normalised 2D array
reshaped_photo = normalized_photo.reshape(-1, 224, 224, 3) # to shape as (1, 224, 224, 3)
return reshaped_photo
uploaded_file = st.sidebar.file_uploader(" ",type=['jpg', 'jpeg'])
if uploaded_file is not None:
## Use a context manager to make sure to close the file!!
with Image.open(uploaded_file) as photo:
tensorflow_image = preprocessImage(photo)
## Show preprocessed image
streamlit_widget_image = st.image(tensorflow_image, 'Uploaded Image', use_column_width=True)
## Do prediction
if st.sidebar.button("Click Here to Predict"):
if uploaded_file is None:
st.sidebar.write("Please upload an Image to Classify")
else:
## Pass the preprocessed image to the vit model (not the streamlit widget)
pred_label = model_vit.predict(tensorflow_image)[0]
## Print prediction
st.sidebar.header("ViT model results:")
if pred_label > 0.5: st.sidebar.info('Human is detected')
else: st.sidebar.info('No human is detected')
Not sure how to register the custom object in my example after looking at this link, “Save and load Keras models | TensorFlow Core”.
Need some help with this please?
here is the downlowdable model_vit.h5, however, it is quite a big file (332MB), https://drive.google.com/file/d/1ASXJ6-QVxV7W-rVUV57pUy5sYK1BokZ4/view?usp=sharing