A loose federation of distributed, typed datasets
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

at main 939 lines 27 kB view raw
"""Test dataset functionality."""

##
# Imports

# Tests
import pytest

# System
from dataclasses import dataclass

# External
import numpy as np
import webdataset as wds

# Local
import atdata
import atdata.dataset as atds

# Typing
from numpy.typing import NDArray
from typing import (
    Type,
)


##
# Sample test cases


@dataclass
class BasicTestSample(atdata.PackableSample):
    name: str
    position: int
    value: float


@dataclass
class NumpyTestSample(atdata.PackableSample):
    label: int
    image: NDArray


@atdata.packable
class BasicTestSampleDecorated:
    name: str
    position: int
    value: float


@atdata.packable
class NumpyTestSampleDecorated:
    label: int
    image: NDArray


@atdata.packable
class NumpyOptionalSampleDecorated:
    label: int
    image: NDArray
    embeddings: NDArray | None = None


# Each case: sample type, raw构data to build it from, a stem for on-disk
# shard names, and whether the parquet-export test applies to it.
test_cases = [
    {
        "SampleType": BasicTestSample,
        "sample_data": {
            "name": "Hello, world!",
            "position": 42,
            "value": 1024.768,
        },
        "sample_wds_stem": "basic_test",
        "test_parquet": True,
    },
    {
        "SampleType": NumpyTestSample,
        "sample_data": {
            "label": 9_001,
            "image": np.random.randn(1024, 1024),
        },
        "sample_wds_stem": "numpy_test",
        "test_parquet": False,
    },
    {
        "SampleType": BasicTestSampleDecorated,
        "sample_data": {
            "name": "Hello, world!",
            "position": 42,
            "value": 1024.768,
        },
        "sample_wds_stem": "basic_test_decorated",
        "test_parquet": True,
    },
    {
        "SampleType": NumpyTestSampleDecorated,
        "sample_data": {
            "label": 9_001,
            "image": np.random.randn(1024, 1024),
        },
        "sample_wds_stem": "numpy_test_decorated",
        "test_parquet": False,
    },
    {
        "SampleType": NumpyOptionalSampleDecorated,
        "sample_data": {
            "label": 9_001,
            "image": np.random.randn(1024, 1024),
            "embeddings": np.random.randn(512),
        },
        "sample_wds_stem": "numpy_optional_decorated",
        "test_parquet": False,
    },
    {
        "SampleType": NumpyOptionalSampleDecorated,
        "sample_data": {
            "label": 9_001,
            "image": np.random.randn(1024, 1024),
            "embeddings": None,
        },
        "sample_wds_stem": "numpy_optional_decorated_none",
        "test_parquet": False,
    },
]


## Tests


@pytest.mark.parametrize(
    ("SampleType", "sample_data"),
    [(case["SampleType"], case["sample_data"]) for case in test_cases],
)
def test_create_sample(
    SampleType: Type[atdata.PackableSample],
    sample_data: atds.WDSRawSample,
):
    """Test our ability to create samples from semi-structured data"""

    sample = SampleType.from_data(sample_data)
    assert isinstance(sample, SampleType), (
        f"Did not properly form sample for test type {SampleType}"
    )

    for k, v in sample_data.items():
        cur_assertion: bool
        if isinstance(v, np.ndarray):
            cur_assertion = np.all(getattr(sample, k) == v)
        else:
            cur_assertion = getattr(sample, k) == v
        assert cur_assertion, (
            f"Did not properly incorporate property {k} of test type {SampleType}"
        )


