# A loose federation of distributed, typed datasets
1"""Test dataaset functionality."""
2
3##
4# Imports
5
6# Tests
7import pytest
8
9# System
10from dataclasses import dataclass
11
12# External
13import numpy as np
14import webdataset as wds
15
16# Local
17import atdata
18import atdata.dataset as atds
19
20# Typing
21from numpy.typing import NDArray
22from typing import (
23 Type,
24)
25
26
27##
28# Sample test cases
29
30
@dataclass
class BasicTestSample(atdata.PackableSample):
    """Hand-declared packable sample containing only scalar fields."""

    name: str
    position: int
    value: float
36
37
@dataclass
class NumpyTestSample(atdata.PackableSample):
    """Hand-declared packable sample carrying a numpy array field."""

    label: int
    image: NDArray
42
43
@atdata.packable
class BasicTestSampleDecorated:
    """Decorator-declared counterpart of ``BasicTestSample`` (scalars only)."""

    name: str
    position: int
    value: float
49
50
@atdata.packable
class NumpyTestSampleDecorated:
    """Decorator-declared counterpart of ``NumpyTestSample`` (array field)."""

    label: int
    image: NDArray
55
56
@atdata.packable
class NumpyOptionalSampleDecorated:
    """Decorator-declared sample with an optional array field.

    ``embeddings`` may be ``None``, exercising optional-field
    (de)serialization in both directions.
    """

    label: int
    image: NDArray
    embeddings: NDArray | None = None
62
63
# Each entry drives the parametrized tests below:
#   SampleType      -- packable class under test
#   sample_data     -- raw field values used to build one sample
#   sample_wds_stem -- filename stem for the on-disk WebDataset shards
#   test_parquet    -- whether the parquet-export test applies
#                      (array-valued samples are excluded)
test_cases = [
    {
        "SampleType": BasicTestSample,
        "sample_data": {
            "name": "Hello, world!",
            "position": 42,
            "value": 1024.768,
        },
        "sample_wds_stem": "basic_test",
        "test_parquet": True,
    },
    {
        "SampleType": NumpyTestSample,
        "sample_data": {
            "label": 9_001,
            "image": np.random.randn(1024, 1024),
        },
        "sample_wds_stem": "numpy_test",
        "test_parquet": False,
    },
    {
        "SampleType": BasicTestSampleDecorated,
        "sample_data": {
            "name": "Hello, world!",
            "position": 42,
            "value": 1024.768,
        },
        "sample_wds_stem": "basic_test_decorated",
        "test_parquet": True,
    },
    {
        "SampleType": NumpyTestSampleDecorated,
        "sample_data": {
            "label": 9_001,
            "image": np.random.randn(1024, 1024),
        },
        "sample_wds_stem": "numpy_test_decorated",
        "test_parquet": False,
    },
    {
        "SampleType": NumpyOptionalSampleDecorated,
        "sample_data": {
            "label": 9_001,
            "image": np.random.randn(1024, 1024),
            "embeddings": np.random.randn(512),
        },
        "sample_wds_stem": "numpy_optional_decorated",
        "test_parquet": False,
    },
    {
        "SampleType": NumpyOptionalSampleDecorated,
        "sample_data": {
            "label": 9_001,
            "image": np.random.randn(1024, 1024),
            # Optional field deliberately absent-as-None for this case
            "embeddings": None,
        },
        "sample_wds_stem": "numpy_optional_decorated_none",
        "test_parquet": False,
    },
]
124
125
126## Tests
127
128
@pytest.mark.parametrize(
    ("SampleType", "sample_data"),
    [(case["SampleType"], case["sample_data"]) for case in test_cases],
)
def test_create_sample(
    SampleType: Type[atdata.PackableSample],
    sample_data: atds.WDSRawSample,
):
    """Test our ability to create samples from semi-structured data"""

    sample = SampleType.from_data(sample_data)
    assert isinstance(sample, SampleType), (
        f"Did not properly form sample for test type {SampleType}"
    )

    # Every raw field must be carried over verbatim onto the sample
    for k, v in sample_data.items():
        got = getattr(sample, k)
        # Arrays require elementwise comparison; everything else plain ==
        ok = bool(np.all(got == v)) if isinstance(v, np.ndarray) else got == v
        assert ok, (
            f"Did not properly incorporate property {k} of test type {SampleType}"
        )
