A loose federation of distributed, typed datasets
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

Merge pull request #25 from foundation-ac/release/v0.1.3b1

Release/v0.1.3b1

authored by

Maxine Levesque and committed by
GitHub
867a9c70 bee2b657

+77 -20
+1 -1
pyproject.toml
··· 1 1 [project] 2 2 name = "atdata" 3 - version = "0.1.3a3" 3 + version = "0.1.3b1" 4 4 description = "A loose federation of distributed, typed datasets" 5 5 readme = "README.md" 6 6 authors = [
+48 -19
src/atdata/dataset.py
··· 8 8 from pathlib import Path 9 9 import uuid 10 10 import functools 11 + 12 + import dataclasses 13 + import types 11 14 from dataclasses import ( 12 15 dataclass, 13 16 asdict, ··· 21 24 import numpy as np 22 25 import pandas as pd 23 26 27 + import typing 24 28 from typing import ( 25 29 Any, 26 30 Optional, ··· 28 32 Sequence, 29 33 Iterable, 30 34 Callable, 35 + Union, 31 36 # 32 37 Self, 33 38 Generic, ··· 108 113 return eh.array_to_bytes( x ) 109 114 return x 110 115 116 + def _is_possibly_ndarray_type( t ): 117 + """Checks if a type annotation is possibly an NDArray.""" 118 + 119 + # Directly an NDArray 120 + if t == NDArray: 121 + print( 'is an NDArray' ) 122 + return True 123 + 124 + # Check for Optionals (i.e., NDArray | None) 125 + if isinstance( t, types.UnionType ): 126 + t_parts = t.__args__ 127 + if any( x == NDArray 128 + for x in t_parts ): 129 + return True 130 + 131 + # Not an NDArray 132 + return False 133 + 111 134 @dataclass 112 135 class PackableSample( ABC ): 113 136 """A sample that can be packed and unpacked with msgpack""" ··· 116 139 """TODO Stupid kludge because of __post_init__ nonsense for wrapped classes""" 117 140 118 141 # Auto-convert known types when annotated 119 - for var_name, var_type in vars( self.__class__ )['__annotations__'].items(): 142 + # for var_name, var_type in vars( self.__class__ )['__annotations__'].items(): 143 + for field in dataclasses.fields( self ): 144 + var_name = field.name 145 + var_type = field.type 120 146 121 147 # Annotation for this variable is to be an NDArray 122 - if var_type == NDArray: 148 + if _is_possibly_ndarray_type( var_type ): 123 149 # ... so, we'll always auto-convert to numpy 124 150 125 151 var_cur_value = getattr( self, var_name ) ··· 135 161 # setattr( self, var_name, var_cur_value.to_numpy ) 136 162 137 163 elif isinstance( var_cur_value, bytes ): 164 + # TODO This does create a constraint that serialized bytes 165 + # in a field that might be an NDArray are always interpreted 166 + # as being the NDArray interpretation 138 167 setattr( self, var_name, eh.bytes_to_array( var_cur_value ) ) 139 168 140 169 def __post_init__( self ): ··· 204 233 @property 205 234 def sample_type( self ) -> Type: 206 235 """The type of each sample in this batch""" 207 - return self.__orig_class__.__args__[0] 236 + return typing.get_args( self.__orig_class__)[0] 208 237 209 238 def __getattr__( self, name ): 210 239 # Aggregate named params of sample type ··· 253 282 def sample_type( self ) -> Type: 254 283 """The type of each returned sample from this `Dataset`'s iterator""" 255 284 # TODO Figure out why linting fails here 256 - return self.__orig_class__.__args__[0] 285 + return typing.get_args( self.__orig_class__ )[0] 257 286 @property 258 287 def batch_type( self ) -> Type: 259 288 """The type of a batch built from `sample_class`""" ··· 371 400 372 401 if batch_size is None: 373 402 # TODO Duplication here 374 - return wds.DataPipeline( 375 - wds.SimpleShardList( self.url ), 376 - wds.shuffle( buffer_shards ), 377 - wds.split_by_worker, 403 + return wds.pipeline.DataPipeline( 404 + wds.shardlists.SimpleShardList( self.url ), 405 + wds.filters.shuffle( buffer_shards ), 406 + wds.shardlists.split_by_worker, 378 407 # 379 - wds.tarfile_to_samples(), 408 + wds.tariterators.tarfile_to_samples(), 380 409 # wds.shuffle( buffer_samples ), 381 410 # wds.map( self.preprocess ), 382 - wds.shuffle( buffer_samples ), 383 - wds.map( self.wrap ), 411 + wds.filters.shuffle( buffer_samples ), 412 + wds.filters.map( self.wrap ), 384 413 ) 385 414 386 - return wds.DataPipeline( 387 - wds.SimpleShardList( self.url ), 388 - wds.shuffle( buffer_shards ), 389 - wds.split_by_worker, 415 + return wds.pipeline.DataPipeline( 416 + wds.shardlists.SimpleShardList( self.url ), 417 + wds.filters.shuffle( buffer_shards ), 418 + wds.shardlists.split_by_worker, 390 419 # 391 - wds.tarfile_to_samples(), 420 + wds.tariterators.tarfile_to_samples(), 392 421 # wds.shuffle( buffer_samples ), 393 422 # wds.map( self.preprocess ), 394 - wds.shuffle( buffer_samples ), 395 - wds.batched( batch_size ), 396 - wds.map( self.wrap_batch ), 423 + wds.filters.shuffle( buffer_samples ), 424 + wds.filters.batched( batch_size ), 425 + wds.filters.map( self.wrap_batch ), 397 426 ) 398 427 399 428 # TODO Rewrite to eliminate `pandas` dependency directly calling
+28
tests/test_dataset.py
··· 50 50 label: int 51 51 image: NDArray 52 52 53 + @atdata.packable 54 + class NumpyOptionalSampleDecorated: 55 + label: int 56 + image: NDArray 57 + embeddings: NDArray | None = None 58 + 53 59 test_cases = [ 54 60 { 55 61 'SampleType': BasicTestSample, ··· 89 95 'image': np.random.randn( 1024, 1024 ), 90 96 }, 91 97 'sample_wds_stem': 'numpy_test_decorated', 98 + 'test_parquet': False, 99 + }, 100 + { 101 + 'SampleType': NumpyOptionalSampleDecorated, 102 + 'sample_data': 103 + { 104 + 'label': 9_001, 105 + 'image': np.random.randn( 1024, 1024 ), 106 + 'embeddings': np.random.randn( 512 ), 107 + }, 108 + 'sample_wds_stem': 'numpy_optional_decorated', 109 + 'test_parquet': False, 110 + }, 111 + { 112 + 'SampleType': NumpyOptionalSampleDecorated, 113 + 'sample_data': 114 + { 115 + 'label': 9_001, 116 + 'image': np.random.randn( 1024, 1024 ), 117 + 'embeddings': None, 118 + }, 119 + 'sample_wds_stem': 'numpy_optional_decorated_none', 92 120 'test_parquet': False, 93 121 }, 94 122 ]