A loose federation of distributed, typed datasets
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

at main 688 lines 23 kB view raw
1"""End-to-end integration tests for atdata data flow pipeline. 2 3Tests the complete workflow: Create → Store → Load → Iterate → Verify. 4 5These tests verify: 6- Full pipeline with various sample types 7- Multi-shard datasets with brace notation 8- Large batch handling and memory efficiency 9- Metadata round-trip preservation 10- Parquet export with transformations 11""" 12 13from dataclasses import dataclass 14from pathlib import Path 15 16import numpy as np 17from numpy.typing import NDArray 18import webdataset as wds 19 20import atdata 21 22 23## 24# Test sample types 25 26 27@atdata.packable 28class SimpleSample: 29 """Basic sample with primitive types only.""" 30 31 name: str 32 value: int 33 score: float 34 active: bool 35 36 37@atdata.packable 38class NDArraySample: 39 """Sample with multiple NDArray fields of different shapes.""" 40 41 label: int 42 image: NDArray 43 features: NDArray 44 45 46@atdata.packable 47class OptionalNDArraySample: 48 """Sample with optional NDArray fields.""" 49 50 label: int 51 image: NDArray 52 embeddings: NDArray | None = None 53 54 55@atdata.packable 56class BytesSample: 57 """Sample with bytes field.""" 58 59 name: str 60 raw_data: bytes 61 62 63@atdata.packable 64class ListSample: 65 """Sample with list fields.""" 66 67 tags: list[str] 68 scores: list[float] 69 ids: list[int] 70 71 72@dataclass 73class InheritanceSample(atdata.PackableSample): 74 """Sample using inheritance syntax instead of decorator.""" 75 76 title: str 77 count: int 78 measurements: NDArray 79 80 81## 82# Helper functions 83 84 85def create_simple_samples(n: int) -> list[SimpleSample]: 86 """Create n simple samples with distinct values.""" 87 return [ 88 SimpleSample( 89 name=f"sample_{i}", 90 value=i * 10, 91 score=float(i) * 0.5, 92 active=(i % 2 == 0), 93 ) 94 for i in range(n) 95 ] 96 97 98def create_ndarray_samples(n: int, img_shape: tuple = (64, 64)) -> list[NDArraySample]: 99 """Create n NDArray samples with distinct values.""" 100 return [ 101 NDArraySample( 102 label=i, 103 image=np.random.randn(*img_shape).astype(np.float32), 104 features=np.random.randn(128).astype(np.float32), 105 ) 106 for i in range(n) 107 ] 108 109 110def create_optional_samples( 111 n: int, include_optional: bool 112) -> list[OptionalNDArraySample]: 113 """Create samples with or without optional embeddings.""" 114 return [ 115 OptionalNDArraySample( 116 label=i, 117 image=np.random.randn(32, 32).astype(np.float32), 118 embeddings=np.random.randn(64).astype(np.float32) 119 if include_optional 120 else None, 121 ) 122 for i in range(n) 123 ] 124 125 126def write_single_shard(path: Path, samples: list) -> str: 127 """Write samples to a single tar file, return path.""" 128 tar_path = path.as_posix() 129 with wds.writer.TarWriter(tar_path) as sink: 130 for sample in samples: 131 sink.write(sample.as_wds) 132 return tar_path 133 134 135def write_multi_shard( 136 base_path: Path, 137 samples: list, 138 samples_per_shard: int = 10, 139) -> tuple[str, int]: 140 """Write samples to multiple shards, return brace pattern and shard count.""" 141 pattern = (base_path / "shard-%06d.tar").as_posix() 142 with wds.writer.ShardWriter(pattern=pattern, maxcount=samples_per_shard) as sink: 143 for sample in samples: 144 sink.write(sample.as_wds) 145 146 n_shards = (len(samples) + samples_per_shard - 1) // samples_per_shard 147 brace_pattern = (base_path / f"shard-{{000000..{n_shards - 1:06d}}}.tar").as_posix() 148 return brace_pattern, n_shards 149 150 151## 152# Full Pipeline Tests 153 154 155class TestFullPipelineSimple: 156 """End-to-end tests with simple primitive-only samples.""" 157 158 def test_create_store_load_iterate_single_shard(self, tmp_path): 159 """Full pipeline: create → store → load → iterate (single shard).""" 160 n_samples = 50 161 samples = create_simple_samples(n_samples) 162 163 # Store 164 tar_path = write_single_shard(tmp_path / "simple.tar", samples) 165 166 # Load 167 dataset = atdata.Dataset[SimpleSample](tar_path) 168 169 # Iterate without batching 170 loaded = list(dataset.ordered(batch_size=None)) 171 172 # Verify 173 assert len(loaded) == n_samples 174 for i, sample in enumerate(loaded): 175 assert isinstance(sample, SimpleSample) 176 assert sample.name == f"sample_{i}" 177 assert sample.value == i * 10 178 assert sample.score == float(i) * 0.5 179 assert sample.active == (i % 2 == 0) 180 181 def test_create_store_load_iterate_batched(self, tmp_path): 182 """Full pipeline with batching.""" 183 n_samples = 100 184 batch_size = 16 185 samples = create_simple_samples(n_samples) 186 187 tar_path = write_single_shard(tmp_path / "batched.tar", samples) 188 dataset = atdata.Dataset[SimpleSample](tar_path) 189 190 # Iterate with batching 191 batches = list(dataset.ordered(batch_size=batch_size)) 192 193 # Verify batch structure (WebDataset drops incomplete final batch) 194 total_samples = sum(len(b.samples) for b in batches) 195 assert total_samples >= (n_samples // batch_size) * batch_size 196 197 for batch in batches: 198 assert isinstance(batch, atdata.SampleBatch) 199 assert batch.sample_type == SimpleSample 200 assert len(batch.samples) <= batch_size 201 202 # Verify aggregated attributes 203 names = batch.name 204 values = batch.value 205 assert isinstance(names, list) 206 assert isinstance(values, list) 207 assert len(names) == len(batch.samples) 208 assert len(values) == len(batch.samples) 209 210 def test_inheritance_syntax_pipeline(self, tmp_path): 211 """Full pipeline using inheritance-style sample definition.""" 212 n_samples = 25 213 samples = [ 214 InheritanceSample( 215 title=f"doc_{i}", 216 count=i * 5, 217 measurements=np.random.randn(10).astype(np.float32), 218 ) 219 for i in range(n_samples) 220 ] 221 222 tar_path = write_single_shard(tmp_path / "inheritance.tar", samples) 223 dataset = atdata.Dataset[InheritanceSample](tar_path) 224 225 loaded = list(dataset.ordered(batch_size=None)) 226 227 assert len(loaded) == n_samples 228 for i, sample in enumerate(loaded): 229 assert isinstance(sample, InheritanceSample) 230 assert sample.title == f"doc_{i}" 231 assert sample.count == i * 5 232 assert isinstance(sample.measurements, np.ndarray) 233 234 235class TestFullPipelineNDArray: 236 """End-to-end tests with NDArray samples.""" 237 238 def test_ndarray_serialization_roundtrip(self, tmp_path): 239 """NDArray fields should serialize and deserialize exactly.""" 240 n_samples = 20 241 samples = create_ndarray_samples(n_samples, img_shape=(32, 32)) 242 243 tar_path = write_single_shard(tmp_path / "ndarray.tar", samples) 244 dataset = atdata.Dataset[NDArraySample](tar_path) 245 246 loaded = list(dataset.ordered(batch_size=None)) 247 248 assert len(loaded) == n_samples 249 for original, loaded_sample in zip(samples, loaded): 250 assert loaded_sample.label == original.label 251 np.testing.assert_array_almost_equal(loaded_sample.image, original.image) 252 np.testing.assert_array_almost_equal( 253 loaded_sample.features, original.features 254 ) 255 256 def test_ndarray_batch_stacking(self, tmp_path): 257 """NDArray fields should stack into batch dimension.""" 258 n_samples = 32 259 batch_size = 8 260 img_shape = (16, 16) 261 feature_dim = 64 262 263 samples = [ 264 NDArraySample( 265 label=i, 266 image=np.full(img_shape, i, dtype=np.float32), 267 features=np.full(feature_dim, i * 0.1, dtype=np.float32), 268 ) 269 for i in range(n_samples) 270 ] 271 272 tar_path = write_single_shard(tmp_path / "stacking.tar", samples) 273 dataset = atdata.Dataset[NDArraySample](tar_path) 274 275 batches = list(dataset.ordered(batch_size=batch_size)) 276 277 for batch_idx, batch in enumerate(batches): 278 # Check stacked shapes 279 assert batch.image.shape == (batch_size, *img_shape) 280 assert batch.features.shape == (batch_size, feature_dim) 281 282 # Check values 283 for i in range(batch_size): 284 sample_idx = batch_idx * batch_size + i 285 np.testing.assert_array_equal( 286 batch.image[i], 287 np.full(img_shape, sample_idx, dtype=np.float32), 288 ) 289 290 def test_optional_ndarray_with_values(self, tmp_path): 291 """Optional NDArray with actual values should roundtrip.""" 292 n_samples = 15 293 samples = create_optional_samples(n_samples, include_optional=True) 294 295 tar_path = write_single_shard(tmp_path / "optional_filled.tar", samples) 296 dataset = atdata.Dataset[OptionalNDArraySample](tar_path) 297 298 loaded = list(dataset.ordered(batch_size=None)) 299 300 for original, loaded_sample in zip(samples, loaded): 301 assert loaded_sample.embeddings is not None 302 np.testing.assert_array_almost_equal( 303 loaded_sample.embeddings, 304 original.embeddings, 305 ) 306 307 def test_optional_ndarray_with_none(self, tmp_path): 308 """Optional NDArray with None should roundtrip.""" 309 n_samples = 15 310 samples = create_optional_samples(n_samples, include_optional=False) 311 312 tar_path = write_single_shard(tmp_path / "optional_none.tar", samples) 313 dataset = atdata.Dataset[OptionalNDArraySample](tar_path) 314 315 loaded = list(dataset.ordered(batch_size=None)) 316 317 for loaded_sample in loaded: 318 assert loaded_sample.embeddings is None 319 320 def test_mixed_dtypes(self, tmp_path): 321 """Various numpy dtypes should serialize correctly.""" 322 323 @atdata.packable 324 class MultiDtypeSample: 325 f32: NDArray 326 f64: NDArray 327 i32: NDArray 328 i64: NDArray 329 u8: NDArray 330 331 samples = [ 332 MultiDtypeSample( 333 f32=np.array([1.0, 2.0, 3.0], dtype=np.float32), 334 f64=np.array([1.0, 2.0, 3.0], dtype=np.float64), 335 i32=np.array([1, 2, 3], dtype=np.int32), 336 i64=np.array([1, 2, 3], dtype=np.int64), 337 u8=np.array([255, 128, 0], dtype=np.uint8), 338 ) 339 for _ in range(10) 340 ] 341 342 tar_path = write_single_shard(tmp_path / "multidtype.tar", samples) 343 dataset = atdata.Dataset[MultiDtypeSample](tar_path) 344 345 loaded = list(dataset.ordered(batch_size=None)) 346 347 for original, loaded_sample in zip(samples, loaded): 348 assert loaded_sample.f32.dtype == np.float32 349 assert loaded_sample.f64.dtype == np.float64 350 assert loaded_sample.i32.dtype == np.int32 351 assert loaded_sample.i64.dtype == np.int64 352 assert loaded_sample.u8.dtype == np.uint8 353 np.testing.assert_array_equal(loaded_sample.f32, original.f32) 354 355 356class TestMultiShardPipeline: 357 """End-to-end tests with multi-shard datasets using brace notation.""" 358 359 def test_multi_shard_ordered_iteration(self, tmp_path): 360 """Multi-shard dataset should iterate all samples in order.""" 361 n_samples = 100 362 samples_per_shard = 10 363 samples = create_simple_samples(n_samples) 364 365 brace_pattern, n_shards = write_multi_shard( 366 tmp_path, 367 samples, 368 samples_per_shard=samples_per_shard, 369 ) 370 371 assert n_shards == 10 372 373 dataset = atdata.Dataset[SimpleSample](brace_pattern) 374 loaded = list(dataset.ordered(batch_size=None)) 375 376 assert len(loaded) == n_samples 377 378 # Verify ordering within each shard 379 for i, sample in enumerate(loaded): 380 assert sample.name == f"sample_{i}" 381 382 def test_multi_shard_batched(self, tmp_path): 383 """Multi-shard dataset with batching should work correctly.""" 384 n_samples = 120 385 samples_per_shard = 15 386 batch_size = 8 387 samples = create_simple_samples(n_samples) 388 389 brace_pattern, n_shards = write_multi_shard( 390 tmp_path, 391 samples, 392 samples_per_shard=samples_per_shard, 393 ) 394 395 dataset = atdata.Dataset[SimpleSample](brace_pattern) 396 batches = list(dataset.ordered(batch_size=batch_size)) 397 398 # Total samples should match 399 total_samples = sum(len(b.samples) for b in batches) 400 assert total_samples == (n_samples // batch_size) * batch_size 401 402 def test_multi_shard_shuffled(self, tmp_path): 403 """Multi-shard shuffled iteration should work.""" 404 n_samples = 50 405 samples_per_shard = 10 406 samples = create_simple_samples(n_samples) 407 408 brace_pattern, _ = write_multi_shard( 409 tmp_path, 410 samples, 411 samples_per_shard=samples_per_shard, 412 ) 413 414 dataset = atdata.Dataset[SimpleSample](brace_pattern) 415 416 # Collect some samples from shuffled iteration 417 shuffled_samples = [] 418 for sample in dataset.shuffled(batch_size=None): 419 shuffled_samples.append(sample) 420 if len(shuffled_samples) >= 30: 421 break 422 423 assert len(shuffled_samples) == 30 424 425 # All samples should be valid SimpleSample instances 426 for sample in shuffled_samples: 427 assert isinstance(sample, SimpleSample) 428 assert sample.name.startswith("sample_") 429 430 def test_single_shard_via_brace_pattern(self, tmp_path): 431 """Single shard via brace pattern should work.""" 432 n_samples = 25 433 samples = create_simple_samples(n_samples) 434 435 # Create exactly one shard 436 brace_pattern, n_shards = write_multi_shard( 437 tmp_path, 438 samples, 439 samples_per_shard=100, # More than samples, so single shard 440 ) 441 442 assert n_shards == 1 443 444 dataset = atdata.Dataset[SimpleSample](brace_pattern) 445 loaded = list(dataset.ordered(batch_size=None)) 446 447 assert len(loaded) == n_samples 448 449 450class TestLargeBatchHandling: 451 """Tests for handling large batches and many samples.""" 452 453 def test_large_batch_size(self, tmp_path): 454 """Large batch sizes should work correctly.""" 455 n_samples = 200 456 batch_size = 64 457 samples = create_simple_samples(n_samples) 458 459 tar_path = write_single_shard(tmp_path / "large_batch.tar", samples) 460 dataset = atdata.Dataset[SimpleSample](tar_path) 461 462 batches = list(dataset.ordered(batch_size=batch_size)) 463 464 # Verify we got the expected number of complete batches 465 total_samples = sum(len(b.samples) for b in batches) 466 assert total_samples >= (n_samples // batch_size) * batch_size 467 for batch in batches: 468 assert len(batch.samples) <= batch_size 469 470 def test_many_samples_single_shard(self, tmp_path): 471 """Many samples in single shard should work.""" 472 n_samples = 500 473 samples = create_simple_samples(n_samples) 474 475 tar_path = write_single_shard(tmp_path / "many.tar", samples) 476 dataset = atdata.Dataset[SimpleSample](tar_path) 477 478 loaded = list(dataset.ordered(batch_size=None)) 479 assert len(loaded) == n_samples 480 481 def test_large_ndarray_samples(self, tmp_path): 482 """Large NDArray fields should serialize correctly.""" 483 n_samples = 10 484 large_shape = (256, 256) # Larger images 485 486 samples = create_ndarray_samples(n_samples, img_shape=large_shape) 487 488 tar_path = write_single_shard(tmp_path / "large_ndarray.tar", samples) 489 dataset = atdata.Dataset[NDArraySample](tar_path) 490 491 loaded = list(dataset.ordered(batch_size=None)) 492 493 for original, loaded_sample in zip(samples, loaded): 494 assert loaded_sample.image.shape == large_shape 495 np.testing.assert_array_almost_equal( 496 loaded_sample.image, 497 original.image, 498 ) 499 500 501class TestBytesAndListSamples: 502 """Tests for bytes and list field types.""" 503 504 def test_bytes_field_roundtrip(self, tmp_path): 505 """Bytes fields should roundtrip correctly.""" 506 samples = [ 507 BytesSample( 508 name=f"item_{i}", 509 raw_data=f"binary_data_{i}".encode("utf-8"), 510 ) 511 for i in range(20) 512 ] 513 514 tar_path = write_single_shard(tmp_path / "bytes.tar", samples) 515 dataset = atdata.Dataset[BytesSample](tar_path) 516 517 loaded = list(dataset.ordered(batch_size=None)) 518 519 for original, loaded_sample in zip(samples, loaded): 520 assert loaded_sample.name == original.name 521 assert loaded_sample.raw_data == original.raw_data 522 523 def test_list_fields_roundtrip(self, tmp_path): 524 """List fields should roundtrip correctly.""" 525 samples = [ 526 ListSample( 527 tags=[f"tag_{j}" for j in range(3)], 528 scores=[float(j) * 0.1 for j in range(5)], 529 ids=[i * 10 + j for j in range(4)], 530 ) 531 for i in range(15) 532 ] 533 534 tar_path = write_single_shard(tmp_path / "lists.tar", samples) 535 dataset = atdata.Dataset[ListSample](tar_path) 536 537 loaded = list(dataset.ordered(batch_size=None)) 538 539 for original, loaded_sample in zip(samples, loaded): 540 assert loaded_sample.tags == original.tags 541 assert loaded_sample.scores == original.scores 542 assert loaded_sample.ids == original.ids 543 544 545class TestMetadataRoundTrip: 546 """Tests for metadata preservation through the pipeline.""" 547 548 def test_dataset_with_metadata_url(self, tmp_path): 549 """Dataset with metadata_url should fetch and cache metadata.""" 550 from unittest.mock import Mock, patch, MagicMock 551 import msgpack 552 553 samples = create_simple_samples(10) 554 tar_path = write_single_shard(tmp_path / "meta.tar", samples) 555 556 test_metadata = { 557 "version": "1.0.0", 558 "created_by": "test", 559 "sample_count": 10, 560 "nested": {"key": "value"}, 561 } 562 563 # Create a proper mock that supports context manager protocol 564 mock_response = MagicMock() 565 mock_response.content = msgpack.packb(test_metadata) 566 mock_response.raise_for_status = Mock() 567 mock_response.__enter__ = Mock(return_value=mock_response) 568 mock_response.__exit__ = Mock(return_value=False) 569 570 with patch("atdata.dataset.requests.get", return_value=mock_response): 571 dataset = atdata.Dataset[SimpleSample]( 572 tar_path, 573 metadata_url="http://example.com/meta.msgpack", 574 ) 575 576 # Fetch metadata 577 metadata = dataset.metadata 578 579 assert metadata == test_metadata 580 assert metadata["version"] == "1.0.0" 581 assert metadata["nested"]["key"] == "value" 582 583 # Second access should use cache 584 metadata2 = dataset.metadata 585 assert metadata2 == test_metadata 586 587 588class TestParquetExport: 589 """Tests for Parquet export functionality.""" 590 591 def test_simple_parquet_export(self, tmp_path): 592 """Simple samples should export to Parquet correctly.""" 593 import pandas as pd 594 595 n_samples = 50 596 samples = create_simple_samples(n_samples) 597 598 tar_path = write_single_shard(tmp_path / "for_parquet.tar", samples) 599 dataset = atdata.Dataset[SimpleSample](tar_path) 600 601 parquet_path = tmp_path / "output.parquet" 602 dataset.to_parquet(parquet_path) 603 604 # Verify Parquet file 605 df = pd.read_parquet(parquet_path) 606 assert len(df) == n_samples 607 assert list(df.columns) == ["name", "value", "score", "active"] 608 assert df["name"].iloc[0] == "sample_0" 609 assert df["value"].iloc[0] == 0 610 611 def test_parquet_export_with_maxcount(self, tmp_path): 612 """Parquet export with maxcount should create segments.""" 613 import pandas as pd 614 615 n_samples = 45 616 maxcount = 10 617 samples = create_simple_samples(n_samples) 618 619 tar_path = write_single_shard(tmp_path / "segmented.tar", samples) 620 dataset = atdata.Dataset[SimpleSample](tar_path) 621 622 parquet_path = tmp_path / "segments.parquet" 623 dataset.to_parquet(parquet_path, maxcount=maxcount) 624 625 # Should create 5 segment files (45 samples / 10 per file) 626 segment_files = list(tmp_path.glob("segments-*.parquet")) 627 assert len(segment_files) == 5 628 629 # Total rows should match 630 total_rows = sum(len(pd.read_parquet(f)) for f in segment_files) 631 assert total_rows == n_samples 632 633 634class TestIterationModes: 635 """Tests for different iteration modes.""" 636 637 def test_ordered_is_deterministic(self, tmp_path): 638 """Ordered iteration should be deterministic across multiple passes.""" 639 n_samples = 30 640 samples = create_simple_samples(n_samples) 641 642 tar_path = write_single_shard(tmp_path / "ordered.tar", samples) 643 dataset = atdata.Dataset[SimpleSample](tar_path) 644 645 # Two passes should yield identical results 646 pass1 = [s.name for s in dataset.ordered(batch_size=None)] 647 pass2 = [s.name for s in dataset.ordered(batch_size=None)] 648 649 assert pass1 == pass2 650 651 def test_shuffled_changes_order(self, tmp_path): 652 """Shuffled iteration should change order (with high probability).""" 653 n_samples = 100 654 samples = create_simple_samples(n_samples) 655 656 tar_path = write_single_shard(tmp_path / "shuffle_test.tar", samples) 657 dataset = atdata.Dataset[SimpleSample](tar_path) 658 659 # Collect samples from multiple shuffled passes 660 passes = [] 661 for _ in range(3): 662 names = [] 663 for sample in dataset.shuffled(batch_size=None): 664 names.append(sample.name) 665 if len(names) >= n_samples: 666 break 667 passes.append(names) 668 669 # At least two passes should differ (very high probability with 100 samples) 670 # Note: This could theoretically fail, but probability is astronomically low 671 assert ( 672 passes[0] != passes[1] or passes[1] != passes[2] or passes[0] != passes[2] 673 ) 674 675 def test_batch_size_one(self, tmp_path): 676 """batch_size=1 should return single-element batches.""" 677 n_samples = 10 678 samples = create_simple_samples(n_samples) 679 680 tar_path = write_single_shard(tmp_path / "batch1.tar", samples) 681 dataset = atdata.Dataset[SimpleSample](tar_path) 682 683 batches = list(dataset.ordered(batch_size=1)) 684 685 assert len(batches) == n_samples 686 for batch in batches: 687 assert isinstance(batch, atdata.SampleBatch) 688 assert len(batch.samples) == 1