153
154
@pytest.mark.parametrize(
    ("SampleType", "sample_data", "sample_wds_stem"),
    [
        (case["SampleType"], case["sample_data"], case["sample_wds_stem"])
        for case in test_cases
    ],
)
def test_wds(
    SampleType: Type[atdata.PackableSample],
    sample_data: atds.WDSRawSample,
    sample_wds_stem: str,
    tmp_path,
):
    """Test our ability to write samples as `WebDatasets` to disk

    Round-trips `n_copies` identical samples through sharded tar files,
    then exercises four read paths: ordered and shuffled, each both with
    and without batching.
    """

    ## Testing hyperparameters

    # 100 samples at 10 per shard -> exactly 10 shard files (000000..000009)
    n_copies = 100
    shard_maxcount = 10
    batch_size = 4
    n_iterate = 10  # samples/batches to pull from each reader before stopping

    ## Write sharded dataset

    # `file_pattern` keeps a {shard_id} placeholder; `file_wds_pattern` is the
    # printf-style form that ShardWriter expects
    file_pattern = (tmp_path / (f"{sample_wds_stem}" + "-{shard_id}.tar")).as_posix()
    file_wds_pattern = file_pattern.format(shard_id="%06d")

    with wds.writer.ShardWriter(
        pattern=file_wds_pattern,
        maxcount=shard_maxcount,
    ) as sink:
        for i_sample in range(n_copies):
            new_sample = SampleType.from_data(sample_data)
            assert isinstance(new_sample, SampleType), (
                f"Did not properly form sample for test type {SampleType}"
            )

            sink.write(new_sample.as_wds)

    ## Ordered

    # Read first shard, no batches

    first_filename = file_pattern.format(shard_id=f"{0:06d}")
    dataset = atdata.Dataset[SampleType](first_filename)

    iterations_run = 0
    for i_iterate, cur_sample in enumerate(dataset.ordered(batch_size=None)):
        assert isinstance(cur_sample, SampleType), (
            f"Single sample for {SampleType} written to `wds` is of wrong type"
        )

        # Check sample values

        for k, v in sample_data.items():
            # Arrays need elementwise comparison; everything else plain ==
            if isinstance(v, np.ndarray):
                is_correct = np.all(getattr(cur_sample, k) == v)
            else:
                is_correct = getattr(cur_sample, k) == v
            assert is_correct, (
                f"{SampleType}: Incorrect sample value found for {k} - {type(getattr(cur_sample, k))}"
            )

        iterations_run += 1
        if iterations_run >= n_iterate:
            break

    assert iterations_run == n_iterate, (
        f"Only found {iterations_run} samples, not {n_iterate}"
    )

    # Read all shards, batches

    # Brace-expansion pattern covering all 10 shards written above
    start_id = f"{0:06d}"
    end_id = f"{9:06d}"
    first_filename = file_pattern.format(shard_id="{" + start_id + ".." + end_id + "}")
    dataset = atdata.Dataset[SampleType](first_filename)

    iterations_run = 0
    for i_iterate, cur_batch in enumerate(dataset.ordered(batch_size=batch_size)):
        assert isinstance(cur_batch, atdata.SampleBatch), (
            f"{SampleType}: Batch sample is not correctly a batch"
        )

        assert cur_batch.sample_type == SampleType, (
            f"{SampleType}: Batch `sample_type` is incorrect type"
        )

        # Batch-size and element-type checks only need to run once
        if i_iterate == 0:
            cur_n = len(cur_batch.samples)
            assert cur_n == batch_size, (
                f"{SampleType}: Batch has {cur_n} samples, not {batch_size}"
            )

            assert isinstance(cur_batch.samples[0], SampleType), (
                f"{SampleType}: Batch sample of wrong type ({type(cur_batch.samples[0])})"
            )

        # Check batch values
        for k, v in sample_data.items():
            # Batch attribute access yields the field aggregated across samples
            cur_batch_data = getattr(cur_batch, k)

            if isinstance(v, np.ndarray):
                assert isinstance(cur_batch_data, np.ndarray), (
                    f"{SampleType}: `NDArray` not carried through to batch"
                )

                # Every row of the aggregated array must equal the source array
                is_correct = all(
                    [
                        np.all(cur_batch_data[i] == v)
                        for i in range(cur_batch_data.shape[0])
                    ]
                )

            else:
                is_correct = all(
                    [cur_batch_data[i] == v for i in range(len(cur_batch_data))]
                )

            assert is_correct, f"{SampleType}: Incorrect sample value found for {k}"

        iterations_run += 1
        if iterations_run >= n_iterate:
            break

    assert iterations_run == n_iterate, (
        f"Only found {iterations_run} samples, not {n_iterate}"
    )

    ## Shuffled

    # Read first shard, no batches

    first_filename = file_pattern.format(shard_id=f"{0:06d}")
    dataset = atdata.Dataset[SampleType](first_filename)

    iterations_run = 0
    for i_iterate, cur_sample in enumerate(dataset.shuffled(batch_size=None)):
        # Shuffled order is nondeterministic, so only types are checked here
        assert isinstance(cur_sample, SampleType), (
            f"Single sample for {SampleType} written to `wds` is of wrong type"
        )

        iterations_run += 1
        if iterations_run >= n_iterate:
            break

    assert iterations_run == n_iterate, (
        f"Only found {iterations_run} samples, not {n_iterate}"
    )

    # Read all shards, batches

    start_id = f"{0:06d}"
    end_id = f"{9:06d}"
    first_filename = file_pattern.format(shard_id="{" + start_id + ".." + end_id + "}")
    dataset = atdata.Dataset[SampleType](first_filename)

    iterations_run = 0
    for i_iterate, cur_sample in enumerate(dataset.shuffled(batch_size=batch_size)):
        assert isinstance(cur_sample, atdata.SampleBatch), (
            f"{SampleType}: Batch sample is not correctly a batch"
        )

        assert cur_sample.sample_type == SampleType, (
            f"{SampleType}: Batch `sample_type` is incorrect type"
        )

        if i_iterate == 0:
            cur_n = len(cur_sample.samples)
            assert cur_n == batch_size, (
                f"{SampleType}: Batch has {cur_n} samples, not {batch_size}"
            )

            assert isinstance(cur_sample.samples[0], SampleType), (
                f"{SampleType}: Batch sample of wrong type ({type(cur_sample.samples[0])})"
            )

        iterations_run += 1
        if iterations_run >= n_iterate:
            break

    assert iterations_run == n_iterate, (
        f"Only found {iterations_run} samples, not {n_iterate}"
    )
