Feat: REST API implemented using FastAPI
This commit is contained in:
		
							parent
							
								
									32a9ab07ac
								
							
						
					
					
						commit
						502d4ee844
					
				@ -1,2 +1,3 @@
 | 
				
			|||||||
MODEL="m3hrdadfi/wav2vec2-large-xlsr-persian"
 | 
					MODEL="m3hrdadfi/wav2vec2-large-xlsr-persian"
 | 
				
			||||||
LOG_LEVEL="INFO"
 | 
					LOG_LEVEL="INFO"
 | 
				
			||||||
 | 
					PORT=8000
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										14
									
								
								README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								README.md
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,14 @@
 | 
				
			|||||||
 | 
					# Persian Speech-to-text
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Testing
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### With `curl`
 | 
				
			||||||
 | 
					```sh
 | 
				
			||||||
 | 
					curl -X POST "http://localhost:8000/transcribe" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "audio_file=@sample.wav"
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### With `HTTPie`
 | 
				
			||||||
 | 
					```sh
 | 
				
			||||||
 | 
					http -f POST http://localhost:8000/transcribe audio_file@sample.wav
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
							
								
								
									
										46
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										46
									
								
								main.py
									
									
									
									
									
								
							@ -1,3 +1,4 @@
 | 
				
			|||||||
 | 
					import io
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
import warnings
 | 
					import warnings
 | 
				
			||||||
@ -5,21 +6,24 @@ import warnings
 | 
				
			|||||||
import librosa
 | 
					import librosa
 | 
				
			||||||
import transformers
 | 
					import transformers
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					import uvicorn
 | 
				
			||||||
from dotenv import load_dotenv
 | 
					from dotenv import load_dotenv
 | 
				
			||||||
 | 
					from fastapi import FastAPI, File, UploadFile
 | 
				
			||||||
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
 | 
					from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					app = FastAPI()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
warnings.filterwarnings("ignore")
 | 
					warnings.filterwarnings("ignore")
 | 
				
			||||||
transformers.logging.set_verbosity_error()
 | 
					transformers.logging.set_verbosity_error()
 | 
				
			||||||
load_dotenv()
 | 
					load_dotenv()
 | 
				
			||||||
MODEL = os.getenv("MODEL", "m3hrdadfi/wav2vec2-large-xlsr-persian")
 | 
					MODEL = os.getenv("MODEL", "m3hrdadfi/wav2vec2-large-xlsr-persian")
 | 
				
			||||||
LOG_LEVEL = os.getenv("LOG_LEVEL", "DEBUG")
 | 
					LOG_LEVEL = os.getenv("LOG_LEVEL", "DEBUG")
 | 
				
			||||||
 | 
					PORT = int(os.getenv("PORT", 8000))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Initialize logger
 | 
				
			||||||
def initLogger():
 | 
					 | 
				
			||||||
logger = logging.getLogger("speech2text-fa")
 | 
					logger = logging.getLogger("speech2text-fa")
 | 
				
			||||||
level = getattr(logging, LOG_LEVEL.upper())
 | 
					level = getattr(logging, LOG_LEVEL.upper())
 | 
				
			||||||
    
 | 
					 | 
				
			||||||
fmt = "%(asctime)s | %(levelname)-8s | %(message)s"
 | 
					fmt = "%(asctime)s | %(levelname)-8s | %(message)s"
 | 
				
			||||||
datefmt = "%Y-%m-%d %H:%M:%S"
 | 
					datefmt = "%Y-%m-%d %H:%M:%S"
 | 
				
			||||||
ch = logging.StreamHandler()
 | 
					ch = logging.StreamHandler()
 | 
				
			||||||
@ -27,21 +31,19 @@ def initLogger():
 | 
				
			|||||||
formatter = logging.Formatter(fmt, datefmt)
 | 
					formatter = logging.Formatter(fmt, datefmt)
 | 
				
			||||||
ch.setFormatter(formatter)
 | 
					ch.setFormatter(formatter)
 | 
				
			||||||
logger.addHandler(ch)
 | 
					logger.addHandler(ch)
 | 
				
			||||||
 | 
					 | 
				
			||||||
logger.setLevel(level)
 | 
					logger.setLevel(level)
 | 
				
			||||||
    return logger
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def mp3_to_text(mp3_file_path):
 | 
					 | 
				
			||||||
    # Load the MP3 file and resample to 16kHz
 | 
					 | 
				
			||||||
    audio, sample_rate = librosa.load(mp3_file_path, sr=16000)
 | 
					 | 
				
			||||||
    logger.info("Resampling is Done!")
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Load tokenizer and model from Hugging Face
 | 
					# Load tokenizer and model from Hugging Face
 | 
				
			||||||
tokenizer = Wav2Vec2Processor.from_pretrained(MODEL)
 | 
					tokenizer = Wav2Vec2Processor.from_pretrained(MODEL)
 | 
				
			||||||
model = Wav2Vec2ForCTC.from_pretrained(MODEL)
 | 
					model = Wav2Vec2ForCTC.from_pretrained(MODEL)
 | 
				
			||||||
logger.info("Loading model is Done!")
 | 
					logger.info("Loading model is Done!")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def mp3_to_text(audio_data: io.BytesIO):
 | 
				
			||||||
 | 
					    # Resample to 16kHz
 | 
				
			||||||
 | 
					    audio, sample_rate = librosa.load(audio_data, sr=16000)
 | 
				
			||||||
 | 
					    logger.info("Resampling is Done!")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Preprocess the audio
 | 
					    # Preprocess the audio
 | 
				
			||||||
    input_values = tokenizer(audio, sampling_rate=16000, return_tensors="pt", padding="longest").input_values
 | 
					    input_values = tokenizer(audio, sampling_rate=16000, return_tensors="pt", padding="longest").input_values
 | 
				
			||||||
    logits = model(input_values).logits
 | 
					    logits = model(input_values).logits
 | 
				
			||||||
@ -55,8 +57,22 @@ def mp3_to_text(mp3_file_path):
 | 
				
			|||||||
    return transcription[0]
 | 
					    return transcription[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@app.post("/transcribe")
 | 
				
			||||||
 | 
					async def transcribe_audio(audio_file: UploadFile = File(...)):
 | 
				
			||||||
 | 
					    # Load the audio from the file
 | 
				
			||||||
 | 
					    contents = await audio_file.read()
 | 
				
			||||||
 | 
					    audio_data = io.BytesIO(contents)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Convert to text
 | 
				
			||||||
 | 
					    transcription = mp3_to_text(audio_data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return {"transcription": transcription}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@app.get("/docs")
 | 
				
			||||||
 | 
					async def docs():
 | 
				
			||||||
 | 
					    return {"message": "Welcome to the speech-to-text API!"}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    logger = initLogger()
 | 
					    uvicorn.run(app, host="0.0.0.0", port=PORT, log_level=LOG_LEVEL.lower())
 | 
				
			||||||
    text = mp3_to_text("samples/sample1.wav")
 | 
					 | 
				
			||||||
    print()
 | 
					 | 
				
			||||||
    print(text)
 | 
					 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user