From 502d4ee8448967d8fb60b66371104f0471886b3d Mon Sep 17 00:00:00 2001
From: Reza Behzadan <rbehzadan@gmail.com>
Date: Mon, 11 Dec 2023 03:55:43 +0330
Subject: [PATCH] Feat: add REST API using FastAPI

---
 .env.sample |  1 +
 README.md   | 14 +++++++++++
 main.py     | 72 ++++++++++++++++++++++++++++++++---------------------
 3 files changed, 59 insertions(+), 28 deletions(-)
 create mode 100644 README.md

diff --git a/.env.sample b/.env.sample
index 366655b..3db13b5 100644
--- a/.env.sample
+++ b/.env.sample
@@ -1,2 +1,3 @@
 MODEL="m3hrdadfi/wav2vec2-large-xlsr-persian"
 LOG_LEVEL="INFO"
+PORT=8000
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..837f5ea
--- /dev/null
+++ b/README.md
@@ -0,0 +1,14 @@
+# Persian Speech-to-text
+
+
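+## Running
+
+Start the API with `python main.py`. The server listens on the port given by the `PORT` environment variable (8000 by default).
+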
+## Testing
+
+### With `curl`
+```sh
+curl -X POST "http://localhost:8000/transcribe" -H "accept: application/json" -F "audio_file=@sample.wav"
+```
+
+### With `HTTPie`
+```sh
+http -f POST http://localhost:8000/transcribe audio_file@sample.wav
+```
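+
+### With Python
+
+A minimal client sketch; it assumes the `requests` package is installed and the server is running on `localhost:8000`:
+
+```python
+import requests
+
+# POST a local WAV file as the "audio_file" form field
+with open("sample.wav", "rb") as f:
+    response = requests.post(
+        "http://localhost:8000/transcribe",
+        files={"audio_file": ("sample.wav", f, "audio/wav")},
+    )
+
+print(response.json()["transcription"])
+```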
diff --git a/main.py b/main.py
index a5fe185..664a170 100644
--- a/main.py
+++ b/main.py
@@ -1,3 +1,4 @@
+import io
 import logging
 import os
 import warnings
@@ -5,43 +6,44 @@ import warnings
 import librosa
 import transformers
 import torch
+import uvicorn
 from dotenv import load_dotenv
+from fastapi import FastAPI, File, UploadFile
 from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
 
 
+app = FastAPI()
+
 warnings.filterwarnings("ignore")
 transformers.logging.set_verbosity_error()
 load_dotenv()
 MODEL = os.getenv("MODEL", "m3hrdadfi/wav2vec2-large-xlsr-persian")
 LOG_LEVEL = os.getenv("LOG_LEVEL", "DEBUG")
+PORT = int(os.getenv("PORT", 8000))
+
+# Initialize logger
+logger = logging.getLogger("speech2text-fa")
+level = getattr(logging, LOG_LEVEL.upper())
+fmt = "%(asctime)s | %(levelname)-8s | %(message)s"
+datefmt = "%Y-%m-%d %H:%M:%S"
+ch = logging.StreamHandler()
+ch.setLevel(logging.DEBUG)
+formatter = logging.Formatter(fmt, datefmt)
+ch.setFormatter(formatter)
+logger.addHandler(ch)
+logger.setLevel(level)
+
+# Load tokenizer and model from Hugging Face
+tokenizer = Wav2Vec2Processor.from_pretrained(MODEL)
+model = Wav2Vec2ForCTC.from_pretrained(MODEL)
+logger.info("Loading model is Done!")
 
 
-def initLogger():
-    logger = logging.getLogger("speech2text-fa")
-    level = getattr(logging, LOG_LEVEL.upper())
-    
-    fmt = "%(asctime)s | %(levelname)-8s | %(message)s"
-    datefmt = "%Y-%m-%d %H:%M:%S"
-    ch = logging.StreamHandler()
-    ch.setLevel(logging.DEBUG)
-    formatter = logging.Formatter(fmt, datefmt)
-    ch.setFormatter(formatter)
-    logger.addHandler(ch)
-
-    logger.setLevel(level)
-    return logger
-
-
-def mp3_to_text(mp3_file_path):
-    # Load the MP3 file and resample to 16kHz
-    audio, sample_rate = librosa.load(mp3_file_path, sr=16000)
+def mp3_to_text(audio_data: io.BytesIO):
+    # Load the audio and resample to 16 kHz
+    audio, sample_rate = librosa.load(audio_data, sr=16000)
     logger.info("Resampling is Done!")
 
-    # Load tokenizer and model from Hugging Face
-    tokenizer = Wav2Vec2Processor.from_pretrained(MODEL)
-    model = Wav2Vec2ForCTC.from_pretrained(MODEL)
-    logger.info("Loading model is Done!")
-
     # Preprocess the audio
     input_values = tokenizer(audio, sampling_rate=16000, return_tensors="pt", padding="longest").input_values
     logits = model(input_values).logits
@@ -55,8 +57,22 @@ def mp3_to_text(mp3_file_path):
     return transcription[0]
 
 
+@app.post("/transcribe")
+async def transcribe_audio(audio_file: UploadFile = File(...)):
+    # Load the audio from the file
+    contents = await audio_file.read()
+    audio_data = io.BytesIO(contents)
+
+    # Convert to text
+    transcription = mp3_to_text(audio_data)
+
+    return {"transcription": transcription}
+
+
+@app.get("/")
+async def root():
+    # Welcome/health endpoint; FastAPI already serves its interactive docs at /docs
+    return {"message": "Welcome to the speech-to-text API!"}
+
+
 if __name__ == "__main__":
-    logger = initLogger()
-    text = mp3_to_text("samples/sample1.wav")
-    print()
-    print(text)
+    uvicorn.run(app, host="0.0.0.0", port=PORT, log_level=LOG_LEVEL.lower())