339
340
341#
342
343
@pytest.mark.parametrize(
    ("SampleType", "sample_data", "sample_wds_stem", "test_parquet"),
    [
        (
            case["SampleType"],
            case["sample_data"],
            case["sample_wds_stem"],
            case["test_parquet"],
        )
        for case in test_cases
    ],
)
def test_parquet_export(
    SampleType: Type[atdata.PackableSample],
    sample_data: atds.WDSRawSample,
    sample_wds_stem: str,
    test_parquet: bool,
    tmp_path,
):
    """Test our ability to export a dataset to `parquet` format

    Writes a tar-backed dataset, then exports it both as a single parquet
    file and as maxcount-limited segment files.
    """

    # Report inapplicable cases (array-valued samples) as skipped rather than
    # silently "passing" via an early return
    if not test_parquet:
        pytest.skip("parquet export not tested for this sample type")

    ## Testing hyperparameters

    n_copies_dataset = 1_000
    n_per_file = 100

    ## Start out by writing tar dataset

    wds_filename = (tmp_path / f"{sample_wds_stem}.tar").as_posix()
    with wds.writer.TarWriter(wds_filename) as sink:
        for _ in range(n_copies_dataset):
            new_sample = SampleType.from_data(sample_data)
            sink.write(new_sample.as_wds)

    ## Now export to `parquet`

    # Single-file export
    dataset = atdata.Dataset[SampleType](wds_filename)
    parquet_filename = tmp_path / f"{sample_wds_stem}.parquet"
    dataset.to_parquet(parquet_filename)

    # Segmented export, capped at `n_per_file` rows per file
    parquet_filename = tmp_path / f"{sample_wds_stem}-segments.parquet"
    dataset.to_parquet(parquet_filename, maxcount=n_per_file)