@pytest.mark.parametrize(
    ("SampleType", "sample_data", "sample_wds_stem"),
    [
        (case["SampleType"], case["sample_data"], case["sample_wds_stem"])
        for case in test_cases
    ],
)
def test_wds(
    SampleType: Type[atdata.PackableSample],
    sample_data: atds.WDSRawSample,
    sample_wds_stem: str,
    tmp_path,
):
    """Test our ability to write samples as `WebDatasets` to disk"""

    ## Testing hyperparameters

    n_copies = 100
    shard_maxcount = 10
    batch_size = 4
    n_iterate = 10

    ## Write sharded dataset

    file_pattern = (tmp_path / (f"{sample_wds_stem}" + "-{shard_id}.tar")).as_posix()
    file_wds_pattern = file_pattern.format(shard_id="%06d")

    with wds.writer.ShardWriter(
        pattern=file_wds_pattern,
        maxcount=shard_maxcount,
    ) as sink:
        for i_sample in range(n_copies):
            new_sample = SampleType.from_data(sample_data)
            assert isinstance(new_sample, SampleType), (
                f"Did not properly form sample for test type {SampleType}"
            )

            sink.write(new_sample.as_wds)

    ## Ordered

    # Read first shard, no batches

    first_filename = file_pattern.format(shard_id=f"{0:06d}")
    dataset = atdata.Dataset[SampleType](first_filename)

    iterations_run = 0
    for i_iterate, cur_sample in enumerate(dataset.ordered(batch_size=None)):
        assert isinstance(cur_sample, SampleType), (
            f"Single sample for {SampleType} written to `wds` is of wrong type"
        )

        # Check sample values

        for k, v in sample_data.items():
            if isinstance(v, np.ndarray):
                is_correct = np.all(getattr(cur_sample, k) == v)
            else:
                is_correct = getattr(cur_sample, k) == v
            assert is_correct, (
                f"{SampleType}: Incorrect sample value found for {k} - {type(getattr(cur_sample, k))}"
            )

        iterations_run += 1
        if iterations_run >= n_iterate:
            break

    assert iterations_run == n_iterate, (
        f"Only found {iterations_run} samples, not {n_iterate}"
    )

    # Read all shards, batches

    start_id = f"{0:06d}"
    end_id = f"{9:06d}"
    first_filename = file_pattern.format(shard_id="{" + start_id + ".." + end_id + "}")
    dataset = atdata.Dataset[SampleType](first_filename)

    iterations_run = 0
    for i_iterate, cur_batch in enumerate(dataset.ordered(batch_size=batch_size)):
        assert isinstance(cur_batch, atdata.SampleBatch), (
            f"{SampleType}: Batch sample is not correctly a batch"
        )

        assert cur_batch.sample_type == SampleType, (
            f"{SampleType}: Batch `sample_type` is incorrect type"
        )

        if i_iterate == 0:
            cur_n = len(cur_batch.samples)
            assert cur_n == batch_size, (
                f"{SampleType}: Batch has {cur_n} samples, not {batch_size}"
            )

            assert isinstance(cur_batch.samples[0], SampleType), (
                f"{SampleType}: Batch sample of wrong type ({type(cur_batch.samples[0])})"
            )

        # Check batch values
        # NOTE(review): original indentation was lost in extraction; every
        # written sample is an identical copy, so checking per-batch (here)
        # vs only the first batch is behaviorally equivalent — confirm intent.
        for k, v in sample_data.items():
            cur_batch_data = getattr(cur_batch, k)

            if isinstance(v, np.ndarray):
                assert isinstance(cur_batch_data, np.ndarray), (
                    f"{SampleType}: `NDArray` not carried through to batch"
                )

                is_correct = all(
                    [
                        np.all(cur_batch_data[i] == v)
                        for i in range(cur_batch_data.shape[0])
                    ]
                )

            else:
                is_correct = all(
                    [cur_batch_data[i] == v for i in range(len(cur_batch_data))]
                )

            assert is_correct, f"{SampleType}: Incorrect sample value found for {k}"

        iterations_run += 1
        if iterations_run >= n_iterate:
            break

    assert iterations_run == n_iterate, (
        f"Only found {iterations_run} samples, not {n_iterate}"
    )

    ## Shuffled

    # Read first shard, no batches

    first_filename = file_pattern.format(shard_id=f"{0:06d}")
    dataset = atdata.Dataset[SampleType](first_filename)

    iterations_run = 0
    for i_iterate, cur_sample in enumerate(dataset.shuffled(batch_size=None)):
        assert isinstance(cur_sample, SampleType), (
            f"Single sample for {SampleType} written to `wds` is of wrong type"
        )

        iterations_run += 1
        if iterations_run >= n_iterate:
            break

    assert iterations_run == n_iterate, (
        f"Only found {iterations_run} samples, not {n_iterate}"
    )

    # Read all shards, batches

    start_id = f"{0:06d}"
    end_id = f"{9:06d}"
    first_filename = file_pattern.format(shard_id="{" + start_id + ".." + end_id + "}")
    dataset = atdata.Dataset[SampleType](first_filename)

    iterations_run = 0
    for i_iterate, cur_sample in enumerate(dataset.shuffled(batch_size=batch_size)):
        assert isinstance(cur_sample, atdata.SampleBatch), (
            f"{SampleType}: Batch sample is not correctly a batch"
        )

        assert cur_sample.sample_type == SampleType, (
            f"{SampleType}: Batch `sample_type` is incorrect type"
        )

        if i_iterate == 0:
            cur_n = len(cur_sample.samples)
            assert cur_n == batch_size, (
                f"{SampleType}: Batch has {cur_n} samples, not {batch_size}"
            )

            assert isinstance(cur_sample.samples[0], SampleType), (
                f"{SampleType}: Batch sample of wrong type ({type(cur_sample.samples[0])})"
            )

        iterations_run += 1
        if iterations_run >= n_iterate:
            break

    assert iterations_run == n_iterate, (
        f"Only found {iterations_run} samples, not {n_iterate}"
    )


