A loose federation of distributed, typed datasets
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

refactor(tests): improve test_local.py with fixtures and reduced duplication

- Add clean_redis fixture for automatic BasicIndexEntry cleanup
- Add make_simple_dataset() and make_array_dataset() helper functions
- Parametrize s3_env missing field tests (3 tests → 1)
- Convert 15+ tests to use clean_redis fixture
- Strengthen assertions in write_to_redis test
- Fix empty dataset test to expect success (matches WebDataset behavior)
- Remove ~200 lines of duplicate cleanup code

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

+140 -425
.chainlink/issues.db

This is a binary file and will not be displayed.

+140 -425
tests/test_local.py
··· 41 41 42 42 43 43 @pytest.fixture 44 + def clean_redis(redis_connection): 45 + """Provide a Redis connection with automatic BasicIndexEntry cleanup. 46 + 47 + Clears all BasicIndexEntry keys before and after each test to ensure 48 + test isolation. 49 + """ 50 + def _clear_entries(): 51 + for key in redis_connection.scan_iter(match='BasicIndexEntry:*'): 52 + redis_connection.delete(key) 53 + 54 + _clear_entries() 55 + yield redis_connection 56 + _clear_entries() 57 + 58 + 59 + @pytest.fixture 44 60 def mock_s3(): 45 61 """Provide a mock S3 environment using moto.""" 46 62 with mock_aws(): ··· 99 115 data: NDArray 100 116 101 117 118 + def make_simple_dataset(tmp_path: Path, num_samples: int = 10, name: str = "test") -> atdata.Dataset: 119 + """Create a SimpleTestSample dataset for testing.""" 120 + dataset_path = tmp_path / f"{name}-dataset-000000.tar" 121 + with wds.TarWriter(str(dataset_path)) as sink: 122 + for i in range(num_samples): 123 + sample = SimpleTestSample(name=f"sample_{i}", value=i * 10) 124 + sink.write(sample.as_wds) 125 + return atdata.Dataset[SimpleTestSample](url=str(dataset_path)) 126 + 127 + 128 + def make_array_dataset(tmp_path: Path, num_samples: int = 3, array_shape: tuple = (10, 10)) -> atdata.Dataset: 129 + """Create an ArrayTestSample dataset for testing.""" 130 + dataset_path = tmp_path / "array-dataset-000000.tar" 131 + with wds.TarWriter(str(dataset_path)) as sink: 132 + for i in range(num_samples): 133 + arr = np.random.randn(*array_shape) 134 + sample = ArrayTestSample(label=f"array_{i}", data=arr) 135 + sink.write(sample.as_wds) 136 + return atdata.Dataset[ArrayTestSample](url=str(dataset_path)) 137 + 138 + 102 139 ## 103 140 # Helper function tests 104 141 ··· 160 197 } 161 198 162 199 163 - def test_s3_env_missing_endpoint(tmp_path): 164 - """Test that loading S3 credentials fails when AWS_ENDPOINT is missing. 165 - 166 - Should raise AssertionError when .env file lacks AWS_ENDPOINT. 167 - """ 168 - env_file = tmp_path / ".env" 169 - env_file.write_text( 170 - "AWS_ACCESS_KEY_ID=minioadmin\n" 171 - "AWS_SECRET_ACCESS_KEY=minioadmin\n" 172 - ) 173 - 174 - with pytest.raises(AssertionError): 175 - atlocal._s3_env(env_file) 176 - 177 - 178 - def test_s3_env_missing_access_key(tmp_path): 179 - """Test that loading S3 credentials fails when AWS_ACCESS_KEY_ID is missing. 180 - 181 - Should raise AssertionError when .env file lacks AWS_ACCESS_KEY_ID. 182 - """ 183 - env_file = tmp_path / ".env" 184 - env_file.write_text( 185 - "AWS_ENDPOINT=http://localhost:9000\n" 186 - "AWS_SECRET_ACCESS_KEY=minioadmin\n" 187 - ) 188 - 189 - with pytest.raises(AssertionError): 190 - atlocal._s3_env(env_file) 191 - 192 - 193 - def test_s3_env_missing_secret_key(tmp_path): 194 - """Test that loading S3 credentials fails when AWS_SECRET_ACCESS_KEY is missing. 200 + @pytest.mark.parametrize("missing_field,env_content", [ 201 + ("AWS_ENDPOINT", "AWS_ACCESS_KEY_ID=minioadmin\nAWS_SECRET_ACCESS_KEY=minioadmin\n"), 202 + ("AWS_ACCESS_KEY_ID", "AWS_ENDPOINT=http://localhost:9000\nAWS_SECRET_ACCESS_KEY=minioadmin\n"), 203 + ("AWS_SECRET_ACCESS_KEY", "AWS_ENDPOINT=http://localhost:9000\nAWS_ACCESS_KEY_ID=minioadmin\n"), 204 + ]) 205 + def test_s3_env_missing_required_field(tmp_path, missing_field, env_content): 206 + """Test that loading S3 credentials fails when a required field is missing. 195 207 196 - Should raise AssertionError when .env file lacks AWS_SECRET_ACCESS_KEY. 208 + Should raise AssertionError when .env file lacks any of the required fields: 209 + AWS_ENDPOINT, AWS_ACCESS_KEY_ID, or AWS_SECRET_ACCESS_KEY. 197 210 """ 198 211 env_file = tmp_path / ".env" 199 - env_file.write_text( 200 - "AWS_ENDPOINT=http://localhost:9000\n" 201 - "AWS_ACCESS_KEY_ID=minioadmin\n" 202 - ) 212 + env_file.write_text(env_content) 203 213 204 214 with pytest.raises(AssertionError): 205 215 atlocal._s3_env(env_file) ··· 283 293 assert str(parsed_uuid) == entry.uuid 284 294 285 295 286 - def test_basic_index_entry_write_to_redis(redis_connection): 296 + def test_basic_index_entry_write_to_redis(clean_redis): 287 297 """Test persisting a BasicIndexEntry to Redis. 288 298 289 299 Should write the entry to Redis as a hash with key 'BasicIndexEntry:{uuid}' 290 - and all fields should be retrievable. 300 + and all fields should be retrievable with correct values. 291 301 """ 292 - redis = redis_connection 293 302 test_uuid = "12345678-1234-1234-1234-123456789abc" 294 303 295 304 entry = atlocal.BasicIndexEntry( ··· 299 308 uuid=test_uuid 300 309 ) 301 310 302 - entry.write_to(redis) 311 + entry.write_to(clean_redis) 303 312 304 - # Retrieve from Redis and verify 305 - stored_data = redis.hgetall(f"BasicIndexEntry:{test_uuid}") 306 - assert stored_data is not None 307 - assert len(stored_data) > 0 308 - 309 - # Clean up 310 - redis.delete(f"BasicIndexEntry:{test_uuid}") 313 + # Retrieve and verify actual stored values 314 + stored_data = atlocal._decode_bytes_dict(clean_redis.hgetall(f"BasicIndexEntry:{test_uuid}")) 315 + assert stored_data['wds_url'] == "s3://bucket/dataset.tar" 316 + assert stored_data['sample_kind'] == "test_module.TestSample" 317 + assert stored_data['metadata_url'] == "s3://bucket/metadata.msgpack" 318 + assert stored_data['uuid'] == test_uuid 311 319 312 320 313 - def test_basic_index_entry_round_trip_redis(redis_connection): 321 + def test_basic_index_entry_round_trip_redis(clean_redis): 314 322 """Test writing and reading a BasicIndexEntry from Redis. 315 323 316 324 Should be able to write an entry to Redis and read it back with all fields 317 325 intact and matching the original values. 318 326 """ 319 - redis = redis_connection 320 327 test_uuid = "12345678-1234-1234-1234-123456789abc" 321 328 322 329 original_entry = atlocal.BasicIndexEntry( ··· 326 333 uuid=test_uuid 327 334 ) 328 335 329 - original_entry.write_to(redis) 336 + original_entry.write_to(clean_redis) 330 337 331 338 # Read back from Redis 332 - stored_data = atlocal._decode_bytes_dict(redis.hgetall(f"BasicIndexEntry:{test_uuid}")) 339 + stored_data = atlocal._decode_bytes_dict(clean_redis.hgetall(f"BasicIndexEntry:{test_uuid}")) 333 340 retrieved_entry = atlocal.BasicIndexEntry(**stored_data) 334 341 335 342 assert retrieved_entry.wds_url == original_entry.wds_url 336 343 assert retrieved_entry.sample_kind == original_entry.sample_kind 337 344 assert retrieved_entry.metadata_url == original_entry.metadata_url 338 345 assert retrieved_entry.uuid == original_entry.uuid 339 - 340 - # Clean up 341 - redis.delete(f"BasicIndexEntry:{test_uuid}") 342 346 343 347 344 348 ## ··· 378 382 assert isinstance(index._redis, Redis) 379 383 380 384 381 - def test_index_add_entry_without_uuid(redis_connection): 385 + def test_index_add_entry_without_uuid(clean_redis): 382 386 """Test adding a dataset entry to the index without specifying UUID. 383 387 384 388 Should create a BasicIndexEntry with auto-generated UUID and persist it to Redis. 385 389 """ 386 - redis = redis_connection 387 - index = atlocal.Index(redis=redis) 390 + index = atlocal.Index(redis=clean_redis) 388 391 389 392 ds = atdata.Dataset[SimpleTestSample]( 390 393 url="s3://bucket/dataset.tar", ··· 399 402 assert entry.metadata_url == ds.metadata_url 400 403 401 404 # Verify it was persisted to Redis 402 - stored_data = redis.hgetall(f"BasicIndexEntry:{entry.uuid}") 405 + stored_data = clean_redis.hgetall(f"BasicIndexEntry:{entry.uuid}") 403 406 assert len(stored_data) > 0 404 407 405 - # Clean up 406 - redis.delete(f"BasicIndexEntry:{entry.uuid}") 407 408 408 - 409 - def test_index_add_entry_with_uuid(redis_connection): 409 + def test_index_add_entry_with_uuid(clean_redis): 410 410 """Test adding a dataset entry to the index with a specified UUID. 411 411 412 412 Should create a BasicIndexEntry with the provided UUID and persist it to Redis. 413 413 """ 414 - redis = redis_connection 415 - index = atlocal.Index(redis=redis) 414 + index = atlocal.Index(redis=clean_redis) 416 415 test_uuid = "12345678-1234-1234-1234-123456789abc" 417 416 418 417 ds = atdata.Dataset[SimpleTestSample]( ··· 427 426 assert entry.sample_kind == f"{SimpleTestSample.__module__}.SimpleTestSample" 428 427 assert entry.metadata_url == ds.metadata_url 429 428 430 - # Clean up 431 - redis.delete(f"BasicIndexEntry:{test_uuid}") 432 429 433 - 434 - def test_index_entries_generator_empty(redis_connection): 430 + def test_index_entries_generator_empty(clean_redis): 435 431 """Test iterating over entries in an empty index. 436 432 437 433 Should yield no entries when the index is empty. 438 434 """ 439 - redis = redis_connection 440 - # Clear any existing BasicIndexEntry keys 441 - for key in redis.scan_iter(match='BasicIndexEntry:*'): 442 - redis.delete(key) 443 - 444 - index = atlocal.Index(redis=redis) 435 + index = atlocal.Index(redis=clean_redis) 445 436 446 437 entries = list(index.entries) 447 438 assert len(entries) == 0 448 439 449 440 450 - def test_index_entries_generator_multiple(redis_connection): 441 + def test_index_entries_generator_multiple(clean_redis): 451 442 """Test iterating over multiple entries in the index. 452 443 453 444 Should yield all BasicIndexEntry objects that have been added to the index. 454 445 """ 455 - redis = redis_connection 456 - # Clear any existing BasicIndexEntry keys 457 - for key in redis.scan_iter(match='BasicIndexEntry:*'): 458 - redis.delete(key) 459 - 460 - index = atlocal.Index(redis=redis) 446 + index = atlocal.Index(redis=clean_redis) 461 447 462 448 ds1 = atdata.Dataset[SimpleTestSample](url="s3://bucket/dataset1.tar") 463 449 ds2 = atdata.Dataset[ArrayTestSample](url="s3://bucket/dataset2.tar") ··· 472 458 assert entry1.uuid in uuids 473 459 assert entry2.uuid in uuids 474 460 475 - # Clean up 476 - redis.delete(f"BasicIndexEntry:{entry1.uuid}") 477 - redis.delete(f"BasicIndexEntry:{entry2.uuid}") 478 461 479 - 480 - def test_index_all_entries_empty(redis_connection): 462 + def test_index_all_entries_empty(clean_redis): 481 463 """Test getting all entries as a list from an empty index. 482 464 483 465 Should return an empty list when no entries exist. 484 466 """ 485 - redis = redis_connection 486 - # Clear any existing BasicIndexEntry keys 487 - for key in redis.scan_iter(match='BasicIndexEntry:*'): 488 - redis.delete(key) 489 - 490 - index = atlocal.Index(redis=redis) 467 + index = atlocal.Index(redis=clean_redis) 491 468 492 469 entries = index.all_entries 493 470 assert isinstance(entries, list) 494 471 assert len(entries) == 0 495 472 496 473 497 - def test_index_all_entries_multiple(redis_connection): 474 + def test_index_all_entries_multiple(clean_redis): 498 475 """Test getting all entries as a list with multiple entries. 499 476 500 477 Should return a list containing all BasicIndexEntry objects in the index. 501 478 """ 502 - redis = redis_connection 503 - # Clear any existing BasicIndexEntry keys 504 - for key in redis.scan_iter(match='BasicIndexEntry:*'): 505 - redis.delete(key) 506 - 507 - index = atlocal.Index(redis=redis) 479 + index = atlocal.Index(redis=clean_redis) 508 480 509 481 ds1 = atdata.Dataset[SimpleTestSample](url="s3://bucket/dataset1.tar") 510 482 ds2 = atdata.Dataset[ArrayTestSample](url="s3://bucket/dataset2.tar") ··· 516 488 assert isinstance(entries, list) 517 489 assert len(entries) == 2 518 490 519 - # Clean up 520 - redis.delete(f"BasicIndexEntry:{entry1.uuid}") 521 - redis.delete(f"BasicIndexEntry:{entry2.uuid}") 522 491 523 - 524 - def test_index_entries_filtering(redis_connection): 492 + def test_index_entries_filtering(clean_redis): 525 493 """Test that index only returns BasicIndexEntry objects. 526 494 527 495 Should only iterate over keys matching 'BasicIndexEntry:*' pattern and 528 496 ignore any other Redis keys. 529 497 """ 530 - redis = redis_connection 531 - # Clear any existing BasicIndexEntry keys 532 - for key in redis.scan_iter(match='BasicIndexEntry:*'): 533 - redis.delete(key) 534 - 535 - index = atlocal.Index(redis=redis) 498 + index = atlocal.Index(redis=clean_redis) 536 499 537 500 # Add a BasicIndexEntry 538 501 ds = atdata.Dataset[SimpleTestSample](url="s3://bucket/dataset.tar") 539 502 entry = index.add_entry(ds) 540 503 541 504 # Add some other Redis keys that should be ignored 542 - redis.set("other_key", "value") 543 - redis.hset("other_hash", "field", "value") 505 + clean_redis.set("other_key", "value") 506 + clean_redis.hset("other_hash", "field", "value") 544 507 545 508 entries = list(index.entries) 546 509 assert len(entries) == 1 547 510 assert entries[0].uuid == entry.uuid 548 511 549 - # Clean up 550 - redis.delete(f"BasicIndexEntry:{entry.uuid}") 551 - redis.delete("other_key") 552 - redis.delete("other_hash") 512 + # Clean up non-BasicIndexEntry keys (fixture only cleans BasicIndexEntry:*) 513 + clean_redis.delete("other_key") 514 + clean_redis.delete("other_hash") 553 515 554 516 555 517 ## ··· 669 631 repo.insert(ds) 670 632 671 633 672 - def test_repo_insert_single_shard(mock_s3, redis_connection, sample_dataset): 634 + def test_repo_insert_single_shard(mock_s3, clean_redis, sample_dataset): 673 635 """Test inserting a small dataset that fits in a single shard. 674 636 675 637 Should write the dataset to S3, create metadata, add index entry, and return 676 638 a new Dataset pointing to the stored copy with correct URL format. 677 639 """ 678 - # Clear Redis 679 - for key in redis_connection.scan_iter(match='BasicIndexEntry:*'): 680 - redis_connection.delete(key) 681 - 682 - # Create repo with mock S3 683 640 repo = atlocal.Repo( 684 641 s3_credentials=mock_s3['credentials'], 685 642 hive_path=mock_s3['hive_path'], 686 - redis=redis_connection 643 + redis=clean_redis 687 644 ) 688 645 689 - # Insert dataset 690 646 entry, new_ds = repo.insert(sample_dataset, maxcount=100) 691 647 692 - # Verify entry was created 693 648 assert entry.uuid is not None 694 649 assert entry.wds_url is not None 695 650 assert entry.sample_kind == f"{SimpleTestSample.__module__}.SimpleTestSample" 696 - 697 - # Verify index entry exists 698 651 assert len(repo.index.all_entries) == 1 699 - 700 - # Verify URL format is correct (single shard) 701 652 assert '.tar' in new_ds.url 702 653 assert new_ds.url.startswith(mock_s3['hive_path']) 703 654 704 - # Clean up 705 - redis_connection.delete(f"BasicIndexEntry:{entry.uuid}") 706 655 707 - 708 - def test_repo_insert_multiple_shards(mock_s3, redis_connection, tmp_path): 656 + def test_repo_insert_multiple_shards(mock_s3, clean_redis, tmp_path): 709 657 """Test inserting a large dataset that spans multiple shards. 710 658 711 659 Should write multiple tar files to S3, use brace notation in returned URL, 712 660 and correctly format the shard range. 713 661 """ 714 - # Clear Redis 715 - for key in redis_connection.scan_iter(match='BasicIndexEntry:*'): 716 - redis_connection.delete(key) 717 - 718 - # Create a larger dataset with multiple samples 719 - dataset_path = tmp_path / "large-dataset-000000.tar" 720 - with wds.TarWriter(str(dataset_path)) as sink: 721 - for i in range(50): # More samples to force multiple shards 722 - sample = SimpleTestSample(name=f"sample_{i}", value=i * 10) 723 - sink.write(sample.as_wds) 724 - 725 - ds = atdata.Dataset[SimpleTestSample](url=str(dataset_path)) 726 - 727 - # Create repo with mock S3 662 + ds = make_simple_dataset(tmp_path, num_samples=50, name="large") 728 663 repo = atlocal.Repo( 729 664 s3_credentials=mock_s3['credentials'], 730 665 hive_path=mock_s3['hive_path'], 731 - redis=redis_connection 666 + redis=clean_redis 732 667 ) 733 668 734 - # Insert dataset with small maxcount to force multiple shards 735 669 entry, new_ds = repo.insert(ds, maxcount=10) 736 670 737 - # Verify entry was created 738 671 assert entry.uuid is not None 739 672 assert entry.wds_url is not None 740 - 741 - # Verify URL uses brace notation for multiple shards 742 673 assert '{' in new_ds.url and '}' in new_ds.url 743 674 744 - # Clean up 745 - redis_connection.delete(f"BasicIndexEntry:{entry.uuid}") 746 675 747 - 748 - def test_repo_insert_with_metadata(mock_s3, redis_connection, tmp_path): 676 + def test_repo_insert_with_metadata(mock_s3, clean_redis, tmp_path): 749 677 """Test inserting a dataset with metadata. 750 678 751 679 Should write metadata as msgpack to S3 and include metadata_url in the 752 680 returned Dataset and BasicIndexEntry. 753 681 """ 754 - # Clear Redis 755 - for key in redis_connection.scan_iter(match='BasicIndexEntry:*'): 756 - redis_connection.delete(key) 757 - 758 - # Create dataset with metadata 759 - dataset_path = tmp_path / "test-dataset-000000.tar" 760 - with wds.TarWriter(str(dataset_path)) as sink: 761 - for i in range(5): 762 - sample = SimpleTestSample(name=f"sample_{i}", value=i * 10) 763 - sink.write(sample.as_wds) 764 - 765 - # Set metadata internally 766 - ds = atdata.Dataset[SimpleTestSample](url=str(dataset_path)) 682 + ds = make_simple_dataset(tmp_path, num_samples=5) 767 683 ds._metadata = {"description": "test dataset", "version": "1.0"} 768 684 769 - # Create repo with mock S3 770 685 repo = atlocal.Repo( 771 686 s3_credentials=mock_s3['credentials'], 772 687 hive_path=mock_s3['hive_path'], 773 - redis=redis_connection 688 + redis=clean_redis 774 689 ) 775 690 776 - # Insert dataset 777 691 entry, new_ds = repo.insert(ds, maxcount=100) 778 692 779 - # Verify metadata_url is set 780 693 assert entry.metadata_url is not None 781 694 assert new_ds.metadata_url is not None 782 695 assert 'metadata' in entry.metadata_url 783 696 784 - # Clean up 785 - redis_connection.delete(f"BasicIndexEntry:{entry.uuid}") 786 - 787 697 788 - def test_repo_insert_without_metadata(mock_s3, redis_connection, tmp_path): 698 + def test_repo_insert_without_metadata(mock_s3, clean_redis, tmp_path): 789 699 """Test inserting a dataset without metadata. 790 700 791 701 Should handle None metadata gracefully and not write a metadata file. 792 702 """ 793 - # Clear Redis 794 - for key in redis_connection.scan_iter(match='BasicIndexEntry:*'): 795 - redis_connection.delete(key) 796 - 797 - # Create dataset without metadata 798 - dataset_path = tmp_path / "test-dataset-000000.tar" 799 - with wds.TarWriter(str(dataset_path)) as sink: 800 - for i in range(5): 801 - sample = SimpleTestSample(name=f"sample_{i}", value=i * 10) 802 - sink.write(sample.as_wds) 803 - 804 - ds = atdata.Dataset[SimpleTestSample](url=str(dataset_path)) 805 - 806 - # Create repo with mock S3 703 + ds = make_simple_dataset(tmp_path, num_samples=5) 807 704 repo = atlocal.Repo( 808 705 s3_credentials=mock_s3['credentials'], 809 706 hive_path=mock_s3['hive_path'], 810 - redis=redis_connection 707 + redis=clean_redis 811 708 ) 812 709 813 - # Insert dataset 814 710 entry, new_ds = repo.insert(ds, maxcount=100) 815 711 816 - # Verify entry was created and index works 817 712 assert entry.uuid is not None 818 713 assert len(repo.index.all_entries) == 1 819 714 820 - # Clean up 821 - redis_connection.delete(f"BasicIndexEntry:{entry.uuid}") 822 715 823 - 824 - def test_repo_insert_cache_local_false(mock_s3, redis_connection, sample_dataset): 716 + def test_repo_insert_cache_local_false(mock_s3, clean_redis, sample_dataset): 825 717 """Test inserting with cache_local=False (direct S3 write). 826 718 827 719 Should write tar shards directly to S3 without local caching. 828 720 """ 829 - # Clear Redis 830 - for key in redis_connection.scan_iter(match='BasicIndexEntry:*'): 831 - redis_connection.delete(key) 832 - 833 - # Create repo with mock S3 834 721 repo = atlocal.Repo( 835 722 s3_credentials=mock_s3['credentials'], 836 723 hive_path=mock_s3['hive_path'], 837 - redis=redis_connection 724 + redis=clean_redis 838 725 ) 839 726 840 - # Insert dataset with cache_local=False 841 727 entry, new_ds = repo.insert(sample_dataset, cache_local=False, maxcount=100) 842 728 843 - # Verify entry was created 844 729 assert entry.uuid is not None 845 730 assert entry.wds_url is not None 846 731 847 - # Clean up 848 - redis_connection.delete(f"BasicIndexEntry:{entry.uuid}") 849 732 850 - 851 - def test_repo_insert_cache_local_true(mock_s3, redis_connection, sample_dataset): 733 + def test_repo_insert_cache_local_true(mock_s3, clean_redis, sample_dataset): 852 734 """Test inserting with cache_local=True (local cache then copy). 853 735 854 736 Should write to temporary local storage first, then copy to S3, and clean up 855 737 local cache files after copying. 856 738 """ 857 - # Create repository 858 739 repo = atlocal.Repo( 859 740 s3_credentials=mock_s3['credentials'], 860 741 hive_path=mock_s3['hive_path'], 861 - redis=redis_connection 742 + redis=clean_redis 862 743 ) 863 744 864 - # Insert dataset with cache_local=True 865 745 entry, new_ds = repo.insert(sample_dataset, cache_local=True, maxcount=100) 866 746 867 - # Verify entry was created 868 747 assert entry.uuid is not None 869 748 assert entry.wds_url is not None 870 749 871 - # Clean up 872 - redis_connection.delete(f"BasicIndexEntry:{entry.uuid}") 873 - 874 750 875 - def test_repo_insert_creates_index_entry(mock_s3, redis_connection, sample_dataset): 751 + def test_repo_insert_creates_index_entry(mock_s3, clean_redis, sample_dataset): 876 752 """Test that insert() creates a valid index entry. 877 753 878 754 Should add a BasicIndexEntry to the index with correct wds_url, sample_kind, 879 755 metadata_url, and UUID. 880 756 """ 881 - # Clear Redis 882 - for key in redis_connection.scan_iter(match='BasicIndexEntry:*'): 883 - redis_connection.delete(key) 884 - 885 - # Create repo with mock S3 886 757 repo = atlocal.Repo( 887 758 s3_credentials=mock_s3['credentials'], 888 759 hive_path=mock_s3['hive_path'], 889 - redis=redis_connection 760 + redis=clean_redis 890 761 ) 891 762 892 - # Insert dataset 893 763 entry, new_ds = repo.insert(sample_dataset, maxcount=100) 894 764 895 - # Verify index entry was created with correct fields 896 765 assert entry.uuid is not None 897 766 assert entry.wds_url == new_ds.url 898 767 assert entry.sample_kind == f"{SimpleTestSample.__module__}.SimpleTestSample" 899 768 900 - # Verify it's in the index 901 769 all_entries = repo.index.all_entries 902 770 assert len(all_entries) == 1 903 771 assert all_entries[0].uuid == entry.uuid 904 772 905 - # Clean up 906 - redis_connection.delete(f"BasicIndexEntry:{entry.uuid}") 907 773 908 - 909 - def test_repo_insert_uuid_generation(mock_s3, redis_connection, sample_dataset): 774 + def test_repo_insert_uuid_generation(mock_s3, clean_redis, sample_dataset): 910 775 """Test that insert() generates a unique UUID for each dataset. 911 776 912 777 Should create a new UUID for the dataset and use it consistently in filenames, 913 778 index entry, and returned Dataset. 914 779 """ 915 - # Clear Redis 916 - for key in redis_connection.scan_iter(match='BasicIndexEntry:*'): 917 - redis_connection.delete(key) 918 - 919 - # Create repo with mock S3 920 780 repo = atlocal.Repo( 921 781 s3_credentials=mock_s3['credentials'], 922 782 hive_path=mock_s3['hive_path'], 923 - redis=redis_connection 783 + redis=clean_redis 924 784 ) 925 785 926 - # Insert two datasets and verify they get different UUIDs 927 786 entry1, new_ds1 = repo.insert(sample_dataset, maxcount=100) 928 787 entry2, new_ds2 = repo.insert(sample_dataset, maxcount=100) 929 788 930 - # Verify UUIDs are different 931 789 assert entry1.uuid != entry2.uuid 932 - 933 - # Verify UUIDs are used in URLs 934 790 assert entry1.uuid in new_ds1.url 935 791 assert entry2.uuid in new_ds2.url 936 - 937 - # Verify both are in index 938 792 assert len(repo.index.all_entries) == 2 939 793 940 - # Clean up 941 - redis_connection.delete(f"BasicIndexEntry:{entry1.uuid}") 942 - redis_connection.delete(f"BasicIndexEntry:{entry2.uuid}") 943 794 944 - 945 - def test_repo_insert_empty_dataset(mock_s3, redis_connection, tmp_path): 946 - """Test that inserting an empty dataset raises RuntimeError. 795 + def test_repo_insert_empty_dataset(mock_s3, clean_redis, tmp_path): 796 + """Test inserting an empty dataset. 947 797 948 - Should raise RuntimeError with message about not writing any shards when 949 - dataset is empty. 798 + WebDataset's ShardWriter creates a shard file even with no samples, 799 + so empty datasets succeed (creating an empty shard) rather than raising 800 + RuntimeError. 950 801 """ 951 - # Clear Redis 952 - for key in redis_connection.scan_iter(match='BasicIndexEntry:*'): 953 - redis_connection.delete(key) 954 - 955 - # Create empty dataset 956 802 dataset_path = tmp_path / "empty-dataset-000000.tar" 957 803 with wds.TarWriter(str(dataset_path)) as sink: 958 804 pass # Write no samples 959 805 960 806 ds = atdata.Dataset[SimpleTestSample](url=str(dataset_path)) 961 - 962 - # Create repo with mock S3 963 807 repo = atlocal.Repo( 964 808 s3_credentials=mock_s3['credentials'], 965 809 hive_path=mock_s3['hive_path'], 966 - redis=redis_connection 810 + redis=clean_redis 967 811 ) 968 812 969 - # Note: Empty datasets may still create a shard file, so RuntimeError may not be raised 970 - # This test documents the actual behavior 971 - try: 972 - entry, new_ds = repo.insert(ds, maxcount=100) 973 - # If it succeeds, verify the entry was created 974 - assert entry.uuid is not None 975 - redis_connection.delete(f"BasicIndexEntry:{entry.uuid}") 976 - except RuntimeError as e: 977 - # If it raises RuntimeError, verify the message 978 - assert "did not write any shards" in str(e) 813 + # Empty datasets succeed because WebDataset creates a shard file regardless 814 + entry, new_ds = repo.insert(ds, maxcount=100) 815 + assert entry.uuid is not None 816 + assert '.tar' in new_ds.url 979 817 980 818 981 - def test_repo_insert_preserves_sample_type(mock_s3, redis_connection, sample_dataset): 819 + def test_repo_insert_preserves_sample_type(mock_s3, clean_redis, sample_dataset): 982 820 """Test that the returned Dataset preserves the original sample type. 983 821 984 822 Should return a Dataset[T] with the same sample type as the input dataset. 985 823 """ 986 - # Clear Redis 987 - for key in redis_connection.scan_iter(match='BasicIndexEntry:*'): 988 - redis_connection.delete(key) 989 - 990 - # Create repo with mock S3 991 824 repo = atlocal.Repo( 992 825 s3_credentials=mock_s3['credentials'], 993 826 hive_path=mock_s3['hive_path'], 994 - redis=redis_connection 827 + redis=clean_redis 995 828 ) 996 829 997 - # Insert dataset 998 830 entry, new_ds = repo.insert(sample_dataset, maxcount=100) 999 831 1000 - # Verify sample type is preserved 1001 832 assert new_ds.sample_type == SimpleTestSample 1002 833 assert entry.sample_kind == f"{SimpleTestSample.__module__}.SimpleTestSample" 1003 834 1004 - # Clean up 1005 - redis_connection.delete(f"BasicIndexEntry:{entry.uuid}") 1006 835 1007 - 1008 - def test_repo_insert_round_trip(mock_s3, redis_connection, tmp_path): 836 + def test_repo_insert_round_trip(mock_s3, clean_redis, tmp_path): 1009 837 """Test full round-trip: insert dataset, then load and compare samples. 1010 838 1011 839 Should be able to insert a dataset and then load it back from the returned ··· 1014 842 pytest.skip("Reading from moto-mocked S3 requires additional s3fs/WebDataset configuration") 1015 843 1016 844 1017 - def test_repo_insert_with_shard_writer_kwargs(mock_s3, redis_connection, tmp_path): 845 + def test_repo_insert_with_shard_writer_kwargs(mock_s3, clean_redis, tmp_path): 1018 846 """Test that insert() passes additional kwargs to ShardWriter. 1019 847 1020 848 Should forward kwargs like maxcount, maxsize to the underlying ShardWriter. 1021 849 """ 1022 - # Clear Redis 1023 - for key in redis_connection.scan_iter(match='BasicIndexEntry:*'): 1024 - redis_connection.delete(key) 1025 - 1026 - # Create dataset with many samples 1027 - dataset_path = tmp_path / "large-dataset-000000.tar" 1028 - with wds.TarWriter(str(dataset_path)) as sink: 1029 - for i in range(30): 1030 - sample = SimpleTestSample(name=f"sample_{i}", value=i * 10) 1031 - sink.write(sample.as_wds) 1032 - 1033 - ds = atdata.Dataset[SimpleTestSample](url=str(dataset_path)) 1034 - 1035 - # Create repo with mock S3 850 + ds = make_simple_dataset(tmp_path, num_samples=30, name="large") 1036 851 repo = atlocal.Repo( 1037 852 s3_credentials=mock_s3['credentials'], 1038 853 hive_path=mock_s3['hive_path'], 1039 - redis=redis_connection 854 + redis=clean_redis 1040 855 ) 1041 856 1042 - # Insert dataset with small maxcount to force multiple shards 1043 857 entry, new_ds = repo.insert(ds, maxcount=5) 1044 858 1045 - # Verify multiple shards were created (indicated by brace notation) 1046 859 assert '{' in new_ds.url and '}' in new_ds.url 1047 860 1048 - # Clean up 1049 - redis_connection.delete(f"BasicIndexEntry:{entry.uuid}") 1050 - 1051 861 1052 - def test_repo_insert_numpy_arrays(mock_s3, redis_connection, tmp_path): 862 + def test_repo_insert_numpy_arrays(mock_s3, clean_redis, tmp_path): 1053 863 """Test inserting a dataset containing samples with numpy arrays. 1054 864 1055 - Should correctly serialize and store numpy arrays, and the returned dataset 1056 - should be able to deserialize them. 865 + Should correctly serialize and store numpy arrays. 1057 866 """ 1058 - # Clear Redis 1059 - for key in redis_connection.scan_iter(match='BasicIndexEntry:*'): 1060 - redis_connection.delete(key) 1061 - 1062 - # Create dataset with numpy arrays 1063 - dataset_path = tmp_path / "numpy-dataset-000000.tar" 1064 - with wds.TarWriter(str(dataset_path)) as sink: 1065 - for i in range(3): 1066 - arr = np.random.randn(10, 10) 1067 - sample = ArrayTestSample(label=f"array_{i}", data=arr) 1068 - sink.write(sample.as_wds) 1069 - 1070 - ds = atdata.Dataset[ArrayTestSample](url=str(dataset_path)) 1071 - 1072 - # Create repo with mock S3 867 + ds = make_array_dataset(tmp_path, num_samples=3, array_shape=(10, 10)) 1073 868 repo = atlocal.Repo( 1074 869 s3_credentials=mock_s3['credentials'], 1075 870 hive_path=mock_s3['hive_path'], 1076 - redis=redis_connection 871 + redis=clean_redis 1077 872 ) 1078 873 1079 - # Insert dataset - just verify it works with numpy arrays 1080 874 entry, new_ds = repo.insert(ds, maxcount=100) 1081 875 1082 - # Verify the insert succeeded 1083 876 assert entry.uuid is not None 1084 877 assert entry.sample_kind == f"{ArrayTestSample.__module__}.ArrayTestSample" 1085 878 1086 - # Clean up 1087 - redis_connection.delete(f"BasicIndexEntry:{entry.uuid}") 1088 - 1089 879 1090 880 ## 1091 881 # Integration tests 1092 882 1093 - def test_repo_index_integration(mock_s3, redis_connection, sample_dataset): 883 + def test_repo_index_integration(mock_s3, clean_redis, sample_dataset): 1094 884 """Test that Repo and Index work together correctly. 1095 885 1096 886 Should be able to insert datasets into Repo and retrieve their entries 1097 887 from the Index. 1098 888 """ 1099 - # Clear Redis 1100 - for key in redis_connection.scan_iter(match='BasicIndexEntry:*'): 1101 - redis_connection.delete(key) 1102 - 1103 - # Create repo with mock S3 1104 889 repo = atlocal.Repo( 1105 890 s3_credentials=mock_s3['credentials'], 1106 891 hive_path=mock_s3['hive_path'], 1107 - redis=redis_connection 892 + redis=clean_redis 1108 893 ) 1109 894 1110 - # Insert dataset 1111 895 entry, new_ds = repo.insert(sample_dataset, maxcount=100) 1112 896 1113 - # Verify we can retrieve from index 1114 897 all_entries = repo.index.all_entries 1115 898 assert len(all_entries) == 1 1116 899 assert all_entries[0].uuid == entry.uuid 1117 900 assert all_entries[0].wds_url == entry.wds_url 1118 901 1119 - # Clean up 1120 - redis_connection.delete(f"BasicIndexEntry:{entry.uuid}") 1121 902 1122 - 1123 - def test_multiple_datasets_same_type(mock_s3, redis_connection, sample_dataset): 903 + def test_multiple_datasets_same_type(mock_s3, clean_redis, sample_dataset): 1124 904 """Test inserting multiple datasets of the same sample type. 1125 905 1126 906 Should create separate entries with different UUIDs and all should be 1127 907 retrievable from the index. 1128 908 """ 1129 - # Clear Redis 1130 - for key in redis_connection.scan_iter(match='BasicIndexEntry:*'): 1131 - redis_connection.delete(key) 1132 - 1133 - # Create repo with mock S3 1134 909 repo = atlocal.Repo( 1135 910 s3_credentials=mock_s3['credentials'], 1136 911 hive_path=mock_s3['hive_path'], 1137 - redis=redis_connection 912 + redis=clean_redis 1138 913 ) 1139 914 1140 - # Insert same dataset multiple times 1141 - entry1, new_ds1 = repo.insert(sample_dataset, maxcount=100) 1142 - entry2, new_ds2 = repo.insert(sample_dataset, maxcount=100) 1143 - entry3, new_ds3 = repo.insert(sample_dataset, maxcount=100) 915 + entry1, _ = repo.insert(sample_dataset, maxcount=100) 916 + entry2, _ = repo.insert(sample_dataset, maxcount=100) 917 + entry3, _ = repo.insert(sample_dataset, maxcount=100) 1144 918 1145 - # Verify all have different UUIDs 1146 919 uuids = {entry1.uuid, entry2.uuid, entry3.uuid} 1147 920 assert len(uuids) == 3 1148 921 1149 - # Verify all are in index 1150 922 all_entries = repo.index.all_entries 1151 923 assert len(all_entries) == 3 1152 924 1153 - # Verify all have same sample_kind 1154 925 for entry in all_entries: 1155 926 assert entry.sample_kind == f"{SimpleTestSample.__module__}.SimpleTestSample" 1156 927 1157 - # Clean up 1158 - redis_connection.delete(f"BasicIndexEntry:{entry1.uuid}") 1159 - redis_connection.delete(f"BasicIndexEntry:{entry2.uuid}") 1160 - redis_connection.delete(f"BasicIndexEntry:{entry3.uuid}") 1161 928 1162 - 1163 - def test_multiple_datasets_different_types(mock_s3, redis_connection, tmp_path): 929 + def test_multiple_datasets_different_types(mock_s3, clean_redis, tmp_path): 1164 930 """Test inserting datasets with different sample types. 1165 931 1166 932 Should correctly track sample_kind for each dataset and create distinct 1167 933 index entries. 1168 934 """ 1169 - # Clear Redis 1170 - for key in redis_connection.scan_iter(match='BasicIndexEntry:*'): 1171 - redis_connection.delete(key) 935 + simple_ds = make_simple_dataset(tmp_path, num_samples=3, name="simple") 936 + array_ds = make_array_dataset(tmp_path, num_samples=3, array_shape=(5, 5)) 1172 937 1173 - # Create dataset with SimpleTestSample 1174 - simple_path = tmp_path / "simple-dataset-000000.tar" 1175 - with wds.TarWriter(str(simple_path)) as sink: 1176 - for i in range(3): 1177 - sample = SimpleTestSample(name=f"sample_{i}", value=i * 10) 1178 - sink.write(sample.as_wds) 1179 - simple_ds = atdata.Dataset[SimpleTestSample](url=str(simple_path)) 1180 - 1181 - # Create dataset with ArrayTestSample 1182 - array_path = tmp_path / "array-dataset-000000.tar" 1183 - with wds.TarWriter(str(array_path)) as sink: 1184 - for i in range(3): 1185 - sample = ArrayTestSample(label=f"array_{i}", data=np.random.randn(5, 5)) 1186 - sink.write(sample.as_wds) 1187 - array_ds = atdata.Dataset[ArrayTestSample](url=str(array_path)) 1188 - 1189 - # Create repo with mock S3 1190 938 repo = atlocal.Repo( 1191 939 s3_credentials=mock_s3['credentials'], 1192 940 hive_path=mock_s3['hive_path'], 1193 - redis=redis_connection 941 + redis=clean_redis 1194 942 ) 1195 943 1196 - # Insert both datasets 1197 - entry1, new_ds1 = repo.insert(simple_ds, maxcount=100) 1198 - entry2, new_ds2 = repo.insert(array_ds, maxcount=100) 944 + entry1, _ = repo.insert(simple_ds, maxcount=100) 945 + entry2, _ = repo.insert(array_ds, maxcount=100) 1199 946 1200 - # Verify different sample_kind values 1201 947 assert entry1.sample_kind == f"{SimpleTestSample.__module__}.SimpleTestSample" 1202 948 assert entry2.sample_kind == f"{ArrayTestSample.__module__}.ArrayTestSample" 1203 949 assert entry1.sample_kind != entry2.sample_kind 950 + assert len(repo.index.all_entries) == 2 1204 951 1205 - # Verify both are in index 1206 - all_entries = repo.index.all_entries 1207 - assert len(all_entries) == 2 1208 952 1209 - # Clean up 1210 - redis_connection.delete(f"BasicIndexEntry:{entry1.uuid}") 1211 - redis_connection.delete(f"BasicIndexEntry:{entry2.uuid}") 1212 - 1213 - 1214 - def test_index_persistence_across_instances(redis_connection): 953 + def test_index_persistence_across_instances(clean_redis): 1215 954 """Test that index entries persist across Index instance recreations. 1216 955 1217 956 Should be able to create an Index, add entries, create a new Index instance 1218 957 with the same Redis connection, and retrieve the same entries. 1219 958 """ 1220 - redis = redis_connection 1221 - # Clear any existing BasicIndexEntry keys 1222 - for key in redis.scan_iter(match='BasicIndexEntry:*'): 1223 - redis.delete(key) 1224 - 1225 - # Create first index instance and add entry 1226 - index1 = atlocal.Index(redis=redis) 959 + index1 = atlocal.Index(redis=clean_redis) 1227 960 ds = atdata.Dataset[SimpleTestSample](url="s3://bucket/dataset.tar") 1228 961 entry1 = index1.add_entry(ds) 1229 962 1230 - # Create new index instance with same Redis connection 1231 - index2 = atlocal.Index(redis=redis) 963 + index2 = atlocal.Index(redis=clean_redis) 1232 964 entries = index2.all_entries 1233 965 1234 966 assert len(entries) == 1 1235 967 assert entries[0].uuid == entry1.uuid 1236 968 assert entries[0].wds_url == entry1.wds_url 1237 969 1238 - # Clean up 1239 - redis.delete(f"BasicIndexEntry:{entry1.uuid}") 1240 - 1241 970 1242 - def test_concurrent_index_access(redis_connection): 971 + def test_concurrent_index_access(clean_redis): 1243 972 """Test that multiple Index instances can access the same Redis store. 1244 973 1245 974 Should handle concurrent access to the same Redis index from multiple 1246 975 Index instances. 1247 976 """ 1248 - redis = redis_connection 1249 - # Clear any existing BasicIndexEntry keys 1250 - for key in redis.scan_iter(match='BasicIndexEntry:*'): 1251 - redis.delete(key) 1252 - 1253 - # Create multiple index instances 1254 - index1 = atlocal.Index(redis=redis) 1255 - index2 = atlocal.Index(redis=redis) 977 + index1 = atlocal.Index(redis=clean_redis) 978 + index2 = atlocal.Index(redis=clean_redis) 1256 979 1257 - # Add entries from different instances 1258 980 ds1 = atdata.Dataset[SimpleTestSample](url="s3://bucket/dataset1.tar") 1259 981 ds2 = atdata.Dataset[ArrayTestSample](url="s3://bucket/dataset2.tar") 1260 982 1261 983 entry1 = index1.add_entry(ds1) 1262 984 entry2 = index2.add_entry(ds2) 1263 985 1264 - # Both instances should see both entries 1265 986 entries1 = index1.all_entries 1266 987 entries2 = index2.all_entries 1267 988 ··· 1271 992 uuids1 = {e.uuid for e in entries1} 1272 993 uuids2 = {e.uuid for e in entries2} 1273 994 1274 - assert entry1.uuid in uuids1 1275 - assert entry2.uuid in uuids1 1276 - assert entry1.uuid in uuids2 1277 - assert entry2.uuid in uuids2 1278 - 1279 - # Clean up 1280 - redis.delete(f"BasicIndexEntry:{entry1.uuid}") 1281 - redis.delete(f"BasicIndexEntry:{entry2.uuid}") 995 + assert entry1.uuid in uuids1 and entry2.uuid in uuids1 996 + assert entry1.uuid in uuids2 and entry2.uuid in uuids2