Inferencing a pytorch model

I have deployed the app to a huggingface space, When i upload the file, and reaches the part where input tensor is fed into the model, always the streamlit server shows CONNECTING… and then the app restarts but the uploaded file is there. sometimes its giving error 500 also…coudn’t figure out the issue here, can someone please point out if there is any bugs in the code

import streamlit as st
from st_audiorec import st_audiorec
import matplotlib.pyplot as plt
import sounddevice as sd
import numpy as np
import pandas as pd
import torch
# import torchaudio
import wave
import io
from import wavfile
import pydub
import time
import os
import atexit
import librosa


def load_model():
    model = torch.jit.load("*****l.ptl")
    return model
model = load_model()

# Audio parameters
def process_data(waveform_chunks):
    snore = 0
    other = 0
    st.write("Reached stage 4")
    for chunk in waveform_chunks:
        st.write("Reached stage 5")
        input_tensor = torch.tensor(chunk).unsqueeze(0)
        st.write("Reached stage 6")
        result = model(input_tensor)
        if np.abs(result[0][0]) > np.abs(result[0][1]):
            other += 1
            snore += 1
    return snore, other



uploaded_file = st.file_uploader("Upload Sample", type=["wav"])
if uploaded_file is not None:
    # time.sleep(2.5)
    audio, sample_rate = librosa.load(uploaded_file, sr=None)
    waveform = audio
    # Set the chunk size
    chunk_size = 16000
    st.write("Reached stage 2")
    # Calculate the number of chunks
    num_chunks = len(waveform) // chunk_size

    # Reshape the waveform into chunks
    waveform_chunks = np.array_split(waveform[:num_chunks * chunk_size], num_chunks)
    st.write("Reached stage 3")
    snore, other = process_data(waveform_chunks)

    total = snore + other
    snore_percentage = (snore / total) * 100
    other_percentage = (other / total) * 100

    categories = ["Snore", "Other"]
    percentages = [snore_percentage, other_percentage]

    st.write(f'Snore Percentage: {snore_percentage}')
    # plt.figure(figsize=(8, 4))
    # plt.barh(categories, percentages, color=['#ff0033', '#00ffee'])
    # plt.xlabel('Percentage')
    # plt.title('Percentage of Snoring')
    # plt.xlim(0, 100)

    # for i, percentage in enumerate(percentages):
    #     plt.text(percentage, i, f' {percentage:.2f}%', va='center')
    # st.write("DONE")
    # st.pyplot(plt)

This topic was automatically closed 180 days after the last reply. New replies are no longer allowed.