Add logging
parent d96084d798
commit 5add45c71c

main.py: 40 changed lines
--- a/main.py
+++ b/main.py
@@ -1,3 +1,5 @@
+import logging
+
 import librosa
 import torch
 from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
@@ -6,34 +8,48 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
 MODEL = "/home/reza/data/huggingface-models/04.wav2vec2-large-xlsr-persian"
 
 
+def initLogger(name=__name__, level=logging.DEBUG):
+    if name[:2] == '__' and name[-2:] == '__':
+        name = name[2:-2]
+    logger = logging.getLogger(name)
+
+    fmt = '%(asctime)s | %(levelname)-8s | %(name)s | %(message)s'
+    datefmt = '%Y-%m-%d %H:%M:%S'
+    ch = logging.StreamHandler()
+    ch.setLevel(logging.DEBUG)
+    formatter = logging.Formatter(fmt, datefmt)
+    ch.setFormatter(formatter)
+    logger.addHandler(ch)
+
+    logger.setLevel(level)
+    return logger
+
+
 def mp3_to_text(mp3_file_path):
     # Load the MP3 file and resample to 16kHz
     audio, sample_rate = librosa.load(mp3_file_path, sr=16000)
-    print()
-    print("Resampling is Done!")
+    logger.info("Resampling is Done!")
 
     # Load tokenizer and model from Hugging Face
     tokenizer = Wav2Vec2Processor.from_pretrained(MODEL)
     model = Wav2Vec2ForCTC.from_pretrained(MODEL)
-    print()
-    print("Loading model is Done!")
+    logger.info("Loading model is Done!")
 
     # Preprocess the audio
     input_values = tokenizer(audio, sampling_rate=16000, return_tensors="pt", padding="longest").input_values
     logits = model(input_values).logits
-    print()
-    print("Processing the audio is Done!")
+    logger.info("Processing the audio is Done!")
 
     # Decode the predicted IDs
     predicted_ids = torch.argmax(logits, dim=-1)
     transcription = tokenizer.batch_decode(predicted_ids)
-    print()
-    print("Decoding the prodicted IDs is Done!")
+    logger.info("Decoding the prodicted IDs is Done!")
 
     return transcription[0]
 
 
-# text = mp3_to_text("samples/captcha.mp3")
-text = mp3_to_text("samples/sample1.wav")
-print()
-print(text)
+if __name__ == "__main__":
+    logger = initLogger('speech2text_fa', level=logging.INFO)
+    text = mp3_to_text("samples/sample1.wav")
+    print()
+    print(text)
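For reference, a minimal sketch of the logging setup this commit introduces, shown in isolation. The helper body and the 'speech2text_fa' logger name are taken from the diff above; the comments and the sample log line are illustrative additions, not captured output. Note that mp3_to_text() refers to the module-level name logger, which is only bound inside the __main__ block, so the logging calls assume main.py is run as a script rather than imported.

import logging


def initLogger(name=__name__, level=logging.DEBUG):
    # Turn dunder names like "__main__" into "main" for cleaner log lines.
    if name[:2] == '__' and name[-2:] == '__':
        name = name[2:-2]
    logger = logging.getLogger(name)

    # One stream handler formatting records as "timestamp | LEVEL | name | message".
    fmt = '%(asctime)s | %(levelname)-8s | %(name)s | %(message)s'
    datefmt = '%Y-%m-%d %H:%M:%S'
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    formatter = logging.Formatter(fmt, datefmt)
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    logger.setLevel(level)
    return logger


if __name__ == "__main__":
    logger = initLogger('speech2text_fa', level=logging.INFO)
    logger.info("Resampling is Done!")
    # Expected shape of the output (timestamp is illustrative):
    # 2024-01-01 12:00:00 | INFO     | speech2text_fa | Resampling is Done!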