390
391
392##
393# Edge case tests for coverage
394
395
def test_batch_aggregate_empty():
    """Aggregating an empty sample list yields an empty list."""
    assert atds._batch_aggregate([]) == [], "Empty input should return empty list"
400
401
def test_sample_batch_attribute_error():
    """Accessing an unknown attribute on a SampleBatch raises AttributeError."""

    @atdata.packable
    class SimpleSample:
        name: str
        value: int

    batch = atdata.SampleBatch[SimpleSample]([SimpleSample(name="test", value=1)])

    with pytest.raises(AttributeError, match="No sample attribute named"):
        _ = batch.nonexistent_attribute
415
416
def test_sample_batch_type_property():
    """`SampleBatch.sample_type` reports the wrapped sample class."""

    @atdata.packable
    class TypedSample:
        data: str

    batch = atdata.SampleBatch[TypedSample]([TypedSample(data="hello")])

    assert batch.sample_type == TypedSample
428
429
def test_dataset_batch_type_property(tmp_path):
    """`Dataset.batch_type` is SampleBatch parameterized by the sample type."""

    @atdata.packable
    class BatchTypeSample:
        value: int

    # Write a one-sample dataset to disk
    wds_filename = (tmp_path / "batch_type_test.tar").as_posix()
    with wds.writer.TarWriter(wds_filename) as sink:
        sink.write(BatchTypeSample(value=42).as_wds)

    dataset = atdata.Dataset[BatchTypeSample](wds_filename)

    # The generic origin of the batch type must be SampleBatch itself
    assert dataset.batch_type.__origin__ == atdata.SampleBatch
448
449
def test_dataset_shard_list_property(tmp_path):
    """`Dataset.shard_list` expands a brace pattern into individual shard URLs."""

    @atdata.packable
    class ShardListSample:
        value: int

    # 15 samples at 5 per shard -> three shard files on disk
    pattern = (tmp_path / "shards_test-%06d.tar").as_posix()
    with wds.writer.ShardWriter(pattern=pattern, maxcount=5) as sink:
        for value in range(15):
            sink.write(ShardListSample(value=value).as_wds)

    # Address all three shards with a brace-expansion pattern
    brace_pattern = (tmp_path / "shards_test-{000000..000002}.tar").as_posix()
    dataset = atdata.Dataset[ShardListSample](brace_pattern)

    shards = dataset.shard_list
    assert isinstance(shards, list)
    assert len(shards) == 3