#


@pytest.mark.parametrize(
    ("SampleType", "sample_data", "sample_wds_stem", "test_parquet"),
    [
        (
            case["SampleType"],
            case["sample_data"],
            case["sample_wds_stem"],
            case["test_parquet"],
        )
        for case in test_cases
    ],
)
def test_parquet_export(
    SampleType: Type[atdata.PackableSample],
    sample_data: atds.WDSRawSample,
    sample_wds_stem: str,
    test_parquet: bool,
    tmp_path,
):
    """Test our ability to export a dataset to `parquet` format"""

    # Skip irrelevant test cases
    if not test_parquet:
        return

    ## Testing hyperparameters

    n_copies_dataset = 1_000
    n_per_file = 100

    ## Start out by writing tar dataset

    wds_filename = (tmp_path / f"{sample_wds_stem}.tar").as_posix()
    with wds.writer.TarWriter(wds_filename) as sink:
        for _ in range(n_copies_dataset):
            new_sample = SampleType.from_data(sample_data)
            sink.write(new_sample.as_wds)

    ## Now export to `parquet`

    dataset = atdata.Dataset[SampleType](wds_filename)
    parquet_filename = tmp_path / f"{sample_wds_stem}.parquet"
    dataset.to_parquet(parquet_filename)

    parquet_filename = tmp_path / f"{sample_wds_stem}-segments.parquet"
    dataset.to_parquet(parquet_filename, maxcount=n_per_file)


##
# Edge case tests for coverage


def test_batch_aggregate_empty():
    """Test _batch_aggregate with empty list returns empty list."""
    result = atds._batch_aggregate([])
    assert result == [], "Empty input should return empty list"


def test_sample_batch_attribute_error():
    """Test SampleBatch raises AttributeError for non-existent attributes."""

    @atdata.packable
    class SimpleSample:
        name: str
        value: int

    samples = [SimpleSample(name="test", value=1)]
    batch = atdata.SampleBatch[SimpleSample](samples)

    with pytest.raises(AttributeError, match="No sample attribute named"):
        _ = batch.nonexistent_attribute


def test_sample_batch_type_property():
    """Test SampleBatch.sample_type property."""

    @atdata.packable
    class TypedSample:
        data: str

    samples = [TypedSample(data="hello")]
    batch = atdata.SampleBatch[TypedSample](samples)

    assert batch.sample_type == TypedSample


