ValueError: Unknown layer: Custom>TFViTMainLayer when using a Google transformer model in Streamlit

I have a Google Vision Transformer (ViT) model that I trained in TensorFlow 2 and saved as an .h5 file.

import tensorflow as tf
from tensorflow import keras
from transformers import TFViTModel

# Base model pre-trained on ImageNet-21k with the 224x224 image resolution
base_model = TFViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
# Freeze base model
base_model.trainable = False
# Create new model (channels-first input, as TFViTModel expects)
inputs = keras.Input(shape = (3, 224, 224))
x = data_augmentation_vit(inputs)   # augmentation layer defined elsewhere; x is not used below

vit = base_model.vit(inputs)[0]     # last_hidden_state, shape (None, 197, 768)
vit = keras.layers.GlobalAveragePooling1D()(vit)
vit = tf.keras.layers.Dense(256, activation='relu')(vit)
vit = tf.keras.layers.Dropout(0.15)(vit)
outputs = tf.keras.layers.Dense(1, activation='sigmoid', name='outputs')(vit)

model_vit = tf.keras.Model(inputs, outputs)

Model: "model_1"
 Layer (type)                     Output Shape                   Param #
 input_2 (InputLayer)             [(None, 3, 224, 224)]          0
 vit (TFViTMainLayer)             TFBaseModelOutputWithPooling(  86389248
                                    last_hidden_state=(None, 197, 768),
                                    pooler_output=(None, 768),
                                    hidden_states=None,
                                    attentions=None)
 global_average_pooling1d         (None, 768)                    0
   (GlobalAveragePooling1D)
 dense_2 (Dense)                  (None, 256)                    196864
 dropout_37 (Dropout)             (None, 256)                    0
 outputs (Dense)                  (None, 1)                      257
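As a sanity check on the summary, the sequence length and the parameter counts of the layers added on top of the base model can be recomputed by hand. This sketch assumes the standard ViT layout (one [CLS] token prepended to the 16x16 image patches); the helper names are illustrative only.

```python
def vit_sequence_length(image_size: int, patch_size: int) -> int:
    """Number of patch tokens plus the single [CLS] token."""
    patches_per_side = image_size // patch_size
    return patches_per_side * patches_per_side + 1


def dense_params(n_inputs: int, n_units: int) -> int:
    """Dense layer parameters: a weight per input/unit pair plus one bias per unit."""
    return n_inputs * n_units + n_units


hidden_size = 768  # ViT-Base hidden size, as shown in the summary
print(vit_sequence_length(224, 16))    # 197
print(dense_params(hidden_size, 256))  # 196864
print(dense_params(256, 1))            # 257
```

The pooling and dropout layers add no parameters, so nearly all of the model's size comes from the frozen `TFViTMainLayer`.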

When I load the model with the code below and run my app in Streamlit, I get this ValueError:

ValueError: Unknown layer: Custom>TFViTMainLayer. Please ensure this object is passed to the `custom_objects` argument. See for details.


File "C:\Users\maria\anaconda3\envs\tfenv\lib\site-packages\streamlit\scriptrunner\", line 557, in _run_script
    exec(code, module.__dict__)
File "", line 248, in <module>
    model_loader = tf.keras.models.load_model(path_to_model)
File "C:\Users\maria\anaconda3\envs\tfenv\lib\site-packages\keras\utils\", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
File "C:\Users\maria\anaconda3\envs\tfenv\lib\site-packages\keras\utils\", line 562, in class_and_config_for_serialized_keras_object
    raise ValueError(
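The same kind of error can be reproduced in miniature with any custom layer: the .h5 file stores only each layer's class name and config, not its Python code, so Keras has to be handed the class at load time through `custom_objects`. The sketch below uses a hypothetical `Scale` layer standing in for `TFViTMainLayer` and assumes TensorFlow with `h5py` installed.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class Scale(tf.keras.layers.Layer):
    """Hypothetical toy layer standing in for TFViTMainLayer (not registered with Keras)."""

    def __init__(self, factor=2.0, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, inputs):
        return inputs * self.factor

    def get_config(self):
        # Only this config dict (plus the class name) is written to the .h5 file.
        config = super().get_config()
        config.update({"factor": self.factor})
        return config


inputs = tf.keras.Input(shape=(3,))
outputs = Scale(3.0)(inputs)
model = tf.keras.Model(inputs, outputs)

path = os.path.join(tempfile.mkdtemp(), "toy_model.h5")
model.save(path)

# Without the class, Keras cannot rebuild the layer and raises "Unknown layer: ...".
try:
    tf.keras.models.load_model(path)
    failed_without_custom_objects = False
except ValueError as err:
    failed_without_custom_objects = True
    print(err)

# Handing the class over by name lets the same file load.
restored = tf.keras.models.load_model(path, custom_objects={"Scale": Scale})
print(restored(np.ones((1, 3), dtype="float32")).numpy())  # [[3. 3. 3.]]
```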

My code:

import streamlit as st
import numpy as np
from PIL import Image
import tensorflow as tf

st.title("Binary Human Detection Web App")
st.markdown("Is there a human in office space? 🧍")

## Initialize tensorflow model (This can be loaded before anything else)
path_to_model = "C:/Users/myname/Jupiter_Notebooks/Dataset_Thermal_Project/Camera_videos/Saved_models/model_vit.h5"
model_loader = tf.keras.models.load_model(path_to_model)   # this line raises the ValueError above
model_vit = tf.keras.models.Model(model_loader.inputs, model_loader.outputs)

## Preprocess images
def preprocessImage(photo):
    resize_photo = photo.convert("RGB").resize((224, 224))   # force 3 channels (e.g. for grayscale JPEGs)
    normalized_photo = np.array(resize_photo) / 255           # shape (224, 224, 3), values in [0, 1]
    # The model's input layer is (None, 3, 224, 224), i.e. channels-first,
    # so move the channel axis to the front and add a batch axis
    channels_first = np.transpose(normalized_photo, (2, 0, 1))
    return channels_first[np.newaxis, ...]                    # shape (1, 3, 224, 224)

uploaded_file = st.sidebar.file_uploader(" ", type=['jpg', 'jpeg'])

if uploaded_file is not None:
    ## Use a context manager to make sure to close the file!!
    with Image.open(uploaded_file) as photo:
        tensorflow_image = preprocessImage(photo)
    ## Show preprocessed image (back in height x width x channels order for display)
    streamlit_widget_image = st.image(np.transpose(tensorflow_image[0], (1, 2, 0)), 'Uploaded Image', use_column_width=True)

## Do prediction
if st.sidebar.button("Click Here to Predict"):
    if uploaded_file is None:
        st.sidebar.write("Please upload an Image to Classify")
    else:
        ## Pass the preprocessed image to the vit model (not the streamlit widget)
        pred_label = model_vit.predict(tensorflow_image)[0]

        ## Print prediction
        st.sidebar.header("ViT model results:")
        if pred_label > 0.5:
            st.sidebar.write('Human is detected')
        else:
            st.sidebar.write('No human is detected')

I looked at the TensorFlow guide “Save and load Keras models  |  TensorFlow Core”, but I am not sure how to register the custom object in my example.
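A minimal sketch of how the registration could look for this model, assuming the Streamlit environment has the same `transformers` version that was used for training. Hugging Face decorates `TFViTMainLayer` with its `keras_serializable` helper, which registers the class with Keras under the default `Custom` package; that is where the `Custom>TFViTMainLayer` name in the error comes from. Since the Streamlit script never imports `transformers`, the class is most likely simply unknown to Keras at load time. The helper name `load_vit_classifier` is hypothetical.

```python
import tensorflow as tf


def load_vit_classifier(path_to_model):
    """Hypothetical helper: load the fine-tuned ViT classifier from the .h5 file."""
    # Importing the class also registers it with Keras as "Custom>TFViTMainLayer".
    from transformers.models.vit.modeling_tf_vit import TFViTMainLayer

    # Pass it explicitly as well, under both the registered and the bare class name.
    custom_objects = {
        "Custom>TFViTMainLayer": TFViTMainLayer,
        "TFViTMainLayer": TFViTMainLayer,
    }
    return tf.keras.models.load_model(path_to_model, custom_objects=custom_objects)


# In the Streamlit script, this would replace the plain load_model call:
# model_loader = load_vit_classifier(path_to_model)
```

Because Streamlit reruns the whole script on every interaction, wrapping such a helper in Streamlit's resource cache (`st.cache_resource` in recent versions) avoids reloading the 332 MB file on each click.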

Any help with this would be appreciated.

Here is the downloadable model_vit.h5; note that it is quite a big file (332 MB).