···2525- **Comprehensive integration test suite**: 593 tests covering E2E flows, error handling, edge cases
26262727### Changed
2828+- Fix type signatures for Dataset.ordered and Dataset.shuffled (GH#28) (#404)
2829- Investigate quartodoc Example section rendering - missing CSS classes on pre/code tags (#401)
2930- Update all docstrings from Example: to Examples: format (#403)
3031- Create GitHub issues for v0.3 roadmap feature domains (#402)
+52-12
src/atdata/dataset.py
···6464 TypeVar,
6565 TypeAlias,
6666 dataclass_transform,
6767+ overload,
6868+ Literal,
6769)
6870from numpy.typing import NDArray
6971···721723 # Use our cached values
722724 return self._metadata
723725726726+ @overload
727727+ def ordered( self,
728728+ batch_size: None = None,
729729+ ) -> Iterable[ST]: ...
730730+731731+ @overload
732732+ def ordered( self,
733733+ batch_size: int,
734734+ ) -> Iterable[SampleBatch[ST]]: ...
735735+724736 def ordered( self,
725737 batch_size: int | None = None,
726726- ) -> Iterable[ST]:
727727- """Iterate over the dataset in order
738738+ ) -> Iterable[ST] | Iterable[SampleBatch[ST]]:
739739+ """Iterate over the dataset in order.
728740729741 Args:
730730- batch_size (:obj:`int`, optional): The size of iterated batches.
731731- Default: None (unbatched). If ``None``, iterates over one
732732- sample at a time with no batch dimension.
742742+ batch_size: The size of iterated batches. Default: None (unbatched).
743743+ If ``None``, iterates over one sample at a time with no batch
744744+ dimension.
733745734746 Returns:
735735- :obj:`webdataset.DataPipeline` A data pipeline that iterates over
736736- the dataset in its original sample order
747747+ A data pipeline that iterates over the dataset in its original
748748+ sample order. When ``batch_size`` is ``None``, yields individual
749749+ samples of type ``ST``. When ``batch_size`` is an integer, yields
750750+ ``SampleBatch[ST]`` instances containing that many samples.
737751752752+ Examples:
753753+ >>> for sample in ds.ordered():
754754+ ... process(sample) # sample is ST
755755+ >>> for batch in ds.ordered(batch_size=32):
756756+ ... process(batch) # batch is SampleBatch[ST]
738757 """
739758 if batch_size is None:
740759 return wds.pipeline.DataPipeline(
···756775 wds.filters.map( self.wrap_batch ),
757776 )
758777778778+ @overload
779779+ def shuffled( self,
780780+ buffer_shards: int = 100,
781781+ buffer_samples: int = 10_000,
782782+ batch_size: None = None,
783783+ ) -> Iterable[ST]: ...
784784+785785+ @overload
786786+ def shuffled( self,
787787+ buffer_shards: int = 100,
788788+ buffer_samples: int = 10_000,
789789+ *,
790790+ batch_size: int,
791791+ ) -> Iterable[SampleBatch[ST]]: ...
792792+759793 def shuffled( self,
760794 buffer_shards: int = 100,
761795 buffer_samples: int = 10_000,
762796 batch_size: int | None = None,
763763- ) -> Iterable[ST]:
797797+ ) -> Iterable[ST] | Iterable[SampleBatch[ST]]:
764798 """Iterate over the dataset in random order.
765799766800 Args:
···775809 dimension.
776810777811 Returns:
778778- A WebDataset data pipeline that iterates over the dataset in
779779- randomized order. If ``batch_size`` is not ``None``, yields
780780- ``SampleBatch[ST]`` instances; otherwise yields individual ``ST``
781781- samples.
812812+ A data pipeline that iterates over the dataset in randomized order.
813813+ When ``batch_size`` is ``None``, yields individual samples of type
814814+ ``ST``. When ``batch_size`` is an integer, yields ``SampleBatch[ST]``
815815+ instances containing that many samples.
816816+817817+ Examples:
818818+ >>> for sample in ds.shuffled():
819819+ ... process(sample) # sample is ST
820820+ >>> for batch in ds.shuffled(batch_size=32):
821821+ ... process(batch) # batch is SampleBatch[ST]
782822 """
783823 if batch_size is None:
784824 return wds.pipeline.DataPipeline(