def test_dataset_batch_type_property(tmp_path):
    """Test Dataset.batch_type property."""

    @atdata.packable
    class BatchTypeSample:
        value: int

    # Create a simple dataset
    wds_filename = (tmp_path / "batch_type_test.tar").as_posix()
    with wds.writer.TarWriter(wds_filename) as sink:
        sample = BatchTypeSample(value=42)
        sink.write(sample.as_wds)

    dataset = atdata.Dataset[BatchTypeSample](wds_filename)
    batch_type = dataset.batch_type

    # batch_type should be SampleBatch parameterized with the sample type
    assert batch_type.__origin__ == atdata.SampleBatch


def test_dataset_shard_list_property(tmp_path):
    """Test Dataset.shard_list property returns list of shard URLs."""

    @atdata.packable
    class ShardListSample:
        value: int

    # Create multiple shards
    file_pattern = (tmp_path / "shards_test-%06d.tar").as_posix()
    with wds.writer.ShardWriter(pattern=file_pattern, maxcount=5) as sink:
        for i in range(15):
            sample = ShardListSample(value=i)
            sink.write(sample.as_wds)

    # Read with brace pattern
    brace_pattern = (tmp_path / "shards_test-{000000..000002}.tar").as_posix()
    dataset = atdata.Dataset[ShardListSample](brace_pattern)

    shard_list = dataset.shard_list
    assert isinstance(shard_list, list)
    assert len(shard_list) == 3


def test_dataset_metadata_property(tmp_path):
    """Test Dataset.metadata property fetches and caches metadata from URL."""
    from unittest.mock import patch, Mock
    import msgpack

    @atdata.packable
    class MetadataSample:
        value: int

    # Create a simple dataset
    wds_filename = (tmp_path / "metadata_test.tar").as_posix()
    with wds.writer.TarWriter(wds_filename) as sink:
        sample = MetadataSample(value=42)
        sink.write(sample.as_wds)

    # Mock the requests.get call
    mock_metadata = {"key": "value", "count": 100}
    mock_response = Mock()
    mock_response.content = msgpack.packb(mock_metadata)
    mock_response.raise_for_status = Mock()
    mock_response.__enter__ = Mock(return_value=mock_response)
    mock_response.__exit__ = Mock(return_value=False)

    with patch("atdata.dataset.requests.get", return_value=mock_response) as mock_get:
        dataset = atdata.Dataset[MetadataSample](
            wds_filename, metadata_url="http://example.com/metadata.msgpack"
        )

        # First call should fetch
        metadata = dataset.metadata
        assert metadata == mock_metadata
        mock_get.assert_called_once_with(
            "http://example.com/metadata.msgpack", stream=True
        )

        # Second call should use cache
        metadata2 = dataset.metadata
        assert metadata2 == mock_metadata
        assert mock_get.call_count == 1  # Still only one call


def test_dataset_metadata_property_none(tmp_path):
    """Test Dataset.metadata returns None when no metadata_url is set."""

    @atdata.packable
    class NoMetadataSample:
        value: int

    wds_filename = (tmp_path / "no_metadata_test.tar").as_posix()
    with wds.writer.TarWriter(wds_filename) as sink:
        sample = NoMetadataSample(value=42)
        sink.write(sample.as_wds)

    dataset = atdata.Dataset[NoMetadataSample](wds_filename)
    assert dataset.metadata is None


def test_parquet_export_with_remainder(tmp_path):
    """Test parquet export with maxcount that doesn't divide evenly."""

    @atdata.packable
    class RemainderSample:
        name: str
        value: int

    # Create dataset with 25 samples
    n_samples = 25
    maxcount = 10  # Will create 3 segments: 10, 10, 5

    wds_filename = (tmp_path / "remainder_test.tar").as_posix()
    with wds.writer.TarWriter(wds_filename) as sink:
        for i in range(n_samples):
            sample = RemainderSample(name=f"sample_{i}", value=i)
            sink.write(sample.as_wds)

    dataset = atdata.Dataset[RemainderSample](wds_filename)
    parquet_path = tmp_path / "remainder_output.parquet"
    dataset.to_parquet(parquet_path, maxcount=maxcount)

    # Should have created 3 segment files
    import pandas as pd

    segment_files = list(tmp_path.glob("remainder_output-*.parquet"))
    assert len(segment_files) == 3

    # Check total row count
    total_rows = sum(len(pd.read_parquet(f)) for f in segment_files)
    assert total_rows == n_samples


