How to combine streamlit_webrtc and asyncio for parrallel audio capture and inference? #1818

How can I combine streamlit_webrtc and asyncio to run inference in parallel with the audio buffering? is there any example of how to do this?

I am currently trying to deploy some Essentia model for music genre recognition with the help of streamlit_webrtc. Basically I want to capture the audio stream and then when a certain amount of frames is reached (e.g. 5 seconds) I run the prediction. Unfortunately, this has a lot of lags and delays so I am trying to run these in parallel with the help of asyncio. This seems to be blocking my pipeline, possibly because streamlit_webrtc uses asyncio too. So what is the proper way to do this. Here is my code:

import json
import time
import json
import pydub
import queue
import socket
import joblib
import librosa
import asyncio
import logging
import numpy as np
import pandas as pd
import logging.handlers
import streamlit as st
import tensorflow as tf
from pathlib import Path
from pprint import pprint
from streamlit_webrtc import WebRtcMode, webrtc_streamer
from essentia.standard import TensorflowPredictEffnetDiscogs, TensorflowPredict2D

DEBUG = False
SAMPLE_RATE = 16000
HERE = Path(__file__).parent
logger = logging.getLogger(__name__)

# Disable eager mode for tf.v1 compatibility with tf.v2
tf.compat.v1.disable_eager_execution()

import warnings
warnings.filterwarnings('ignore')

def get_metadata(fname, section=None):
    with open(fname) as f:
        meta = json.load(f)

    if section:
        return meta[section]
    return meta

MODEL_LABELS = {"ESSENTIA_genre_discogs400-discogs-effnet-1-random-forrest": get_metadata("pretrained-models/genre_discogs400-discogs-effnet-1.json", "classes")}

def load_essentia_model(model, labels, n_frames=None, overlap=None, debug=True):
    embeddings_model_args = {"graphFilename": "pretrained-models/discogs-effnet-bs64-1.pb", "output": "PartitionedCall:1"}
    predictions_model_args = {"graphFilename": "pretrained-models/genre_discogs400-discogs-effnet-1.pb", "input": "serving_default_model_Placeholder", "output": "PartitionedCall:0"}
    embeddings_model = TensorflowPredictEffnetDiscogs(**embeddings_model_args)
    predictions_model = TensorflowPredict2D(**predictions_model_args)
    return embeddings_model, predictions_model

def predict_essentia(audio, embeddings_model, predictions_model):
    # Run the models
    #print("-> Audio to predict: ", audio.shape, audio.dtype, type(audio))
    embeddings = embeddings_model(audio)
    predictions = predictions_model(embeddings)
    return predictions       

async def audio_collector(webrtc_ctx, sound_chunk_queue):
    """Asynchronously collect audio frames from webrtc context"""
    while webrtc_ctx.state.playing:
        try:
            audio_frames = webrtc_ctx.audio_receiver.get_frames(timeout=1)
        except queue.Empty:
            continue
        
        for audio_frame in audio_frames:
            sound = pydub.AudioSegment(
                data=audio_frame.to_ndarray().tobytes(),
                sample_width=audio_frame.format.bytes,
                frame_rate=audio_frame.sample_rate,
                channels=len(audio_frame.layout.channels)
            )
            # Put the sound into the queue
            await sound_chunk_queue.put(sound)

async def run_inference(sound_chunk, model, embeddings_model, predictions_model, random_forrest_model, text_output, json_output):
    start_process = time.time()

    # Audio processing
    print(f"Inference: Received sound_chunk of length {len(sound_chunk)}")
    sound_chunk = sound_chunk.set_channels(1).set_frame_rate(SAMPLE_RATE)
    musicnn_audio_buffer = np.array(sound_chunk.get_array_of_samples(), dtype=np.int16)
    float_musicnn_audio_buffer = musicnn_audio_buffer.astype(np.float32, order='C') / 32768.0

    # Run inference
    likelihoods = predict_essentia(float_musicnn_audio_buffer, embeddings_model, predictions_model)
    results_dict = dict(zip(MODEL_LABELS[model], np.round(np.mean(likelihoods, axis=0), 9)))

    # Aggregate results into merged categories
    probabilities = random_forrest_model.predict_proba(pd.DataFrame([results_dict]))
    genres_dict = {genre: prob for genre, prob in zip(random_forrest_model.classes_, probabilities.flatten())}

    # Display genres_dict using Streamlit
    maximum = max(results_dict, key=results_dict.get)
    text_output.markdown(f"**Top Tag:** {maximum, results_dict[maximum]}")
    json_output.write("Genres Dict:")
    json_output.json(genres_dict)

    # End timing
    end_process = time.time()
    print(f"> Inference done. Total processing time: {end_process - start_process:.4f}s")
    print(f"Genres Dict: {genres_dict}")
    print(f">===============================================================<")

