Add logging

This commit is contained in:
Reza Behzadan 2023-12-10 21:54:02 +03:30
parent d96084d798
commit 5add45c71c

40
main.py
View File

@ -1,3 +1,5 @@
import logging
import librosa import librosa
import torch import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
@ -6,34 +8,48 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
MODEL = "/home/reza/data/huggingface-models/04.wav2vec2-large-xlsr-persian" MODEL = "/home/reza/data/huggingface-models/04.wav2vec2-large-xlsr-persian"
def initLogger(name=__name__, level=logging.DEBUG):
    """Create (or fetch) a named logger with a formatted stream handler.

    Args:
        name: Logger name. A name wrapped in double underscores, such as
            ``__main__``, is stripped to its bare form (``main``).
        level: Threshold set on the logger itself. The stream handler's own
            threshold is always ``logging.DEBUG``.

    Returns:
        The configured ``logging.Logger`` instance.
    """
    # Strip the surrounding double underscores so "__main__" logs as "main".
    if name[:2] == '__' and name[-2:] == '__':
        name = name[2:-2]
    logger = logging.getLogger(name)

    # logging.getLogger returns the same object for the same name, so the
    # handler must be attached only once; otherwise every repeated call
    # would make each record print one more time.
    already_attached = any(
        getattr(handler, '_init_logger_handler', False)
        for handler in logger.handlers
    )
    if not already_attached:
        fmt = '%(asctime)s | %(levelname)-8s | %(name)s | %(message)s'
        datefmt = '%Y-%m-%d %H:%M:%S'
        ch = logging.StreamHandler()
        ch.setLevel(logging.DEBUG)
        ch.setFormatter(logging.Formatter(fmt, datefmt))
        # Tag the handler so later calls can recognise it without touching
        # handlers that other code may have attached to the same logger.
        ch._init_logger_handler = True
        logger.addHandler(ch)

    # The level is (re)applied on every call so callers can still adjust it.
    logger.setLevel(level)
    return logger
def mp3_to_text(mp3_file_path):
    """Transcribe an audio file to text with the wav2vec2 model at ``MODEL``.

    Args:
        mp3_file_path: Path of the audio file passed to ``librosa.load``.

    Returns:
        The first string of the processor's batch-decoded transcription.
    """
    # Look the logger up by name instead of relying on a module-global
    # ``logger`` that is only assigned when this file runs as a script;
    # calling this function after an import would otherwise raise NameError.
    log = logging.getLogger('speech2text_fa')

    # Load the MP3 file and resample to 16kHz
    audio, sample_rate = librosa.load(mp3_file_path, sr=16000)
    log.info("Resampling is Done!")

    # Load tokenizer and model from Hugging Face
    tokenizer = Wav2Vec2Processor.from_pretrained(MODEL)
    model = Wav2Vec2ForCTC.from_pretrained(MODEL)
    log.info("Loading model is Done!")

    # Preprocess the audio, passing on the rate librosa reported so the
    # two values cannot drift apart.
    input_values = tokenizer(
        audio,
        sampling_rate=sample_rate,
        return_tensors="pt",
        padding="longest",
    ).input_values
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        logits = model(input_values).logits
    log.info("Processing the audio is Done!")

    # Decode the predicted IDs
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = tokenizer.batch_decode(predicted_ids)
    log.info("Decoding the prodicted IDs is Done!")

    return transcription[0]
# text = mp3_to_text("samples/captcha.mp3")
text = mp3_to_text("samples/sample1.wav")
print()
print(text)
if __name__ == "__main__":
    # Set up the application logger before any transcription work starts.
    logger = initLogger(name='speech2text_fa', level=logging.INFO)

    # Transcribe the bundled sample and print it after one blank line.
    text = mp3_to_text("samples/sample1.wav")
    print("", text, sep="\n")