def test_dataset_with_lens_batched(tmp_path):
    """Test dataset iteration with lens transformation in batch mode."""
    from dataclasses import dataclass

    @dataclass
    class SourceSample(atdata.PackableSample):
        name: str
        age: int
        score: float

    @dataclass
    class ViewSample(atdata.PackableSample):
        name: str
        score: float

    @atdata.lens
    def extract_view(s: SourceSample) -> ViewSample:
        return ViewSample(name=s.name, score=s.score)

    # Create dataset
    n_samples = 20
    batch_size = 4
    wds_filename = (tmp_path / "lens_batch_test.tar").as_posix()

    with wds.writer.TarWriter(wds_filename) as sink:
        for i in range(n_samples):
            sample = SourceSample(name=f"person_{i}", age=20 + i, score=float(i) * 1.5)
            sink.write(sample.as_wds)

    # Read with lens transformation in batch mode
    dataset = atdata.Dataset[SourceSample](wds_filename).as_type(ViewSample)

    batches_seen = 0
    for batch in dataset.ordered(batch_size=batch_size):
        assert isinstance(batch, atdata.SampleBatch)
        assert batch.sample_type == ViewSample

        # Check that samples are ViewSample type (not SourceSample)
        for sample in batch.samples:
            assert isinstance(sample, ViewSample)
            assert hasattr(sample, "name")
            assert hasattr(sample, "score")
            assert not hasattr(sample, "age")  # age is not in ViewSample

        batches_seen += 1

    assert batches_seen == n_samples // batch_size


def test_from_bytes_invalid_msgpack():
    """Test from_bytes raises on invalid msgpack data."""

    @atdata.packable
    class SimpleSample:
        value: int

    with pytest.raises(Exception):  # ormsgpack raises on invalid data
        SimpleSample.from_bytes(b"not valid msgpack data")


def test_from_bytes_missing_field():
    """Test from_bytes raises when required field is missing."""

    @atdata.packable
    class RequiredFieldSample:
        name: str
        count: int

    import ormsgpack

    # Only provide 'name', missing 'count'
    incomplete_data = ormsgpack.packb({"name": "test"})

    with pytest.raises(TypeError):  # Missing required argument
        RequiredFieldSample.from_bytes(incomplete_data)


def test_wrap_missing_msgpack_key(tmp_path):
    """Test wrap raises ValueError on sample missing msgpack key."""

    @atdata.packable
    class WrapTestSample:
        value: int

    wds_filename = (tmp_path / "wrap_test.tar").as_posix()
    with wds.writer.TarWriter(wds_filename) as sink:
        sample = WrapTestSample(value=42)
        sink.write(sample.as_wds)

    dataset = atdata.Dataset[WrapTestSample](wds_filename)

    # Directly call wrap with missing key
    with pytest.raises(ValueError, match="missing 'msgpack' key"):
        dataset.wrap({"__key__": "test"})  # Missing 'msgpack' key


def test_wrap_wrong_msgpack_type(tmp_path):
    """Test wrap raises ValueError when msgpack value is not bytes."""

    @atdata.packable
    class WrapTypeSample:
        value: int

    wds_filename = (tmp_path / "wrap_type_test.tar").as_posix()
    with wds.writer.TarWriter(wds_filename) as sink:
        sample = WrapTypeSample(value=42)
        sink.write(sample.as_wds)

    dataset = atdata.Dataset[WrapTypeSample](wds_filename)

    # Directly call wrap with wrong type
    with pytest.raises(ValueError, match="to be bytes"):
        dataset.wrap({"__key__": "test", "msgpack": "not bytes"})


