Feat: REST API implemented using FastAPI
parent 32a9ab07ac
commit 502d4ee844
@@ -1,2 +1,3 @@
 MODEL="m3hrdadfi/wav2vec2-large-xlsr-persian"
 LOG_LEVEL="INFO"
+PORT=8000
README.md (new file, 14 lines)
@@ -0,0 +1,14 @@
+# Persian Speech-to-text
+
+
+## Testing
+
+### With `curl`
+```sh
+curl -X POST "http://localhost:8000/transcribe" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "audio_file=@sample.wav"
+```
+
+### With `HTTPie`
+```sh
+http -f POST http://localhost:8000/transcribe audio_file@sample.wav
+```
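The same request can also be sent from Python. Below is a minimal client sketch, not part of this commit: it assumes the third-party `requests` package is installed, a `sample.wav` file sitting next to the script, the hypothetical file name `query_api.py`, and the API already running on the default port 8000.

```python
# query_api.py - hypothetical helper script, not included in the commit.
# Posts a local WAV file to the /transcribe endpoint and prints the result.
import requests

with open("sample.wav", "rb") as f:
    response = requests.post(
        "http://localhost:8000/transcribe",
        files={"audio_file": ("sample.wav", f, "audio/wav")},
    )

response.raise_for_status()
print(response.json()["transcription"])
```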
main.py (46 lines changed)
@@ -1,3 +1,4 @@
+import io
 import logging
 import os
 import warnings
@@ -5,21 +6,24 @@ import warnings
 import librosa
 import transformers
 import torch
+import uvicorn
 from dotenv import load_dotenv
+from fastapi import FastAPI, File, UploadFile
 from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
 
+app = FastAPI()
 
 warnings.filterwarnings("ignore")
 transformers.logging.set_verbosity_error()
 load_dotenv()
 MODEL = os.getenv("MODEL", "m3hrdadfi/wav2vec2-large-xlsr-persian")
 LOG_LEVEL = os.getenv("LOG_LEVEL", "DEBUG")
+PORT = int(os.getenv("PORT", 8000))
 
-
-def initLogger():
+# Initialize logger
 logger = logging.getLogger("speech2text-fa")
 level = getattr(logging, LOG_LEVEL.upper())
 
 fmt = "%(asctime)s | %(levelname)-8s | %(message)s"
 datefmt = "%Y-%m-%d %H:%M:%S"
 ch = logging.StreamHandler()
@@ -27,21 +31,19 @@ def initLogger():
 formatter = logging.Formatter(fmt, datefmt)
 ch.setFormatter(formatter)
 logger.addHandler(ch)
 
 logger.setLevel(level)
-    return logger
-
-
-def mp3_to_text(mp3_file_path):
-    # Load the MP3 file and resample to 16kHz
-    audio, sample_rate = librosa.load(mp3_file_path, sr=16000)
-    logger.info("Resampling is Done!")
-
 # Load tokenizer and model from Hugging Face
 tokenizer = Wav2Vec2Processor.from_pretrained(MODEL)
 model = Wav2Vec2ForCTC.from_pretrained(MODEL)
 logger.info("Loading model is Done!")
+
+
+def mp3_to_text(audio_data: io.BytesIO):
+    # Resample to 16kHz
+    audio, sample_rate = librosa.load(audio_data, sr=16000)
+    logger.info("Resampling is Done!")
 
     # Preprocess the audio
     input_values = tokenizer(audio, sampling_rate=16000, return_tensors="pt", padding="longest").input_values
     logits = model(input_values).logits
@@ -55,8 +57,22 @@ def mp3_to_text(mp3_file_path):
     return transcription[0]
 
 
+@app.post("/transcribe")
+async def transcribe_audio(audio_file: UploadFile = File(...)):
+    # Load the audio from the file
+    contents = await audio_file.read()
+    audio_data = io.BytesIO(contents)
+
+    # Convert to text
+    transcription = mp3_to_text(audio_data)
+
+    return {"transcription": transcription}
+
+
+@app.get("/docs")
+async def docs():
+    return {"message": "Welcome to the speech-to-text API!"}
+
+
 if __name__ == "__main__":
-    logger = initLogger()
-    text = mp3_to_text("samples/sample1.wav")
-    print()
-    print(text)
+    uvicorn.run(app, host="0.0.0.0", port=PORT, log_level=LOG_LEVEL.lower())
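The new endpoint can also be exercised in-process, without starting uvicorn. Below is a minimal smoke-test sketch, not part of this commit: it uses FastAPI's `TestClient`, assumes a short `sample.wav` beside the test, and a hypothetical file name `test_api.py` next to `main.py`.

```python
# test_api.py - hypothetical smoke test, not included in the commit.
# Importing main loads the wav2vec2 model, so the first run is slow.
from fastapi.testclient import TestClient

from main import app

client = TestClient(app)


def test_transcribe_returns_text():
    with open("sample.wav", "rb") as f:
        response = client.post(
            "/transcribe",
            files={"audio_file": ("sample.wav", f, "audio/wav")},
        )
    assert response.status_code == 200
    assert "transcription" in response.json()
```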