"""Speech-to-text HTTP service: a FastAPI app wrapping a Wav2Vec2 CTC model.

Configuration is read from the environment (optionally via a ``.env`` file):

* ``MODEL``     -- Hugging Face model id passed to ``from_pretrained``.
* ``LOG_LEVEL`` -- name of a ``logging`` level (case-insensitive).
* ``PORT``      -- TCP port used when the module is run as a script.
"""

import io
import logging
import os
import warnings

import librosa
import torch
import transformers
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

app = FastAPI()

# Silence Python warnings and the transformers library's own log output.
warnings.filterwarnings("ignore")
transformers.logging.set_verbosity_error()

load_dotenv()

MODEL = os.getenv("MODEL", "m3hrdadfi/wav2vec2-large-xlsr-persian")
LOG_LEVEL = os.getenv("LOG_LEVEL", "DEBUG")
PORT = int(os.getenv("PORT", "8000"))

# Initialize logger
logger = logging.getLogger("speech2text-fa")
level = getattr(logging, LOG_LEVEL.upper())
fmt = "%(asctime)s | %(levelname)-8s | %(message)s"
datefmt = "%Y-%m-%d %H:%M:%S"
ch = logging.StreamHandler()
# The handler lets every record through; filtering happens on the logger.
ch.setLevel(logging.DEBUG)
formatter = logging.Formatter(fmt, datefmt)
ch.setFormatter(formatter)
logger.addHandler(ch)
logger.setLevel(level)

# Load tokenizer and model from Hugging Face
tokenizer = Wav2Vec2Processor.from_pretrained(MODEL)
model = Wav2Vec2ForCTC.from_pretrained(MODEL)
logger.info("Loading model is Done!")


def mp3_to_text(audio_data: io.BytesIO) -> str:
    """Transcribe an in-memory audio file to text.

    Despite the name, the buffer is handed straight to ``librosa.load``, so
    the accepted formats are whatever that call can decode.

    Args:
        audio_data: Binary buffer holding the encoded audio file.

    Returns:
        The decoded transcription of the audio.
    """
    # Resample to 16kHz
    audio, _ = librosa.load(audio_data, sr=16000)
    logger.info("Resampling is Done!")

    # Preprocess the audio
    input_values = tokenizer(
        audio,
        sampling_rate=16000,
        return_tensors="pt",
        padding="longest",
    ).input_values

    # Inference only: disable autograd so no gradient graph is built and
    # retained for each request.
    with torch.no_grad():
        logits = model(input_values).logits
    logger.info("Processing the audio is Done!")

    # Decode the predicted IDs
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = tokenizer.batch_decode(predicted_ids)
    logger.info("Decoding the prodicted IDs is Done!")

    return transcription[0]


@app.post("/transcribe")
async def transcribe_audio(audio_file: UploadFile = File(...)):
    """Transcribe an uploaded audio file.

    Args:
        audio_file: The uploaded audio file (multipart form field).

    Returns:
        A dict of the form ``{"transcription": <text>}``.

    Raises:
        HTTPException: 400 if the uploaded file is empty.
    """
    # Load the audio from the file
    contents = await audio_file.read()
    if not contents:
        # Reject early with a clear client error instead of letting the
        # audio decoder fail on an empty buffer.
        raise HTTPException(status_code=400, detail="Uploaded audio file is empty.")
    audio_data = io.BytesIO(contents)

    # Convert to text. Decoding and model inference are blocking, CPU-bound
    # work, so run them in the thread pool to keep the event loop responsive.
    transcription = await run_in_threadpool(mp3_to_text, audio_data)

    return {"transcription": transcription}


# NOTE(review): FastAPI serves its built-in Swagger UI at "/docs" by default
# and registers it when the app is created, so this handler is presumably
# shadowed and never reached -- confirm, and consider moving it to "/".
@app.get("/docs")
async def docs():
    """Return a static welcome message."""
    return {"message": "Welcome to the speech-to-text API!"}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=PORT, log_level=LOG_LEVEL.lower())