def test_wrap_corrupted_msgpack(tmp_path):
    """Test wrap raises on corrupted msgpack bytes."""

    @atdata.packable
    class CorruptedSample:
        value: int

    wds_filename = (tmp_path / "corrupted_test.tar").as_posix()
    with wds.writer.TarWriter(wds_filename) as sink:
        sample = CorruptedSample(value=42)
        sink.write(sample.as_wds)

    dataset = atdata.Dataset[CorruptedSample](wds_filename)

    # Corrupted msgpack bytes should raise during deserialization
    with pytest.raises(Exception):  # ormsgpack raises on corrupted data
        dataset.wrap({"__key__": "test", "msgpack": b"\xff\xfe\x00\x01invalid"})


def test_dataset_nonexistent_file():
    """Test Dataset raises on nonexistent tar file during iteration."""

    @atdata.packable
    class NonexistentSample:
        value: int

    dataset = atdata.Dataset[NonexistentSample]("/nonexistent/path/data.tar")

    # Dataset creation succeeds (lazy loading)
    assert dataset is not None

    # Iteration fails when file doesn't exist
    with pytest.raises(Exception):  # FileNotFoundError or similar
        list(dataset.ordered(batch_size=None))


def test_dataset_invalid_batch_size(tmp_path):
    """Test Dataset raises on invalid batch_size values."""

    @atdata.packable
    class BatchSizeSample:
        value: int

    wds_filename = (tmp_path / "batch_test.tar").as_posix()
    with wds.writer.TarWriter(wds_filename) as sink:
        sample = BatchSizeSample(value=42)
        sink.write(sample.as_wds)

    dataset = atdata.Dataset[BatchSizeSample](wds_filename)

    # batch_size=0 produces empty batches, causing IndexError in webdataset
    with pytest.raises((ValueError, AssertionError, IndexError)):
        list(dataset.ordered(batch_size=0))


##
# DictSample tests


def test_dictsample_creation():
    """Test DictSample can be created with keyword args or dict."""
    # From keyword args
    ds1 = atdata.DictSample(name="test", value=42)
    assert ds1.name == "test"
    assert ds1.value == 42

    # From _data dict
    ds2 = atdata.DictSample(_data={"name": "test2", "value": 100})
    assert ds2.name == "test2"
    assert ds2.value == 100


def test_dictsample_getattr():
    """Test DictSample attribute access."""
    sample = atdata.DictSample(text="hello", label=1)

    assert sample.text == "hello"
    assert sample.label == 1

    # Non-existent attribute raises AttributeError
    with pytest.raises(AttributeError, match="has no field"):
        _ = sample.nonexistent


def test_dictsample_getitem():
    """Test DictSample dict-style access."""
    sample = atdata.DictSample(text="hello", label=1)

    assert sample["text"] == "hello"
    assert sample["label"] == 1

    # Non-existent key raises KeyError
    with pytest.raises(KeyError):
        _ = sample["nonexistent"]


def test_dictsample_dict_methods():
    """Test DictSample dict-like methods."""
    sample = atdata.DictSample(a=1, b=2, c=3)

    assert set(sample.keys()) == {"a", "b", "c"}
    assert set(sample.values()) == {1, 2, 3}
    assert set(sample.items()) == {("a", 1), ("b", 2), ("c", 3)}
    assert "a" in sample
    assert "x" not in sample
    assert sample.get("a") == 1
    assert sample.get("x", "default") == "default"


def test_dictsample_to_dict():
    """Test DictSample.to_dict returns a copy."""
    sample = atdata.DictSample(name="test", value=42)
    d = sample.to_dict()

    assert d == {"name": "test", "value": 42}
    # Should be a copy
    d["name"] = "modified"
    assert sample.name == "test"


def test_dictsample_serialization():
    """Test DictSample can be serialized and deserialized."""
    original = atdata.DictSample(text="hello", count=42)

    # Serialize
    packed = original.packed

    # Deserialize
    restored = atdata.DictSample.from_bytes(packed)

    assert restored.text == "hello"
    assert restored.count == 42


