Add logging
This commit is contained in:
		
							parent
							
								
									d96084d798
								
							
						
					
					
						commit
						5add45c71c
					
				
							
								
								
									
										40
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										40
									
								
								main.py
									
									
									
									
									
								
							@ -1,3 +1,5 @@
 | 
				
			|||||||
 | 
					import logging
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import librosa
 | 
					import librosa
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
 | 
					from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
 | 
				
			||||||
@ -6,34 +8,48 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
 | 
				
			|||||||
MODEL = "/home/reza/data/huggingface-models/04.wav2vec2-large-xlsr-persian"
 | 
					MODEL = "/home/reza/data/huggingface-models/04.wav2vec2-large-xlsr-persian"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def initLogger(name=__name__, level=logging.DEBUG):
 | 
				
			||||||
 | 
					    if name[:2] == '__' and name[-2:] == '__':
 | 
				
			||||||
 | 
					        name = name[2:-2]
 | 
				
			||||||
 | 
					    logger = logging.getLogger(name)
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    fmt = '%(asctime)s | %(levelname)-8s | %(name)s | %(message)s'
 | 
				
			||||||
 | 
					    datefmt = '%Y-%m-%d %H:%M:%S'
 | 
				
			||||||
 | 
					    ch = logging.StreamHandler()
 | 
				
			||||||
 | 
					    ch.setLevel(logging.DEBUG)
 | 
				
			||||||
 | 
					    formatter = logging.Formatter(fmt, datefmt)
 | 
				
			||||||
 | 
					    ch.setFormatter(formatter)
 | 
				
			||||||
 | 
					    logger.addHandler(ch)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    logger.setLevel(level)
 | 
				
			||||||
 | 
					    return logger
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def mp3_to_text(mp3_file_path):
 | 
					def mp3_to_text(mp3_file_path):
 | 
				
			||||||
    # Load the MP3 file and resample to 16kHz
 | 
					    # Load the MP3 file and resample to 16kHz
 | 
				
			||||||
    audio, sample_rate = librosa.load(mp3_file_path, sr=16000)
 | 
					    audio, sample_rate = librosa.load(mp3_file_path, sr=16000)
 | 
				
			||||||
    print()
 | 
					    logger.info("Resampling is Done!")
 | 
				
			||||||
    print("Resampling is Done!")
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Load tokenizer and model from Hugging Face
 | 
					    # Load tokenizer and model from Hugging Face
 | 
				
			||||||
    tokenizer = Wav2Vec2Processor.from_pretrained(MODEL)
 | 
					    tokenizer = Wav2Vec2Processor.from_pretrained(MODEL)
 | 
				
			||||||
    model = Wav2Vec2ForCTC.from_pretrained(MODEL)
 | 
					    model = Wav2Vec2ForCTC.from_pretrained(MODEL)
 | 
				
			||||||
    print()
 | 
					    logger.info("Loading model is Done!")
 | 
				
			||||||
    print("Loading model is Done!")
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Preprocess the audio
 | 
					    # Preprocess the audio
 | 
				
			||||||
    input_values = tokenizer(audio, sampling_rate=16000, return_tensors="pt", padding="longest").input_values
 | 
					    input_values = tokenizer(audio, sampling_rate=16000, return_tensors="pt", padding="longest").input_values
 | 
				
			||||||
    logits = model(input_values).logits
 | 
					    logits = model(input_values).logits
 | 
				
			||||||
    print()
 | 
					    logger.info("Processing the audio is Done!")
 | 
				
			||||||
    print("Processing the audio is Done!")
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Decode the predicted IDs
 | 
					    # Decode the predicted IDs
 | 
				
			||||||
    predicted_ids = torch.argmax(logits, dim=-1)
 | 
					    predicted_ids = torch.argmax(logits, dim=-1)
 | 
				
			||||||
    transcription = tokenizer.batch_decode(predicted_ids)
 | 
					    transcription = tokenizer.batch_decode(predicted_ids)
 | 
				
			||||||
    print()
 | 
					    logger.info("Decoding the prodicted IDs is Done!")
 | 
				
			||||||
    print("Decoding the prodicted IDs is Done!")
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return transcription[0]
 | 
					    return transcription[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# text = mp3_to_text("samples/captcha.mp3")
 | 
					 | 
				
			||||||
text = mp3_to_text("samples/sample1.wav")
 | 
					 | 
				
			||||||
print()
 | 
					 | 
				
			||||||
print(text)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    logger = initLogger('speech2text_fa', level=logging.INFO)
 | 
				
			||||||
 | 
					    text = mp3_to_text("samples/sample1.wav")
 | 
				
			||||||
 | 
					    print()
 | 
				
			||||||
 | 
					    print(text)
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user