471
472
def test_dataset_metadata_property(tmp_path):
    """Test Dataset.metadata property fetches and caches metadata from URL."""
    from unittest.mock import patch, Mock
    import msgpack

    @atdata.packable
    class MetadataSample:
        value: int

    # Create a simple dataset
    wds_filename = (tmp_path / "metadata_test.tar").as_posix()
    with wds.writer.TarWriter(wds_filename) as sink:
        sample = MetadataSample(value=42)
        sink.write(sample.as_wds)

    # Mock the requests.get call; the response must act as a context manager
    # and expose msgpack-encoded bytes via `.content`
    mock_metadata = {"key": "value", "count": 100}
    mock_response = Mock()
    mock_response.content = msgpack.packb(mock_metadata)
    mock_response.raise_for_status = Mock()
    mock_response.__enter__ = Mock(return_value=mock_response)
    mock_response.__exit__ = Mock(return_value=False)

    with patch("atdata.dataset.requests.get", return_value=mock_response) as mock_get:
        dataset = atdata.Dataset[MetadataSample](
            wds_filename, metadata_url="http://example.com/metadata.msgpack"
        )

        # First call should fetch over the (mocked) network exactly once
        metadata = dataset.metadata
        assert metadata == mock_metadata
        mock_get.assert_called_once_with(
            "http://example.com/metadata.msgpack", stream=True
        )

        # Second call should use cache -- no additional request
        metadata2 = dataset.metadata
        assert metadata2 == mock_metadata
        assert mock_get.call_count == 1  # Still only one call
512
513
def test_dataset_metadata_property_none(tmp_path):
    """Without a metadata_url, `Dataset.metadata` is simply None."""

    @atdata.packable
    class NoMetadataSample:
        value: int

    wds_filename = (tmp_path / "no_metadata_test.tar").as_posix()
    with wds.writer.TarWriter(wds_filename) as sink:
        sink.write(NoMetadataSample(value=42).as_wds)

    assert atdata.Dataset[NoMetadataSample](wds_filename).metadata is None
527 assert dataset.metadata is None
528
529
def test_parquet_export_with_remainder(tmp_path):
    """Segmented parquet export handles a maxcount that doesn't divide evenly."""
    import pandas as pd

    @atdata.packable
    class RemainderSample:
        name: str
        value: int

    n_samples = 25
    maxcount = 10  # 25 samples -> segments of 10, 10, and 5

    wds_filename = (tmp_path / "remainder_test.tar").as_posix()
    with wds.writer.TarWriter(wds_filename) as sink:
        for idx in range(n_samples):
            sink.write(RemainderSample(name=f"sample_{idx}", value=idx).as_wds)

    dataset = atdata.Dataset[RemainderSample](wds_filename)
    dataset.to_parquet(tmp_path / "remainder_output.parquet", maxcount=maxcount)

    # Two full segments plus one partial
    segment_files = list(tmp_path.glob("remainder_output-*.parquet"))
    assert len(segment_files) == 3

    # No rows lost or duplicated across segments
    assert sum(len(pd.read_parquet(f)) for f in segment_files) == n_samples
560 assert total_rows == n_samples
561
562
def test_dataset_with_lens_batched(tmp_path):
    """Test dataset iteration with lens transformation in batch mode.

    Writes `SourceSample` records, reads them back through an
    `as_type(ViewSample)` lens, and checks that each batch carries only the
    lensed view type and its fields.
    """
    # NOTE: the redundant function-local `from dataclasses import dataclass`
    # was removed; `dataclass` is already imported at module level.

    @dataclass
    class SourceSample(atdata.PackableSample):
        name: str
        age: int
        score: float

    @dataclass
    class ViewSample(atdata.PackableSample):
        name: str
        score: float

    # Lens projecting a source record down to its view fields
    @atdata.lens
    def extract_view(s: SourceSample) -> ViewSample:
        return ViewSample(name=s.name, score=s.score)

    # Create dataset
    n_samples = 20
    batch_size = 4
    wds_filename = (tmp_path / "lens_batch_test.tar").as_posix()

    with wds.writer.TarWriter(wds_filename) as sink:
        for i in range(n_samples):
            sample = SourceSample(name=f"person_{i}", age=20 + i, score=float(i) * 1.5)
            sink.write(sample.as_wds)

    # Read with lens transformation in batch mode
    dataset = atdata.Dataset[SourceSample](wds_filename).as_type(ViewSample)

    batches_seen = 0
    for batch in dataset.ordered(batch_size=batch_size):
        assert isinstance(batch, atdata.SampleBatch)
        assert batch.sample_type == ViewSample

        # Check that samples are ViewSample type (not SourceSample)
        for sample in batch.samples:
            assert isinstance(sample, ViewSample)
            assert hasattr(sample, "name")
            assert hasattr(sample, "score")
            assert not hasattr(sample, "age")  # age is not in ViewSample

        batches_seen += 1

    # 20 samples with batch_size 4 -> exactly 5 full batches
    assert batches_seen == n_samples // batch_size
