Feat: REST API implemented using FastAPI

This commit is contained in:
Reza Behzadan 2023-12-11 03:55:43 +03:30
parent 32a9ab07ac
commit 502d4ee844
3 changed files with 59 additions and 28 deletions

View File

@ -1,2 +1,3 @@
MODEL="m3hrdadfi/wav2vec2-large-xlsr-persian"
LOG_LEVEL="INFO"
PORT=8000

14
README.md Normal file
View File

@ -0,0 +1,14 @@
# Persian Speech-to-Text
## Testing
### With `curl`
```sh
curl -X POST "http://localhost:8000/transcribe" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "audio_file=@sample.wav"
```
### With `HTTPie`
```sh
http -f POST http://localhost:8000/transcribe audio_file@sample.wav
```

72
main.py
View File

@ -1,3 +1,4 @@
import io
import logging
import os
import warnings
@ -5,43 +6,44 @@ import warnings
import librosa
import transformers
import torch
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, File, UploadFile
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
# Application setup: FastAPI app, warning/log-noise suppression, and
# environment-driven configuration (read from a .env file if present).
app = FastAPI()

# Silence Python warnings and Hugging Face transformers info/progress
# output so the service logs stay readable.
warnings.filterwarnings("ignore")
transformers.logging.set_verbosity_error()

# Load variables from a .env file into the process environment.
load_dotenv()

# Hugging Face model id used for Persian speech recognition.
MODEL = os.getenv("MODEL", "m3hrdadfi/wav2vec2-large-xlsr-persian")
# NOTE(review): the .env template sets LOG_LEVEL="INFO" while this fallback
# is "DEBUG" — confirm which default is intended.
LOG_LEVEL = os.getenv("LOG_LEVEL", "DEBUG")
# Environment values are strings, so the fallback is given as a string and
# converted to int exactly once here.
PORT = int(os.getenv("PORT", "8000"))
# Initialize logger
# Module-level logger for the whole service; effective verbosity comes
# from the LOG_LEVEL environment setting.
logger = logging.getLogger("speech2text-fa")
level = getattr(logging, LOG_LEVEL.upper())
fmt = "%(asctime)s | %(levelname)-8s | %(message)s"
datefmt = "%Y-%m-%d %H:%M:%S"
ch = logging.StreamHandler()
# Handler passes everything through; filtering happens at the logger level
# set below.
ch.setLevel(logging.DEBUG)
formatter = logging.Formatter(fmt, datefmt)
ch.setFormatter(formatter)
logger.addHandler(ch)
logger.setLevel(level)
# Load tokenizer and model from Hugging Face
# Loaded once at import time so requests don't pay the load cost.
# NOTE(review): from_pretrained presumably downloads the model on first
# run — confirm startup-time/network expectations for deployment.
tokenizer = Wav2Vec2Processor.from_pretrained(MODEL)
model = Wav2Vec2ForCTC.from_pretrained(MODEL)
logger.info("Loading model is Done!")
def initLogger(log_level=None):
    """Create and configure the "speech2text-fa" logger.

    Args:
        log_level: Optional level name (e.g. "INFO", "debug"); when omitted
            the module-level LOG_LEVEL setting is used, preserving the
            original zero-argument behavior.

    Returns:
        logging.Logger: the configured, shared logger instance.
    """
    logger = logging.getLogger("speech2text-fa")
    level_name = LOG_LEVEL if log_level is None else log_level
    level = getattr(logging, level_name.upper())
    # Attach the stream handler only once: previously every call added a
    # fresh handler to the same shared logger, duplicating each log line.
    if not logger.handlers:
        ch = logging.StreamHandler()
        # Handler lets everything through; the logger level set below is
        # what actually filters.
        ch.setLevel(logging.DEBUG)
        formatter = logging.Formatter(
            "%(asctime)s | %(levelname)-8s | %(message)s",
            "%Y-%m-%d %H:%M:%S",
        )
        ch.setFormatter(formatter)
        logger.addHandler(ch)
    logger.setLevel(level)
    return logger
def mp3_to_text(mp3_file_path):
# Load the MP3 file and resample to 16kHz
audio, sample_rate = librosa.load(mp3_file_path, sr=16000)
def mp3_to_text(audio_data: io.BytesIO):
    """Transcribe an in-memory audio stream to Persian text.

    Args:
        audio_data: In-memory audio bytes (e.g. the body of an uploaded
            file). Despite the name, presumably any format librosa can
            decode is accepted — confirm against librosa.load.
    """
    # Resample to 16kHz
    # wav2vec2 checkpoints expect 16 kHz input; librosa resamples on load.
    audio, sample_rate = librosa.load(audio_data, sr=16000)
    logger.info("Resampling is Done!")
    # Load tokenizer and model from Hugging Face
    # NOTE(review): the processor/model are also loaded at module level;
    # re-loading them on every call looks like leftover pre-API code and
    # makes each request pay the full load cost — confirm which copy is
    # intended to be used.
    tokenizer = Wav2Vec2Processor.from_pretrained(MODEL)
    model = Wav2Vec2ForCTC.from_pretrained(MODEL)
    logger.info("Loading model is Done!")
    # Preprocess the audio
    input_values = tokenizer(audio, sampling_rate=16000, return_tensors="pt", padding="longest").input_values
    logits = model(input_values).logits
@ -55,8 +57,22 @@ def mp3_to_text(mp3_file_path):
return transcription[0]
@app.post("/transcribe")
async def transcribe_audio(audio_file: UploadFile = File(...)):
    """Accept an uploaded audio file and return its Persian transcription.

    The upload is read fully into memory, wrapped in a BytesIO stream for
    mp3_to_text, and the result is returned as a JSON object with a single
    "transcription" key.
    """
    raw_bytes = await audio_file.read()
    buffer = io.BytesIO(raw_bytes)
    return {"transcription": mp3_to_text(buffer)}
@app.get("/docs")
async def docs():
    """Return a static welcome message.

    NOTE(review): FastAPI serves its interactive Swagger UI at /docs by
    default; registering a custom route on the same path conflicts with
    that built-in page — confirm the override is intentional or move this
    greeting to another path (e.g. "/").
    """
    return {"message": "Welcome to the speech-to-text API!"}
if __name__ == "__main__":
    # NOTE(review): the next four lines look like leftover debug code from
    # the pre-API version — they transcribe a local sample and print it
    # before the server starts, and mp3_to_text is now annotated to take an
    # in-memory stream rather than a file path. Confirm they should still
    # run in the API entry point.
    logger = initLogger()
    text = mp3_to_text("samples/sample1.wav")
    print()
    print(text)
    # Start the API server on all interfaces at the configured port.
    uvicorn.run(app, host="0.0.0.0", port=PORT, log_level=LOG_LEVEL.lower())