diff --git a/.env.sample b/.env.sample
index 366655b..3db13b5 100644
--- a/.env.sample
+++ b/.env.sample
@@ -1,2 +1,3 @@
 MODEL="m3hrdadfi/wav2vec2-large-xlsr-persian"
 LOG_LEVEL="INFO"
+PORT=8000
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..837f5ea
--- /dev/null
+++ b/README.md
@@ -0,0 +1,14 @@
+# Persian Speech-to-text
+
+
+## Testing
+
+### With `curl`
+```sh
+curl -X POST "http://localhost:8000/transcribe" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "audio_file=@sample.wav"
+```
+
+### With `HTTPie`
+```sh
+http -f POST http://localhost:8000/transcribe audio_file@sample.wav
+```
diff --git a/main.py b/main.py
index a5fe185..664a170 100644
--- a/main.py
+++ b/main.py
@@ -1,3 +1,4 @@
+import io
 import logging
 import os
 import warnings
@@ -5,43 +6,44 @@ import warnings
 import librosa
 import transformers
 import torch
+import uvicorn
 from dotenv import load_dotenv
+from fastapi import FastAPI, File, UploadFile
 from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
 
+app = FastAPI()
+
 warnings.filterwarnings("ignore")
 transformers.logging.set_verbosity_error()
 load_dotenv()
 
 MODEL = os.getenv("MODEL", "m3hrdadfi/wav2vec2-large-xlsr-persian")
 LOG_LEVEL = os.getenv("LOG_LEVEL", "DEBUG")
+PORT = int(os.getenv("PORT", "8000"))
+
+# Initialize logger
+logger = logging.getLogger("speech2text-fa")
+level = getattr(logging, LOG_LEVEL.upper())
+fmt = "%(asctime)s | %(levelname)-8s | %(message)s"
+datefmt = "%Y-%m-%d %H:%M:%S"
+ch = logging.StreamHandler()
+ch.setLevel(logging.DEBUG)
+formatter = logging.Formatter(fmt, datefmt)
+ch.setFormatter(formatter)
+logger.addHandler(ch)
+logger.setLevel(level)
+
+# Load tokenizer and model from Hugging Face
+tokenizer = Wav2Vec2Processor.from_pretrained(MODEL)
+model = Wav2Vec2ForCTC.from_pretrained(MODEL)
+logger.info("Loading model is Done!")
 
 
-def initLogger():
-    logger = logging.getLogger("speech2text-fa")
-    level = getattr(logging, LOG_LEVEL.upper())
-
-    fmt = "%(asctime)s | %(levelname)-8s | %(message)s"
-    datefmt = "%Y-%m-%d %H:%M:%S"
-    ch = logging.StreamHandler()
-    ch.setLevel(logging.DEBUG)
-    formatter = logging.Formatter(fmt, datefmt)
-    ch.setFormatter(formatter)
-    logger.addHandler(ch)
-
-    logger.setLevel(level)
-    return logger
-
-
-def mp3_to_text(mp3_file_path):
-    # Load the MP3 file and resample to 16kHz
-    audio, sample_rate = librosa.load(mp3_file_path, sr=16000)
+def mp3_to_text(audio_data: io.BytesIO):
+    # Resample to 16kHz
+    audio, sample_rate = librosa.load(audio_data, sr=16000)
     logger.info("Resampling is Done!")
 
-    # Load tokenizer and model from Hugging Face
-    tokenizer = Wav2Vec2Processor.from_pretrained(MODEL)
-    model = Wav2Vec2ForCTC.from_pretrained(MODEL)
-    logger.info("Loading model is Done!")
-
     # Preprocess the audio
     input_values = tokenizer(audio, sampling_rate=16000, return_tensors="pt", padding="longest").input_values
     logits = model(input_values).logits
@@ -55,8 +57,24 @@ def mp3_to_text(mp3_file_path):
     return transcription[0]
 
 
+@app.post("/transcribe")
+async def transcribe_audio(audio_file: UploadFile = File(...)):
+    # Load the audio from the file
+    contents = await audio_file.read()
+    audio_data = io.BytesIO(contents)
+
+    # Convert to text
+    transcription = mp3_to_text(audio_data)
+
+    return {"transcription": transcription}
+
+
+@app.get("/")
+async def root():
+    # Serve the welcome message at the root path; mounting it at /docs
+    # would shadow FastAPI's auto-generated Swagger UI.
+    return {"message": "Welcome to the speech-to-text API!"}
+
 
 if __name__ == "__main__":
-    logger = initLogger()
-    text = mp3_to_text("samples/sample1.wav")
-    print()
-    print(text)
+    uvicorn.run(app, host="0.0.0.0", port=PORT, log_level=LOG_LEVEL.lower())