# Source code for psifx.audio.transcription.whisper.tool

"""WhisperX transcription tool."""
from collections import OrderedDict, defaultdict
import os
from typing import Union, Optional
from typing import Any
from pathlib import Path
import torch
from psifx.audio.transcription.tool import TranscriptionTool
from psifx.io import vtt, wav
from psifx.utils.huggingface import patch_hf_hub_download_use_auth_token
import whisperx


class WhisperXTool(TranscriptionTool):
    """
    WhisperX transcription and translation tool.

    :param model_name: The name of the model to use.
    :param task: Either ``"transcribe"`` or ``"translate"``.
    :param device: The device where the computation should be executed.
    :param overwrite: Whether to overwrite existing files, otherwise raise an error.
    :param verbose: Whether to execute the computation verbosely.
    """

    def __init__(
        self,
        model_name: str = "distil-large-v3",
        task: str = "transcribe",
        device: str = "cpu",
        overwrite: bool = False,
        verbose: Union[bool, int] = True,
    ):
        super().__init__(
            device=device,
            overwrite=overwrite,
            verbose=verbose,
        )
        self.model_name = model_name
        if task not in ("transcribe", "translate"):
            # NameError is kept for backward compatibility with existing
            # callers, even though ValueError would be more conventional.
            raise NameError(
                f"task should be 'transcribe' or 'translate', got '{task}' instead"
            )
        self.task = task
        # Fix: match any CUDA device spec ("cuda", "cuda:0", ...), not only the
        # bare "cuda" string, so half precision is used on every GPU device.
        compute_type = "float16" if device.startswith("cuda") else "float32"
        # NOTE(review): presumably works around torch defaulting to
        # weights_only=True when loading checkpoints that pickle omegaconf
        # containers — confirm against the torch version in use.
        os.environ.setdefault("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD", "1")
        self._register_torch_safe_globals()
        patch_hf_hub_download_use_auth_token()
        self.pipeline = whisperx.load_model(
            model_name, task=task, device=device, compute_type=compute_type
        )
[docs] def inference( self, audio_path: Union[str, Path], transcription_path: Union[str, Path], batch_size: int = 16, language: Optional[str] = None, ): """ WhisperX's backed transcription method. :param audio_path: Path to the audio track. :param transcription_path: Path to the transcription file. :param batch_size: Batch size, reduce if low on GPU memory. :param language: Country-code string of the spoken language. :return: """ audio_path = Path(audio_path) transcription_path = Path(transcription_path) if self.verbose: print(f"WhisperX") print(f"model name = {self.model_name}") print(f"task = {self.task}") if language is not None: print(f"language = {language}") print(f"audio = {audio_path}") print(f"transcription = {transcription_path}") wav.WAVReader.check(path=audio_path) vtt.VTTWriter.check(path=transcription_path, overwrite=self.overwrite) audio = whisperx.load_audio(audio_path) result = self.pipeline.transcribe(audio, batch_size=batch_size) model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=self.device) result = whisperx.align(result["segments"], model_a, metadata, audio, self.device, return_char_alignments=False) vtt.VTTWriter.write( segments=result["segments"], path=transcription_path, overwrite=self.overwrite )
@staticmethod def _register_torch_safe_globals() -> None: add_safe_globals = getattr(torch.serialization, "add_safe_globals", None) if add_safe_globals is None: return safe_globals = [] try: from omegaconf.listconfig import ListConfig safe_globals.append(ListConfig) except Exception: pass try: from omegaconf.dictconfig import DictConfig safe_globals.append(DictConfig) except Exception: pass try: from omegaconf.base import ContainerMetadata safe_globals.append(ContainerMetadata) except Exception: pass try: import omegaconf.base as omegaconf_base safe_globals.extend( value for value in vars(omegaconf_base).values() if isinstance(value, type) ) except Exception: pass try: import omegaconf.nodes as omegaconf_nodes safe_globals.extend( value for name, value in vars(omegaconf_nodes).items() if name.endswith("Node") and isinstance(value, type) ) except Exception: pass safe_globals.extend([Any, list, dict, tuple, set, int, float, bool, str, bytes, OrderedDict, defaultdict]) if safe_globals: add_safe_globals(safe_globals)