from functools import reduce
import json
from multiprocessing import Pool
import os
import random

import numpy as np
import torch
from torch.utils.data import Dataset

from scripts.utils import get_logger

logger = get_logger(__name__)


def get_frequencies(array: np.ndarray, user_count: int) -> np.ndarray:
    """Count how often each user id appears as a follow source (column 0)
    and as a follow target (column 1)."""
    frequencies = np.zeros((user_count, 2))
    for source_id, target_id in array:
        frequencies[source_id][0] += 1
        frequencies[target_id][1] += 1
    return frequencies


class FollowDataset(Dataset):
    def __init__(
        self,
        dataset_path: str,
        split: str,
        negative_sample_chance: float,
    ):
        with open(os.path.join(dataset_path, "metadata.json"), "r") as in_file:
            metadata = json.load(in_file)

        with open(os.path.join(dataset_path, metadata["did_id_map"]), "r") as in_file:
            self.did_id_map: dict[str, int] = json.load(in_file)

        if split not in metadata["splits"]:
            raise ValueError(f"Could not find split in metadata file: {split}")
        split_files = metadata["splits"][split]

        self.numpy_files: list[tuple[str, str, int, int]] = []
        for file in split_files:
            # Filenames are assumed to look like "<prefix>_<split>_<index>.npy";
            # extract the numeric file index so files can be ordered deterministically.
            file_idx = int(file["filename"].split("_")[2].split(".")[0])
            self.numpy_files.append(
                (file["filename"], file["dtype"], file_idx, file["shape"][0])
            )

        # Sort by file index so dataset indices map to a stable file order.
        self.numpy_files.sort(key=lambda x: x[2])
        self.dataframes: list[np.ndarray] = []
        for filename, dtype, _, row_count in self.numpy_files:
            self.dataframes.append(
                np.memmap(
                    os.path.join(dataset_path, filename),
                    dtype=dtype,
                    mode="r",
                    shape=(row_count, 2),
                )
            )

        logger.info("Calculating node frequency...")
        with Pool(7) as p:
            self.cumulative_freq = reduce(
                np.add,
                p.starmap(
                    get_frequencies,
                    [
                        (dataframe, len(self.did_id_map))
                        for dataframe in self.dataframes
                    ],
                ),
            )
        # Turn per-user counts into cumulative sums down each column so that
        # searchsorted can draw ids weighted by frequency in collate_rows.
        self.cumulative_freq[:, 0] = self.cumulative_freq[:, 0].cumsum()
        self.cumulative_freq[:, 1] = self.cumulative_freq[:, 1].cumsum()

        self.negative_sample_chance = negative_sample_chance

    def __len__(self) -> int:
        return sum(row_count for _, _, _, row_count in self.numpy_files)

    def num_users(self) -> int:
        return len(self.did_id_map)

    def _idx_to_row(self, idx: int) -> tuple[int, int]:
        if idx < 0 or idx >= len(self):
            raise IndexError(f"Invalid index: {idx}")

        # Find which file contains the index.
        row_index_total = 0
        effective_idx = idx
        i = 0
        for i, (_, _, _, row_count) in enumerate(self.numpy_files):
            row_index_total += row_count
            if idx < row_index_total:
                break
            effective_idx -= row_count

        row = self.dataframes[i][effective_idx]
        return (row[0].item(), row[1].item())

    def __getitem__(self, idx: int) -> tuple[tuple[int, int], int]:
        """
        Grab a follow connection as a positive sample. Corruption into
        negative edges happens per batch in collate_rows, where replacement
        ids are drawn weighted by their prevalence in the dataset.
        """
        sample = self._idx_to_row(idx)
        return (sample, 1)

    def collate_rows(
        self,
        batch: list[tuple[tuple[int, int], int]],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Pick rows to corrupt into negative edges, replacing either the
        # source or the target with equal probability.
        corrupted_sources = []
        corrupted_targets = []
        for i in range(len(batch)):
            if random.random() < self.negative_sample_chance:
                if random.random() < 0.5:
                    corrupted_sources.append(i)
                else:
                    corrupted_targets.append(i)

        # Draw replacement ids weighted by frequency: a uniform sample in
        # [0, total_count) searchsorted into the cumulative counts lands on a
        # user with probability proportional to its count.
        new_sources = self.cumulative_freq[:, 0].searchsorted(
            np.random.sample(len(corrupted_sources)) * self.cumulative_freq[-1, 0]
        )
        new_targets = self.cumulative_freq[:, 1].searchsorted(
            np.random.sample(len(corrupted_targets)) * self.cumulative_freq[-1, 1]
        )
        # Clamp as a safety net against out-of-range indices.
        new_sources[new_sources >= self.cumulative_freq.shape[0]] = (
            self.cumulative_freq.shape[0] - 1
        )
        new_targets[new_targets >= self.cumulative_freq.shape[0]] = (
            self.cumulative_freq.shape[0] - 1
        )

        for i, idx in enumerate(corrupted_sources):
            batch[idx] = ((new_sources[i], batch[idx][0][1]), -1)

        for i, idx in enumerate(corrupted_targets):
            batch[idx] = ((batch[idx][0][0], new_targets[i]), -1)

        follows = torch.concat(tuple(torch.IntTensor([follow]) for follow, _ in batch))
        labels = torch.concat(tuple(torch.IntTensor([label]) for _, label in batch))
        return (follows, labels)
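

# A minimal sketch (not part of the original module) illustrating the
# cumsum + searchsorted trick used in collate_rows: uniform draws in
# [0, total_count) searchsorted into a cumulative count array select each id
# with probability proportional to its raw count. The counts below are
# made-up example values.
def _weighted_sample_demo() -> None:
    counts = np.array([1.0, 3.0, 6.0])  # hypothetical per-user counts
    cumulative = counts.cumsum()  # -> [1, 4, 10]
    draws = np.random.sample(10_000) * cumulative[-1]
    picks = cumulative.searchsorted(draws)
    # Expect roughly 10% of picks at id 0, 30% at id 1, 60% at id 2.
    logger.info("sampled id frequencies: %s", np.bincount(picks) / len(picks))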
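

# A usage sketch under assumed inputs: the dataset directory, split name, and
# negative-sample chance below are hypothetical, not values shipped with this
# repo. collate_rows is passed as the DataLoader's collate_fn so negative
# edges are generated per batch.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = FollowDataset(
        dataset_path="data/follows",  # hypothetical path
        split="train",  # assumed split name present in metadata.json
        negative_sample_chance=0.5,  # hypothetical corruption rate
    )
    loader = DataLoader(
        dataset,
        batch_size=1024,
        shuffle=True,
        collate_fn=dataset.collate_rows,
    )
    follows, labels = next(iter(loader))
    logger.info("batch follows %s labels %s", follows.shape, labels.shape)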