Source code for psifx.audio.identification.pyannote.tool
"""pyannote speaker identification tool."""
import os
from typing import Union, Optional, Sequence
from itertools import permutations
from pathlib import Path
from tqdm import tqdm
import numpy as np
import pandas as pd
import torch
from torch import Tensor
from pyannote.audio import Audio
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
from pyannote.core import Segment
from psifx.audio.identification.tool import IdentificationTool
from psifx.io import rttm, json, wav
[docs]
def cropped_waveform(
    path: Union[str, Path],
    start: float,
    end: float,
    sample_rate: int = 32000,
) -> Tensor:
    """
    Crops an audio track and returns its corresponding waveform.

    :param path: Path to the audio track.
    :param start: Start of segment in seconds.
    :param end: End of the segment in seconds.
    :param sample_rate: Sample rate passed to pyannote's ``Audio`` loader.
    :return: Tensor containing the waveform of the audio segment, with the
        channel dimension first (exactly one channel).
    :raises ValueError: If the audio track has more than one channel.
    """
    audio = Audio(sample_rate=sample_rate)
    # mode="pad" is pyannote's option for segments that extend beyond the
    # file; the returned sample rate is not needed here.
    waveform, _ = audio.crop(
        file=path,
        segment=Segment(start, end),
        mode="pad",
    )
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    num_channels = waveform.shape[0]
    if num_channels != 1:
        raise ValueError(
            f"Audio at path {path} is not mono, got {num_channels} channels."
        )
    return waveform