610
611
def test_from_bytes_invalid_msgpack():
    """`from_bytes` propagates the serializer's error on garbage input."""

    @atdata.packable
    class SimpleSample:
        value: int

    bad_payload = b"not valid msgpack data"
    with pytest.raises(Exception):  # ormsgpack raises on invalid data
        SimpleSample.from_bytes(bad_payload)
620 SimpleSample.from_bytes(b"not valid msgpack data")
621
622
def test_from_bytes_missing_field():
    """`from_bytes` raises TypeError when a required field is absent."""
    import ormsgpack

    @atdata.packable
    class RequiredFieldSample:
        name: str
        count: int

    # 'count' is deliberately left out of the payload
    incomplete_data = ormsgpack.packb({"name": "test"})

    with pytest.raises(TypeError):  # Missing required argument
        RequiredFieldSample.from_bytes(incomplete_data)
637 RequiredFieldSample.from_bytes(incomplete_data)
638
639
def test_wrap_missing_msgpack_key(tmp_path):
    """`wrap` rejects a raw sample that has no 'msgpack' entry."""

    @atdata.packable
    class WrapTestSample:
        value: int

    wds_filename = (tmp_path / "wrap_test.tar").as_posix()
    with wds.writer.TarWriter(wds_filename) as sink:
        sink.write(WrapTestSample(value=42).as_wds)

    dataset = atdata.Dataset[WrapTestSample](wds_filename)

    raw_sample = {"__key__": "test"}  # deliberately missing 'msgpack'
    with pytest.raises(ValueError, match="missing 'msgpack' key"):
        dataset.wrap(raw_sample)
656 dataset.wrap({"__key__": "test"}) # Missing 'msgpack' key
657
658
def test_wrap_wrong_msgpack_type(tmp_path):
    """`wrap` rejects a raw sample whose 'msgpack' value is not bytes."""

    @atdata.packable
    class WrapTypeSample:
        value: int

    wds_filename = (tmp_path / "wrap_type_test.tar").as_posix()
    with wds.writer.TarWriter(wds_filename) as sink:
        sink.write(WrapTypeSample(value=42).as_wds)

    dataset = atdata.Dataset[WrapTypeSample](wds_filename)

    raw_sample = {"__key__": "test", "msgpack": "not bytes"}  # str, not bytes
    with pytest.raises(ValueError, match="to be bytes"):
        dataset.wrap(raw_sample)
675 dataset.wrap({"__key__": "test", "msgpack": "not bytes"})
676
677
def test_wrap_corrupted_msgpack(tmp_path):
    """`wrap` propagates deserialization errors on corrupted payload bytes."""

    @atdata.packable
    class CorruptedSample:
        value: int

    wds_filename = (tmp_path / "corrupted_test.tar").as_posix()
    with wds.writer.TarWriter(wds_filename) as sink:
        sink.write(CorruptedSample(value=42).as_wds)

    dataset = atdata.Dataset[CorruptedSample](wds_filename)

    garbage = b"\xff\xfe\x00\x01invalid"
    with pytest.raises(Exception):  # ormsgpack raises on corrupted data
        dataset.wrap({"__key__": "test", "msgpack": garbage})
