Source code for psifx.audio.transcription.tool

"""transcription tool."""

from typing import Union

from pathlib import Path

import pandas as pd

from psifx.tool import Tool
from psifx.io import vtt, rttm, json


class TranscriptionTool(Tool):
    """
    Base class for transcription tools.
    """

    def inference(
        self,
        audio_path: Union[str, Path],
        transcription_path: Union[str, Path],
    ):
        """
        Template of the inference method.

        This base implementation only normalizes the paths and raises; concrete
        transcription tools must override it.

        :param audio_path: Path to the audio track.
        :param transcription_path: Path to the transcription file.
        :return:
        :raises NotImplementedError: Always, in this base class.
        """
        audio_path = Path(audio_path)
        transcription_path = Path(transcription_path)
        # Expected steps for a concrete implementation:
        # audio = load(audio_path)
        # audio = pre_process_func(audio)
        # transcription = model(audio)
        # transcription = post_process_func(transcription)
        # write(transcription, transcription_path)
        raise NotImplementedError

    @staticmethod
    def _interval_iou(a_start, a_end, b_start, b_end) -> float:
        """
        Temporal intersection-over-union of the intervals [a_start, a_end] and
        [b_start, b_end].

        :return: IoU in [0, 1]; 0.0 when the union has no positive length (e.g. two
            identical zero-length intervals), instead of dividing by zero.
        """
        intersection = max(0.0, min(a_end, b_end) - max(a_start, b_start))
        union = max(a_end, b_end) - min(a_start, b_start)
        if union <= 0.0:
            return 0.0
        return intersection / union

    @staticmethod
    def _best_iou_index(start, end, intervals) -> Union[int, None]:
        """
        Index of the interval in ``intervals`` with the highest IoU against
        [start, end].

        :param start: Start of the query interval.
        :param end: End of the query interval.
        :param intervals: Sequence of ``(start, end, label)`` tuples.
        :return: The index of the best match (first one wins on ties), or None if no
            interval has a strictly positive IoU.
        """
        best_index, best_iou = None, 0.0
        for index, (other_start, other_end, _) in enumerate(intervals):
            iou = TranscriptionTool._interval_iou(start, end, other_start, other_end)
            if iou > best_iou:
                best_index, best_iou = index, iou
        return best_index

    def enhance(
        self,
        transcription_path: Union[str, Path],
        diarization_path: Union[str, Path],
        identification_path: Union[str, Path],
        enhanced_transcription_path: Union[str, Path],
    ):
        """
        Enhances an audio transcription by fusing the transcribed audio with inferred
        speaker diarization and identification.

        Each transcription segment is assigned the speaker of the diarization segment
        with the highest temporal intersection-over-union (IoU). That diarization
        speaker label is translated through the ``"mapping"`` entry of the
        identification file. Segments overlapping no diarization segment (or when the
        diarization is empty) get the speaker ``"NA"``.

        :param transcription_path: Path to the transcription file.
        :param diarization_path: Path to the diarization file.
        :param identification_path: Path to the identification file.
        :param enhanced_transcription_path: Path to the enhanced transcription file.
        :return:
        :raises AssertionError: If the input and output transcription paths are equal.
        """
        transcription_path = Path(transcription_path)
        diarization_path = Path(diarization_path)
        identification_path = Path(identification_path)
        enhanced_transcription_path = Path(enhanced_transcription_path)
        # Never overwrite the input transcription with the enhanced one.
        assert transcription_path != enhanced_transcription_path

        if self.verbose:
            print(f"transcription = {transcription_path}")
            print(f"diarization = {diarization_path}")
            print(f"identification = {identification_path}")
            print(f"enhanced_transcription = {enhanced_transcription_path}")

        vtt.VTTReader.check(path=transcription_path)
        rttm.RTTMReader.check(path=diarization_path)
        json.JSONReader.check(path=identification_path)
        vtt.VTTWriter.check(path=enhanced_transcription_path, overwrite=self.overwrite)

        transcription = pd.DataFrame.from_records(vtt.VTTReader.read(transcription_path))
        diarization = pd.DataFrame.from_records(rttm.RTTMReader.read(diarization_path))
        identification = json.JSONReader.read(identification_path)
        mapping = identification["mapping"]

        # Convert to plain (start, end, speaker_name) tuples once: per-row
        # ``DataFrame.iloc`` access inside the nested loop is very slow, and an empty
        # diarization yields a DataFrame without any columns, which would make
        # column access raise a KeyError.
        diarization_intervals = [
            (row["start"], row["start"] + row["duration"], row["speaker_name"])
            for row in diarization.to_dict("records")
        ]

        segments = []
        for row in transcription.to_dict("records"):
            match_index = self._best_iou_index(
                row["start"], row["end"], diarization_intervals
            )
            if match_index is None:
                speaker_name = "NA"
            else:
                # Only the matched speaker is looked up, so speakers that never
                # match a transcription segment need not be present in the mapping.
                speaker_name = mapping[diarization_intervals[match_index][2]]
            segments.append(
                {
                    "start": row["start"],
                    "end": row["end"],
                    "speaker": speaker_name,
                    "text": row["text"],
                }
            )

        vtt.VTTWriter.write(
            segments=segments,
            path=enhanced_transcription_path,
            overwrite=self.overwrite,
            verbose=self.verbose,
        )