[docs]
class PyannoteIdentificationTool(IdentificationTool):
    """
    pyannote speaker identification tool.

    Maps the speaker labels of a diarization of a mixed audio track onto a set
    of mono audio tracks. For every diarized segment and every embedding
    model, the embedding of the mixed track is compared (L2 distance) with the
    embedding of each mono track, and the closest mono track is that model's
    vote. The speaker-to-track mapping that best agrees with the votes,
    averaged over models, is written to a JSON file.

    :param model_names: The names of the models to use.
    :param api_token: The HuggingFace API token to use. Falls back to the
        ``HF_TOKEN`` environment variable when not given.
    :param device: The device where the computation should be executed.
    :param overwrite: Whether to overwrite existing files, otherwise raise an error.
    :param verbose: Whether to execute the computation verbosely.
    """

    # Segments shorter than this many seconds are not embedded and are left
    # out of the vote.
    min_segment_duration: float = 0.300

    def __init__(
        self,
        model_names: Sequence[str],
        api_token: Optional[str] = None,
        device: str = "cpu",
        overwrite: bool = False,
        verbose: Union[bool, int] = True,
    ):
        super().__init__(
            device=device,
            overwrite=overwrite,
            verbose=verbose,
        )
        self.api_token = api_token or os.environ.get("HF_TOKEN")
        self.models = {
            name: PretrainedSpeakerEmbedding(
                embedding=name,
                device=torch.device(device),
                # Use the resolved token so the HF_TOKEN fallback is honored.
                use_auth_token=self.api_token,
            )
            for name in model_names
        }

    def inference(
        self,
        mixed_audio_path: Union[str, Path],
        diarization_path: Union[str, Path],
        mono_audio_paths: Sequence[Union[str, Path]],
        identification_path: Union[str, Path],
    ):
        """
        pyannote's backed inference method.

        :param mixed_audio_path: Path to the mixed audio track.
        :param diarization_path: Path to the diarization file.
        :param mono_audio_paths: Paths to the mono audio tracks; they must be
            distinct and must not include the mixed audio track.
        :param identification_path: Path to the identification file.
        :return: None. The identification file receives ``"mapping"``
            (speaker name -> mono track file name, or ``None`` if no mapping
            agrees with any vote) and ``"agreement"`` (fraction of segment
            votes the mapping agrees with, averaged over models).
        :raises ValueError: If the paths are inconsistent, there is no model,
            the diarization is empty, no segment is usable, or there are more
            speakers than mono audio tracks.
        """
        mixed_audio_path = Path(mixed_audio_path)
        diarization_path = Path(diarization_path)
        mono_audio_paths = [Path(path) for path in mono_audio_paths]
        identification_path = Path(identification_path)

        # Explicit checks instead of ``assert`` so they survive ``python -O``.
        if mixed_audio_path in mono_audio_paths:
            raise ValueError(
                f"The mixed audio track {mixed_audio_path} is also listed "
                f"as a mono audio track."
            )
        if len(set(mono_audio_paths)) != len(mono_audio_paths):
            raise ValueError("The mono audio tracks must be distinct.")
        if not self.models:
            raise ValueError("At least one embedding model is required.")

        wav.WAVReader.check(path=mixed_audio_path)
        rttm.RTTMReader.check(path=diarization_path)
        for path in mono_audio_paths:
            wav.WAVReader.check(path=path)
        json.JSONWriter.check(path=identification_path, overwrite=self.overwrite)

        segments = rttm.RTTMReader.read(path=diarization_path, verbose=self.verbose)
        dataframe = pd.DataFrame.from_records(segments)
        if dataframe.empty:
            raise ValueError(
                f"No segment found in the diarization file {diarization_path}."
            )
        dataframe["end"] = dataframe["start"] + dataframe["duration"]

        for name, model in tqdm(
            self.models.items(),
            desc="Processing",
            disable=not self.verbose,
        ):
            distances = []
            for start, end, duration in tqdm(
                zip(dataframe["start"], dataframe["end"], dataframe["duration"]),
                total=len(dataframe),
                desc="Model Embedding",
                disable=not self.verbose,
                leave=False,
            ):
                if duration < self.min_segment_duration:
                    # Skipped (presumably too short for a reliable embedding);
                    # NaN marks it invalid below.
                    distances.append(np.nan)
                else:
                    distances.append(
                        self._segment_distances(
                            model=model,
                            mixed_audio_path=mixed_audio_path,
                            mono_audio_paths=mono_audio_paths,
                            start=start,
                            end=end,
                        )
                    )
            dataframe[f"distance_{name}"] = distances
            # A segment is usable for this model only if every distance is finite.
            dataframe[f"valid_{name}"] = [
                bool(np.isfinite(distance).all()) for distance in distances
            ]

        # Keep only the segments that are usable for every model.
        valid_columns = [f"valid_{name}" for name in self.models]
        dataframe = dataframe[dataframe[valid_columns].all(axis=1)]
        if dataframe.empty:
            raise ValueError(
                "No diarized segment is long enough and yields finite "
                "embedding distances for every model."
            )

        speakers = pd.Categorical(dataframe["speaker_name"])
        speaker_names = speakers.categories.tolist()
        # Output uses file names; voting works on track indices so tracks
        # that happen to share a file name are still told apart internally.
        mono_audio_names = [path.name for path in mono_audio_paths]
        if len(speaker_names) > len(mono_audio_names):
            raise ValueError(
                f"Cannot map {len(speaker_names)} speakers onto "
                f"{len(mono_audio_names)} mono audio tracks."
            )

        # Each model's vote: index of the closest mono track for each segment.
        votes = [
            np.stack(dataframe[f"distance_{name}"].values).argmin(axis=-1)
            for name in self.models
        ]
        assignment, best_agreement = self._vote(
            speaker_codes=np.asarray(speakers.codes),
            votes=votes,
            num_speakers=len(speaker_names),
            num_tracks=len(mono_audio_names),
        )
        if assignment is None:
            best_mapping = None
        else:
            best_mapping = {
                speaker: mono_audio_names[track]
                for speaker, track in zip(speaker_names, assignment)
            }

        data = {
            "mapping": best_mapping,
            "agreement": best_agreement,
        }
        json.JSONWriter.write(
            data=data,
            path=identification_path,
            overwrite=self.overwrite,
            verbose=self.verbose,
        )

    @staticmethod
    def _segment_distances(
        model,
        mixed_audio_path: Path,
        mono_audio_paths: Sequence[Path],
        start: float,
        end: float,
    ) -> np.ndarray:
        """
        Computes the L2 distance between the mixed-track embedding of a
        segment and the embedding of the same segment in each mono track.

        :param model: Pretrained speaker-embedding model.
        :param mixed_audio_path: Path to the mixed audio track.
        :param mono_audio_paths: Paths to the mono audio tracks.
        :param start: Start of the segment in seconds.
        :param end: End of the segment in seconds.
        :return: One distance per mono track, in the order of ``mono_audio_paths``.
        """
        # Load audio at the model's own sample rate rather than a fixed one.
        sample_rate = model.sample_rate

        def embed(path: Path) -> np.ndarray:
            waveform = cropped_waveform(
                path=path,
                start=start,
                end=end,
                sample_rate=sample_rate,
            )
            # Add a leading batch dimension in front of the channel dimension.
            return model(waveforms=waveform[None, ...])

        mixed_embedding = embed(mixed_audio_path)
        mono_embeddings = np.concatenate([embed(path) for path in mono_audio_paths])
        return np.linalg.norm(mixed_embedding - mono_embeddings, ord=2, axis=-1)

    def _vote(
        self,
        speaker_codes: np.ndarray,
        votes: Sequence[np.ndarray],
        num_speakers: int,
        num_tracks: int,
    ) -> tuple:
        """
        Finds the speaker-to-track assignment that best agrees with the votes.

        Every injective assignment of the speakers to the mono tracks is
        tried, so speakers may be mapped to any subset of the tracks.

        :param speaker_codes: Speaker index of each segment.
        :param votes: For each model, the closest mono track index of each segment.
        :param num_speakers: Number of distinct speakers.
        :param num_tracks: Number of mono tracks.
        :return: ``(assignment, agreement)``; ``assignment[i]`` is the track
            index of speaker ``i`` (``None`` if no assignment has a positive
            agreement) and ``agreement`` is its agreement averaged over models.
        """
        best_assignment = None
        best_agreement = 0.0
        for assignment in tqdm(
            permutations(range(num_tracks), num_speakers),
            desc="Voting",
            disable=not self.verbose,
        ):
            # Track index implied by this assignment for every segment.
            track_ids = np.asarray(assignment)[speaker_codes]
            agreement = float(
                np.mean([(track_ids == vote).mean() for vote in votes])
            )
            # Strict comparison: on ties the first assignment tried is kept.
            if agreement > best_agreement:
                best_agreement = agreement
                best_assignment = assignment
        return best_assignment, best_agreement