A loose federation of distributed, typed datasets
feat: add HuggingFace Datasets-style load_dataset() API

Implements a familiar load_dataset() interface inspired by HuggingFace
Datasets, adapted for atdata's typed WebDataset approach:

- load_dataset() function with path resolution, split detection, and
support for brace notation, globs, and explicit data_files mapping
- DatasetDict class for multi-split dataset containers
- Automatic split detection from filenames (train/test/validation)
- streaming parameter to explicitly signal streaming mode

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
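
As a rough usage sketch of what this adds (adapted from the module docstring in src/atdata/_hf_api.py below; the shard paths and the MyData sample type are illustrative placeholders):

```python
import atdata
from atdata import load_dataset

# Illustrative typed sample; any @atdata.packable dataclass should work here.
@atdata.packable
class MyData:
    text: str
    label: int

# Load a single split; WebDataset brace notation is passed through as-is.
train_ds = load_dataset("path/to/train-{000000..000099}.tar", MyData, split="train")

# Load all detected splits at once; returns a DatasetDict keyed by split name.
ds_dict = load_dataset("path/to/{train,test}-*.tar", MyData)
train_ds, test_ds = ds_dict["train"], ds_dict["test"]
```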

+2392
+182
.reference/huggingface-datasets/architecture.md
··· 1 + # HuggingFace Datasets - Architecture Overview 2 + 3 + Source: https://huggingface.co/docs/datasets/en/about_dataset_load 4 + 5 + ## How load_dataset Works (ELI5) 6 + 7 + A dataset is a directory that contains: 8 + - Data files in generic formats (JSON, CSV, Parquet, text, etc.) 9 + - A dataset card (`README.md`) with documentation and YAML configuration 10 + 11 + `load_dataset()` fetches the requested dataset locally or from the Hugging Face Hub. 12 + 13 + ### Automatic Format Detection 14 + 15 + If the dataset only contains data files, `load_dataset()` automatically infers how to load them based on file extensions. Under the hood, it uses an appropriate `DatasetBuilder`: 16 + 17 + | Format | Builder Class | 18 + |--------|---------------| 19 + | `.txt` | `datasets.packaged_modules.text.Text` | 20 + | `.csv`, `.tsv` | `datasets.packaged_modules.csv.Csv` | 21 + | `.json`, `.jsonl` | `datasets.packaged_modules.json.Json` | 22 + | `.parquet` | `datasets.packaged_modules.parquet.Parquet` | 23 + | `.arrow` | `datasets.packaged_modules.arrow.Arrow` | 24 + | SQL | `datasets.packaged_modules.sql.Sql` | 25 + | Image folders | `datasets.packaged_modules.imagefolder.ImageFolder` | 26 + | Audio folders | `datasets.packaged_modules.audiofolder.AudioFolder` | 27 + | WebDataset TAR | `datasets.packaged_modules.webdataset.WebDataset` | 28 + 29 + --- 30 + 31 + ## Building a Dataset 32 + 33 + Two main classes are responsible for building a dataset: 34 + 35 + ### BuilderConfig 36 + 37 + Configuration class containing dataset attributes: 38 + 39 + | Attribute | Description | 40 + |-----------|-------------| 41 + | `name` | Short name of the dataset | 42 + | `version` | Dataset version identifier | 43 + | `data_dir` | Path to local folder containing data files | 44 + | `data_files` | Paths to local data files | 45 + | `description` | Description of the dataset | 46 + 47 + Custom attributes (like class labels) can be added by subclassing `BuilderConfig`. 48 + 49 + Configuration can be populated: 50 + 1. Via predefined `BuilderConfig` instances in `DatasetBuilder.BUILDER_CONFIGS` 51 + 2. Via keyword arguments to `load_dataset()` (overrides predefined) 52 + 53 + ### DatasetBuilder 54 + 55 + Accesses `BuilderConfig` attributes to build the actual dataset. 56 + 57 + Three main methods: 58 + 59 + #### 1. `_info()` - Define dataset attributes 60 + 61 + - Defines dataset attributes returned by `dataset.info` 62 + - Specifies `Features` (schema with column names and types) 63 + 64 + #### 2. `_split_generator()` - Organize data files 65 + 66 + - Downloads or retrieves data files 67 + - Uses `DownloadManager` for downloading/extracting 68 + - Organizes files into splits via `SplitGenerator` 69 + - Returns keyword arguments for `_generate_examples` 70 + 71 + #### 3. `_generate_examples()` - Parse and yield examples 72 + 73 + - Reads and parses data files for each split 74 + - Yields examples as Python dicts matching the schema 75 + - Uses Python generator (memory efficient) 76 + - Examples buffered in `ArrowWriter` before writing to disk 77 + 78 + --- 79 + 80 + ## Data Flow 81 + 82 + ``` 83 + load_dataset("name", split="train") 84 + 85 + 86 + ┌───────────────────────────────────────┐ 87 + │ 1. Resolve dataset path │ 88 + │ - Hub repo? Local dir? Builder? │ 89 + └───────────────────────────────────────┘ 90 + 91 + 92 + ┌───────────────────────────────────────┐ 93 + │ 2. 
Load DatasetBuilder │ 94 + │ - Auto-detect format │ 95 + │ - Apply BuilderConfig │ 96 + └───────────────────────────────────────┘ 97 + 98 + 99 + ┌───────────────────────────────────────┐ 100 + │ 3. Download & prepare (if not cached) │ 101 + │ - _split_generator() downloads │ 102 + │ - _generate_examples() yields │ 103 + │ - Arrow tables cached to disk │ 104 + └───────────────────────────────────────┘ 105 + 106 + 107 + ┌───────────────────────────────────────┐ 108 + │ 4. Load from cache │ 109 + │ - Memory-map Arrow files │ 110 + │ - Return Dataset/DatasetDict │ 111 + └───────────────────────────────────────┘ 112 + ``` 113 + 114 + --- 115 + 116 + ## Caching 117 + 118 + - Datasets are cached as Arrow tables in `~/.cache/huggingface/datasets` 119 + - Subsequent loads use the cache (fast!) 120 + - Cache can be disabled or customized via `cache_dir` parameter 121 + - `download_mode` controls cache behavior: 122 + - `REUSE_DATASET_IF_EXISTS` (default): Use cache if available 123 + - `FORCE_REDOWNLOAD`: Re-download everything 124 + - `REUSE_CACHE_IF_EXISTS`: Reuse cache for downloads but regenerate dataset 125 + 126 + --- 127 + 128 + ## Streaming Mode 129 + 130 + With `streaming=True`: 131 + - No downloading or caching 132 + - Data streamed on-the-fly during iteration 133 + - Returns `IterableDataset` instead of `Dataset` 134 + - Best for large datasets 135 + 136 + ```python 137 + ds = load_dataset("large_dataset", split="train", streaming=True) 138 + for example in ds: 139 + process(example) # Examples fetched as needed 140 + ``` 141 + 142 + --- 143 + 144 + ## Integrity Verification 145 + 146 + `load_dataset()` verifies downloaded data: 147 + - Number of splits in generated `DatasetDict` 148 + - Number of samples in each split 149 + - List of downloaded files 150 + - SHA256 checksums (disabled by default) 151 + 152 + Disable with `verification_mode="no_checks"` if needed. 153 + 154 + --- 155 + 156 + ## Key Design Patterns for atdata Integration 157 + 158 + ### Pattern 1: Path Resolution 159 + HF Datasets supports multiple path types: 160 + - Hub repository: `"username/dataset"` 161 + - Local directory: `"./path/to/data"` 162 + - Builder name: `"parquet"` with `data_files` 163 + 164 + ### Pattern 2: Split Handling 165 + - `split=None` → `DatasetDict` with all splits 166 + - `split="train"` → Single `Dataset` 167 + - Split string algebra: `"train+test"`, `"train[:10%]"` 168 + 169 + ### Pattern 3: Lazy Loading 170 + - Streaming mode for large datasets 171 + - Generator-based iteration 172 + - Buffer-based shuffling 173 + 174 + ### Pattern 4: Format Abstraction 175 + - Single API for multiple formats 176 + - Auto-detection based on file extensions 177 + - Builder-specific configuration via kwargs 178 + 179 + ### Pattern 5: Type System 180 + - `Features` schema defines column types 181 + - Automatic type inference with override capability 182 + - Special types for media (Audio, Image, Video)
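
To make Pattern 2 above concrete, this is roughly how split handling behaves in HF Datasets itself (dataset name taken from the HF docs examples; exact outputs may vary):

```python
from datasets import load_dataset

# split=None -> DatasetDict keyed by split name
all_splits = load_dataset("cornell-movie-review-data/rotten_tomatoes")
train = all_splits["train"]

# split="train" -> a single Dataset
train_only = load_dataset("cornell-movie-review-data/rotten_tomatoes", split="train")

# Split string algebra: concatenation and percentage slices
mix = load_dataset("cornell-movie-review-data/rotten_tomatoes", split="train[:10%]+test")
```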
+308
.reference/huggingface-datasets/loading-guide.md
··· 1 + # HuggingFace Datasets - Loading Guide 2 + 3 + Source: https://huggingface.co/docs/datasets/en/loading 4 + 5 + ## Overview 6 + 7 + Data can be loaded from multiple sources: 8 + - The Hugging Face Hub 9 + - Local files (CSV, JSON, Parquet, etc.) 10 + - In-memory data (dicts, lists, generators, DataFrames) 11 + - SQL databases 12 + - Remote URLs 13 + 14 + --- 15 + 16 + ## Loading from Hugging Face Hub 17 + 18 + ```python 19 + from datasets import load_dataset 20 + 21 + # Basic usage 22 + dataset = load_dataset("lhoestq/demo1") 23 + 24 + # Specific version (git tag, branch, or commit) 25 + dataset = load_dataset("lhoestq/custom_squad", revision="main") 26 + 27 + # Map data files to splits 28 + data_files = {"train": "train.csv", "test": "test.csv"} 29 + dataset = load_dataset("namespace/your_dataset_name", data_files=data_files) 30 + 31 + # Load subset of files with patterns 32 + c4_subset = load_dataset("allenai/c4", data_files="en/c4-train.0000*-of-01024.json.gz") 33 + 34 + # Load from subdirectory 35 + c4_subset = load_dataset("allenai/c4", data_dir="en") 36 + ``` 37 + 38 + --- 39 + 40 + ## Loading Local Files 41 + 42 + ### CSV 43 + 44 + ```python 45 + from datasets import load_dataset 46 + 47 + # Single file 48 + dataset = load_dataset("csv", data_files="my_file.csv") 49 + 50 + # Multiple files 51 + dataset = load_dataset("csv", data_files=["file1.csv", "file2.csv"]) 52 + 53 + # With split mapping 54 + dataset = load_dataset("csv", data_files={"train": "train.csv", "test": "test.csv"}) 55 + ``` 56 + 57 + ### JSON 58 + 59 + ```python 60 + # Standard JSON lines format (one object per line) 61 + dataset = load_dataset("json", data_files="my_file.json") 62 + 63 + # Nested JSON with field parameter 64 + # File: {"version": "0.1.0", "data": [{"a": 1}, {"a": 2}]} 65 + dataset = load_dataset("json", data_files="my_file.json", field="data") 66 + 67 + # Remote JSON 68 + base_url = "https://example.com/data/" 69 + dataset = load_dataset("json", data_files={ 70 + "train": base_url + "train.json", 71 + "validation": base_url + "dev.json" 72 + }, field="data") 73 + ``` 74 + 75 + ### Parquet 76 + 77 + ```python 78 + # Local 79 + dataset = load_dataset("parquet", data_files={'train': 'train.parquet', 'test': 'test.parquet'}) 80 + 81 + # Remote 82 + base_url = "https://huggingface.co/datasets/wikimedia/wikipedia/resolve/main/20231101.ab/" 83 + data_files = {"train": base_url + "train-00000-of-00001.parquet"} 84 + wiki = load_dataset("parquet", data_files=data_files, split="train") 85 + ``` 86 + 87 + ### Arrow 88 + 89 + ```python 90 + # Via load_dataset 91 + dataset = load_dataset("arrow", data_files={'train': 'train.arrow'}) 92 + 93 + # Direct memory mapping (faster, no cache) 94 + from datasets import Dataset 95 + dataset = Dataset.from_file("data.arrow") 96 + ``` 97 + 98 + ### Text 99 + 100 + ```python 101 + dataset = load_dataset("text", data_files="my_file.txt") 102 + ``` 103 + 104 + ### WebDataset (TAR archives) 105 + 106 + ```python 107 + # Best used with streaming for large datasets 108 + path = "path/to/train/*.tar" 109 + dataset = load_dataset("webdataset", data_files={"train": path}, split="train", streaming=True) 110 + 111 + # Remote WebDataset 112 + base_url = "https://example.com/dataset/" 113 + urls = [base_url + f"shard-{i:06d}.tar" for i in range(4)] 114 + dataset = load_dataset("webdataset", data_files={"train": urls}, split="train", streaming=True) 115 + ``` 116 + 117 + ### HDF5 118 + 119 + ```python 120 + dataset = load_dataset("hdf5", data_files="data.h5") 121 + ``` 122 + 123 + 
### SQL Databases 124 + 125 + ```python 126 + from datasets import Dataset 127 + 128 + # Load entire table 129 + dataset = Dataset.from_sql("data_table_name", con="sqlite:///sqlite_file.db") 130 + 131 + # Load from query 132 + dataset = Dataset.from_sql( 133 + "SELECT text FROM table WHERE length(text) > 100 LIMIT 10", 134 + con="sqlite:///sqlite_file.db" 135 + ) 136 + ``` 137 + 138 + --- 139 + 140 + ## Loading In-Memory Data 141 + 142 + ### Python Dictionary 143 + 144 + ```python 145 + from datasets import Dataset 146 + 147 + my_dict = {"a": [1, 2, 3], "b": ["x", "y", "z"]} 148 + dataset = Dataset.from_dict(my_dict) 149 + ``` 150 + 151 + ### Python List of Dictionaries 152 + 153 + ```python 154 + my_list = [{"a": 1, "b": "x"}, {"a": 2, "b": "y"}, {"a": 3, "b": "z"}] 155 + dataset = Dataset.from_list(my_list) 156 + ``` 157 + 158 + ### Python Generator 159 + 160 + ```python 161 + from datasets import Dataset, IterableDataset 162 + 163 + # For data larger than memory 164 + def my_gen(): 165 + for i in range(1, 1000000): 166 + yield {"a": i, "text": f"example {i}"} 167 + 168 + dataset = Dataset.from_generator(my_gen) 169 + 170 + # Sharded generator for distributed processing 171 + def gen(shards): 172 + for shard in shards: 173 + with open(shard) as f: 174 + for line in f: 175 + yield {"line": line} 176 + 177 + shards = [f"data{i}.txt" for i in range(32)] 178 + ds = IterableDataset.from_generator(gen, gen_kwargs={"shards": shards}) 179 + ds = ds.shuffle(seed=42, buffer_size=10_000) 180 + ``` 181 + 182 + ### Pandas DataFrame 183 + 184 + ```python 185 + import pandas as pd 186 + from datasets import Dataset 187 + 188 + df = pd.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]}) 189 + dataset = Dataset.from_pandas(df) 190 + ``` 191 + 192 + --- 193 + 194 + ## Multiprocessing 195 + 196 + Speed up loading with multiple processes: 197 + 198 + ```python 199 + from datasets import load_dataset 200 + 201 + # Each process handles a subset of shards 202 + imagenet = load_dataset("timm/imagenet-1k-wds", num_proc=8) 203 + ``` 204 + 205 + --- 206 + 207 + ## Slicing Splits 208 + 209 + ### String API 210 + 211 + ```python 212 + import datasets 213 + 214 + # Concatenate splits 215 + train_test_ds = datasets.load_dataset("dataset_name", split="train+test") 216 + 217 + # Select rows by index 218 + train_10_20_ds = datasets.load_dataset("dataset_name", split="train[10:20]") 219 + 220 + # Select by percentage 221 + train_10pct_ds = datasets.load_dataset("dataset_name", split="train[:10%]") 222 + 223 + # Combine percentage slices 224 + train_10_80pct_ds = datasets.load_dataset("dataset_name", split="train[:10%]+train[-80%:]") 225 + 226 + # Cross-validation splits 227 + val_ds = datasets.load_dataset("dataset_name", 228 + split=[f"train[{k}%:{k+10}%]" for k in range(0, 100, 10)]) 229 + train_ds = datasets.load_dataset("dataset_name", 230 + split=[f"train[:{k}%]+train[{k+10}%:]" for k in range(0, 100, 10)]) 231 + ``` 232 + 233 + ### ReadInstruction API 234 + 235 + ```python 236 + import datasets 237 + 238 + # Concatenate 239 + ri = datasets.ReadInstruction("train") + datasets.ReadInstruction("test") 240 + train_test_ds = datasets.load_dataset("dataset_name", split=ri) 241 + 242 + # Percentage with rounding control 243 + ri = datasets.ReadInstruction("train", from_=50, to=52, unit="%", rounding="pct1_dropremainder") 244 + train_50_52_ds = datasets.load_dataset("dataset_name", split=ri) 245 + ``` 246 + 247 + --- 248 + 249 + ## Specifying Features 250 + 251 + Override auto-inferred features: 252 + 253 + ```python 254 + 
from datasets import load_dataset, Features, Value, ClassLabel 255 + 256 + # Define custom features 257 + class_names = ["sadness", "joy", "love", "anger", "fear", "surprise"] 258 + emotion_features = Features({ 259 + 'text': Value('string'), 260 + 'label': ClassLabel(names=class_names) 261 + }) 262 + 263 + # Apply when loading 264 + dataset = load_dataset('csv', data_files='data.csv', features=emotion_features) 265 + 266 + # Verify 267 + print(dataset['train'].features) 268 + # {'text': Value('string'), 'label': ClassLabel(names=['sadness', 'joy', ...])} 269 + ``` 270 + 271 + --- 272 + 273 + ## Offline Mode 274 + 275 + Use cached datasets without internet: 276 + 277 + ```bash 278 + # Set environment variable 279 + export HF_HUB_OFFLINE=1 280 + ``` 281 + 282 + ```python 283 + # Will use cache only 284 + dataset = load_dataset("dataset_name") 285 + ``` 286 + 287 + --- 288 + 289 + ## Image/Audio/Video Datasets 290 + 291 + ### ImageFolder 292 + 293 + ```python 294 + # Directory structure: images/{class_name}/{image_file} 295 + dataset = load_dataset("imagefolder", data_dir="path/to/images", split="train") 296 + ``` 297 + 298 + ### AudioFolder 299 + 300 + ```python 301 + dataset = load_dataset("audiofolder", data_dir="path/to/audio", split="train") 302 + ``` 303 + 304 + ### VideoFolder 305 + 306 + ```python 307 + dataset = load_dataset("videofolder", data_dir="path/to/videos", split="train") 308 + ```
+233
.reference/huggingface-datasets/loading-methods.md
··· 1 + # HuggingFace Datasets - Loading Methods API Reference 2 + 3 + Source: https://huggingface.co/docs/datasets/en/package_reference/loading_methods 4 + 5 + ## datasets.load_dataset 6 + 7 + Load a dataset from the Hugging Face Hub, or a local dataset. 8 + 9 + ```python 10 + load_dataset( 11 + path: str, 12 + name: Optional[str] = None, 13 + data_dir: Optional[str] = None, 14 + data_files: Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]], None] = None, 15 + split: Union[str, Split, list[str], list[Split], None] = None, 16 + cache_dir: Optional[str] = None, 17 + features: Optional[Features] = None, 18 + download_config: Optional[DownloadConfig] = None, 19 + download_mode: Union[DownloadMode, str, None] = None, 20 + verification_mode: Union[VerificationMode, str, None] = None, 21 + keep_in_memory: Optional[bool] = None, 22 + save_infos: bool = False, 23 + revision: Union[Version, str, None] = None, 24 + token: Union[bool, str, None] = None, 25 + streaming: bool = False, 26 + num_proc: Optional[int] = None, 27 + storage_options: Optional[dict] = None, 28 + **config_kwargs, 29 + ) 30 + ``` 31 + 32 + ### What it does under the hood 33 + 34 + 1. **Load a dataset builder:** 35 + - Find the most common data format in the dataset and pick its associated builder (JSON, CSV, Parquet, Webdataset, ImageFolder, AudioFolder, etc.) 36 + - Find which file goes into which split (e.g. train/test) based on file and directory names or on the YAML configuration 37 + - Can specify `data_files` manually, and which dataset builder to use (e.g. "parquet") 38 + 39 + 2. **Run the dataset builder:** 40 + - Download the data files from the dataset if they are not already available locally or cached 41 + - Process and cache the dataset in typed Arrow tables 42 + - In streaming mode: don't download or cache anything, dataset is lazily loaded 43 + 44 + 3. **Return a dataset built from the requested splits** 45 + 46 + ### Parameters 47 + 48 + | Parameter | Type | Description | 49 + |-----------|------|-------------| 50 + | `path` | `str` | Path or name of the dataset. Can be: Hub repo (`'username/dataset_name'`), local directory (`'./path/to/data'`), or builder name with `data_files`/`data_dir` (`'parquet'`) | 51 + | `name` | `str`, optional | Dataset configuration name | 52 + | `data_dir` | `str`, optional | Directory containing the data files | 53 + | `data_files` | `str` or `Sequence` or `Mapping`, optional | Path(s) to source data file(s) | 54 + | `split` | `str` or `Split`, optional | Which split to load. If `None`, returns `DatasetDict` with all splits | 55 + | `cache_dir` | `str`, optional | Directory to read/write data. Default: `~/.cache/huggingface/datasets` | 56 + | `features` | `Features`, optional | Set the features type to use | 57 + | `download_mode` | `DownloadMode`, optional | Download/generate mode. Default: `REUSE_DATASET_IF_EXISTS` | 58 + | `verification_mode` | `VerificationMode`, optional | Checks to run on downloaded data. Default: `BASIC_CHECKS` | 59 + | `keep_in_memory` | `bool`, optional | Whether to copy the dataset in-memory | 60 + | `revision` | `str`, optional | Version (git tag/commit/branch) of the dataset to load | 61 + | `token` | `str` or `bool`, optional | Bearer token for remote files on the Hub | 62 + | `streaming` | `bool` | If `True`, returns `IterableDataset` without downloading. 
Default: `False` | 63 + | `num_proc` | `int`, optional | Number of processes for downloading and generating | 64 + | `storage_options` | `dict`, optional | Key/value pairs for file-system backend | 65 + 66 + ### Returns 67 + 68 + - If `split` is not `None`: `Dataset` (or `IterableDataset` if streaming) 69 + - If `split` is `None`: `DatasetDict` (or `IterableDatasetDict` if streaming) 70 + 71 + ### Examples 72 + 73 + ```python 74 + from datasets import load_dataset 75 + 76 + # Load from Hugging Face Hub 77 + ds = load_dataset('cornell-movie-review-data/rotten_tomatoes', split='train') 78 + 79 + # Load a subset/configuration 80 + ds = load_dataset('nyu-mll/glue', 'sst2', split='train') 81 + 82 + # Manual mapping of data files to splits 83 + data_files = {'train': 'train.csv', 'test': 'test.csv'} 84 + ds = load_dataset('namespace/your_dataset_name', data_files=data_files) 85 + 86 + # Load local CSV file 87 + ds = load_dataset('csv', data_files='path/to/local/my_dataset.csv') 88 + 89 + # Load local JSON file 90 + ds = load_dataset('json', data_files='path/to/local/my_dataset.json') 91 + 92 + # Streaming mode (no download) 93 + ds = load_dataset('cornell-movie-review-data/rotten_tomatoes', split='train', streaming=True) 94 + 95 + # ImageFolder 96 + ds = load_dataset('imagefolder', data_dir='/path/to/images', split='train') 97 + 98 + # WebDataset 99 + ds = load_dataset('webdataset', data_files='path/to/train/*.tar', split='train', streaming=True) 100 + ``` 101 + 102 + --- 103 + 104 + ## datasets.load_from_disk 105 + 106 + Loads a dataset that was previously saved using `save_to_disk()`. 107 + 108 + ```python 109 + load_from_disk( 110 + dataset_path: str, 111 + keep_in_memory: Optional[bool] = None, 112 + storage_options: Optional[dict] = None, 113 + ) 114 + ``` 115 + 116 + ### Parameters 117 + 118 + | Parameter | Type | Description | 119 + |-----------|------|-------------| 120 + | `dataset_path` | `path-like` | Path or remote URI (e.g. `"s3://my-bucket/dataset/train"`) | 121 + | `keep_in_memory` | `bool`, optional | Whether to copy the dataset in-memory | 122 + | `storage_options` | `dict`, optional | Key/value pairs for file-system backend | 123 + 124 + ### Returns 125 + 126 + - `Dataset` or `DatasetDict` 127 + 128 + ### Example 129 + 130 + ```python 131 + from datasets import load_from_disk 132 + 133 + ds = load_from_disk('path/to/dataset/directory') 134 + ``` 135 + 136 + --- 137 + 138 + ## datasets.load_dataset_builder 139 + 140 + Load a dataset builder for inspection or streaming without full download. 141 + 142 + ```python 143 + load_dataset_builder( 144 + path: str, 145 + name: Optional[str] = None, 146 + data_dir: Optional[str] = None, 147 + data_files: Optional[...] = None, 148 + cache_dir: Optional[str] = None, 149 + features: Optional[Features] = None, 150 + download_config: Optional[DownloadConfig] = None, 151 + download_mode: Optional[...] = None, 152 + revision: Optional[...] = None, 153 + token: Optional[...] = None, 154 + storage_options: Optional[dict] = None, 155 + **config_kwargs, 156 + ) 157 + ``` 158 + 159 + ### Example 160 + 161 + ```python 162 + from datasets import load_dataset_builder 163 + 164 + ds_builder = load_dataset_builder('cornell-movie-review-data/rotten_tomatoes') 165 + print(ds_builder.info.features) 166 + # {'label': ClassLabel(names=['neg', 'pos']), 'text': Value('string')} 167 + ``` 168 + 169 + --- 170 + 171 + ## datasets.get_dataset_config_names 172 + 173 + Get the list of available config names for a dataset. 
174 + 175 + ```python 176 + from datasets import get_dataset_config_names 177 + 178 + get_dataset_config_names("nyu-mll/glue") 179 + # ['cola', 'sst2', 'mrpc', 'qqp', 'stsb', 'mnli', ...] 180 + ``` 181 + 182 + --- 183 + 184 + ## datasets.get_dataset_split_names 185 + 186 + Get the list of available splits for a dataset. 187 + 188 + ```python 189 + from datasets import get_dataset_split_names 190 + 191 + get_dataset_split_names('cornell-movie-review-data/rotten_tomatoes') 192 + # ['train', 'validation', 'test'] 193 + ``` 194 + 195 + --- 196 + 197 + ## Built-in Dataset Builders 198 + 199 + The following builders are available for loading different file formats: 200 + 201 + | Builder | File Types | 202 + |---------|------------| 203 + | `text` | `.txt` | 204 + | `csv` | `.csv`, `.tsv` | 205 + | `json` | `.json`, `.jsonl` | 206 + | `parquet` | `.parquet` | 207 + | `arrow` | `.arrow` | 208 + | `xml` | `.xml` | 209 + | `sql` | SQL databases | 210 + | `webdataset` | `.tar` (WebDataset format) | 211 + | `imagefolder` | Image directories | 212 + | `audiofolder` | Audio directories | 213 + | `videofolder` | Video directories | 214 + | `pdffolder` | PDF directories | 215 + | `hdf5` | `.h5`, `.hdf5` | 216 + 217 + ### Builder-specific options 218 + 219 + Each builder has its own configuration class with specific options: 220 + 221 + ```python 222 + # CSV with custom separator 223 + load_dataset("csv", data_files="data.csv", sep="\t") 224 + 225 + # JSON with nested field 226 + load_dataset("json", data_files="data.json", field="data") 227 + 228 + # Parquet with column selection 229 + load_dataset("parquet", data_files="data.parquet", columns=["col1", "col2"]) 230 + 231 + # Parquet with filters (pushed down to file) 232 + load_dataset("parquet", data_files="data.parquet", filters=[("col", "==", 0)]) 233 + ```
+457
.reference/huggingface-datasets/main-classes.md
··· 1 + # HuggingFace Datasets - Main Classes API Reference 2 + 3 + Source: https://huggingface.co/docs/datasets/en/package_reference/main_classes 4 + 5 + ## Dataset 6 + 7 + A map-style dataset backed by Apache Arrow table, supporting random access. 8 + 9 + ### Creation Methods 10 + 11 + ```python 12 + # From various file formats 13 + Dataset.from_csv('path/to/file.csv') 14 + Dataset.from_json('path/to/file.json') 15 + Dataset.from_parquet('path/to/file.parquet') 16 + Dataset.from_text('path/to/file.txt') 17 + Dataset.from_sql("SELECT * FROM table", "sqlite:///db.sqlite") 18 + 19 + # From in-memory data 20 + Dataset.from_dict({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']}) 21 + Dataset.from_pandas(df) 22 + Dataset.from_generator(generator_function) 23 + 24 + # From Arrow 25 + Dataset.from_file('dataset.arrow') 26 + Dataset.from_buffer(arrow_buffer) 27 + ``` 28 + 29 + ### Key Properties 30 + 31 + ```python 32 + ds.num_rows # Number of rows 33 + ds.num_columns # Number of columns 34 + ds.column_names # List of column names 35 + ds.shape # (num_rows, num_columns) 36 + ds.features # Features schema 37 + ds.info # DatasetInfo object 38 + ds.data # Apache Arrow table 39 + ds.cache_files # Cache file locations 40 + ``` 41 + 42 + ### Data Transformation Methods 43 + 44 + #### `map()` - Apply function to examples 45 + 46 + ```python 47 + # Apply to individual examples 48 + ds = ds.map(lambda x: {'text': x['text'].upper()}) 49 + 50 + # Apply to batches 51 + ds = ds.map(lambda batch: {'text': [t.upper() for t in batch['text']]}, 52 + batched=True, batch_size=32) 53 + 54 + # With indices 55 + ds = ds.map(lambda x, idx: {'id': idx, **x}, with_indices=True) 56 + 57 + # Multiprocessing 58 + ds = ds.map(process_fn, num_proc=4) 59 + 60 + # Remove original columns 61 + ds = ds.map(tokenize_fn, remove_columns=['text']) 62 + 63 + # Add new columns 64 + ds = ds.map(lambda x: {'length': len(x['text'])}) 65 + ``` 66 + 67 + #### `filter()` - Keep examples matching condition 68 + 69 + ```python 70 + # Filter by condition 71 + ds_filtered = ds.filter(lambda x: x['label'] == 1) 72 + 73 + # Batched filtering 74 + ds_filtered = ds.filter(lambda batch: [l == 1 for l in batch['label']], 75 + batched=True) 76 + 77 + # With indices 78 + ds_filtered = ds.filter(lambda x, idx: idx % 2 == 0, with_indices=True) 79 + ``` 80 + 81 + #### `select()` - Select specific rows 82 + 83 + ```python 84 + ds_subset = ds.select(range(100)) # First 100 rows 85 + ds_subset = ds.select([0, 10, 20, 30]) # Specific indices 86 + ``` 87 + 88 + #### `shuffle()` - Randomize row order 89 + 90 + ```python 91 + ds_shuffled = ds.shuffle(seed=42) 92 + # For large datasets, flatten indices for better performance 93 + ds_shuffled = ds.shuffle(seed=42).flatten_indices() 94 + ``` 95 + 96 + #### `sort()` - Sort by column(s) 97 + 98 + ```python 99 + ds_sorted = ds.sort('label') 100 + ds_sorted = ds.sort(['label', 'text'], reverse=[True, False]) 101 + ``` 102 + 103 + #### `train_test_split()` - Split into train/test 104 + 105 + ```python 106 + train_test = ds.train_test_split(test_size=0.2, seed=42) 107 + train_ds = train_test['train'] 108 + test_ds = train_test['test'] 109 + 110 + # Stratified split 111 + train_test = ds.train_test_split(test_size=0.2, stratify_by_column='label') 112 + ``` 113 + 114 + ### Column Operations 115 + 116 + ```python 117 + # Add column 118 + ds = ds.add_column('new_col', new_data) 119 + 120 + # Remove columns 121 + ds = ds.remove_columns(['col1', 'col2']) 122 + 123 + # Rename column 124 + ds = ds.rename_column('old_name', 'new_name') 
125 + 126 + # Rename multiple columns 127 + ds = ds.rename_columns({'old1': 'new1', 'old2': 'new2'}) 128 + 129 + # Select columns 130 + ds = ds.select_columns(['col1', 'col2']) 131 + 132 + # Cast column to new type 133 + from datasets import ClassLabel 134 + ds = ds.cast_column('label', ClassLabel(names=['neg', 'pos'])) 135 + 136 + # Cast all features 137 + from datasets import Features, Value 138 + new_features = Features({'text': Value('string'), 'label': Value('int32')}) 139 + ds = ds.cast(new_features) 140 + 141 + # Flatten nested features 142 + ds_flat = ds.flatten() 143 + 144 + # Class encode (convert to ClassLabel) 145 + ds = ds.class_encode_column('label') 146 + ``` 147 + 148 + ### Slicing & Iteration 149 + 150 + ```python 151 + # Index access 152 + item = ds[0] # Single item 153 + batch = ds[0:10] # Slice 154 + batch = ds[[0, 5, 9]] # Multiple indices 155 + 156 + # Iteration 157 + for example in ds: 158 + pass 159 + 160 + # Batched iteration 161 + for batch in ds.iter(batch_size=32): 162 + pass 163 + 164 + # Take first n 165 + subset = ds.take(5) 166 + 167 + # Skip first n 168 + subset = ds.skip(10) 169 + 170 + # Shard dataset 171 + shard = ds.shard(num_shards=4, index=0) 172 + ``` 173 + 174 + ### Data Format Control 175 + 176 + ```python 177 + # Set format for ML framework 178 + ds.set_format(type='torch', columns=['input_ids', 'attention_mask']) 179 + ds.set_format(type='numpy') 180 + ds.set_format(type='pandas') 181 + ds.set_format(type='jax') 182 + ds.set_format(type='tensorflow') 183 + 184 + # With context manager 185 + with ds.formatted_as(type='pandas'): 186 + df = ds[:] 187 + 188 + # Reset to default (dict of lists) 189 + ds.reset_format() 190 + 191 + # With on-the-fly transforms 192 + ds = ds.with_transform(tokenize_fn) 193 + ``` 194 + 195 + ### Persistence 196 + 197 + ```python 198 + # Save to disk (Arrow format) 199 + ds.save_to_disk('path/to/dataset') 200 + 201 + # Load from disk 202 + ds = Dataset.load_from_disk('path/to/dataset') 203 + 204 + # Export to formats 205 + ds.to_csv('output.csv') 206 + ds.to_json('output.jsonl') 207 + ds.to_parquet('output.parquet') 208 + df = ds.to_pandas() 209 + d = ds.to_dict() 210 + 211 + # Push to Hub 212 + ds.push_to_hub('username/dataset-name', private=True) 213 + ``` 214 + 215 + --- 216 + 217 + ## DatasetDict 218 + 219 + Dictionary-like container for multiple Dataset splits (train, validation, test, etc.). 
220 + 221 + ### Creation 222 + 223 + ```python 224 + from datasets import DatasetDict, load_dataset 225 + 226 + # From dict of datasets 227 + dataset_dict = DatasetDict({ 228 + 'train': train_dataset, 229 + 'validation': val_dataset, 230 + 'test': test_dataset 231 + }) 232 + 233 + # Load from Hub (returns DatasetDict if split not specified) 234 + dataset_dict = load_dataset('dataset_name') 235 + ``` 236 + 237 + ### Access Splits 238 + 239 + ```python 240 + train_ds = dataset_dict['train'] 241 + all_splits = list(dataset_dict.keys()) # ['train', 'validation', 'test'] 242 + 243 + # Iterate 244 + for split_name, dataset in dataset_dict.items(): 245 + print(f"{split_name}: {len(dataset)} examples") 246 + ``` 247 + 248 + ### Properties 249 + 250 + ```python 251 + dataset_dict.num_rows # {'train': N, 'validation': M, ...} 252 + dataset_dict.num_columns # {'train': K, ...} 253 + dataset_dict.column_names # {'train': [...], ...} 254 + dataset_dict.shape # {'train': (N, K), ...} 255 + ``` 256 + 257 + ### Collective Operations 258 + 259 + All Dataset methods can be called on DatasetDict and will be applied to all splits: 260 + 261 + ```python 262 + # Map over all splits 263 + dataset_dict = dataset_dict.map(lambda x: {'text': x['text'].upper()}) 264 + 265 + # Filter all splits 266 + dataset_dict = dataset_dict.filter(lambda x: x['label'] == 1) 267 + 268 + # Remove columns from all splits 269 + dataset_dict = dataset_dict.remove_columns(['col1']) 270 + 271 + # Rename column in all splits 272 + dataset_dict = dataset_dict.rename_column('old', 'new') 273 + 274 + # Sort all splits 275 + dataset_dict = dataset_dict.sort('label') 276 + 277 + # Shuffle all splits 278 + dataset_dict = dataset_dict.shuffle(seed=42) 279 + 280 + # Format all splits 281 + dataset_dict = dataset_dict.set_format(type='torch') 282 + ``` 283 + 284 + ### Persistence 285 + 286 + ```python 287 + # Save all splits 288 + dataset_dict.save_to_disk('path/to/dataset') 289 + 290 + # Load from disk 291 + dataset_dict = DatasetDict.load_from_disk('path/to/dataset') 292 + 293 + # Push to Hub 294 + dataset_dict.push_to_hub('username/dataset-name') 295 + ``` 296 + 297 + --- 298 + 299 + ## IterableDataset 300 + 301 + Iterable dataset for streaming/lazy loading, backed by Python generators. 302 + 303 + ### Creation 304 + 305 + ```python 306 + from datasets import IterableDataset 307 + 308 + # From generator function 309 + def gen(): 310 + for i in range(1000): 311 + yield {'text': f'Example {i}', 'label': i % 2} 312 + 313 + ds = IterableDataset.from_generator(gen) 314 + 315 + # With sharded data 316 + def gen(shards): 317 + for shard in shards: 318 + with open(shard) as f: 319 + for line in f: 320 + yield {'line': line} 321 + 322 + shards = [f'data{i}.txt' for i in range(32)] 323 + ds = IterableDataset.from_generator(gen, gen_kwargs={'shards': shards}) 324 + 325 + # From load_dataset with streaming=True 326 + ds = load_dataset('dataset_name', split='train', streaming=True) 327 + ``` 328 + 329 + ### Iteration 330 + 331 + ```python 332 + # Basic iteration (no random access!) 
333 + for example in ds: 334 + process(example) 335 + 336 + # Batched iteration 337 + for batch in ds.iter(batch_size=32): 338 + process_batch(batch) 339 + 340 + # Take n examples 341 + subset = ds.take(100) 342 + 343 + # Skip n examples 344 + subset = ds.skip(10) 345 + ``` 346 + 347 + ### Transformations (applied lazily during iteration) 348 + 349 + ```python 350 + # Map 351 + ds = ds.map(lambda x: {'text': x['text'].upper()}) 352 + 353 + # Filter 354 + ds = ds.filter(lambda x: x['label'] == 1) 355 + 356 + # Shuffle with buffer (approximate) 357 + ds = ds.shuffle(seed=42, buffer_size=1000) 358 + 359 + # Batch into groups 360 + ds = ds.batch(batch_size=32) 361 + 362 + # Remove/select columns 363 + ds = ds.remove_columns(['unwanted_col']) 364 + ds = ds.select_columns(['text', 'label']) 365 + 366 + # Rename column 367 + ds = ds.rename_column('old', 'new') 368 + 369 + # Cast features 370 + ds = ds.cast(new_features) 371 + ``` 372 + 373 + ### Format Control 374 + 375 + ```python 376 + # Set format 377 + ds = ds.with_format('torch') 378 + ds = ds.with_format('numpy') 379 + ``` 380 + 381 + ### Distributed Processing 382 + 383 + ```python 384 + # Shard across workers 385 + ds = ds.shard(num_shards=4, index=0) 386 + 387 + # State management for resumable iteration 388 + state = ds.state_dict() 389 + # ... resume later ... 390 + ds.load_state_dict(state) 391 + ``` 392 + 393 + --- 394 + 395 + ## Features 396 + 397 + Schema definition for dataset structure, specifying column names and types. 398 + 399 + ### Creating Features 400 + 401 + ```python 402 + from datasets import Features, Value, ClassLabel, Sequence, Array2D, Audio, Image 403 + 404 + features = Features({ 405 + 'text': Value('string'), 406 + 'label': ClassLabel(names=['neg', 'pos']), 407 + 'score': Value('float32'), 408 + 'tokens': Sequence(Value('string')), 409 + 'embeddings': Sequence(Value('float32')), 410 + 'audio': Audio(sampling_rate=16000), 411 + 'image': Image(), 412 + }) 413 + ``` 414 + 415 + ### Feature Types 416 + 417 + | Type | Description | 418 + |------|-------------| 419 + | `Value('string')` | String scalar | 420 + | `Value('int32')`, `Value('int64')` | Integer scalars | 421 + | `Value('float32')`, `Value('float64')` | Float scalars | 422 + | `Value('bool')` | Boolean | 423 + | `ClassLabel(names=['a', 'b'])` | Classification labels | 424 + | `Sequence(Value('string'))` | Variable-length sequence | 425 + | `Array2D(shape=(28, 28), dtype='uint8')` | Fixed-shape 2D array | 426 + | `Array3D(shape=(3, 224, 224), dtype='float32')` | Fixed-shape 3D array | 427 + | `Audio(sampling_rate=16000)` | Audio file | 428 + | `Image()` | Image file | 429 + | `Translation(languages=['en', 'fr'])` | Translation pair | 430 + 431 + ### Using Features 432 + 433 + ```python 434 + # Specify features when loading 435 + features = Features({'text': Value('string'), 'label': ClassLabel(names=['neg', 'pos'])}) 436 + ds = load_dataset('csv', data_files='data.csv', features=features) 437 + 438 + # Access features from dataset 439 + print(ds.features) 440 + 441 + # Cast existing dataset 442 + ds = ds.cast(new_features) 443 + ds = ds.cast_column('label', ClassLabel(names=['a', 'b', 'c'])) 444 + ``` 445 + 446 + --- 447 + 448 + ## Key Differences: Dataset vs IterableDataset 449 + 450 + | Feature | Dataset | IterableDataset | 451 + |---------|---------|-----------------| 452 + | **Access** | Random access (`ds[0]`) | Sequential only | 453 + | **Speed** | Fast for batch ops | Better for streaming | 454 + | **Memory** | Arrow memory-mapped | Lazy evaluation | 
455 + | **Shuffling** | Full dataset | Approximate (buffer) | 456 + | **Use Case** | Training with epochs | Streaming/large data | 457 + | **len()** | Supported | Not supported |
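
A quick illustration of the access-model difference summarised in the table above (dataset name illustrative, reused from the HF docs examples):

```python
from datasets import load_dataset

# Map-style Dataset: random access and len() are supported.
ds = load_dataset("cornell-movie-review-data/rotten_tomatoes", split="train")
print(len(ds), ds[0])

# IterableDataset (streaming): sequential access only;
# ds_stream[0] and len(ds_stream) are not available.
ds_stream = load_dataset("cornell-movie-review-data/rotten_tomatoes", split="train", streaming=True)
first = next(iter(ds_stream))
```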
+7
CHANGELOG.md
···
11 11   ### Fixed
12 12
13 13   ### Changed
   14 + - Add HuggingFace Datasets-style API to atdata (#103)
   15 + - Support streaming mode parameter (#108)
   16 + - Add split parameter handling (train/test/validation) (#107)
   17 + - Implement path/URL resolution and shard discovery (#106)
   18 + - Add DatasetDict class for multi-split datasets (#105)
   19 + - Implement load_dataset() entry point function (#104)
   20 + - Write test suite for _hf_api.py module (#109)
14 21   - Investigate test-bucket directory creation issue (#105)
15 22   - Add remaining Dataset edge case tests (#104)
16 23   - Improve test coverage for edge cases (#103)
issues.db

This is a binary file and will not be displayed.

+5
src/atdata/__init__.py
···
51 51       lens,
52 52   )
53 53
   54 + from ._hf_api import (
   55 +     load_dataset,
   56 +     DatasetDict,
   57 + )
   58 +
54 59   # ATProto integration (lazy import to avoid requiring atproto package)
55 60   from . import atmosphere
56 61
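
With this re-export in place, both import spellings below should resolve to the same objects (a small sanity check, assuming the package is installed):

```python
import atdata
from atdata import load_dataset, DatasetDict

# The top-level names and the _hf_api exports should be the same objects.
assert atdata.load_dataset is load_dataset
assert atdata.DatasetDict is DatasetDict
```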
+555
src/atdata/_hf_api.py
··· 1 + """HuggingFace Datasets-style API for atdata. 2 + 3 + This module provides a familiar `load_dataset()` interface inspired by the 4 + HuggingFace Datasets library, adapted for atdata's typed WebDataset approach. 5 + 6 + Key differences from HuggingFace Datasets: 7 + - Requires explicit `sample_type` parameter (typed dataclass) 8 + - Returns atdata.Dataset[ST] instead of HF Dataset 9 + - Built on WebDataset for efficient streaming of large datasets 10 + - No Arrow caching layer (WebDataset handles remote/local transparently) 11 + 12 + Example: 13 + >>> import atdata 14 + >>> from atdata import load_dataset 15 + >>> 16 + >>> @atdata.packable 17 + ... class MyData: 18 + ... text: str 19 + ... label: int 20 + >>> 21 + >>> # Load a single split 22 + >>> ds = load_dataset("path/to/train-{000000..000099}.tar", MyData, split="train") 23 + >>> 24 + >>> # Load all splits (returns DatasetDict) 25 + >>> ds_dict = load_dataset("path/to/{train,test}-*.tar", MyData) 26 + >>> train_ds = ds_dict["train"] 27 + """ 28 + 29 + from __future__ import annotations 30 + 31 + import re 32 + from pathlib import Path 33 + from typing import ( 34 + Any, 35 + Generic, 36 + Iterator, 37 + Mapping, 38 + Type, 39 + TypeVar, 40 + Union, 41 + overload, 42 + ) 43 + 44 + from .dataset import Dataset, PackableSample 45 + 46 + ## 47 + # Type variables 48 + 49 + ST = TypeVar("ST", bound=PackableSample) 50 + 51 + 52 + ## 53 + # DatasetDict - container for multiple splits 54 + 55 + 56 + class DatasetDict(Generic[ST], dict): 57 + """A dictionary of split names to Dataset instances. 58 + 59 + Similar to HuggingFace's DatasetDict, this provides a container for 60 + multiple dataset splits (train, test, validation, etc.) with convenience 61 + methods that operate across all splits. 62 + 63 + Type Parameters: 64 + ST: The sample type for all datasets in this dict. 65 + 66 + Example: 67 + >>> ds_dict = load_dataset("path/to/data", MyData) 68 + >>> train = ds_dict["train"] 69 + >>> test = ds_dict["test"] 70 + >>> 71 + >>> # Iterate over all splits 72 + >>> for split_name, dataset in ds_dict.items(): 73 + ... print(f"{split_name}: {len(dataset.shard_list)} shards") 74 + """ 75 + 76 + def __init__( 77 + self, 78 + splits: Mapping[str, Dataset[ST]] | None = None, 79 + sample_type: Type[ST] | None = None, 80 + streaming: bool = False, 81 + ) -> None: 82 + """Create a DatasetDict from a mapping of split names to datasets. 83 + 84 + Args: 85 + splits: Mapping of split names to Dataset instances. 86 + sample_type: The sample type for datasets in this dict. If not 87 + provided, inferred from the first dataset in splits. 88 + streaming: Whether this DatasetDict was loaded in streaming mode. 
89 + """ 90 + super().__init__(splits or {}) 91 + self._sample_type = sample_type 92 + self._streaming = streaming 93 + 94 + @property 95 + def sample_type(self) -> Type[ST] | None: 96 + """The sample type for datasets in this dict.""" 97 + if self._sample_type is not None: 98 + return self._sample_type 99 + # Infer from first dataset 100 + if self: 101 + first_ds = next(iter(self.values())) 102 + return first_ds.sample_type 103 + return None 104 + 105 + def __getitem__(self, key: str) -> Dataset[ST]: 106 + """Get a dataset by split name.""" 107 + return super().__getitem__(key) 108 + 109 + def __setitem__(self, key: str, value: Dataset[ST]) -> None: 110 + """Set a dataset for a split name.""" 111 + super().__setitem__(key, value) 112 + 113 + @property 114 + def streaming(self) -> bool: 115 + """Whether this DatasetDict was loaded in streaming mode.""" 116 + return self._streaming 117 + 118 + @property 119 + def num_shards(self) -> dict[str, int]: 120 + """Number of shards in each split. 121 + 122 + Returns: 123 + Dict mapping split names to shard counts. 124 + 125 + Note: 126 + This property accesses the shard list, which may trigger 127 + shard enumeration for remote datasets. 128 + """ 129 + return {name: len(ds.shard_list) for name, ds in self.items()} 130 + 131 + 132 + ## 133 + # Path resolution utilities 134 + 135 + 136 + def _is_brace_pattern(path: str) -> bool: 137 + """Check if path contains WebDataset brace expansion notation. 138 + 139 + Examples: 140 + >>> _is_brace_pattern("data-{000000..000099}.tar") 141 + True 142 + >>> _is_brace_pattern("data-{train,test}.tar") 143 + True 144 + >>> _is_brace_pattern("data-000000.tar") 145 + False 146 + """ 147 + return bool(re.search(r"\{[^}]+\}", path)) 148 + 149 + 150 + def _is_glob_pattern(path: str) -> bool: 151 + """Check if path contains glob wildcards. 152 + 153 + Examples: 154 + >>> _is_glob_pattern("data-*.tar") 155 + True 156 + >>> _is_glob_pattern("data-000000.tar") 157 + False 158 + """ 159 + return "*" in path or "?" in path 160 + 161 + 162 + def _is_remote_url(path: str) -> bool: 163 + """Check if path is a remote URL (s3, http, etc.). 164 + 165 + Examples: 166 + >>> _is_remote_url("s3://bucket/path") 167 + True 168 + >>> _is_remote_url("https://example.com/data.tar") 169 + True 170 + >>> _is_remote_url("/local/path/data.tar") 171 + False 172 + """ 173 + return path.startswith(("s3://", "gs://", "http://", "https://", "az://")) 174 + 175 + 176 + def _expand_local_glob(pattern: str) -> list[str]: 177 + """Expand a local glob pattern to list of paths. 178 + 179 + Args: 180 + pattern: Glob pattern like "path/to/*.tar" 181 + 182 + Returns: 183 + Sorted list of matching file paths. 
184 + """ 185 + base_path = Path(pattern).parent 186 + glob_part = Path(pattern).name 187 + 188 + if not base_path.exists(): 189 + return [] 190 + 191 + matches = sorted(base_path.glob(glob_part)) 192 + return [str(p) for p in matches if p.is_file()] 193 + 194 + 195 + # Common split name patterns in filenames 196 + _SPLIT_PATTERNS = [ 197 + # Patterns like "dataset-train-000000.tar" (split in middle with delimiters) 198 + (r"[_-](train|training)[_-]", "train"), 199 + (r"[_-](test|testing)[_-]", "test"), 200 + (r"[_-](val|valid|validation)[_-]", "validation"), 201 + (r"[_-](dev|development)[_-]", "validation"), 202 + # Patterns at start of filename like "train-000.tar" or "test_data.tar" 203 + (r"^(train|training)[_-]", "train"), 204 + (r"^(test|testing)[_-]", "test"), 205 + (r"^(val|valid|validation)[_-]", "validation"), 206 + (r"^(dev|development)[_-]", "validation"), 207 + # Patterns in directory path like "/path/train/shard-000.tar" 208 + (r"[/\\](train|training)[/\\]", "train"), 209 + (r"[/\\](test|testing)[/\\]", "test"), 210 + (r"[/\\](val|valid|validation)[/\\]", "validation"), 211 + (r"[/\\](dev|development)[/\\]", "validation"), 212 + # Patterns at start of path like "train/shard-000.tar" 213 + (r"^(train|training)[/\\]", "train"), 214 + (r"^(test|testing)[/\\]", "test"), 215 + (r"^(val|valid|validation)[/\\]", "validation"), 216 + (r"^(dev|development)[/\\]", "validation"), 217 + ] 218 + 219 + 220 + def _detect_split_from_path(path: str) -> str | None: 221 + """Attempt to detect split name from a file path. 222 + 223 + Args: 224 + path: File path to analyze. 225 + 226 + Returns: 227 + Detected split name ("train", "test", "validation") or None. 228 + """ 229 + # Extract just the filename for pattern matching on full paths 230 + filename = Path(path).name 231 + path_lower = path.lower() 232 + filename_lower = filename.lower() 233 + 234 + # Check filename first (more specific) 235 + for pattern, split_name in _SPLIT_PATTERNS: 236 + if re.search(pattern, filename_lower): 237 + return split_name 238 + 239 + # Fall back to full path (catches directory patterns like "train/...") 240 + for pattern, split_name in _SPLIT_PATTERNS: 241 + if re.search(pattern, path_lower): 242 + return split_name 243 + 244 + return None 245 + 246 + 247 + def _resolve_shards( 248 + path: str, 249 + data_files: str | list[str] | dict[str, str | list[str]] | None = None, 250 + ) -> dict[str, list[str]]: 251 + """Resolve path specification to dict of split -> shard URLs. 252 + 253 + Handles: 254 + - WebDataset brace notation: "path/{train,test}-{000..099}.tar" 255 + - Glob patterns: "path/*.tar" 256 + - Explicit data_files mapping 257 + 258 + Args: 259 + path: Base path or pattern. 260 + data_files: Optional explicit mapping of splits to files. 261 + 262 + Returns: 263 + Dict mapping split names to lists of shard URLs. 
264 + """ 265 + # If explicit data_files provided, use those 266 + if data_files is not None: 267 + return _resolve_data_files(path, data_files) 268 + 269 + # WebDataset brace notation - pass through as-is 270 + # WebDataset handles expansion internally 271 + if _is_brace_pattern(path): 272 + # Try to detect split from the pattern itself 273 + split = _detect_split_from_path(path) 274 + split_name = split or "train" 275 + return {split_name: [path]} 276 + 277 + # Local glob pattern 278 + if not _is_remote_url(path) and _is_glob_pattern(path): 279 + shards = _expand_local_glob(path) 280 + return _group_shards_by_split(shards) 281 + 282 + # Local directory - scan for .tar files 283 + if not _is_remote_url(path) and Path(path).is_dir(): 284 + shards = _expand_local_glob(str(Path(path) / "*.tar")) 285 + return _group_shards_by_split(shards) 286 + 287 + # Single file or remote URL - treat as single shard 288 + split = _detect_split_from_path(path) 289 + split_name = split or "train" 290 + return {split_name: [path]} 291 + 292 + 293 + def _resolve_data_files( 294 + base_path: str, 295 + data_files: str | list[str] | dict[str, str | list[str]], 296 + ) -> dict[str, list[str]]: 297 + """Resolve explicit data_files specification. 298 + 299 + Args: 300 + base_path: Base path for relative file references. 301 + data_files: File specification - can be: 302 + - str: Single file pattern 303 + - list[str]: List of file patterns 304 + - dict[str, ...]: Mapping of split names to patterns 305 + 306 + Returns: 307 + Dict mapping split names to lists of resolved file paths. 308 + """ 309 + base = Path(base_path) if not _is_remote_url(base_path) else None 310 + 311 + if isinstance(data_files, str): 312 + # Single pattern -> "train" split 313 + if base and not Path(data_files).is_absolute(): 314 + data_files = str(base / data_files) 315 + return {"train": [data_files]} 316 + 317 + if isinstance(data_files, list): 318 + # List of patterns -> "train" split 319 + resolved = [] 320 + for f in data_files: 321 + if base and not Path(f).is_absolute(): 322 + f = str(base / f) 323 + resolved.append(f) 324 + return {"train": resolved} 325 + 326 + # Dict mapping splits to patterns 327 + result: dict[str, list[str]] = {} 328 + for split_name, files in data_files.items(): 329 + if isinstance(files, str): 330 + files = [files] 331 + resolved = [] 332 + for f in files: 333 + if base and not Path(f).is_absolute(): 334 + f = str(base / f) 335 + resolved.append(f) 336 + result[split_name] = resolved 337 + 338 + return result 339 + 340 + 341 + def _shards_to_wds_url(shards: list[str]) -> str: 342 + """Convert a list of shard paths to a WebDataset URL. 343 + 344 + WebDataset supports brace expansion, so we convert multiple shards 345 + into brace notation when they share a common prefix/suffix. 346 + 347 + Args: 348 + shards: List of shard file paths. 349 + 350 + Returns: 351 + WebDataset-compatible URL string. 
352 + 353 + Examples: 354 + >>> _shards_to_wds_url(["data-000.tar", "data-001.tar", "data-002.tar"]) 355 + "data-{000,001,002}.tar" 356 + >>> _shards_to_wds_url(["train.tar"]) 357 + "train.tar" 358 + """ 359 + if len(shards) == 0: 360 + raise ValueError("Cannot create URL from empty shard list") 361 + 362 + if len(shards) == 1: 363 + return shards[0] 364 + 365 + # Find common prefix across ALL shards 366 + prefix = shards[0] 367 + for s in shards[1:]: 368 + # Shorten prefix until it matches 369 + while not s.startswith(prefix) and prefix: 370 + prefix = prefix[:-1] 371 + 372 + # Find common suffix across ALL shards 373 + suffix = shards[0] 374 + for s in shards[1:]: 375 + # Shorten suffix until it matches 376 + while not s.endswith(suffix) and suffix: 377 + suffix = suffix[1:] 378 + 379 + prefix_len = len(prefix) 380 + suffix_len = len(suffix) 381 + 382 + # Ensure prefix and suffix don't overlap 383 + min_shard_len = min(len(s) for s in shards) 384 + if prefix_len + suffix_len > min_shard_len: 385 + # Overlapping - prefer prefix, reduce suffix 386 + suffix_len = max(0, min_shard_len - prefix_len) 387 + suffix = shards[0][-suffix_len:] if suffix_len > 0 else "" 388 + 389 + if prefix_len > 0 or suffix_len > 0: 390 + # Extract the varying middle parts 391 + middles = [] 392 + for s in shards: 393 + if suffix_len > 0: 394 + middle = s[prefix_len:-suffix_len] 395 + else: 396 + middle = s[prefix_len:] 397 + middles.append(middle) 398 + 399 + # Only use brace notation if we have meaningful variation 400 + if all(middles): 401 + return f"{prefix}{{{','.join(middles)}}}{suffix}" 402 + 403 + # Fallback: space-separated URLs for WebDataset 404 + return " ".join(shards) 405 + 406 + 407 + def _group_shards_by_split(shards: list[str]) -> dict[str, list[str]]: 408 + """Group a list of shard paths by detected split. 409 + 410 + Args: 411 + shards: List of shard file paths. 412 + 413 + Returns: 414 + Dict mapping split names to lists of shards. Files with no 415 + detected split are placed in "train". 416 + """ 417 + result: dict[str, list[str]] = {} 418 + 419 + for shard in shards: 420 + split = _detect_split_from_path(shard) 421 + split_name = split or "train" 422 + if split_name not in result: 423 + result[split_name] = [] 424 + result[split_name].append(shard) 425 + 426 + return result 427 + 428 + 429 + ## 430 + # Main load_dataset function 431 + 432 + 433 + @overload 434 + def load_dataset( 435 + path: str, 436 + sample_type: Type[ST], 437 + *, 438 + split: str, 439 + data_files: str | list[str] | dict[str, str | list[str]] | None = None, 440 + streaming: bool = False, 441 + ) -> Dataset[ST]: ... 442 + 443 + 444 + @overload 445 + def load_dataset( 446 + path: str, 447 + sample_type: Type[ST], 448 + *, 449 + split: None = None, 450 + data_files: str | list[str] | dict[str, str | list[str]] | None = None, 451 + streaming: bool = False, 452 + ) -> DatasetDict[ST]: ... 453 + 454 + 455 + def load_dataset( 456 + path: str, 457 + sample_type: Type[ST], 458 + *, 459 + split: str | None = None, 460 + data_files: str | list[str] | dict[str, str | list[str]] | None = None, 461 + streaming: bool = False, 462 + ) -> Dataset[ST] | DatasetDict[ST]: 463 + """Load a dataset from local files or remote URLs. 464 + 465 + This function provides a HuggingFace Datasets-style interface for loading 466 + atdata typed datasets. It handles path resolution, split detection, and 467 + returns either a single Dataset or a DatasetDict depending on the split 468 + parameter. 469 + 470 + Args: 471 + path: Path to dataset. 
Can be: 472 + - WebDataset brace notation: "path/to/{train,test}-{000..099}.tar" 473 + - Local directory: "./data/" (scans for .tar files) 474 + - Glob pattern: "path/to/*.tar" 475 + - Remote URL: "s3://bucket/path/data-*.tar" 476 + - Single file: "path/to/data.tar" 477 + 478 + sample_type: The PackableSample subclass defining the schema for 479 + samples in this dataset. This is required (unlike HF Datasets) 480 + because atdata uses typed dataclasses. 481 + 482 + split: Which split to load. If None, returns a DatasetDict with all 483 + detected splits. If specified (e.g., "train", "test"), returns 484 + a single Dataset for that split. 485 + 486 + data_files: Optional explicit mapping of data files. Can be: 487 + - str: Single file pattern 488 + - list[str]: List of file patterns (assigned to "train") 489 + - dict[str, str | list[str]]: Explicit split -> files mapping 490 + 491 + streaming: If True, explicitly marks the dataset for streaming mode. 492 + Note: atdata Datasets are already lazy/streaming via WebDataset 493 + pipelines, so this parameter primarily signals intent. When True, 494 + shard list precomputation is skipped. Default False. 495 + 496 + Returns: 497 + If split is None: DatasetDict[ST] with all detected splits. 498 + If split is specified: Dataset[ST] for that split. 499 + 500 + Raises: 501 + ValueError: If the specified split is not found. 502 + FileNotFoundError: If no data files are found at the path. 503 + 504 + Example: 505 + >>> @atdata.packable 506 + ... class TextData: 507 + ... text: str 508 + ... label: int 509 + >>> 510 + >>> # Load single split 511 + >>> train_ds = load_dataset("./data/train-*.tar", TextData, split="train") 512 + >>> 513 + >>> # Load all splits 514 + >>> ds_dict = load_dataset("./data/", TextData) 515 + >>> train_ds = ds_dict["train"] 516 + >>> test_ds = ds_dict["test"] 517 + >>> 518 + >>> # Explicit data files 519 + >>> ds_dict = load_dataset("./data/", TextData, data_files={ 520 + ... "train": "train-*.tar", 521 + ... "test": "test-*.tar", 522 + ... }) 523 + """ 524 + # Resolve path to split -> shard URL mapping 525 + splits_shards = _resolve_shards(path, data_files) 526 + 527 + if not splits_shards: 528 + raise FileNotFoundError(f"No data files found at path: {path}") 529 + 530 + # Build Dataset for each split 531 + datasets: dict[str, Dataset[ST]] = {} 532 + for split_name, shards in splits_shards.items(): 533 + url = _shards_to_wds_url(shards) 534 + ds = Dataset[sample_type](url) 535 + datasets[split_name] = ds 536 + 537 + # Return single Dataset or DatasetDict 538 + if split is not None: 539 + if split not in datasets: 540 + available = list(datasets.keys()) 541 + raise ValueError( 542 + f"Split '{split}' not found. Available splits: {available}" 543 + ) 544 + return datasets[split] 545 + 546 + return DatasetDict(datasets, sample_type=sample_type, streaming=streaming) 547 + 548 + 549 + ## 550 + # Convenience re-exports (will be exposed in __init__.py) 551 + 552 + __all__ = [ 553 + "load_dataset", 554 + "DatasetDict", 555 + ]
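
Based on the doctests and helper docstrings above, path resolution should behave roughly as follows (file names are illustrative; the underscore-prefixed helpers are internal and may change):

```python
from atdata._hf_api import (
    _detect_split_from_path,
    _group_shards_by_split,
    _shards_to_wds_url,
)

# Split names are inferred from delimited tokens in the filename or directory path.
assert _detect_split_from_path("dataset-train-000000.tar") == "train"
assert _detect_split_from_path("dataset-val-000000.tar") == "validation"
assert _detect_split_from_path("dataset-000000.tar") is None

# Shards with no detectable split fall back to the "train" split.
groups = _group_shards_by_split(
    ["data-train-000.tar", "data-test-000.tar", "misc-000.tar"]
)
# -> {"train": ["data-train-000.tar", "misc-000.tar"], "test": ["data-test-000.tar"]}

# Shards sharing a common prefix/suffix collapse into WebDataset brace notation.
url = _shards_to_wds_url(["data-000.tar", "data-001.tar", "data-002.tar"])
# e.g. "data-00{0,1,2}.tar" (the exact split point depends on the common prefix found)
```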
+645
tests/test_hf_api.py
··· 1 + """Tests for the HuggingFace Datasets-style API (_hf_api.py).""" 2 + 3 + ## 4 + # Imports 5 + 6 + import pytest 7 + from dataclasses import dataclass 8 + from pathlib import Path 9 + 10 + import numpy as np 11 + import webdataset as wds 12 + 13 + import atdata 14 + from atdata._hf_api import ( 15 + load_dataset, 16 + DatasetDict, 17 + _is_brace_pattern, 18 + _is_glob_pattern, 19 + _is_remote_url, 20 + _detect_split_from_path, 21 + _shards_to_wds_url, 22 + _expand_local_glob, 23 + _resolve_shards, 24 + _resolve_data_files, 25 + _group_shards_by_split, 26 + ) 27 + 28 + from numpy.typing import NDArray 29 + 30 + 31 + ## 32 + # Test sample types 33 + 34 + 35 + @atdata.packable 36 + class SimpleTestSample: 37 + """Simple sample type for testing.""" 38 + 39 + text: str 40 + label: int 41 + 42 + 43 + @atdata.packable 44 + class NumpyTestSample: 45 + """Sample type with numpy arrays for testing.""" 46 + 47 + embedding: NDArray 48 + label: int 49 + 50 + 51 + ## 52 + # Helper function tests 53 + 54 + 55 + class TestIsBracePattern: 56 + """Tests for _is_brace_pattern().""" 57 + 58 + def test_range_pattern(self): 59 + assert _is_brace_pattern("data-{000000..000099}.tar") is True 60 + 61 + def test_list_pattern(self): 62 + assert _is_brace_pattern("data-{train,test,val}.tar") is True 63 + 64 + def test_no_pattern(self): 65 + assert _is_brace_pattern("data-000000.tar") is False 66 + 67 + def test_empty_braces(self): 68 + # Empty braces are not valid WebDataset brace notation 69 + assert _is_brace_pattern("data-{}.tar") is False 70 + 71 + def test_nested_path_with_pattern(self): 72 + assert _is_brace_pattern("path/to/data-{000..099}.tar") is True 73 + 74 + 75 + class TestIsGlobPattern: 76 + """Tests for _is_glob_pattern().""" 77 + 78 + def test_asterisk(self): 79 + assert _is_glob_pattern("data-*.tar") is True 80 + 81 + def test_question_mark(self): 82 + assert _is_glob_pattern("data-00000?.tar") is True 83 + 84 + def test_no_pattern(self): 85 + assert _is_glob_pattern("data-000000.tar") is False 86 + 87 + def test_path_with_glob(self): 88 + assert _is_glob_pattern("path/to/*.tar") is True 89 + 90 + 91 + class TestIsRemoteUrl: 92 + """Tests for _is_remote_url().""" 93 + 94 + def test_s3_url(self): 95 + assert _is_remote_url("s3://bucket/path/data.tar") is True 96 + 97 + def test_https_url(self): 98 + assert _is_remote_url("https://example.com/data.tar") is True 99 + 100 + def test_http_url(self): 101 + assert _is_remote_url("http://example.com/data.tar") is True 102 + 103 + def test_gs_url(self): 104 + assert _is_remote_url("gs://bucket/path/data.tar") is True 105 + 106 + def test_az_url(self): 107 + assert _is_remote_url("az://container/path/data.tar") is True 108 + 109 + def test_local_absolute_path(self): 110 + assert _is_remote_url("/local/path/data.tar") is False 111 + 112 + def test_local_relative_path(self): 113 + assert _is_remote_url("./data/data.tar") is False 114 + 115 + def test_windows_path(self): 116 + assert _is_remote_url("C:\\data\\data.tar") is False 117 + 118 + 119 + class TestDetectSplitFromPath: 120 + """Tests for _detect_split_from_path().""" 121 + 122 + def test_train_in_filename(self): 123 + assert _detect_split_from_path("dataset-train-000000.tar") == "train" 124 + 125 + def test_test_in_filename(self): 126 + assert _detect_split_from_path("dataset-test-000000.tar") == "test" 127 + 128 + def test_validation_in_filename(self): 129 + assert _detect_split_from_path("dataset-validation-000000.tar") == "validation" 130 + 131 + def test_val_in_filename(self): 132 + assert 
_detect_split_from_path("dataset-val-000000.tar") == "validation" 133 + 134 + def test_dev_in_filename(self): 135 + assert _detect_split_from_path("dataset-dev-000000.tar") == "validation" 136 + 137 + def test_train_directory(self): 138 + assert _detect_split_from_path("train/shard-000000.tar") == "train" 139 + 140 + def test_test_directory(self): 141 + assert _detect_split_from_path("test/shard-000000.tar") == "test" 142 + 143 + def test_no_split_detected(self): 144 + assert _detect_split_from_path("dataset-000000.tar") is None 145 + 146 + def test_case_insensitive(self): 147 + assert _detect_split_from_path("dataset-TRAIN-000000.tar") == "train" 148 + assert _detect_split_from_path("dataset-Train-000000.tar") == "train" 149 + 150 + def test_training_variant(self): 151 + assert _detect_split_from_path("dataset-training-000000.tar") == "train" 152 + 153 + def test_testing_variant(self): 154 + assert _detect_split_from_path("dataset-testing-000000.tar") == "test" 155 + 156 + 157 + class TestShardsToWdsUrl: 158 + """Tests for _shards_to_wds_url().""" 159 + 160 + def test_single_shard(self): 161 + assert _shards_to_wds_url(["data.tar"]) == "data.tar" 162 + 163 + def test_multiple_shards_common_pattern(self): 164 + shards = ["data-000.tar", "data-001.tar", "data-002.tar"] 165 + result = _shards_to_wds_url(shards) 166 + # Algorithm finds longest common prefix/suffix, resulting in compact notation 167 + # Both "data-{000,001,002}.tar" and "data-00{0,1,2}.tar" are valid 168 + assert "{" in result and "}" in result 169 + assert ".tar" in result 170 + assert "data-" in result 171 + 172 + def test_multiple_shards_different_lengths(self): 173 + shards = ["data-0.tar", "data-1.tar", "data-10.tar"] 174 + result = _shards_to_wds_url(shards) 175 + # Should still produce brace notation 176 + assert "{" in result and "}" in result 177 + 178 + def test_empty_list_raises(self): 179 + with pytest.raises(ValueError, match="empty shard list"): 180 + _shards_to_wds_url([]) 181 + 182 + def test_no_common_pattern(self): 183 + shards = ["train.tar", "test.tar", "val.tar"] 184 + result = _shards_to_wds_url(shards) 185 + # Falls back to space-separated or brace notation 186 + assert "train" in result 187 + 188 + 189 + class TestExpandLocalGlob: 190 + """Tests for _expand_local_glob().""" 191 + 192 + def test_no_matches(self, tmp_path): 193 + pattern = str(tmp_path / "*.tar") 194 + assert _expand_local_glob(pattern) == [] 195 + 196 + def test_matches_files(self, tmp_path): 197 + # Create test files 198 + (tmp_path / "data-000.tar").touch() 199 + (tmp_path / "data-001.tar").touch() 200 + (tmp_path / "data-002.tar").touch() 201 + 202 + pattern = str(tmp_path / "*.tar") 203 + result = _expand_local_glob(pattern) 204 + 205 + assert len(result) == 3 206 + assert all(".tar" in p for p in result) 207 + 208 + def test_ignores_directories(self, tmp_path): 209 + # Create a file and a directory 210 + (tmp_path / "data.tar").touch() 211 + (tmp_path / "subdir.tar").mkdir() 212 + 213 + pattern = str(tmp_path / "*.tar") 214 + result = _expand_local_glob(pattern) 215 + 216 + assert len(result) == 1 217 + 218 + def test_nonexistent_directory(self): 219 + result = _expand_local_glob("/nonexistent/path/*.tar") 220 + assert result == [] 221 + 222 + 223 + class TestGroupShardsBySplit: 224 + """Tests for _group_shards_by_split().""" 225 + 226 + def test_single_split(self): 227 + shards = [ 228 + "train-000.tar", 229 + "train-001.tar", 230 + "train-002.tar", 231 + ] 232 + result = _group_shards_by_split(shards) 233 + assert "train" in 
        assert len(result["train"]) == 3

    def test_multiple_splits(self):
        shards = [
            "data-train-000.tar",
            "data-train-001.tar",
            "data-test-000.tar",
            "data-val-000.tar",
        ]
        result = _group_shards_by_split(shards)
        assert "train" in result
        assert "test" in result
        assert "validation" in result
        assert len(result["train"]) == 2
        assert len(result["test"]) == 1
        assert len(result["validation"]) == 1

    def test_no_detected_split_defaults_to_train(self):
        shards = ["shard-000.tar", "shard-001.tar"]
        result = _group_shards_by_split(shards)
        assert "train" in result
        assert len(result["train"]) == 2


class TestResolveDataFiles:
    """Tests for _resolve_data_files()."""

    def test_string_input(self, tmp_path):
        result = _resolve_data_files(str(tmp_path), "data.tar")
        assert "train" in result
        assert len(result["train"]) == 1

    def test_list_input(self, tmp_path):
        result = _resolve_data_files(str(tmp_path), ["a.tar", "b.tar"])
        assert "train" in result
        assert len(result["train"]) == 2

    def test_dict_input(self, tmp_path):
        data_files = {
            "train": ["train-000.tar", "train-001.tar"],
            "test": "test-000.tar",
        }
        result = _resolve_data_files(str(tmp_path), data_files)
        assert "train" in result
        assert "test" in result
        assert len(result["train"]) == 2
        assert len(result["test"]) == 1

    def test_resolves_relative_paths(self, tmp_path):
        result = _resolve_data_files(str(tmp_path), "subdir/data.tar")
        assert str(tmp_path) in result["train"][0]


class TestResolveShards:
    """Tests for _resolve_shards()."""

    def test_brace_pattern_passthrough(self):
        path = "data-{000000..000099}.tar"
        result = _resolve_shards(path)
        assert "train" in result
        assert path in result["train"]

    def test_brace_pattern_with_split_name(self):
        path = "data-train-{000..099}.tar"
        result = _resolve_shards(path)
        assert "train" in result

    def test_single_file(self):
        path = "data.tar"
        result = _resolve_shards(path)
        assert "train" in result
        assert result["train"] == [path]

    def test_with_data_files_override(self, tmp_path):
        data_files = {"train": "train.tar", "test": "test.tar"}
        result = _resolve_shards(str(tmp_path), data_files)
        assert "train" in result
        assert "test" in result

    def test_local_directory(self, tmp_path):
        # Create test tar files
        (tmp_path / "train-000.tar").touch()
        (tmp_path / "train-001.tar").touch()
        (tmp_path / "test-000.tar").touch()

        result = _resolve_shards(str(tmp_path))
        assert "train" in result
        assert "test" in result

    def test_glob_pattern(self, tmp_path):
        # Create test files
        (tmp_path / "data-000.tar").touch()
        (tmp_path / "data-001.tar").touch()

        pattern = str(tmp_path / "*.tar")
        result = _resolve_shards(pattern)
        assert "train" in result  # defaults to train when no split detected


##
# DatasetDict tests


class TestDatasetDict:
    """Tests for DatasetDict class."""

    def test_empty_init(self):
        dd = DatasetDict()
        assert len(dd) == 0

    def test_init_with_splits(self, tmp_path):
        # Create a minimal tar file for Dataset
        tar_path = tmp_path / "data.tar"
        with wds.writer.TarWriter(str(tar_path)) as sink:
            sample = SimpleTestSample(text="hello", label=1)
            sink.write(sample.as_wds)

        train_ds = atdata.Dataset[SimpleTestSample](str(tar_path))
        test_ds = atdata.Dataset[SimpleTestSample](str(tar_path))

        dd = DatasetDict({"train": train_ds, "test": test_ds})

        assert len(dd) == 2
        assert "train" in dd
        assert "test" in dd

    def test_getitem(self, tmp_path):
        tar_path = tmp_path / "data.tar"
        with wds.writer.TarWriter(str(tar_path)) as sink:
            sample = SimpleTestSample(text="hello", label=1)
            sink.write(sample.as_wds)

        train_ds = atdata.Dataset[SimpleTestSample](str(tar_path))
        dd = DatasetDict({"train": train_ds})

        assert dd["train"] is train_ds

    def test_setitem(self, tmp_path):
        tar_path = tmp_path / "data.tar"
        with wds.writer.TarWriter(str(tar_path)) as sink:
            sample = SimpleTestSample(text="hello", label=1)
            sink.write(sample.as_wds)

        dd = DatasetDict()
        train_ds = atdata.Dataset[SimpleTestSample](str(tar_path))
        dd["train"] = train_ds

        assert "train" in dd
        assert dd["train"] is train_ds

    def test_keys_values_items(self, tmp_path):
        tar_path = tmp_path / "data.tar"
        with wds.writer.TarWriter(str(tar_path)) as sink:
            sample = SimpleTestSample(text="hello", label=1)
            sink.write(sample.as_wds)

        train_ds = atdata.Dataset[SimpleTestSample](str(tar_path))
        test_ds = atdata.Dataset[SimpleTestSample](str(tar_path))

        dd = DatasetDict({"train": train_ds, "test": test_ds})

        assert set(dd.keys()) == {"train", "test"}
        assert len(list(dd.values())) == 2
        assert len(list(dd.items())) == 2

    def test_streaming_property(self):
        dd = DatasetDict(streaming=True)
        assert dd.streaming is True

        dd2 = DatasetDict(streaming=False)
        assert dd2.streaming is False

    def test_sample_type_explicit(self):
        dd = DatasetDict(sample_type=SimpleTestSample)
        assert dd.sample_type is SimpleTestSample

    def test_num_shards(self, tmp_path):
        # Create a single tar file for the train split
        train_path = tmp_path / "train.tar"
        with wds.writer.TarWriter(str(train_path)) as sink:
            sample = SimpleTestSample(text="hello", label=1)
            sink.write(sample.as_wds)

        train_ds = atdata.Dataset[SimpleTestSample](str(train_path))
        dd = DatasetDict({"train": train_ds})

        num_shards = dd.num_shards
        assert "train" in num_shards
        assert num_shards["train"] == 1


##
# load_dataset tests


class TestLoadDataset:
    """Tests for load_dataset() function."""

    def test_load_single_file_with_split(self, tmp_path):
        """Load a single tar file specifying a split."""
        tar_path = tmp_path / "data.tar"
        with wds.writer.TarWriter(str(tar_path)) as sink:
            for i in range(10):
                sample = SimpleTestSample(text=f"sample_{i}", label=i)
                sink.write(sample.as_wds)

        ds = load_dataset(str(tar_path), SimpleTestSample, split="train")

        assert isinstance(ds, atdata.Dataset)
        # Verify we can iterate
        samples = list(ds.ordered(batch_size=None))
        assert len(samples) == 10

    def test_load_returns_dataset_dict_without_split(self, tmp_path):
        """Without split parameter, returns DatasetDict."""
DatasetDict.""" 449 + tar_path = tmp_path / "data.tar" 450 + with wds.writer.TarWriter(str(tar_path)) as sink: 451 + sample = SimpleTestSample(text="hello", label=1) 452 + sink.write(sample.as_wds) 453 + 454 + result = load_dataset(str(tar_path), SimpleTestSample) 455 + 456 + assert isinstance(result, DatasetDict) 457 + assert "train" in result 458 + 459 + def test_load_with_data_files_dict(self, tmp_path): 460 + """Load with explicit data_files mapping.""" 461 + # Create train and test files 462 + train_path = tmp_path / "train.tar" 463 + test_path = tmp_path / "test.tar" 464 + 465 + with wds.writer.TarWriter(str(train_path)) as sink: 466 + for i in range(5): 467 + sample = SimpleTestSample(text=f"train_{i}", label=i) 468 + sink.write(sample.as_wds) 469 + 470 + with wds.writer.TarWriter(str(test_path)) as sink: 471 + for i in range(3): 472 + sample = SimpleTestSample(text=f"test_{i}", label=i) 473 + sink.write(sample.as_wds) 474 + 475 + result = load_dataset( 476 + str(tmp_path), 477 + SimpleTestSample, 478 + data_files={"train": "train.tar", "test": "test.tar"}, 479 + ) 480 + 481 + assert isinstance(result, DatasetDict) 482 + assert "train" in result 483 + assert "test" in result 484 + 485 + def test_load_nonexistent_split_raises(self, tmp_path): 486 + """Requesting a split that doesn't exist raises ValueError.""" 487 + tar_path = tmp_path / "train.tar" 488 + with wds.writer.TarWriter(str(tar_path)) as sink: 489 + sample = SimpleTestSample(text="hello", label=1) 490 + sink.write(sample.as_wds) 491 + 492 + with pytest.raises(ValueError, match="Split 'test' not found"): 493 + load_dataset(str(tar_path), SimpleTestSample, split="test") 494 + 495 + def test_load_directory_with_split_detection(self, tmp_path): 496 + """Load from directory auto-detecting splits from filenames.""" 497 + # Create files with split names 498 + train_path = tmp_path / "data-train-000.tar" 499 + test_path = tmp_path / "data-test-000.tar" 500 + 501 + with wds.writer.TarWriter(str(train_path)) as sink: 502 + for i in range(5): 503 + sample = SimpleTestSample(text=f"train_{i}", label=i) 504 + sink.write(sample.as_wds) 505 + 506 + with wds.writer.TarWriter(str(test_path)) as sink: 507 + for i in range(3): 508 + sample = SimpleTestSample(text=f"test_{i}", label=i) 509 + sink.write(sample.as_wds) 510 + 511 + result = load_dataset(str(tmp_path), SimpleTestSample) 512 + 513 + assert isinstance(result, DatasetDict) 514 + assert "train" in result 515 + assert "test" in result 516 + 517 + def test_load_with_streaming_flag(self, tmp_path): 518 + """streaming=True sets the streaming property.""" 519 + tar_path = tmp_path / "data.tar" 520 + with wds.writer.TarWriter(str(tar_path)) as sink: 521 + sample = SimpleTestSample(text="hello", label=1) 522 + sink.write(sample.as_wds) 523 + 524 + result = load_dataset(str(tar_path), SimpleTestSample, streaming=True) 525 + 526 + assert isinstance(result, DatasetDict) 527 + assert result.streaming is True 528 + 529 + def test_load_with_numpy_sample_type(self, tmp_path): 530 + """Load dataset with numpy arrays in samples.""" 531 + tar_path = tmp_path / "data.tar" 532 + with wds.writer.TarWriter(str(tar_path)) as sink: 533 + for i in range(5): 534 + sample = NumpyTestSample( 535 + embedding=np.random.randn(128).astype(np.float32), label=i 536 + ) 537 + sink.write(sample.as_wds) 538 + 539 + ds = load_dataset(str(tar_path), NumpyTestSample, split="train") 540 + samples = list(ds.ordered(batch_size=None)) 541 + 542 + assert len(samples) == 5 543 + assert isinstance(samples[0].embedding, 
        assert samples[0].embedding.shape == (128,)

    def test_load_glob_pattern(self, tmp_path):
        """Load using glob pattern."""
        # Create multiple shard files
        for i in range(3):
            shard_path = tmp_path / f"data-{i:03d}.tar"
            with wds.writer.TarWriter(str(shard_path)) as sink:
                sample = SimpleTestSample(text=f"shard_{i}", label=i)
                sink.write(sample.as_wds)

        pattern = str(tmp_path / "*.tar")
        result = load_dataset(pattern, SimpleTestSample)

        assert isinstance(result, DatasetDict)
        assert "train" in result

    def test_load_brace_notation(self, tmp_path):
        """Load using WebDataset brace notation."""
        # Create sharded files
        for i in range(3):
            shard_path = tmp_path / f"data-{i:06d}.tar"
            with wds.writer.TarWriter(str(shard_path)) as sink:
                for j in range(2):
                    sample = SimpleTestSample(text=f"shard_{i}_sample_{j}", label=j)
                    sink.write(sample.as_wds)

        # Use brace notation
        pattern = str(tmp_path / "data-{000000..000002}.tar")
        ds = load_dataset(pattern, SimpleTestSample, split="train")

        assert isinstance(ds, atdata.Dataset)
        samples = list(ds.ordered(batch_size=None))
        assert len(samples) == 6  # 3 shards * 2 samples each

    def test_load_empty_directory_raises(self, tmp_path):
        """Loading from empty directory raises FileNotFoundError."""
        empty_dir = tmp_path / "empty"
        empty_dir.mkdir()

        with pytest.raises(FileNotFoundError):
            load_dataset(str(empty_dir), SimpleTestSample)


##
# Integration tests


class TestLoadDatasetIntegration:
    """Integration tests combining multiple features."""

    def test_full_workflow_train_test_split(self, tmp_path):
        """Full workflow: create sharded dataset, load with splits, iterate."""
        # Create train shards
        for i in range(2):
            shard_path = tmp_path / f"train-{i:03d}.tar"
            with wds.writer.TarWriter(str(shard_path)) as sink:
                for j in range(5):
                    sample = SimpleTestSample(text=f"train_{i}_{j}", label=j)
                    sink.write(sample.as_wds)

        # Create test shard
        test_path = tmp_path / "test-000.tar"
        with wds.writer.TarWriter(str(test_path)) as sink:
            for j in range(3):
                sample = SimpleTestSample(text=f"test_{j}", label=j)
                sink.write(sample.as_wds)

        # Load dataset
        ds = load_dataset(str(tmp_path), SimpleTestSample)

        # Verify structure
        assert "train" in ds
        assert "test" in ds

        # Iterate train
        train_samples = list(ds["train"].ordered(batch_size=None))
        assert len(train_samples) == 10  # 2 shards * 5 samples

        # Iterate test
        test_samples = list(ds["test"].ordered(batch_size=None))
        assert len(test_samples) == 3

    def test_batched_iteration(self, tmp_path):
        """Test batched iteration through loaded dataset."""
        tar_path = tmp_path / "data.tar"
        with wds.writer.TarWriter(str(tar_path)) as sink:
            for i in range(20):
                sample = SimpleTestSample(text=f"sample_{i}", label=i % 5)
                sink.write(sample.as_wds)

        ds = load_dataset(str(tar_path), SimpleTestSample, split="train")

        batches = list(ds.ordered(batch_size=4))
        assert len(batches) == 5  # 20 samples / 4 per batch

        # Check batch structure
        first_batch = batches[0]
        assert len(first_batch.samples) == 4
        # Aggregated attributes
        labels = first_batch.label
        assert len(labels) == 4
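Taken together, the integration tests above outline the intended end-to-end workflow. The sketch below restates it as user-facing code using only the calls the tests exercise (`atdata.packable`, `wds.writer.TarWriter`, `sample.as_wds`, `load_dataset`, `DatasetDict`, `Dataset.ordered`); the `DocSample` type, the `./shards` directory, and the sample counts are illustrative and not part of the test suite.

```python
"""Sketch: the load_dataset() workflow exercised by the tests above (illustrative names)."""

from pathlib import Path

import webdataset as wds

import atdata
from atdata._hf_api import load_dataset


@atdata.packable
class DocSample:  # hypothetical sample type, mirrors SimpleTestSample
    text: str
    label: int


shard_dir = Path("./shards")  # illustrative location
shard_dir.mkdir(exist_ok=True)

# Write one shard per split; split names in the filenames drive auto-detection.
for split, count in [("train", 8), ("test", 4)]:
    with wds.writer.TarWriter(str(shard_dir / f"docs-{split}-000.tar")) as sink:
        for i in range(count):
            sink.write(DocSample(text=f"{split}_{i}", label=i).as_wds)

# Without `split=`, load_dataset returns a DatasetDict keyed by detected splits.
splits = load_dataset(str(shard_dir), DocSample)
print(sorted(splits.keys()))  # expected: ['test', 'train']

# With `split=`, a single atdata.Dataset is returned and can be iterated in batches.
train = load_dataset(str(shard_dir), DocSample, split="train")
for batch in train.ordered(batch_size=4):
    print(len(batch.samples), batch.label)  # per-batch aggregated attribute
```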