A loose federation of distributed, typed datasets
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

Merge branch 'feature/gh-28-dataset-iterator-overloading' into release/v0.2.3b1

+53 -12
.chainlink/issues.db

This is a binary file and will not be displayed.

+1
CHANGELOG.md
··· 25 25 - **Comprehensive integration test suite**: 593 tests covering E2E flows, error handling, edge cases 26 26 27 27 ### Changed 28 + - Fix type signatures for Dataset.ordered and Dataset.shuffled (GH#28) (#404) 28 29 - Investigate quartodoc Example section rendering - missing CSS classes on pre/code tags (#401) 29 30 - Update all docstrings from Example: to Examples: format (#403) 30 31 - Create GitHub issues for v0.3 roadmap feature domains (#402)
+52 -12
src/atdata/dataset.py
··· 64 64 TypeVar, 65 65 TypeAlias, 66 66 dataclass_transform, 67 + overload, 68 + Literal, 67 69 ) 68 70 from numpy.typing import NDArray 69 71 ··· 721 723 # Use our cached values 722 724 return self._metadata 723 725 726 + @overload 727 + def ordered( self, 728 + batch_size: None = None, 729 + ) -> Iterable[ST]: ... 730 + 731 + @overload 732 + def ordered( self, 733 + batch_size: int, 734 + ) -> Iterable[SampleBatch[ST]]: ... 735 + 724 736 def ordered( self, 725 737 batch_size: int | None = None, 726 - ) -> Iterable[ST]: 727 - """Iterate over the dataset in order 738 + ) -> Iterable[ST] | Iterable[SampleBatch[ST]]: 739 + """Iterate over the dataset in order. 728 740 729 741 Args: 730 - batch_size (:obj:`int`, optional): The size of iterated batches. 731 - Default: None (unbatched). If ``None``, iterates over one 732 - sample at a time with no batch dimension. 742 + batch_size: The size of iterated batches. Default: None (unbatched). 743 + If ``None``, iterates over one sample at a time with no batch 744 + dimension. 733 745 734 746 Returns: 735 - :obj:`webdataset.DataPipeline` A data pipeline that iterates over 736 - the dataset in its original sample order 747 + A data pipeline that iterates over the dataset in its original 748 + sample order. When ``batch_size`` is ``None``, yields individual 749 + samples of type ``ST``. When ``batch_size`` is an integer, yields 750 + ``SampleBatch[ST]`` instances containing that many samples. 737 751 752 + Examples: 753 + >>> for sample in ds.ordered(): 754 + ... process(sample) # sample is ST 755 + >>> for batch in ds.ordered(batch_size=32): 756 + ... process(batch) # batch is SampleBatch[ST] 738 757 """ 739 758 if batch_size is None: 740 759 return wds.pipeline.DataPipeline( ··· 756 775 wds.filters.map( self.wrap_batch ), 757 776 ) 758 777 778 + @overload 779 + def shuffled( self, 780 + buffer_shards: int = 100, 781 + buffer_samples: int = 10_000, 782 + batch_size: None = None, 783 + ) -> Iterable[ST]: ... 784 + 785 + @overload 786 + def shuffled( self, 787 + buffer_shards: int = 100, 788 + buffer_samples: int = 10_000, 789 + *, 790 + batch_size: int, 791 + ) -> Iterable[SampleBatch[ST]]: ... 792 + 759 793 def shuffled( self, 760 794 buffer_shards: int = 100, 761 795 buffer_samples: int = 10_000, 762 796 batch_size: int | None = None, 763 - ) -> Iterable[ST]: 797 + ) -> Iterable[ST] | Iterable[SampleBatch[ST]]: 764 798 """Iterate over the dataset in random order. 765 799 766 800 Args: ··· 775 809 dimension. 776 810 777 811 Returns: 778 - A WebDataset data pipeline that iterates over the dataset in 779 - randomized order. If ``batch_size`` is not ``None``, yields 780 - ``SampleBatch[ST]`` instances; otherwise yields individual ``ST`` 781 - samples. 812 + A data pipeline that iterates over the dataset in randomized order. 813 + When ``batch_size`` is ``None``, yields individual samples of type 814 + ``ST``. When ``batch_size`` is an integer, yields ``SampleBatch[ST]`` 815 + instances containing that many samples. 816 + 817 + Examples: 818 + >>> for sample in ds.shuffled(): 819 + ... process(sample) # sample is ST 820 + >>> for batch in ds.shuffled(batch_size=32): 821 + ... process(batch) # batch is SampleBatch[ST] 782 822 """ 783 823 if batch_size is None: 784 824 return wds.pipeline.DataPipeline(