Source code for psifx.audio.transcription.whisper.tool
"""WhisperX transcription tool."""
from collections import OrderedDict, defaultdict
import os
from typing import Union, Optional
from typing import Any
from pathlib import Path
import torch
from psifx.audio.transcription.tool import TranscriptionTool
from psifx.io import vtt, wav
from psifx.utils.huggingface import patch_hf_hub_download_use_auth_token
import whisperx
[docs]
class WhisperXTool(TranscriptionTool):
    """
    WhisperX transcription and translation tool.

    :param model_name: The name of the model to use.
    :param task: Either ``"transcribe"`` or ``"translate"``.
    :param device: The device where the computation should be executed.
    :param overwrite: Whether to overwrite existing files, otherwise raise an error.
    :param verbose: Whether to execute the computation verbosely.
    :raises NameError: If ``task`` is neither ``"transcribe"`` nor ``"translate"``.
    """

    def __init__(
        self,
        model_name: str = "distil-large-v3",
        task: str = "transcribe",
        device: str = "cpu",
        overwrite: bool = False,
        verbose: Union[bool, int] = True,
    ):
        super().__init__(
            device=device,
            overwrite=overwrite,
            verbose=verbose,
        )

        self.model_name = model_name
        # NameError is kept (rather than ValueError) because callers may
        # already catch this exception type.
        if task not in ("transcribe", "translate"):
            raise NameError(f"task should be 'transcribe' or 'translate', got '{task}' instead")
        self.task = task

        # Half precision is only selected for the exact device string "cuda";
        # every other device string falls back to float32.
        compute_type = "float16" if device == "cuda" else "float32"

        # setdefault keeps any value the user has already exported.
        os.environ.setdefault("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD", "1")
        # Both calls below run before the model is loaded so that checkpoint
        # loading sees the allow-listed globals and the patched hub download.
        self._register_torch_safe_globals()
        patch_hf_hub_download_use_auth_token()

        self.pipeline = whisperx.load_model(
            model_name,
            task=task,
            device=device,
            compute_type=compute_type,
        )

    def inference(
        self,
        audio_path: Union[str, Path],
        transcription_path: Union[str, Path],
        batch_size: int = 16,
        language: Optional[str] = None,
    ):
        """
        WhisperX-backed transcription method.

        Transcribes the audio track, aligns the resulting segments and writes
        them to a VTT file.

        :param audio_path: Path to the audio track.
        :param transcription_path: Path to the transcription file.
        :param batch_size: Batch size, reduce if low on GPU memory.
        :param language: Language code of the spoken language (e.g. ``"en"``).
            When ``None``, language selection is left to the WhisperX pipeline.
        :return:
        """
        audio_path = Path(audio_path)
        transcription_path = Path(transcription_path)

        if self.verbose:
            print("WhisperX")
            print(f"model name = {self.model_name}")
            print(f"task = {self.task}")
            if language is not None:
                print(f"language = {language}")
            print(f"audio = {audio_path}")
            print(f"transcription = {transcription_path}")

        # Validate input and output paths before doing any expensive work.
        wav.WAVReader.check(path=audio_path)
        vtt.VTTWriter.check(path=transcription_path, overwrite=self.overwrite)

        # load_audio documents its argument as a file name string.
        audio = whisperx.load_audio(str(audio_path))

        # Forward the requested language; previously the argument was accepted
        # but never reached the pipeline, so it was silently ignored.
        result = self.pipeline.transcribe(
            audio,
            batch_size=batch_size,
            language=language,
        )

        # The alignment model is chosen from the language reported by the
        # transcription result.
        # NOTE(review): with task="translate" the segment text is presumably in
        # the target language while result["language"] is the source language;
        # confirm that alignment behaves as intended in that case.
        model_a, metadata = whisperx.load_align_model(
            language_code=result["language"],
            device=self.device,
        )
        result = whisperx.align(
            result["segments"],
            model_a,
            metadata,
            audio,
            self.device,
            return_char_alignments=False,
        )

        vtt.VTTWriter.write(
            segments=result["segments"], path=transcription_path, overwrite=self.overwrite
        )

    @staticmethod
    def _register_torch_safe_globals() -> None:
        """
        Register omegaconf and builtin types with torch's safe-globals list.

        Does nothing when ``torch.serialization.add_safe_globals`` is not
        available. Collection of omegaconf types is best effort: a module that
        cannot be imported simply contributes no types.
        """
        add_safe_globals = getattr(torch.serialization, "add_safe_globals", None)
        if add_safe_globals is None:
            return

        import importlib

        def _collect_types(module_name, keep):
            # Return the classes defined in or imported by `module_name` whose
            # attribute name satisfies `keep`; empty list if the import fails.
            try:
                module = importlib.import_module(module_name)
            except Exception:
                # Deliberate best effort: omegaconf may be missing or broken.
                return []
            return [
                value
                for name, value in vars(module).items()
                if isinstance(value, type) and keep(name)
            ]

        safe_globals = [
            *_collect_types("omegaconf.listconfig", lambda name: name == "ListConfig"),
            *_collect_types("omegaconf.dictconfig", lambda name: name == "DictConfig"),
            # Every class visible in omegaconf.base (this includes
            # ContainerMetadata when it is present).
            *_collect_types("omegaconf.base", lambda name: True),
            *_collect_types("omegaconf.nodes", lambda name: name.endswith("Node")),
            Any,
            list,
            dict,
            tuple,
            set,
            int,
            float,
            bool,
            str,
            bytes,
            OrderedDict,
            defaultdict,
        ]

        # Drop duplicates while preserving first-seen order.
        add_safe_globals(list(dict.fromkeys(safe_globals)))