"""Tests for the HuggingFace Datasets-style API (_hf_api.py)."""

##
# Imports

import pytest

import numpy as np
import webdataset as wds

import atdata
from atdata._hf_api import (
    load_dataset,
    DatasetDict,
    _is_brace_pattern,
    _is_glob_pattern,
    _is_remote_url,
    _detect_split_from_path,
    _shards_to_wds_url,
    _expand_local_glob,
    _resolve_shards,
    _resolve_data_files,
    _group_shards_by_split,
    _is_indexed_path,
    _parse_indexed_path,
)
from unittest.mock import Mock

from numpy.typing import NDArray


##
# Test sample types


@atdata.packable
class SimpleTestSample:
    """Simple sample type for testing."""

    text: str
    label: int


@atdata.packable
class NumpyTestSample:
    """Sample type with numpy arrays for testing."""

    embedding: NDArray
    label: int


##
# Helper function tests


class TestIsBracePattern:
    """Tests for _is_brace_pattern()."""

    def test_range_pattern(self):
        assert _is_brace_pattern("data-{000000..000099}.tar") is True

    def test_list_pattern(self):
        assert _is_brace_pattern("data-{train,test,val}.tar") is True

    def test_no_pattern(self):
        assert _is_brace_pattern("data-000000.tar") is False

    def test_empty_braces(self):
        # Empty braces are not valid WebDataset brace notation
        assert _is_brace_pattern("data-{}.tar") is False

    def test_nested_path_with_pattern(self):
        assert _is_brace_pattern("path/to/data-{000..099}.tar") is True


class TestIsGlobPattern:
    """Tests for _is_glob_pattern()."""

    def test_asterisk(self):
        assert _is_glob_pattern("data-*.tar") is True

    def test_question_mark(self):
        assert _is_glob_pattern("data-00000?.tar") is True

    def test_no_pattern(self):
        assert _is_glob_pattern("data-000000.tar") is False

    def test_path_with_glob(self):
        assert _is_glob_pattern("path/to/*.tar") is True


class TestIsRemoteUrl:
    """Tests for _is_remote_url()."""

    def test_s3_url(self):
        assert _is_remote_url("s3://bucket/path/data.tar") is True

    def test_https_url(self):
        assert _is_remote_url("https://example.com/data.tar") is True

    def test_http_url(self):
        assert _is_remote_url("http://example.com/data.tar") is True

    def test_gs_url(self):
        assert _is_remote_url("gs://bucket/path/data.tar") is True

    def test_az_url(self):
        assert _is_remote_url("az://container/path/data.tar") is True

    def test_local_absolute_path(self):
        assert _is_remote_url("/local/path/data.tar") is False

    def test_local_relative_path(self):
        assert _is_remote_url("./data/data.tar") is False

    def test_windows_path(self):
        assert _is_remote_url("C:\\data\\data.tar") is False


class TestDetectSplitFromPath:
    """Tests for _detect_split_from_path()."""

    def test_train_in_filename(self):
        assert _detect_split_from_path("dataset-train-000000.tar") == "train"

    def test_test_in_filename(self):
        assert _detect_split_from_path("dataset-test-000000.tar") == "test"

    def test_validation_in_filename(self):
        assert _detect_split_from_path("dataset-validation-000000.tar") == "validation"

    def test_val_in_filename(self):
        assert _detect_split_from_path("dataset-val-000000.tar") == "validation"

    def test_dev_in_filename(self):
        assert _detect_split_from_path("dataset-dev-000000.tar") == "validation"

    def test_train_directory(self):
        assert _detect_split_from_path("train/shard-000000.tar") == "train"

    def test_test_directory(self):
        assert _detect_split_from_path("test/shard-000000.tar") == "test"

    def test_no_split_detected(self):
        assert _detect_split_from_path("dataset-000000.tar") is None

    def test_case_insensitive(self):
        assert _detect_split_from_path("dataset-TRAIN-000000.tar") == "train"
        assert _detect_split_from_path("dataset-Train-000000.tar") == "train"

    def test_training_variant(self):
        assert _detect_split_from_path("dataset-training-000000.tar") == "train"

    def test_testing_variant(self):
        assert _detect_split_from_path("dataset-testing-000000.tar") == "test"


class TestShardsToWdsUrl:
    """Tests for _shards_to_wds_url()."""

    def test_single_shard(self):
        assert _shards_to_wds_url(["data.tar"]) == "data.tar"

    def test_multiple_shards_common_pattern(self):
        shards = ["data-000.tar", "data-001.tar", "data-002.tar"]
        result = _shards_to_wds_url(shards)
        # Algorithm finds longest common prefix/suffix, resulting in compact notation
        # Both "data-{000,001,002}.tar" and "data-00{0,1,2}.tar" are valid
        assert "{" in result and "}" in result
        assert ".tar" in result
        assert "data-" in result

    def test_multiple_shards_different_lengths(self):
        shards = ["data-0.tar", "data-1.tar", "data-10.tar"]
        result = _shards_to_wds_url(shards)
        # Should still produce brace notation
        assert "{" in result and "}" in result

    def test_empty_list_raises(self):
        with pytest.raises(ValueError, match="empty shard list"):
            _shards_to_wds_url([])

    def test_no_common_pattern(self):
        shards = ["train.tar", "test.tar", "val.tar"]
        result = _shards_to_wds_url(shards)
        # Falls back to space-separated or brace notation
        assert "train" in result


class TestExpandLocalGlob:
    """Tests for _expand_local_glob()."""

    def test_no_matches(self, tmp_path):
        pattern = str(tmp_path / "*.tar")
        assert _expand_local_glob(pattern) == []

    def test_matches_files(self, tmp_path):
        # Create test files
        (tmp_path / "data-000.tar").touch()
        (tmp_path / "data-001.tar").touch()
        (tmp_path / "data-002.tar").touch()

        pattern = str(tmp_path / "*.tar")
        result = _expand_local_glob(pattern)

        assert len(result) == 3
        assert all(".tar" in p for p in result)

    def test_ignores_directories(self, tmp_path):
        # Create a file and a directory
        (tmp_path / "data.tar").touch()
        (tmp_path / "subdir.tar").mkdir()

        pattern = str(tmp_path / "*.tar")
        result = _expand_local_glob(pattern)

        assert len(result) == 1

    def test_nonexistent_directory(self):
        result = _expand_local_glob("/nonexistent/path/*.tar")
        assert result == []


class TestGroupShardsBySplit:
    """Tests for _group_shards_by_split()."""

    def test_single_split(self):
        shards = [
            "train-000.tar",
            "train-001.tar",
            "train-002.tar",
        ]
        result = _group_shards_by_split(shards)
        assert "train" in result
        assert len(result["train"]) == 3

    def test_multiple_splits(self):
        shards = [
            "data-train-000.tar",
            "data-train-001.tar",
            "data-test-000.tar",
            "data-val-000.tar",
        ]
        result = _group_shards_by_split(shards)
        assert "train" in result
        assert "test" in result
        assert "validation" in result
        assert len(result["train"]) == 2
        assert len(result["test"]) == 1
        assert len(result["validation"]) == 1

    def test_no_detected_split_defaults_to_train(self):
        shards = ["shard-000.tar", "shard-001.tar"]
        result = _group_shards_by_split(shards)
        assert "train" in result
        assert len(result["train"]) == 2


class TestResolveDataFiles:
    """Tests for _resolve_data_files()."""

    def test_string_input(self, tmp_path):
        result = _resolve_data_files(str(tmp_path), "data.tar")
        assert "train" in result
        assert len(result["train"]) == 1

    def test_list_input(self, tmp_path):
        result = _resolve_data_files(str(tmp_path), ["a.tar", "b.tar"])
        assert "train" in result
        assert len(result["train"]) == 2

    def test_dict_input(self, tmp_path):
        data_files = {
            "train": ["train-000.tar", "train-001.tar"],
            "test": "test-000.tar",
        }
        result = _resolve_data_files(str(tmp_path), data_files)
        assert "train" in result
        assert "test" in result
        assert len(result["train"]) == 2
        assert len(result["test"]) == 1

    def test_resolves_relative_paths(self, tmp_path):
        result = _resolve_data_files(str(tmp_path), "subdir/data.tar")
        assert str(tmp_path) in result["train"][0]


class TestResolveShards:
    """Tests for _resolve_shards()."""

    def test_brace_pattern_passthrough(self):
        path = "data-{000000..000099}.tar"
        result = _resolve_shards(path)
        assert "train" in result
        assert path in result["train"]

    def test_brace_pattern_with_split_name(self):
        path = "data-train-{000..099}.tar"
        result = _resolve_shards(path)
        assert "train" in result

    def test_single_file(self):
        path = "data.tar"
        result = _resolve_shards(path)
        assert "train" in result
        assert result["train"] == [path]

    def test_with_data_files_override(self, tmp_path):
        data_files = {"train": "train.tar", "test": "test.tar"}
        result = _resolve_shards(str(tmp_path), data_files)
        assert "train" in result
        assert "test" in result

    def test_local_directory(self, tmp_path):
        # Create test tar files
        (tmp_path / "train-000.tar").touch()
        (tmp_path / "train-001.tar").touch()
        (tmp_path / "test-000.tar").touch()

        result = _resolve_shards(str(tmp_path))
        assert "train" in result
        assert "test" in result

    def test_glob_pattern(self, tmp_path):
        # Create test files
        (tmp_path / "data-000.tar").touch()
        (tmp_path / "data-001.tar").touch()

        pattern = str(tmp_path / "*.tar")
        result = _resolve_shards(pattern)
        assert "train" in result  # defaults to train when no split detected


##
# DatasetDict tests


class TestDatasetDict:
    """Tests for DatasetDict class."""

    def test_empty_init(self):
        dd = DatasetDict()
        assert len(dd) == 0

    def test_init_with_splits(self, tmp_path):
        # Create a minimal tar file for Dataset
        tar_path = tmp_path / "data.tar"
        with wds.writer.TarWriter(str(tar_path)) as sink:
            sample = SimpleTestSample(text="hello", label=1)
            sink.write(sample.as_wds)

        train_ds = atdata.Dataset[SimpleTestSample](str(tar_path))
        test_ds = atdata.Dataset[SimpleTestSample](str(tar_path))

        dd = DatasetDict({"train": train_ds, "test": test_ds})

        assert len(dd) == 2
        assert "train" in dd
        assert "test" in dd

    def test_getitem(self, tmp_path):
        tar_path = tmp_path / "data.tar"
        with wds.writer.TarWriter(str(tar_path)) as sink:
            sample = SimpleTestSample(text="hello", label=1)
            sink.write(sample.as_wds)

        train_ds = atdata.Dataset[SimpleTestSample](str(tar_path))
        dd = DatasetDict({"train": train_ds})

        assert dd["train"] is train_ds

    def test_setitem(self, tmp_path):
        tar_path = tmp_path / "data.tar"
        with wds.writer.TarWriter(str(tar_path)) as sink:
            sample = SimpleTestSample(text="hello", label=1)
            sink.write(sample.as_wds)

        dd = DatasetDict()
        train_ds = atdata.Dataset[SimpleTestSample](str(tar_path))
        dd["train"] = train_ds

        assert "train" in dd
        assert dd["train"] is train_ds

    def test_keys_values_items(self, tmp_path):
        tar_path = tmp_path / "data.tar"
        with wds.writer.TarWriter(str(tar_path)) as sink:
            sample = SimpleTestSample(text="hello", label=1)
            sink.write(sample.as_wds)

        train_ds = atdata.Dataset[SimpleTestSample](str(tar_path))
        test_ds = atdata.Dataset[SimpleTestSample](str(tar_path))

        dd = DatasetDict({"train": train_ds, "test": test_ds})

        assert set(dd.keys()) == {"train", "test"}
        assert len(list(dd.values())) == 2
        assert len(list(dd.items())) == 2

    def test_streaming_property(self):
        dd = DatasetDict(streaming=True)
        assert dd.streaming is True

        dd2 = DatasetDict(streaming=False)
        assert dd2.streaming is False

    def test_sample_type_explicit(self):
        dd = DatasetDict(sample_type=SimpleTestSample)
        assert dd.sample_type is SimpleTestSample

    def test_num_shards(self, tmp_path):
        # Create a single tar file for the train split
        train_path = tmp_path / "train.tar"
        with wds.writer.TarWriter(str(train_path)) as sink:
            sample = SimpleTestSample(text="hello", label=1)
            sink.write(sample.as_wds)

        train_ds = atdata.Dataset[SimpleTestSample](str(train_path))
        dd = DatasetDict({"train": train_ds})

        num_shards = dd.num_shards
        assert "train" in num_shards
        assert num_shards["train"] == 1


##
# load_dataset tests


class TestLoadDataset:
    """Tests for load_dataset() function."""

    def test_load_single_file_with_split(self, tmp_path):
        """Load a single tar file specifying a split."""
        tar_path = tmp_path / "data.tar"
        with wds.writer.TarWriter(str(tar_path)) as sink:
            for i in range(10):
                sample = SimpleTestSample(text=f"sample_{i}", label=i)
                sink.write(sample.as_wds)

        ds = load_dataset(str(tar_path), SimpleTestSample, split="train")

        assert isinstance(ds, atdata.Dataset)
        # Verify we can iterate
        samples = list(ds.ordered(batch_size=None))
        assert len(samples) == 10

    def test_load_returns_dataset_dict_without_split(self, tmp_path):
        """Without split parameter, returns DatasetDict."""
        tar_path = tmp_path / "data.tar"
        with wds.writer.TarWriter(str(tar_path)) as sink:
            sample = SimpleTestSample(text="hello", label=1)
            sink.write(sample.as_wds)

        result = load_dataset(str(tar_path), SimpleTestSample)

        assert isinstance(result, DatasetDict)
        assert "train" in result

    def test_load_with_data_files_dict(self, tmp_path):
        """Load with explicit data_files mapping."""
        # Create train and test files
        train_path = tmp_path / "train.tar"
        test_path = tmp_path / "test.tar"

        with wds.writer.TarWriter(str(train_path)) as sink:
            for i in range(5):
                sample = SimpleTestSample(text=f"train_{i}", label=i)
                sink.write(sample.as_wds)

        with wds.writer.TarWriter(str(test_path)) as sink:
            for i in range(3):
                sample = SimpleTestSample(text=f"test_{i}", label=i)
                sink.write(sample.as_wds)

        result = load_dataset(
            str(tmp_path),
            SimpleTestSample,
            data_files={"train": "train.tar", "test": "test.tar"},
        )

        assert isinstance(result, DatasetDict)
        assert "train" in result
        assert "test" in result

    def test_load_nonexistent_split_raises(self, tmp_path):
        """Requesting a split that doesn't exist raises ValueError."""
        tar_path = tmp_path / "train.tar"
        with wds.writer.TarWriter(str(tar_path)) as sink:
            sample = SimpleTestSample(text="hello", label=1)
            sink.write(sample.as_wds)

        with pytest.raises(ValueError, match="Split 'test' not found"):
            load_dataset(str(tar_path), SimpleTestSample, split="test")

    def test_load_directory_with_split_detection(self, tmp_path):
        """Load from directory auto-detecting splits from filenames."""
        # Create files with split names
        train_path = tmp_path / "data-train-000.tar"
        test_path = tmp_path / "data-test-000.tar"

        with wds.writer.TarWriter(str(train_path)) as sink:
            for i in range(5):
                sample = SimpleTestSample(text=f"train_{i}", label=i)
                sink.write(sample.as_wds)

        with wds.writer.TarWriter(str(test_path)) as sink:
            for i in range(3):
                sample = SimpleTestSample(text=f"test_{i}", label=i)
                sink.write(sample.as_wds)

        result = load_dataset(str(tmp_path), SimpleTestSample)

        assert isinstance(result, DatasetDict)
        assert "train" in result
        assert "test" in result

    def test_load_with_streaming_flag(self, tmp_path):
        """streaming=True sets the streaming property."""
        tar_path = tmp_path / "data.tar"
        with wds.writer.TarWriter(str(tar_path)) as sink:
            sample = SimpleTestSample(text="hello", label=1)
            sink.write(sample.as_wds)

        result = load_dataset(str(tar_path), SimpleTestSample, streaming=True)

        assert isinstance(result, DatasetDict)
        assert result.streaming is True

    def test_load_with_numpy_sample_type(self, tmp_path):
        """Load dataset with numpy arrays in samples."""
        tar_path = tmp_path / "data.tar"
        with wds.writer.TarWriter(str(tar_path)) as sink:
            for i in range(5):
                sample = NumpyTestSample(
                    embedding=np.random.randn(128).astype(np.float32), label=i
                )
                sink.write(sample.as_wds)

        ds = load_dataset(str(tar_path), NumpyTestSample, split="train")
        samples = list(ds.ordered(batch_size=None))

        assert len(samples) == 5
        assert isinstance(samples[0].embedding, np.ndarray)
        assert samples[0].embedding.shape == (128,)

    def test_load_glob_pattern(self, tmp_path):
        """Load using glob pattern."""
        # Create multiple shard files
        for i in range(3):
            shard_path = tmp_path / f"data-{i:03d}.tar"
            with wds.writer.TarWriter(str(shard_path)) as sink:
                sample = SimpleTestSample(text=f"shard_{i}", label=i)
                sink.write(sample.as_wds)

        pattern = str(tmp_path / "*.tar")
        result = load_dataset(pattern, SimpleTestSample)

        assert isinstance(result, DatasetDict)
        assert "train" in result

    def test_load_brace_notation(self, tmp_path):
        """Load using WebDataset brace notation."""
        # Create sharded files
        for i in range(3):
            shard_path = tmp_path / f"data-{i:06d}.tar"
            with wds.writer.TarWriter(str(shard_path)) as sink:
                for j in range(2):
                    sample = SimpleTestSample(text=f"shard_{i}_sample_{j}", label=j)
                    sink.write(sample.as_wds)

        # Use brace notation
        pattern = str(tmp_path / "data-{000000..000002}.tar")
        ds = load_dataset(pattern, SimpleTestSample, split="train")

        assert isinstance(ds, atdata.Dataset)
        samples = list(ds.ordered(batch_size=None))
        assert len(samples) == 6  # 3 shards * 2 samples each

    def test_load_empty_directory_raises(self, tmp_path):
        """Loading from empty directory raises FileNotFoundError."""
        empty_dir = tmp_path / "empty"
        empty_dir.mkdir()

        with pytest.raises(FileNotFoundError):
            load_dataset(str(empty_dir), SimpleTestSample)


##
# Integration tests


class TestLoadDatasetIntegration:
    """Integration tests combining multiple features."""

    def test_full_workflow_train_test_split(self, tmp_path):
        """Full workflow: create sharded dataset, load with splits, iterate."""
        # Create train shards
        for i in range(2):
            shard_path = tmp_path / f"train-{i:03d}.tar"
            with wds.writer.TarWriter(str(shard_path)) as sink:
                for j in range(5):
                    sample = SimpleTestSample(text=f"train_{i}_{j}", label=j)
                    sink.write(sample.as_wds)

        # Create test shard
        test_path = tmp_path / "test-000.tar"
        with wds.writer.TarWriter(str(test_path)) as sink:
            for j in range(3):
                sample = SimpleTestSample(text=f"test_{j}", label=j)
                sink.write(sample.as_wds)

        # Load dataset
        ds = load_dataset(str(tmp_path), SimpleTestSample)

        # Verify structure
        assert "train" in ds
        assert "test" in ds

        # Iterate train
        train_samples = list(ds["train"].ordered(batch_size=None))
        assert len(train_samples) == 10  # 2 shards * 5 samples

        # Iterate test
        test_samples = list(ds["test"].ordered(batch_size=None))
        assert len(test_samples) == 3

    def test_batched_iteration(self, tmp_path):
        """Test batched iteration through loaded dataset."""
        tar_path = tmp_path / "data.tar"
        with wds.writer.TarWriter(str(tar_path)) as sink:
            for i in range(20):
                sample = SimpleTestSample(text=f"sample_{i}", label=i % 5)
                sink.write(sample.as_wds)

        ds = load_dataset(str(tar_path), SimpleTestSample, split="train")

        batches = list(ds.ordered(batch_size=4))
        assert len(batches) == 5  # 20 samples / 4 per batch

        # Check batch structure
        first_batch = batches[0]
        assert len(first_batch.samples) == 4
        # Aggregated attributes
        labels = first_batch.label
        assert len(labels) == 4


##
# Indexed path tests


class TestIsIndexedPath:
    """Tests for _is_indexed_path function."""

    def test_at_handle_path(self):
        """@handle/dataset is indexed."""
        assert _is_indexed_path("@maxine.science/mnist") is True

    def test_at_did_path(self):
        """@did:plc:abc/dataset is indexed."""
        assert _is_indexed_path("@did:plc:abc123/my-dataset") is True

    def test_local_path(self):
        """Local paths are not indexed."""
        assert _is_indexed_path("/path/to/data.tar") is False

    def test_s3_path(self):
        """S3 URLs are not indexed."""
        assert _is_indexed_path("s3://bucket/data.tar") is False

    def test_relative_path(self):
        """Relative paths are not indexed."""
        assert _is_indexed_path("./data/train.tar") is False


class TestParseIndexedPath:
    """Tests for _parse_indexed_path function."""

    def test_parse_handle_dataset(self):
        """Parse @handle/dataset format."""
        handle, name = _parse_indexed_path("@maxine.science/mnist")
        assert handle == "maxine.science"
        assert name == "mnist"

    def test_parse_did_dataset(self):
        """Parse @did:plc:xxx/dataset format."""
        handle, name = _parse_indexed_path("@did:plc:abc123/my-dataset")
        assert handle == "did:plc:abc123"
        assert name == "my-dataset"

    def test_parse_invalid_no_slash(self):
        """Invalid path without slash raises ValueError."""
        with pytest.raises(ValueError, match="Invalid indexed path format"):
            _parse_indexed_path("@handle-only")

    def test_parse_invalid_no_at(self):
        """Path without @ raises ValueError."""
        with pytest.raises(ValueError, match="Not an indexed path"):
            _parse_indexed_path("handle/dataset")

    def test_parse_invalid_empty_parts(self):
        """Empty handle or dataset raises ValueError."""
        with pytest.raises(ValueError, match="Invalid indexed path"):
            _parse_indexed_path("@/dataset")


class TestLoadDatasetWithIndex:
    """Tests for load_dataset with index parameter."""

    def test_indexed_path_requires_index(self):
        """@handle/dataset without index raises ValueError."""
        with pytest.raises(ValueError, match="Index required"):
            load_dataset("@handle/dataset", SimpleTestSample)

    def test_none_sample_type_defaults_to_dictsample(self, tmp_path):
        """sample_type=None returns Dataset[DictSample]."""
        from atdata import DictSample

        # Create a test tar file
        tar_path = tmp_path / "data.tar"
        sample = SimpleTestSample(text="hello", label=42)
        with wds.writer.TarWriter(str(tar_path)) as writer:
            writer.write(sample.as_wds)

        # Load without specifying sample_type
        ds = load_dataset(str(tar_path), split="train")

        # Should return Dataset[DictSample]
        assert ds.sample_type == DictSample

        # Should be able to iterate and access fields
        for sample in ds.ordered():
            assert sample["text"] == "hello"
            assert sample.label == 42
            break

    def test_indexed_path_with_mock_index(self):
        """load_dataset with indexed path uses index lookup."""
        mock_index = Mock()
        mock_index.data_store = None  # No data store, so no URL transformation
        mock_entry = Mock()
        mock_entry.data_urls = ["s3://bucket/data.tar"]
        mock_entry.schema_ref = "local://schemas/test@1.0.0"
        mock_index.get_dataset.return_value = mock_entry

        # No need to mock decode_schema since sample_type is provided
        ds = load_dataset(
            "@local/my-dataset",
            SimpleTestSample,
            index=mock_index,
            split="train",
        )

        mock_index.get_dataset.assert_called_once_with("my-dataset")
        assert ds.url == "s3://bucket/data.tar"

    def test_indexed_path_auto_type_resolution(self):
        """load_dataset with sample_type=None uses decode_schema."""
        mock_index = Mock()
        mock_index.data_store = None  # No data store, so no URL transformation
        mock_entry = Mock()
        mock_entry.data_urls = ["s3://bucket/data.tar"]
        mock_entry.schema_ref = "local://schemas/test@1.0.0"
        mock_index.get_dataset.return_value = mock_entry
        mock_index.decode_schema.return_value = SimpleTestSample

        ds = load_dataset(
            "@local/my-dataset",
            None,
            index=mock_index,
            split="train",
        )

        mock_index.decode_schema.assert_called_once_with("local://schemas/test@1.0.0")
        assert ds.sample_type == SimpleTestSample

    def test_indexed_path_returns_datasetdict_without_split(self):
        """load_dataset with indexed path returns DatasetDict when split=None."""
        mock_index = Mock()
        mock_index.data_store = None  # No data store, so no URL transformation
        mock_entry = Mock()
        mock_entry.data_urls = ["s3://bucket/data.tar"]
        mock_entry.schema_ref = "local://schemas/test@1.0.0"
        mock_index.get_dataset.return_value = mock_entry

        result = load_dataset(
            "@local/my-dataset",
            SimpleTestSample,
            index=mock_index,
        )

        assert isinstance(result, DatasetDict)
        assert "train" in result

    def test_indexed_path_transforms_urls_via_data_store(self):
        """load_dataset transforms URLs through data_store.read_url() if available."""
        mock_data_store = Mock()
        mock_data_store.read_url.return_value = "https://r2.example.com/bucket/data.tar"

        mock_index = Mock()
        mock_index.data_store = mock_data_store
        mock_entry = Mock()
        mock_entry.data_urls = ["s3://bucket/data.tar"]
        mock_entry.schema_ref = "local://schemas/test@1.0.0"
        mock_index.get_dataset.return_value = mock_entry

        ds = load_dataset(
            "@local/my-dataset",
            SimpleTestSample,
            index=mock_index,
            split="train",
        )

        # Verify read_url was called to transform the URL
        mock_data_store.read_url.assert_called_once_with("s3://bucket/data.tar")
        # Verify the transformed URL is used
        assert ds.url == "https://r2.example.com/bucket/data.tar"

    def test_indexed_path_no_transform_without_data_store(self):
        """load_dataset uses URLs unchanged when index has no data_store."""
        mock_index = Mock()
        mock_index.data_store = None
        mock_entry = Mock()
        mock_entry.data_urls = ["s3://bucket/data.tar"]
        mock_entry.schema_ref = "local://schemas/test@1.0.0"
        mock_index.get_dataset.return_value = mock_entry

        ds = load_dataset(
            "@local/my-dataset",
            SimpleTestSample,
            index=mock_index,
            split="train",
        )

        # URL should be unchanged
        assert ds.url == "s3://bucket/data.tar"

    def test_indexed_path_creates_s3source_with_credentials(self):
        """load_dataset creates S3Source with credentials when S3DataStore is available."""
        from atdata.local import S3DataStore
        from atdata._sources import S3Source

        # Credentials exposed by a mocked S3DataStore
        mock_credentials = {
            "AWS_ACCESS_KEY_ID": "test-access-key",
            "AWS_SECRET_ACCESS_KEY": "test-secret-key",
            "AWS_ENDPOINT": "https://r2.example.com",
        }

        # Mock the S3DataStore
        mock_store = Mock(spec=S3DataStore)
        mock_store.credentials = mock_credentials

        mock_index = Mock()
        mock_index.data_store = mock_store
        mock_entry = Mock()
        mock_entry.data_urls = [
            "s3://my-bucket/train-000.tar",
            "s3://my-bucket/train-001.tar",
        ]
        mock_entry.schema_ref = "local://schemas/test@1.0.0"
        mock_index.get_dataset.return_value = mock_entry

        ds = load_dataset(
            "@local/my-dataset",
            SimpleTestSample,
            index=mock_index,
            split="train",
        )

        # Verify the dataset source is an S3Source with credentials
        assert isinstance(ds.source, S3Source)
        assert ds.source.bucket == "my-bucket"
        assert ds.source.keys == ["train-000.tar", "train-001.tar"]
        assert ds.source.endpoint == "https://r2.example.com"
        assert ds.source.access_key == "test-access-key"
        assert ds.source.secret_key == "test-secret-key"
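

##
# Usage sketch (illustrative only, not a test)
#
# A minimal sketch of the API exercised above, kept here for orientation.
# The "./shards" path is hypothetical; everything else (load_dataset,
# DatasetDict, split=, data_files=, Dataset.ordered, and per-batch attribute
# aggregation) mirrors the calls made in the tests in this module.


def _example_usage() -> None:  # leading underscore: not collected by pytest
    # Loading a directory (or glob/brace pattern) without a split returns a
    # DatasetDict, with splits detected from the shard filenames.
    dsets = load_dataset("./shards", SimpleTestSample)  # hypothetical path
    assert isinstance(dsets, DatasetDict)

    # Requesting a split directly returns a single atdata.Dataset.
    train = load_dataset("./shards", SimpleTestSample, split="train")

    # Iterate sample-by-sample...
    for sample in train.ordered(batch_size=None):
        _ = (sample.text, sample.label)
        break

    # ...or in fixed-size batches, where fields are aggregated per batch.
    for batch in train.ordered(batch_size=4):
        _ = (batch.samples, batch.label)
        break

    # Indexed datasets are addressed as "@handle/dataset-name" and require an
    # index= argument (see TestLoadDatasetWithIndex above).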