Source code for psifx.audio.transcription.whisper.tool
"""WhisperX transcription tool."""
from typing import Union, Optional
from pathlib import Path
from psifx.audio.transcription.tool import TranscriptionTool
from psifx.io import vtt, wav
import whisperx
class WhisperXTool(TranscriptionTool):
"""
WhisperX transcription and translation tool.
:param model_name: The name of the model to use.
:param device: The device where the computation should be executed.
:param overwrite: Whether to overwrite existing files, otherwise raise an error.
:param verbose: Whether to execute the computation verbosely.
"""
    def __init__(
        self,
        model_name: str = "distil-large-v3",
        task: str = "transcribe",
        device: str = "cpu",
        overwrite: bool = False,
        verbose: Union[bool, int] = True,
    ):
        super().__init__(
            device=device,
            overwrite=overwrite,
            verbose=verbose,
        )
        self.model_name = model_name
        if task not in ["transcribe", "translate"]:
            raise ValueError(
                f"task should be 'transcribe' or 'translate', got '{task}' instead"
            )
        self.task = task
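        # Note: float16 is only reliably supported on CUDA devices, so CPU
        # runs fall back to float32.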
        compute_type = "float16" if device == "cuda" else "float32"
        self.pipeline = whisperx.load_model(
            model_name, task=task, device=device, compute_type=compute_type
        )
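
    # A minimal usage sketch (hypothetical file names, for illustration only):
    #
    #   tool = WhisperXTool(model_name="distil-large-v3", device="cpu")
    #   tool.inference(
    #       audio_path="interview.wav",
    #       transcription_path="interview.vtt",
    #   )
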
    def inference(
        self,
        audio_path: Union[str, Path],
        transcription_path: Union[str, Path],
        batch_size: int = 16,
        language: Optional[str] = None,
    ):
"""
WhisperX's backed transcription method.
:param audio_path: Path to the audio track.
:param transcription_path: Path to the transcription file.
:param batch_size: Batch size, reduce if low on GPU memory.
:param language: Country-code string of the spoken language.
:return:
"""
        audio_path = Path(audio_path)
        transcription_path = Path(transcription_path)
        if self.verbose:
            print("WhisperX")
            print(f"model name = {self.model_name}")
            print(f"task = {self.task}")
            if language is not None:
                print(f"language = {language}")
            print(f"audio = {audio_path}")
            print(f"transcription = {transcription_path}")
        wav.WAVReader.check(path=audio_path)
        vtt.VTTWriter.check(path=transcription_path, overwrite=self.overwrite)
        audio = whisperx.load_audio(str(audio_path))
        result = self.pipeline.transcribe(
            audio, batch_size=batch_size, language=language
        )
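        # Refine the segment timestamps to word level with a phoneme-based
        # alignment model matched to the detected (or user-provided) language.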
        model_a, metadata = whisperx.load_align_model(
            language_code=result["language"], device=self.device
        )
        result = whisperx.align(
            result["segments"],
            model_a,
            metadata,
            audio,
            self.device,
            return_char_alignments=False,
        )
        vtt.VTTWriter.write(
            segments=result["segments"],
            path=transcription_path,
            overwrite=self.overwrite,
        )
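
# A sketch with an explicit language hint (hypothetical file names); skipping
# automatic language detection can help with short or noisy recordings:
#
#   tool = WhisperXTool(device="cuda")
#   tool.inference(
#       audio_path="entretien.wav",
#       transcription_path="entretien.vtt",
#       language="fr",
#   )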