def test_dictsample_as_wds():
    """Test DictSample.as_wds produces valid WebDataset format."""
    sample = atdata.DictSample(name="test", value=123)
    wds_dict = sample.as_wds

    assert "__key__" in wds_dict
    assert "msgpack" in wds_dict
    assert isinstance(wds_dict["msgpack"], bytes)


def test_dictsample_repr():
    """Test DictSample has a useful repr."""
    sample = atdata.DictSample(name="test", value=42)
    repr_str = repr(sample)

    assert "DictSample" in repr_str
    assert "name" in repr_str
    assert "value" in repr_str


def test_dictsample_dataset_iteration(tmp_path):
    """Test Dataset[DictSample] can iterate over data."""

    # Create typed sample data
    @atdata.packable
    class SourceSample:
        text: str
        label: int

    wds_filename = (tmp_path / "dictsample_test.tar").as_posix()
    with wds.writer.TarWriter(wds_filename) as sink:
        for i in range(5):
            sample = SourceSample(text=f"item_{i}", label=i)
            sink.write(sample.as_wds)

    # Read as DictSample
    dataset = atdata.Dataset[atdata.DictSample](wds_filename)

    samples = list(dataset.ordered())
    assert len(samples) == 5

    for i, sample in enumerate(samples):
        assert isinstance(sample, atdata.DictSample)
        assert sample.text == f"item_{i}"
        assert sample["label"] == i


def test_dictsample_to_typed_via_as_type(tmp_path):
    """Test converting DictSample dataset to typed via as_type."""

    @atdata.packable
    class TypedSample:
        text: str
        label: int

    # Create data using typed sample
    wds_filename = (tmp_path / "astype_test.tar").as_posix()
    with wds.writer.TarWriter(wds_filename) as sink:
        for i in range(5):
            sample = TypedSample(text=f"item_{i}", label=i)
            sink.write(sample.as_wds)

    # Load as DictSample first
    ds_dict = atdata.Dataset[atdata.DictSample](wds_filename)

    # Convert to typed
    ds_typed = ds_dict.as_type(TypedSample)

    # Verify typed iteration works
    samples = list(ds_typed.ordered())
    assert len(samples) == 5

    for i, sample in enumerate(samples):
        assert isinstance(sample, TypedSample)
        assert sample.text == f"item_{i}"
        assert sample.label == i


def test_packable_auto_registers_dictsample_lens():
    """Test @packable decorator auto-registers lens from DictSample."""

    @atdata.packable
    class AutoLensSample:
        name: str
        value: int

    # The lens should be registered automatically
    network = atdata.LensNetwork()
    lens = network.transform(atdata.DictSample, AutoLensSample)

    # Test the lens works
    dict_sample = atdata.DictSample(name="test", value=42)
    typed_sample = lens(dict_sample)

    assert isinstance(typed_sample, AutoLensSample)
    assert typed_sample.name == "test"
    assert typed_sample.value == 42


def test_dictsample_batched_iteration(tmp_path):
    """Test Dataset[DictSample] works with batched iteration."""

    @atdata.packable
    class BatchSource:
        text: str
        value: int

    wds_filename = (tmp_path / "batch_dictsample_test.tar").as_posix()
    with wds.writer.TarWriter(wds_filename) as sink:
        for i in range(10):
            sample = BatchSource(text=f"item_{i}", value=i)
            sink.write(sample.as_wds)

    # Read as DictSample with batching
    dataset = atdata.Dataset[atdata.DictSample](wds_filename)

    batch_count = 0
    for batch in dataset.ordered(batch_size=4):
        assert isinstance(batch, atdata.SampleBatch)
        assert len(batch.samples) <= 4
        for sample in batch.samples:
            assert isinstance(sample, atdata.DictSample)
        batch_count += 1

    assert batch_count == 3  # 10 samples / 4 per batch = 2 full + 1 partial


##