Source code for psifx.audio.identification.pyannote.tool
"""pyannote speaker identification tool."""
from collections import OrderedDict, defaultdict
import os
from typing import Union, Optional, Sequence
from typing import Any
from itertools import permutations
from pathlib import Path
from tqdm import tqdm
import numpy as np
import pandas as pd
import torch
from torch import Tensor
from pyannote.audio import Audio
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
from pyannote.core import Segment
from psifx.audio.identification.tool import IdentificationTool
from psifx.io import rttm, json, wav
from psifx.utils.huggingface import patch_hf_hub_download_use_auth_token
[docs]
def cropped_waveform(
path: Union[str, Path],
start: float,
end: float,
sample_rate: int = 32000,
) -> Tensor:
"""
Crops an audio track and returns its corresponding waveform.
:param path: Path to the audio track.
:param start: Start of segment in seconds.
:param end: End of the segment in seconds.
:param sample_rate: Sample rate of the audio track.
:return: Tensor containing the waveform of the audio segment.
"""
waveform, sample_rate = Audio(
sample_rate=sample_rate,
).crop(
file=path,
segment=Segment(start, end),
mode="pad",
)
assert waveform.shape[0] == 1, f"Audio at path {path} is not mono, got {waveform.shape[0]} channels."
return waveform
[docs]
class PyannoteIdentificationTool(IdentificationTool):
"""
pyannote speaker identification tool.
:param model_names: The names of the models to use.
:param api_token: The HuggingFace API token to use.
:param device: The device where the computation should be executed.
:param overwrite: Whether to overwrite existing files, otherwise raise an error.
:param verbose: Whether to execute the computation verbosely.
"""
def __init__(
self,
model_names: Sequence[str],
api_token: Optional[str] = None,
device: str = "cpu",
overwrite: bool = False,
verbose: Union[bool, int] = True,
):
super().__init__(
device=device,
overwrite=overwrite,
verbose=verbose,
)
self.api_token = api_token or os.environ.get('HF_TOKEN')
os.environ.setdefault("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD", "1")
self._register_torch_safe_globals()
patch_hf_hub_download_use_auth_token()
try:
self.models = {
name: PretrainedSpeakerEmbedding(
embedding=name,
device=torch.device(device),
use_auth_token=self.api_token,
)
for name in model_names
}
except AttributeError as exc:
message = str(exc)
if "'NoneType' object has no attribute 'eval'" in message:
raise PermissionError(
"Could not load pyannote embedding model(s). Ensure HF_TOKEN is valid and "
"the token account accepted pyannote model terms on Hugging Face."
) from exc
raise
[docs]
def inference(
self,
mixed_audio_path: Union[str, Path],
diarization_path: Union[str, Path],
mono_audio_paths: Sequence[Union[str, Path]],
identification_path: Union[str, Path],
):
"""
pyannote's backed inference method.
:param mixed_audio_path: Path to the mixed audio track.
:param diarization_path: Path to the diarization file.
:param mono_audio_paths: Path to the mono audio tracks.
:param identification_path: Path to the identification file.
:return:
"""
mixed_audio_path = Path(mixed_audio_path)
diarization_path = Path(diarization_path)
mono_audio_paths = [Path(path) for path in mono_audio_paths]
identification_path = Path(identification_path)
assert mixed_audio_path not in mono_audio_paths
assert sorted(set(mono_audio_paths)) == sorted(mono_audio_paths)
wav.WAVReader.check(path=mixed_audio_path)
rttm.RTTMReader.check(path=diarization_path)
for path in mono_audio_paths:
wav.WAVReader.check(path=path)
json.JSONWriter.check(path=identification_path, overwrite=self.overwrite)
segments = rttm.RTTMReader.read(path=diarization_path, verbose=self.verbose)
dataframe = pd.DataFrame.from_records(segments)
dataframe["end"] = dataframe["start"] + dataframe["duration"]
for name, model in tqdm(
self.models.items(),
desc="Processing",
disable=not self.verbose,
):
distances = []
valids = []
for index in tqdm(
range(dataframe.shape[0]),
desc="Model Embedding",
disable=not self.verbose,
leave=False,
):
row = dataframe.iloc[index]
if row["duration"] < 0.300:
distance = np.nan
else:
mixed_embedding = model(
waveforms=cropped_waveform(
path=mixed_audio_path,
start=row["start"],
end=row["end"],
)[None, ...]
)
mono_embeddings = np.concatenate(
[
model(
waveforms=cropped_waveform(
path=path,
start=row["start"],
end=row["end"],
)[None, ...]
)
for path in mono_audio_paths
]
)
delta = mixed_embedding - mono_embeddings
distance = np.linalg.norm(delta, ord=2, axis=-1)
valid = np.isfinite(distance).all()
distances.append(distance)
valids.append(valid)
dataframe[f"distance_{name}"] = distances
dataframe[f"valid_{name}"] = valids
for name in self.models.keys():
dataframe.drop(
index=dataframe[~dataframe[f"valid_{name}"]].index,
inplace=True,
)
mono_audio_names = [path.name for path in mono_audio_paths]
best_mapping = None
best_agreement = 0.0
for speaker_names in tqdm(
permutations(pd.Categorical(dataframe["speaker_name"]).categories.tolist()),
desc="Voting",
disable=not self.verbose,
):
mapping = dict(zip(speaker_names, mono_audio_names))
speaker_ids = np.stack(
[
mono_audio_names.index(mapping[name])
for name in dataframe["speaker_name"].values
]
)
model_agreements = []
for name in self.models.keys():
distances = np.stack(dataframe[f"distance_{name}"].values)
closest_ids = distances.argmin(axis=-1)
model_agreement = (speaker_ids == closest_ids).mean()
model_agreements.append(model_agreement)
average_agreement = np.stack(model_agreements).mean()
if average_agreement > best_agreement:
best_agreement = average_agreement
best_mapping = mapping
data = {
"mapping": best_mapping,
"agreement": best_agreement,
}
json.JSONWriter.write(
data=data,
path=identification_path,
overwrite=self.overwrite,
verbose=self.verbose,
)
@staticmethod
def _register_torch_safe_globals() -> None:
add_safe_globals = getattr(torch.serialization, "add_safe_globals", None)
if add_safe_globals is None:
return
safe_globals = []
try:
from omegaconf.listconfig import ListConfig
safe_globals.append(ListConfig)
except Exception:
pass
try:
from omegaconf.dictconfig import DictConfig
safe_globals.append(DictConfig)
except Exception:
pass
try:
from omegaconf.base import ContainerMetadata
safe_globals.append(ContainerMetadata)
except Exception:
pass
try:
import omegaconf.base as omegaconf_base
safe_globals.extend(
value for value in vars(omegaconf_base).values() if isinstance(value, type)
)
except Exception:
pass
try:
import omegaconf.nodes as omegaconf_nodes
safe_globals.extend(
value
for name, value in vars(omegaconf_nodes).items()
if name.endswith("Node") and isinstance(value, type)
)
except Exception:
pass
safe_globals.extend([Any, list, dict, tuple, set, int, float, bool, str, bytes, OrderedDict, defaultdict])
if safe_globals:
add_safe_globals(safe_globals)