async def process_audio(sound_chunk_queue, model, embeddings_model, predictions_model, random_forrest_model, text_output, json_output):
    sound_chunk = pydub.AudioSegment.empty()
    min_buffer_size = int(0.5 * SAMPLE_RATE)  # 0.5 seconds buffer

    while True:
        sound = await sound_chunk_queue.get()  # Get sound from the queue

        # Debugging print: track the queue size after consuming a chunk
        print(f"ProcessAudio: Retrieved frame from queue, current queue size: {sound_chunk_queue.qsize()}")
        sound_chunk += sound

        print(f"Current sound_chunk length: {len(sound_chunk)}, required: {min_buffer_size}")

        # If we have enough audio data (2 seconds), run inference
        if len(sound_chunk) >= min_buffer_size:
            print(f"Running inference with sound chunk of length {len(sound_chunk)}")
            await run_inference(sound_chunk, model, embeddings_model, predictions_model, random_forrest_model, text_output, json_output)

            # Keep only the last 0.5 seconds of audio for overlap in next inference
            last_frames = int(-0.5 * min_buffer_size)
            sound_chunk = sound_chunk[last_frames:]
        else:
            # Force inference every few cycles even if the buffer isn't quite full (if no new data is coming)
            if sound_chunk_queue.qsize() == 0 and len(sound_chunk) > int(1.5 * SAMPLE_RATE):  # Check buffer length, even with slightly smaller chunks
                print(f"Running forced inference with incomplete sound chunk of length {len(sound_chunk)}")
                await run_inference(sound_chunk, model, embeddings_model, predictions_model, random_forrest_model, text_output, json_output)
                sound_chunk = pydub.AudioSegment.empty()  # Clear buffer after inference


def app_sst(model):
    webrtc_ctx = webrtc_streamer(key="music-to-genre",
                                 mode=WebRtcMode.SENDONLY,
                                 audio_receiver_size=1024 * 8,
                                 rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]},
                                 media_stream_constraints={"video": False, "audio": True})
    
    status_indicator = st.empty()
    text_output = st.empty()
    json_output = st.empty()

    # Load model
    embeddings_model, predictions_model = load_essentia_model(model, MODEL_LABELS[model])
    random_forrest_model = joblib.load('pretrained-classifiers/random_forest_model.pkl')

    sound_chunk_queue = asyncio.Queue()

    async def run_tasks():
        audio_task = asyncio.create_task(audio_collector(webrtc_ctx, sound_chunk_queue))
        process_task = asyncio.create_task(process_audio(sound_chunk_queue, model, embeddings_model, predictions_model, random_forrest_model, text_output, json_output))
        
        # Run both tasks concurrently until `webrtc_ctx.state.playing` is False
        while webrtc_ctx.state.playing:
            await asyncio.sleep(0.1)  # Non-blocking sleep to keep checking the state

        # Gracefully cancel tasks when streaming stops
        audio_task.cancel()
        process_task.cancel()

    # Start event loop for tasks
    asyncio.run(run_tasks())


def main():
    st.set_page_config(layout="wide")
    st.header("Real Time Music Genre Classification")
    st.markdown("""This demo app is using musicnn as a recognition engine.""")
    sound_only_page = "Sound only demo"
    app_mode = st.selectbox("Choose the app mode", [sound_only_page])
    model = st.selectbox("Choose the model", ["ESSENTIA_genre_discogs400-discogs-effnet-1-random-forrest"])
    if app_mode == sound_only_page:
        app_sst(model)

if __name__ == "__main__":
    logging.basicConfig(format="[%(asctime)s] %(levelname)7s from %(name)s in %(pathname)s:%(lineno)d: %(message)s", force=True)
    logger.setLevel(level=logging.DEBUG if DEBUG else logging.INFO)

    st_webrtc_logger = logging.getLogger("streamlit_webrtc")
    st_webrtc_logger.setLevel(logging.DEBUG)

    fsevents_logger = logging.getLogger("fsevents")
    fsevents_logger.setLevel(logging.WARNING)

    main()

I would also welcome any ideas to improve this for speed and performance.