694 dataset.wrap({"__key__": "test", "msgpack": b"\xff\xfe\x00\x01invalid"})
695
696
def test_dataset_nonexistent_file():
    """Dataset construction is lazy; iterating a missing tar file fails."""

    @atdata.packable
    class NonexistentSample:
        value: int

    dataset = atdata.Dataset[NonexistentSample]("/nonexistent/path/data.tar")

    # Construction itself performs no I/O, so it succeeds
    assert dataset is not None

    # The missing file only surfaces once iteration begins
    with pytest.raises(Exception):  # FileNotFoundError or similar
        list(dataset.ordered(batch_size=None))
711 list(dataset.ordered(batch_size=None))
712
713
def test_dataset_invalid_batch_size(tmp_path):
    """Iteration with batch_size=0 blows up rather than looping silently."""

    @atdata.packable
    class BatchSizeSample:
        value: int

    wds_filename = (tmp_path / "batch_test.tar").as_posix()
    with wds.writer.TarWriter(wds_filename) as sink:
        sink.write(BatchSizeSample(value=42).as_wds)

    dataset = atdata.Dataset[BatchSizeSample](wds_filename)

    # batch_size=0 produces empty batches, causing IndexError in webdataset
    with pytest.raises((ValueError, AssertionError, IndexError)):
        list(dataset.ordered(batch_size=0))
730 list(dataset.ordered(batch_size=0))
731
732
733##
734# DictSample tests
735
736
def test_dictsample_creation():
    """DictSample accepts either keyword args or an explicit `_data` dict."""

    # From keyword args
    from_kwargs = atdata.DictSample(name="test", value=42)
    assert (from_kwargs.name, from_kwargs.value) == ("test", 42)

    # From _data dict
    from_dict = atdata.DictSample(_data={"name": "test2", "value": 100})
    assert (from_dict.name, from_dict.value) == ("test2", 100)
747 assert ds2.value == 100
748
749
def test_dictsample_getattr():
    """Attribute access resolves stored keys; unknown names raise."""
    sample = atdata.DictSample(text="hello", label=1)

    assert (sample.text, sample.label) == ("hello", 1)

    # Non-existent attribute raises AttributeError
    with pytest.raises(AttributeError, match="has no field"):
        _ = sample.nonexistent
759 _ = sample.nonexistent
760
761
def test_dictsample_getitem():
    """Subscript access mirrors attribute access; missing keys raise KeyError."""
    sample = atdata.DictSample(text="hello", label=1)

    assert (sample["text"], sample["label"]) == ("hello", 1)

    # Non-existent key raises KeyError
    with pytest.raises(KeyError):
        _ = sample["nonexistent"]
771 _ = sample["nonexistent"]
772
773
def test_dictsample_dict_methods():
    """DictSample supports the usual read-only dict protocol."""
    sample = atdata.DictSample(a=1, b=2, c=3)

    assert sorted(sample.keys()) == ["a", "b", "c"]
    assert sorted(sample.values()) == [1, 2, 3]
    assert sorted(sample.items()) == [("a", 1), ("b", 2), ("c", 3)]
    assert "a" in sample and "x" not in sample
    assert sample.get("a") == 1
    assert sample.get("x", "default") == "default"
784 assert sample.get("x", "default") == "default"
785
786
def test_dictsample_to_dict():
    """`to_dict` returns an independent copy of the stored mapping."""
    sample = atdata.DictSample(name="test", value=42)

    exported = sample.to_dict()
    assert exported == {"name": "test", "value": 42}

    # Mutating the export must not touch the sample itself
    exported["name"] = "modified"
    assert sample.name == "test"
796
797
def test_dictsample_serialization():
    """A DictSample round-trips through its packed byte representation."""
    original = atdata.DictSample(text="hello", count=42)

    # Serialize, then deserialize back into a fresh instance
    restored = atdata.DictSample.from_bytes(original.packed)

    assert (restored.text, restored.count) == ("hello", 42)
809 assert restored.count == 42
810
811
def test_dictsample_as_wds():
    """`as_wds` emits the '__key__' and binary 'msgpack' entries WebDataset needs."""
    wds_dict = atdata.DictSample(name="test", value=123).as_wds

    for required in ("__key__", "msgpack"):
        assert required in wds_dict
    assert isinstance(wds_dict["msgpack"], bytes)
