personal memory agent
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

Update diarize.py for pyannote.audio 4.0 API compatibility

- Replace deprecated `use_auth_token` with `token` parameter
- Use `exclusive_speaker_diarization` for cleaner turn segments
- Use built-in `speaker_embeddings` from DiarizeOutput instead of
manual extraction with separate embedding model
- Simplify `save_speaker_embeddings` to accept dict directly
- Remove unused EMB_MODEL_ID constant and imports

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

+52 -83
+44 -75
observe/diarize.py
··· 1 1 """Speaker diarization using pyannote pipeline with embedding extraction. 2 2 3 3 This module provides speaker diarization (who spoke when) using the pyannote 4 - community pipeline, along with per-turn speaker embeddings for future 5 - speaker identification. 4 + speaker-diarization-3.1 pipeline, along with per-turn speaker embeddings for 5 + future speaker identification. 6 6 7 7 Requires HUGGINGFACE_API_KEY environment variable for HuggingFace authentication. 8 8 """ ··· 18 18 import soundfile as sf 19 19 20 20 PIPELINE_ID = "pyannote/speaker-diarization-3.1" 21 - EMB_MODEL_ID = "pyannote/wespeaker-voxceleb-resnet34-LM" 22 21 23 22 # Built-in parameters 24 23 SEGMENTATION_STEP = 0.2 # 0.1 = 90% overlap (default), 0.2 = 80% overlap (2x faster) ··· 49 48 50 49 def diarize( 51 50 audio_path: Path, 52 - ) -> tuple[list[dict], np.ndarray, dict, list[dict]]: 53 - """Run speaker diarization and extract per-turn embeddings. 51 + ) -> tuple[list[dict], dict[str, np.ndarray], dict, list[dict]]: 52 + """Run speaker diarization and extract per-speaker embeddings. 53 + 54 + Uses exclusive diarization (no overlapping speech in turns) for cleaner 55 + transcription segments. Overlapping speech regions are reported separately. 54 56 55 57 Args: 56 58 audio_path: Path to audio file (FLAC, WAV, etc.) 57 59 58 60 Returns: 59 - Tuple of (turns, embeddings, timings, overlaps) where: 61 + Tuple of (turns, speaker_embeddings, timings, overlaps) where: 60 62 - turns: List of dicts with "start", "end", "speaker" keys 61 63 Speaker labels are human-readable: "Speaker 1", "Speaker 2", etc. 62 - - embeddings: numpy array of shape (num_turns, 256) 64 + Uses exclusive diarization (no overlapping segments). 65 + - speaker_embeddings: Dict mapping speaker labels to embedding arrays 63 66 - timings: Dict with timing information for each stage 64 67 - overlaps: List of dicts with "start", "end" keys for overlapping speech 65 68 66 69 Raises: 67 70 DiarizationError: If diarization fails (missing token, access denied, etc.) 68 71 """ 69 - from pyannote.audio import Inference, Model, Pipeline 70 - from pyannote.core import Segment 72 + from pyannote.audio import Inference, Pipeline 71 73 72 74 hf_token = get_hf_token() 73 75 timings: dict[str, float] = {} ··· 75 77 76 78 # Get audio info 77 79 info = sf.info(input_path) 78 - audio_duration = info.duration 79 - logging.info(f"Diarizing: {audio_path.name} ({audio_duration:.1f}s)") 80 + logging.info(f"Diarizing: {audio_path.name} ({info.duration:.1f}s)") 80 81 81 82 # Load pipeline 82 83 logging.info("Loading diarization pipeline...") 83 84 t0 = time.perf_counter() 84 85 try: 85 - pipeline = Pipeline.from_pretrained(PIPELINE_ID, use_auth_token=hf_token) 86 + pipeline = Pipeline.from_pretrained(PIPELINE_ID, token=hf_token) 86 87 except Exception as e: 87 88 if "403" in str(e) or "gated" in str(e).lower(): 88 89 raise DiarizationError( ··· 130 131 timings["diarization"] = time.perf_counter() - t0 131 132 logging.info(f" Diarization: {timings['diarization']:.2f}s") 132 133 133 - # Extract turns and map speaker labels to human-readable names 134 + # Extract turns from exclusive diarization (no overlapping segments) 135 + # and map speaker labels to human-readable names 134 136 raw_turns: list[dict] = [] 135 137 speaker_map: dict[str, str] = {} # SPEAKER_00 -> Speaker 1 136 138 137 - for turn, _, speaker in diarization.itertracks(yield_label=True): 139 + for turn, _, speaker in diarization.exclusive_speaker_diarization.itertracks( 140 + yield_label=True 141 + ): 138 142 if speaker not in speaker_map: 139 143 speaker_map[speaker] = f"Speaker {len(speaker_map) + 1}" 140 144 raw_turns.append( ··· 145 149 } 146 150 ) 147 151 148 - # Extract overlapping speech regions 149 - overlap_timeline = diarization.get_overlap() 152 + # Extract overlapping speech regions from regular diarization 153 + overlap_timeline = diarization.speaker_diarization.get_overlap() 150 154 overlaps: list[dict] = [ 151 155 {"start": float(seg.start), "end": float(seg.end)} for seg in overlap_timeline 152 156 ] ··· 157 161 if not raw_turns: 158 162 logging.info("No speech detected.") 159 163 timings["total"] = timings["pipeline_load"] + timings["diarization"] 160 - return [], np.array([]), timings, overlaps 164 + return [], {}, timings, overlaps 161 165 162 166 n_speakers = len(speaker_map) 163 167 logging.info(f" Found {len(raw_turns)} turns, {n_speakers} speakers") ··· 173 177 if not turns: 174 178 logging.info("No turns remaining after filtering.") 175 179 timings["total"] = timings["pipeline_load"] + timings["diarization"] 176 - return [], np.array([]), timings, overlaps 180 + return [], {}, timings, overlaps 177 181 178 - # Load embedding model 179 - logging.info("Loading embedding model...") 180 - t0 = time.perf_counter() 181 - emb_model = Model.from_pretrained(EMB_MODEL_ID, use_auth_token=hf_token) 182 - if torch.cuda.is_available(): 183 - emb_model.to(torch.device("cuda")) 184 - emb_infer = Inference(emb_model, window="whole") 185 - timings["emb_model_load"] = time.perf_counter() - t0 182 + # Extract per-speaker embeddings from pipeline output 183 + # Embeddings are indexed by speaker order matching exclusive_speaker_diarization.labels() 184 + speaker_embeddings: dict[str, np.ndarray] = {} 185 + if diarization.speaker_embeddings is not None: 186 + labels = diarization.exclusive_speaker_diarization.labels() 187 + for idx, raw_label in enumerate(labels): 188 + if raw_label in speaker_map and idx < len(diarization.speaker_embeddings): 189 + emb = diarization.speaker_embeddings[idx] 190 + # Normalize for cosine similarity 191 + emb = emb / (np.linalg.norm(emb) + 1e-8) 192 + speaker_embeddings[speaker_map[raw_label]] = emb.astype(np.float32) 193 + logging.info(f" Extracted embeddings for {len(speaker_embeddings)} speakers") 186 194 187 - # Extract embeddings for each turn 188 - logging.info(f"Extracting embeddings for {len(turns)} turns...") 189 - t0 = time.perf_counter() 190 - emb_list: list[np.ndarray] = [] 191 - 192 - for i, turn in enumerate(turns): 193 - try: 194 - emb_vec = emb_infer.crop(input_path, Segment(turn["start"], turn["end"])) 195 - if emb_vec.ndim > 1: 196 - emb_vec = emb_vec.mean(axis=0) 197 - emb_list.append(emb_vec.astype(np.float32)) 198 - except Exception as e: 199 - logging.warning(f"Failed to extract embedding for turn {i}: {e}") 200 - emb_list.append(np.zeros(256, dtype=np.float32)) 201 - 202 - if (i + 1) % 20 == 0: 203 - logging.info(f" Embedded {i + 1}/{len(turns)} turns") 195 + timings["total"] = timings["pipeline_load"] + timings["diarization"] 204 196 205 - embeddings = np.stack(emb_list, axis=0) 206 - timings["embedding"] = time.perf_counter() - t0 207 - logging.info(f" Embeddings: {timings['embedding']:.2f}s ({embeddings.shape})") 208 - 209 - timings["total"] = ( 210 - timings["pipeline_load"] 211 - + timings["diarization"] 212 - + timings["emb_model_load"] 213 - + timings["embedding"] 214 - ) 215 - 216 - return turns, embeddings, timings, overlaps 197 + return turns, speaker_embeddings, timings, overlaps 217 198 218 199 219 200 def save_speaker_embeddings( 220 201 output_dir: Path, 221 - turns: list[dict], 222 - embeddings: np.ndarray, 202 + speaker_embeddings: dict[str, np.ndarray], 223 203 ) -> list[Path]: 224 - """Save per-speaker mean embeddings to NPZ files. 204 + """Save per-speaker embeddings to NPZ files. 225 205 226 206 Args: 227 207 output_dir: Directory to save embeddings (e.g., segment_dir/audio_stem/) 228 - turns: List of turn dicts with "speaker" key 229 - embeddings: Array of shape (num_turns, 256) 208 + speaker_embeddings: Dict mapping speaker labels to embedding arrays 230 209 231 210 Returns: 232 211 List of paths to saved NPZ files 233 212 """ 213 + if not speaker_embeddings: 214 + return [] 215 + 234 216 output_dir.mkdir(parents=True, exist_ok=True) 235 217 236 - # Group embeddings by speaker 237 - speaker_embeddings: dict[str, list[np.ndarray]] = {} 238 - for turn, emb in zip(turns, embeddings): 239 - speaker = turn["speaker"] 240 - if speaker not in speaker_embeddings: 241 - speaker_embeddings[speaker] = [] 242 - speaker_embeddings[speaker].append(emb) 243 - 244 - # Save mean embedding per speaker 245 218 saved_paths: list[Path] = [] 246 - for speaker, embs in speaker_embeddings.items(): 247 - mean_emb = np.mean(embs, axis=0).astype(np.float32) 248 - # Normalize for cosine similarity 249 - mean_emb = mean_emb / (np.linalg.norm(mean_emb) + 1e-8) 250 - 219 + for speaker, embedding in speaker_embeddings.items(): 251 220 emb_path = output_dir / f"{speaker}.npz" 252 - np.savez_compressed(emb_path, embedding=mean_emb) 221 + np.savez_compressed(emb_path, embedding=embedding) 253 222 saved_paths.append(emb_path) 254 223 logging.info(f" Saved embedding: {emb_path}") 255 224
+8 -8
observe/transcribe.py
··· 250 250 251 251 try: 252 252 # Run diarization 253 - diarization_turns, embeddings, timings, overlaps = diarize(audio_path) 253 + diarization_turns, speaker_embeddings, timings, overlaps = diarize( 254 + audio_path 255 + ) 254 256 255 257 if not diarization_turns: 256 258 logging.info(f"No speech detected in {raw_path}") 257 259 return { 258 260 "turns": [], 259 - "embeddings": np.array([]), 261 + "speaker_embeddings": {}, 260 262 "speakers": [], 261 263 "diarization": { 262 264 "turns": [], ··· 335 337 336 338 return { 337 339 "turns": processed, 338 - "embeddings": embeddings, 340 + "speaker_embeddings": speaker_embeddings, 339 341 "speakers": speakers, 340 342 "diarization": { 341 343 "turns": diarization_turns, ··· 499 501 raise SystemExit(1) 500 502 501 503 turns = result["turns"] 502 - embeddings = result["embeddings"] 504 + speaker_embeddings = result["speaker_embeddings"] 503 505 speakers = result["speakers"] 504 506 diarization_data = result["diarization"] 505 507 ··· 519 521 final_path = self._move_to_segment(raw_path) 520 522 521 523 # Save speaker embeddings 522 - if embeddings.size > 0: 524 + if speaker_embeddings: 523 525 embeddings_dir = self._get_embeddings_dir(raw_path) 524 - # Need to reconstruct turn list with speaker info for embedding save 525 - turn_info = [{"speaker": t["speaker"]} for t in turns] 526 - save_speaker_embeddings(embeddings_dir, turn_info, embeddings) 526 + save_speaker_embeddings(embeddings_dir, speaker_embeddings) 527 527 528 528 # Emit completion event 529 529 journal_path = Path(os.getenv("JOURNAL_PATH", ""))