1"""End-to-end integration tests for atdata data flow pipeline.
2
3Tests the complete workflow: Create → Store → Load → Iterate → Verify.
4
5These tests verify:
6- Full pipeline with various sample types
7- Multi-shard datasets with brace notation
8- Large batch handling and memory efficiency
9- Metadata round-trip preservation
10- Parquet export with transformations
11"""

from dataclasses import dataclass
from pathlib import Path

import numpy as np
from numpy.typing import NDArray
import webdataset as wds

import atdata
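
# A minimal sketch of the Create → Store → Load → Iterate → Verify flow these
# tests exercise, using the same atdata and webdataset calls as the tests
# below (SimpleSample is defined in the next section):
#
#   samples = [SimpleSample(name="a", value=1, score=0.5, active=True)]
#   with wds.writer.TarWriter("data.tar") as sink:       # Store
#       for s in samples:
#           sink.write(s.as_wds)
#   ds = atdata.Dataset[SimpleSample]("data.tar")        # Load
#   for s in ds.ordered(batch_size=None):                # Iterate
#       assert s.name == "a"                             # Verify

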
##
# Test sample types


@atdata.packable
class SimpleSample:
    """Basic sample with primitive types only."""

    name: str
    value: int
    score: float
    active: bool


@atdata.packable
class NDArraySample:
    """Sample with multiple NDArray fields of different shapes."""

    label: int
    image: NDArray
    features: NDArray


@atdata.packable
class OptionalNDArraySample:
    """Sample with optional NDArray fields."""

    label: int
    image: NDArray
    embeddings: NDArray | None = None


@atdata.packable
class BytesSample:
    """Sample with bytes field."""

    name: str
    raw_data: bytes


@atdata.packable
class ListSample:
    """Sample with list fields."""

    tags: list[str]
    scores: list[float]
    ids: list[int]


@dataclass
class InheritanceSample(atdata.PackableSample):
    """Sample using inheritance syntax instead of decorator."""

    title: str
    count: int
    measurements: NDArray
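

# Note: the decorator and inheritance styles above are assumed to be
# interchangeable; both should yield samples exposing the same `as_wds`
# serialization consumed by the writer helpers below.

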
##
# Helper functions


def create_simple_samples(n: int) -> list[SimpleSample]:
    """Create n simple samples with distinct values."""
    return [
        SimpleSample(
            name=f"sample_{i}",
            value=i * 10,
            score=float(i) * 0.5,
            active=(i % 2 == 0),
        )
        for i in range(n)
    ]


def create_ndarray_samples(n: int, img_shape: tuple = (64, 64)) -> list[NDArraySample]:
    """Create n NDArray samples with distinct values."""
    return [
        NDArraySample(
            label=i,
            image=np.random.randn(*img_shape).astype(np.float32),
            features=np.random.randn(128).astype(np.float32),
        )
        for i in range(n)
    ]


def create_optional_samples(
    n: int, include_optional: bool
) -> list[OptionalNDArraySample]:
    """Create samples with or without optional embeddings."""
    return [
        OptionalNDArraySample(
            label=i,
            image=np.random.randn(32, 32).astype(np.float32),
            embeddings=np.random.randn(64).astype(np.float32)
            if include_optional
            else None,
        )
        for i in range(n)
    ]


def write_single_shard(path: Path, samples: list) -> str:
    """Write samples to a single tar file, return path."""
    tar_path = path.as_posix()
    with wds.writer.TarWriter(tar_path) as sink:
        for sample in samples:
            sink.write(sample.as_wds)
    return tar_path


def write_multi_shard(
    base_path: Path,
    samples: list,
    samples_per_shard: int = 10,
) -> tuple[str, int]:
    """Write samples to multiple shards, return brace pattern and shard count."""
    pattern = (base_path / "shard-%06d.tar").as_posix()
    with wds.writer.ShardWriter(pattern=pattern, maxcount=samples_per_shard) as sink:
        for sample in samples:
            sink.write(sample.as_wds)

    n_shards = (len(samples) + samples_per_shard - 1) // samples_per_shard
    brace_pattern = (base_path / f"shard-{{000000..{n_shards - 1:06d}}}.tar").as_posix()
    return brace_pattern, n_shards
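

# For reference: WebDataset expands a brace pattern like
# "shard-{000000..000002}.tar" (via the braceexpand package) into
# ['shard-000000.tar', 'shard-000001.tar', 'shard-000002.tar'], so the
# pattern returned above addresses every shard ShardWriter wrote.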


##
# Full Pipeline Tests


class TestFullPipelineSimple:
    """End-to-end tests with simple primitive-only samples."""

    def test_create_store_load_iterate_single_shard(self, tmp_path):
        """Full pipeline: create → store → load → iterate (single shard)."""
        n_samples = 50
        samples = create_simple_samples(n_samples)

        # Store
        tar_path = write_single_shard(tmp_path / "simple.tar", samples)

        # Load
        dataset = atdata.Dataset[SimpleSample](tar_path)

        # Iterate without batching
        loaded = list(dataset.ordered(batch_size=None))

        # Verify
        assert len(loaded) == n_samples
        for i, sample in enumerate(loaded):
            assert isinstance(sample, SimpleSample)
            assert sample.name == f"sample_{i}"
            assert sample.value == i * 10
            assert sample.score == float(i) * 0.5
            assert sample.active == (i % 2 == 0)

    def test_create_store_load_iterate_batched(self, tmp_path):
        """Full pipeline with batching."""
        n_samples = 100
        batch_size = 16
        samples = create_simple_samples(n_samples)

        tar_path = write_single_shard(tmp_path / "batched.tar", samples)
        dataset = atdata.Dataset[SimpleSample](tar_path)

        # Iterate with batching
        batches = list(dataset.ordered(batch_size=batch_size))

        # Verify batch structure (WebDataset may drop an incomplete final batch)
        total_samples = sum(len(b.samples) for b in batches)
        assert total_samples >= (n_samples // batch_size) * batch_size

        for batch in batches:
            assert isinstance(batch, atdata.SampleBatch)
            assert batch.sample_type == SimpleSample
            assert len(batch.samples) <= batch_size

            # Verify aggregated attributes
            names = batch.name
            values = batch.value
            assert isinstance(names, list)
            assert isinstance(values, list)
            assert len(names) == len(batch.samples)
            assert len(values) == len(batch.samples)

    def test_inheritance_syntax_pipeline(self, tmp_path):
        """Full pipeline using inheritance-style sample definition."""
        n_samples = 25
        samples = [
            InheritanceSample(
                title=f"doc_{i}",
                count=i * 5,
                measurements=np.random.randn(10).astype(np.float32),
            )
            for i in range(n_samples)
        ]

        tar_path = write_single_shard(tmp_path / "inheritance.tar", samples)
        dataset = atdata.Dataset[InheritanceSample](tar_path)

        loaded = list(dataset.ordered(batch_size=None))

        assert len(loaded) == n_samples
        for i, sample in enumerate(loaded):
            assert isinstance(sample, InheritanceSample)
            assert sample.title == f"doc_{i}"
            assert sample.count == i * 5
            assert isinstance(sample.measurements, np.ndarray)


class TestFullPipelineNDArray:
    """End-to-end tests with NDArray samples."""

    def test_ndarray_serialization_roundtrip(self, tmp_path):
        """NDArray fields should serialize and deserialize exactly."""
        n_samples = 20
        samples = create_ndarray_samples(n_samples, img_shape=(32, 32))

        tar_path = write_single_shard(tmp_path / "ndarray.tar", samples)
        dataset = atdata.Dataset[NDArraySample](tar_path)

        loaded = list(dataset.ordered(batch_size=None))

        assert len(loaded) == n_samples
        for original, loaded_sample in zip(samples, loaded):
            assert loaded_sample.label == original.label
            np.testing.assert_array_almost_equal(loaded_sample.image, original.image)
            np.testing.assert_array_almost_equal(
                loaded_sample.features, original.features
            )

    def test_ndarray_batch_stacking(self, tmp_path):
        """NDArray fields should stack into batch dimension."""
        n_samples = 32
        batch_size = 8
        img_shape = (16, 16)
        feature_dim = 64

        samples = [
            NDArraySample(
                label=i,
                image=np.full(img_shape, i, dtype=np.float32),
                features=np.full(feature_dim, i * 0.1, dtype=np.float32),
            )
            for i in range(n_samples)
        ]

        tar_path = write_single_shard(tmp_path / "stacking.tar", samples)
        dataset = atdata.Dataset[NDArraySample](tar_path)

        batches = list(dataset.ordered(batch_size=batch_size))

        for batch_idx, batch in enumerate(batches):
            # Check stacked shapes
            assert batch.image.shape == (batch_size, *img_shape)
            assert batch.features.shape == (batch_size, feature_dim)

            # Check values
            for i in range(batch_size):
                sample_idx = batch_idx * batch_size + i
                np.testing.assert_array_equal(
                    batch.image[i],
                    np.full(img_shape, sample_idx, dtype=np.float32),
                )

    def test_optional_ndarray_with_values(self, tmp_path):
        """Optional NDArray with actual values should roundtrip."""
        n_samples = 15
        samples = create_optional_samples(n_samples, include_optional=True)

        tar_path = write_single_shard(tmp_path / "optional_filled.tar", samples)
        dataset = atdata.Dataset[OptionalNDArraySample](tar_path)

        loaded = list(dataset.ordered(batch_size=None))

        for original, loaded_sample in zip(samples, loaded):
            assert loaded_sample.embeddings is not None
            np.testing.assert_array_almost_equal(
                loaded_sample.embeddings,
                original.embeddings,
            )

    def test_optional_ndarray_with_none(self, tmp_path):
        """Optional NDArray with None should roundtrip."""
        n_samples = 15
        samples = create_optional_samples(n_samples, include_optional=False)

        tar_path = write_single_shard(tmp_path / "optional_none.tar", samples)
        dataset = atdata.Dataset[OptionalNDArraySample](tar_path)

        loaded = list(dataset.ordered(batch_size=None))

        for loaded_sample in loaded:
            assert loaded_sample.embeddings is None

    def test_mixed_dtypes(self, tmp_path):
        """Various numpy dtypes should serialize correctly."""

        @atdata.packable
        class MultiDtypeSample:
            f32: NDArray
            f64: NDArray
            i32: NDArray
            i64: NDArray
            u8: NDArray

        samples = [
            MultiDtypeSample(
                f32=np.array([1.0, 2.0, 3.0], dtype=np.float32),
                f64=np.array([1.0, 2.0, 3.0], dtype=np.float64),
                i32=np.array([1, 2, 3], dtype=np.int32),
                i64=np.array([1, 2, 3], dtype=np.int64),
                u8=np.array([255, 128, 0], dtype=np.uint8),
            )
            for _ in range(10)
        ]

        tar_path = write_single_shard(tmp_path / "multidtype.tar", samples)
        dataset = atdata.Dataset[MultiDtypeSample](tar_path)

        loaded = list(dataset.ordered(batch_size=None))

        for original, loaded_sample in zip(samples, loaded):
            assert loaded_sample.f32.dtype == np.float32
            assert loaded_sample.f64.dtype == np.float64
            assert loaded_sample.i32.dtype == np.int32
            assert loaded_sample.i64.dtype == np.int64
            assert loaded_sample.u8.dtype == np.uint8
            np.testing.assert_array_equal(loaded_sample.f32, original.f32)
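

# The ordered multi-shard tests below assume that brace expansion preserves
# shard order (shard-000000, shard-000001, ...), so global sample order
# should match write order.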
class TestMultiShardPipeline:
    """End-to-end tests with multi-shard datasets using brace notation."""

    def test_multi_shard_ordered_iteration(self, tmp_path):
        """Multi-shard dataset should iterate all samples in order."""
        n_samples = 100
        samples_per_shard = 10
        samples = create_simple_samples(n_samples)

        brace_pattern, n_shards = write_multi_shard(
            tmp_path,
            samples,
            samples_per_shard=samples_per_shard,
        )

        assert n_shards == 10

        dataset = atdata.Dataset[SimpleSample](brace_pattern)
        loaded = list(dataset.ordered(batch_size=None))

        assert len(loaded) == n_samples

        # Verify global sample ordering across shards
        for i, sample in enumerate(loaded):
            assert sample.name == f"sample_{i}"

    def test_multi_shard_batched(self, tmp_path):
        """Multi-shard dataset with batching should work correctly."""
        n_samples = 120
        samples_per_shard = 15
        batch_size = 8
        samples = create_simple_samples(n_samples)

        brace_pattern, n_shards = write_multi_shard(
            tmp_path,
            samples,
            samples_per_shard=samples_per_shard,
        )

        dataset = atdata.Dataset[SimpleSample](brace_pattern)
        batches = list(dataset.ordered(batch_size=batch_size))

        # Total samples should match exactly (120 is a multiple of batch_size 8)
        total_samples = sum(len(b.samples) for b in batches)
        assert total_samples == (n_samples // batch_size) * batch_size

    def test_multi_shard_shuffled(self, tmp_path):
        """Multi-shard shuffled iteration should work."""
        n_samples = 50
        samples_per_shard = 10
        samples = create_simple_samples(n_samples)

        brace_pattern, _ = write_multi_shard(
            tmp_path,
            samples,
            samples_per_shard=samples_per_shard,
        )

        dataset = atdata.Dataset[SimpleSample](brace_pattern)

        # Collect some samples from shuffled iteration
        shuffled_samples = []
        for sample in dataset.shuffled(batch_size=None):
            shuffled_samples.append(sample)
            if len(shuffled_samples) >= 30:
                break

        assert len(shuffled_samples) == 30

        # All samples should be valid SimpleSample instances
        for sample in shuffled_samples:
            assert isinstance(sample, SimpleSample)
            assert sample.name.startswith("sample_")

    def test_single_shard_via_brace_pattern(self, tmp_path):
        """Single shard via brace pattern should work."""
        n_samples = 25
        samples = create_simple_samples(n_samples)

        # Create exactly one shard
        brace_pattern, n_shards = write_multi_shard(
            tmp_path,
            samples,
            samples_per_shard=100,  # More than samples, so single shard
        )

        assert n_shards == 1

        dataset = atdata.Dataset[SimpleSample](brace_pattern)
        loaded = list(dataset.ordered(batch_size=None))

        assert len(loaded) == n_samples


class TestLargeBatchHandling:
    """Tests for handling large batches and many samples."""

    def test_large_batch_size(self, tmp_path):
        """Large batch sizes should work correctly."""
        n_samples = 200
        batch_size = 64
        samples = create_simple_samples(n_samples)

        tar_path = write_single_shard(tmp_path / "large_batch.tar", samples)
        dataset = atdata.Dataset[SimpleSample](tar_path)

        batches = list(dataset.ordered(batch_size=batch_size))

        # All complete batches should be present; a trailing partial batch may also appear
        total_samples = sum(len(b.samples) for b in batches)
        assert total_samples >= (n_samples // batch_size) * batch_size
        for batch in batches:
            assert len(batch.samples) <= batch_size

    def test_many_samples_single_shard(self, tmp_path):
        """Many samples in single shard should work."""
        n_samples = 500
        samples = create_simple_samples(n_samples)

        tar_path = write_single_shard(tmp_path / "many.tar", samples)
        dataset = atdata.Dataset[SimpleSample](tar_path)

        loaded = list(dataset.ordered(batch_size=None))
        assert len(loaded) == n_samples

    def test_large_ndarray_samples(self, tmp_path):
        """Large NDArray fields should serialize correctly."""
        n_samples = 10
        large_shape = (256, 256)  # Larger images

        samples = create_ndarray_samples(n_samples, img_shape=large_shape)

        tar_path = write_single_shard(tmp_path / "large_ndarray.tar", samples)
        dataset = atdata.Dataset[NDArraySample](tar_path)

        loaded = list(dataset.ordered(batch_size=None))

        for original, loaded_sample in zip(samples, loaded):
            assert loaded_sample.image.shape == large_shape
            np.testing.assert_array_almost_equal(
                loaded_sample.image,
                original.image,
            )


class TestBytesAndListSamples:
    """Tests for bytes and list field types."""

    def test_bytes_field_roundtrip(self, tmp_path):
        """Bytes fields should roundtrip correctly."""
        samples = [
            BytesSample(
                name=f"item_{i}",
                raw_data=f"binary_data_{i}".encode("utf-8"),
            )
            for i in range(20)
        ]

        tar_path = write_single_shard(tmp_path / "bytes.tar", samples)
        dataset = atdata.Dataset[BytesSample](tar_path)

        loaded = list(dataset.ordered(batch_size=None))

        for original, loaded_sample in zip(samples, loaded):
            assert loaded_sample.name == original.name
            assert loaded_sample.raw_data == original.raw_data

    def test_list_fields_roundtrip(self, tmp_path):
        """List fields should roundtrip correctly."""
        samples = [
            ListSample(
                tags=[f"tag_{j}" for j in range(3)],
                scores=[float(j) * 0.1 for j in range(5)],
                ids=[i * 10 + j for j in range(4)],
            )
            for i in range(15)
        ]

        tar_path = write_single_shard(tmp_path / "lists.tar", samples)
        dataset = atdata.Dataset[ListSample](tar_path)

        loaded = list(dataset.ordered(batch_size=None))

        for original, loaded_sample in zip(samples, loaded):
            assert loaded_sample.tags == original.tags
            assert loaded_sample.scores == original.scores
            assert loaded_sample.ids == original.ids
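

# The test below assumes Dataset fetches `metadata_url` via requests.get used
# as a context manager and decodes the response body as msgpack; the MagicMock
# is shaped to match that call pattern.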
class TestMetadataRoundTrip:
    """Tests for metadata preservation through the pipeline."""

    def test_dataset_with_metadata_url(self, tmp_path):
        """Dataset with metadata_url should fetch and cache metadata."""
        from unittest.mock import MagicMock, Mock, patch
        import msgpack

        samples = create_simple_samples(10)
        tar_path = write_single_shard(tmp_path / "meta.tar", samples)

        test_metadata = {
            "version": "1.0.0",
            "created_by": "test",
            "sample_count": 10,
            "nested": {"key": "value"},
        }

        # Create a proper mock that supports context manager protocol
        mock_response = MagicMock()
        mock_response.content = msgpack.packb(test_metadata)
        mock_response.raise_for_status = Mock()
        mock_response.__enter__ = Mock(return_value=mock_response)
        mock_response.__exit__ = Mock(return_value=False)

        with patch("atdata.dataset.requests.get", return_value=mock_response):
            dataset = atdata.Dataset[SimpleSample](
                tar_path,
                metadata_url="http://example.com/meta.msgpack",
            )

            # Fetch metadata
            metadata = dataset.metadata

            assert metadata == test_metadata
            assert metadata["version"] == "1.0.0"
            assert metadata["nested"]["key"] == "value"

            # Second access should use cache
            metadata2 = dataset.metadata
            assert metadata2 == test_metadata


class TestParquetExport:
    """Tests for Parquet export functionality."""

    def test_simple_parquet_export(self, tmp_path):
        """Simple samples should export to Parquet correctly."""
        import pandas as pd

        n_samples = 50
        samples = create_simple_samples(n_samples)

        tar_path = write_single_shard(tmp_path / "for_parquet.tar", samples)
        dataset = atdata.Dataset[SimpleSample](tar_path)

        parquet_path = tmp_path / "output.parquet"
        dataset.to_parquet(parquet_path)

        # Verify Parquet file
        df = pd.read_parquet(parquet_path)
        assert len(df) == n_samples
        assert list(df.columns) == ["name", "value", "score", "active"]
        assert df["name"].iloc[0] == "sample_0"
        assert df["value"].iloc[0] == 0

    def test_parquet_export_with_maxcount(self, tmp_path):
        """Parquet export with maxcount should create segments."""
        import pandas as pd

        n_samples = 45
        maxcount = 10
        samples = create_simple_samples(n_samples)

        tar_path = write_single_shard(tmp_path / "segmented.tar", samples)
        dataset = atdata.Dataset[SimpleSample](tar_path)

        parquet_path = tmp_path / "segments.parquet"
        dataset.to_parquet(parquet_path, maxcount=maxcount)

        # Should create 5 segment files (45 samples / 10 per file, rounded up)
        segment_files = list(tmp_path.glob("segments-*.parquet"))
        assert len(segment_files) == 5

        # Total rows should match
        total_rows = sum(len(pd.read_parquet(f)) for f in segment_files)
        assert total_rows == n_samples


class TestIterationModes:
    """Tests for different iteration modes."""

    def test_ordered_is_deterministic(self, tmp_path):
        """Ordered iteration should be deterministic across multiple passes."""
        n_samples = 30
        samples = create_simple_samples(n_samples)

        tar_path = write_single_shard(tmp_path / "ordered.tar", samples)
        dataset = atdata.Dataset[SimpleSample](tar_path)

        # Two passes should yield identical results
        pass1 = [s.name for s in dataset.ordered(batch_size=None)]
        pass2 = [s.name for s in dataset.ordered(batch_size=None)]

        assert pass1 == pass2

    def test_shuffled_changes_order(self, tmp_path):
        """Shuffled iteration should change order (with high probability)."""
        n_samples = 100
        samples = create_simple_samples(n_samples)

        tar_path = write_single_shard(tmp_path / "shuffle_test.tar", samples)
        dataset = atdata.Dataset[SimpleSample](tar_path)

        # Collect samples from multiple shuffled passes
        passes = []
        for _ in range(3):
            names = []
            for sample in dataset.shuffled(batch_size=None):
                names.append(sample.name)
                if len(names) >= n_samples:
                    break
            passes.append(names)

        # At least two passes should differ (very high probability with 100 samples)
        # Note: This could theoretically fail, but probability is astronomically low
        assert (
            passes[0] != passes[1] or passes[1] != passes[2] or passes[0] != passes[2]
        )

    def test_batch_size_one(self, tmp_path):
        """batch_size=1 should return single-element batches."""
        n_samples = 10
        samples = create_simple_samples(n_samples)

        tar_path = write_single_shard(tmp_path / "batch1.tar", samples)
        dataset = atdata.Dataset[SimpleSample](tar_path)

        batches = list(dataset.ordered(batch_size=1))

        assert len(batches) == n_samples
        for batch in batches:
            assert isinstance(batch, atdata.SampleBatch)
            assert len(batch.samples) == 1