819 assert isinstance(wds_dict["msgpack"], bytes)
820
821
def test_dictsample_repr():
    """The repr names the class and mentions every stored field."""
    repr_str = repr(atdata.DictSample(name="test", value=42))

    for fragment in ("DictSample", "name", "value"):
        assert fragment in repr_str
829 assert "value" in repr_str
830
831
def test_dictsample_dataset_iteration(tmp_path):
    """A dataset written with a typed sample can be read back as DictSample."""

    # Create typed sample data on disk
    @atdata.packable
    class SourceSample:
        text: str
        label: int

    wds_filename = (tmp_path / "dictsample_test.tar").as_posix()
    with wds.writer.TarWriter(wds_filename) as sink:
        for idx in range(5):
            sink.write(SourceSample(text=f"item_{idx}", label=idx).as_wds)

    # Read it back untyped, as DictSample
    dataset = atdata.Dataset[atdata.DictSample](wds_filename)
    samples = list(dataset.ordered())

    assert len(samples) == 5
    for idx, sample in enumerate(samples):
        assert isinstance(sample, atdata.DictSample)
        # Both attribute- and dict-style access should work
        assert sample.text == f"item_{idx}"
        assert sample["label"] == idx
856 assert sample["label"] == i
857
858
def test_dictsample_to_typed_via_as_type(tmp_path):
    """An untyped (DictSample) dataset converts to typed via `as_type`."""

    @atdata.packable
    class TypedSample:
        text: str
        label: int

    # Create data using the typed sample
    wds_filename = (tmp_path / "astype_test.tar").as_posix()
    with wds.writer.TarWriter(wds_filename) as sink:
        for idx in range(5):
            sink.write(TypedSample(text=f"item_{idx}", label=idx).as_wds)

    # Load as DictSample first, then convert to the typed view
    ds_typed = atdata.Dataset[atdata.DictSample](wds_filename).as_type(TypedSample)

    # Verify typed iteration works
    samples = list(ds_typed.ordered())
    assert len(samples) == 5

    for idx, sample in enumerate(samples):
        assert isinstance(sample, TypedSample)
        assert (sample.text, sample.label) == (f"item_{idx}", idx)
887 assert sample.label == i
888
889
def test_packable_auto_registers_dictsample_lens():
    """@packable auto-registers a DictSample -> typed lens in the network."""

    @atdata.packable
    class AutoLensSample:
        name: str
        value: int

    # The decorator should have registered the lens already
    lens = atdata.LensNetwork().transform(atdata.DictSample, AutoLensSample)

    # Applying the lens converts an untyped sample into the typed class
    typed_sample = lens(atdata.DictSample(name="test", value=42))

    assert isinstance(typed_sample, AutoLensSample)
    assert (typed_sample.name, typed_sample.value) == ("test", 42)
908 assert typed_sample.value == 42
909
910
def test_dictsample_batched_iteration(tmp_path):
    """Batched iteration yields SampleBatch objects of DictSample members."""

    @atdata.packable
    class BatchSource:
        text: str
        value: int

    wds_filename = (tmp_path / "batch_dictsample_test.tar").as_posix()
    with wds.writer.TarWriter(wds_filename) as sink:
        for idx in range(10):
            sink.write(BatchSource(text=f"item_{idx}", value=idx).as_wds)

    # Read as DictSample with batching
    dataset = atdata.Dataset[atdata.DictSample](wds_filename)

    batch_count = 0
    for batch in dataset.ordered(batch_size=4):
        assert isinstance(batch, atdata.SampleBatch)
        assert len(batch.samples) <= 4
        assert all(isinstance(s, atdata.DictSample) for s in batch.samples)
        batch_count += 1

    # 10 samples / 4 per batch = 2 full + 1 partial
    assert batch_count == 3
936 assert batch_count == 3 # 10 samples / 4 per batch = 2 full + 1 partial
937
938
939##