A loose federation of distributed, typed datasets
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

Merge pull request #30 from foundation-ac/release/v0.1.3b4

Release/v0.1.3b4

authored by

Maxine Levesque and committed by
GitHub
a5120c44 84a0b958

+551 -65
+36 -1
src/atdata/__init__.py
··· 1 - """A loose federation of distributed, typed datasets""" 1 + """A loose federation of distributed, typed datasets. 2 + 3 + ``atdata`` provides a typed dataset abstraction built on WebDataset, with support 4 + for: 5 + 6 + - **Typed samples** with automatic msgpack serialization 7 + - **NDArray handling** with transparent bytes conversion 8 + - **Lens transformations** for viewing datasets through different type schemas 9 + - **Batch aggregation** with automatic numpy array stacking 10 + - **WebDataset integration** for efficient large-scale dataset storage 11 + 12 + Quick Start: 13 + >>> import atdata 14 + >>> import numpy as np 15 + >>> 16 + >>> @atdata.packable 17 + ... class MyData: 18 + ... features: np.ndarray 19 + ... label: str 20 + >>> 21 + >>> # Create dataset from WebDataset tar files 22 + >>> ds = atdata.Dataset[MyData]("path/to/data-{000000..000009}.tar") 23 + >>> 24 + >>> # Iterate with automatic batching 25 + >>> for batch in ds.shuffled(batch_size=32): 26 + ... features = batch.features # numpy array (32, ...) 27 + ... labels = batch.label # list of 32 strings 28 + 29 + Main Components: 30 + - ``PackableSample``: Base class for msgpack-serializable samples 31 + - ``Dataset``: Typed dataset wrapper for WebDataset 32 + - ``SampleBatch``: Automatic batch aggregation 33 + - ``Lens``: Bidirectional type transformations 34 + - ``@packable``: Decorator for creating PackableSample classes 35 + - ``@lens``: Decorator for creating lens transformations 36 + """ 2 37 3 38 ## 4 39 # Expose components
+39 -3
src/atdata/_helpers.py
··· 1 - """Assorted helper methods for `atdata`""" 1 + """Helper utilities for numpy array serialization. 2 + 3 + This module provides utility functions for converting numpy arrays to and from 4 + bytes for msgpack serialization. The functions use numpy's native save/load 5 + format to preserve array dtype and shape information. 6 + 7 + Functions: 8 + - ``array_to_bytes()``: Serialize numpy array to bytes 9 + - ``bytes_to_array()``: Deserialize bytes to numpy array 10 + 11 + These helpers are used internally by ``PackableSample`` to enable transparent 12 + handling of NDArray fields during msgpack packing/unpacking. 13 + """ 2 14 3 15 ## 4 16 # Imports ··· 11 23 ## 12 24 13 25 def array_to_bytes( x: np.ndarray ) -> bytes: 14 - """Convert `numpy` array to a format suitable for packing""" 26 + """Convert a numpy array to bytes for msgpack serialization. 27 + 28 + Uses numpy's native ``save()`` format to preserve array dtype and shape. 29 + 30 + Args: 31 + x: A numpy array to serialize. 32 + 33 + Returns: 34 + Raw bytes representing the serialized array. 35 + 36 + Note: 37 + Uses ``allow_pickle=True`` to support object dtypes. 38 + """ 15 39 np_bytes = BytesIO() 16 40 np.save( np_bytes, x, allow_pickle = True ) 17 41 return np_bytes.getvalue() 18 42 19 43 def bytes_to_array( b: bytes ) -> np.ndarray: 20 - """Convert packed bytes back to a `numpy` array""" 44 + """Convert serialized bytes back to a numpy array. 45 + 46 + Reverses the serialization performed by ``array_to_bytes()``. 47 + 48 + Args: 49 + b: Raw bytes from a serialized numpy array. 50 + 51 + Returns: 52 + The deserialized numpy array with original dtype and shape. 53 + 54 + Note: 55 + Uses ``allow_pickle=True`` to support object dtypes. 56 + """ 21 57 np_bytes = BytesIO( b ) 22 58 return np.load( np_bytes, allow_pickle = True )
+299 -37
src/atdata/dataset.py
··· 1 - """Schematized WebDatasets""" 1 + """Core dataset and sample infrastructure for typed WebDatasets. 2 + 3 + This module provides the core components for working with typed, msgpack-serialized 4 + samples in WebDataset format: 5 + 6 + - ``PackableSample``: Base class for msgpack-serializable samples with automatic 7 + NDArray handling 8 + - ``SampleBatch``: Automatic batching with attribute aggregation 9 + - ``Dataset``: Generic typed dataset wrapper for WebDataset tar files 10 + - ``@packable``: Decorator to convert regular classes into PackableSample subclasses 11 + 12 + The implementation handles automatic conversion between numpy arrays and bytes 13 + during serialization, enabling efficient storage of numerical data in WebDataset 14 + archives. 15 + 16 + Example: 17 + >>> @packable 18 + ... class ImageSample: 19 + ... image: NDArray 20 + ... label: str 21 + ... 22 + >>> ds = Dataset[ImageSample]("data-{000000..000009}.tar") 23 + >>> for batch in ds.shuffled(batch_size=32): 24 + ... images = batch.image # Stacked numpy array (32, H, W, C) 25 + ... labels = batch.label # List of 32 strings 26 + """ 2 27 3 28 ## 4 29 # Imports ··· 107 132 # return eh.bytes_to_array( self.raw_bytes ) 108 133 109 134 def _make_packable( x ): 135 + """Convert a value to a msgpack-compatible format. 136 + 137 + Args: 138 + x: A value to convert. If it's a numpy array, converts to bytes. 139 + Otherwise returns the value unchanged. 140 + 141 + Returns: 142 + The value in a format suitable for msgpack serialization. 143 + """ 110 144 # if isinstance( x, ArrayBytes ): 111 145 # return x.raw_bytes 112 146 if isinstance( x, np.ndarray ): ··· 114 148 return x 115 149 116 150 def _is_possibly_ndarray_type( t ): 117 - """Checks if a type annotation is possibly an NDArray.""" 151 + """Check if a type annotation is or contains NDArray. 152 + 153 + Args: 154 + t: A type annotation to check. 155 + 156 + Returns: 157 + ``True`` if the type is ``NDArray`` or a union containing ``NDArray`` 158 + (e.g., ``NDArray | None``), ``False`` otherwise. 159 + """ 118 160 119 161 # Directly an NDArray 120 162 if t == NDArray: ··· 133 175 134 176 @dataclass 135 177 class PackableSample( ABC ): 136 - """A sample that can be packed and unpacked with msgpack""" 178 + """Base class for samples that can be serialized with msgpack. 179 + 180 + This abstract base class provides automatic serialization/deserialization 181 + for dataclass-based samples. Fields annotated as ``NDArray`` or 182 + ``NDArray | None`` are automatically converted between numpy arrays and 183 + bytes during packing/unpacking. 184 + 185 + Subclasses should be defined either by: 186 + 1. Direct inheritance with the ``@dataclass`` decorator 187 + 2. Using the ``@packable`` decorator (recommended) 188 + 189 + Example: 190 + >>> @packable 191 + ... class MyData: 192 + ... name: str 193 + ... embeddings: NDArray 194 + ... 195 + >>> sample = MyData(name="test", embeddings=np.array([1.0, 2.0])) 196 + >>> packed = sample.packed # Serialize to bytes 197 + >>> restored = MyData.from_bytes(packed) # Deserialize 198 + """ 137 199 138 200 def _ensure_good( self ): 139 - """TODO Stupid kludge because of __post_init__ nonsense for wrapped classes""" 201 + """Auto-convert annotated NDArray fields from bytes to numpy arrays. 202 + 203 + This method scans all dataclass fields and for any field annotated as 204 + ``NDArray`` or ``NDArray | None``, automatically converts bytes values 205 + to numpy arrays using the helper deserialization function. This enables 206 + transparent handling of array serialization in msgpack data. 207 + 208 + Note: 209 + This is called during ``__post_init__`` to ensure proper type 210 + conversion after deserialization. 211 + """ 140 212 141 213 # Auto-convert known types when annotated 142 214 # for var_name, var_type in vars( self.__class__ )['__annotations__'].items(): ··· 173 245 174 246 @classmethod 175 247 def from_data( cls, data: MsgpackRawSample ) -> Self: 176 - """Create a sample instance from unpacked msgpack data""" 248 + """Create a sample instance from unpacked msgpack data. 249 + 250 + Args: 251 + data: A dictionary of unpacked msgpack data with keys matching 252 + the sample's field names. 253 + 254 + Returns: 255 + A new instance of this sample class with fields populated from 256 + the data dictionary and NDArray fields auto-converted from bytes. 257 + """ 177 258 ret = cls( **data ) 178 259 ret._ensure_good() 179 260 return ret 180 261 181 262 @classmethod 182 263 def from_bytes( cls, bs: bytes ) -> Self: 183 - """Create a sample instance from raw msgpack bytes""" 264 + """Create a sample instance from raw msgpack bytes. 265 + 266 + Args: 267 + bs: Raw bytes from a msgpack-serialized sample. 268 + 269 + Returns: 270 + A new instance of this sample class deserialized from the bytes. 271 + """ 184 272 return cls.from_data( ormsgpack.unpackb( bs ) ) 185 273 186 274 @property 187 275 def packed( self ) -> bytes: 188 - """Pack this sample's data into msgpack bytes""" 276 + """Pack this sample's data into msgpack bytes. 277 + 278 + NDArray fields are automatically converted to bytes before packing. 279 + All other fields are packed as-is if they're msgpack-compatible. 280 + 281 + Returns: 282 + Raw msgpack bytes representing this sample's data. 283 + 284 + Raises: 285 + RuntimeError: If msgpack serialization fails. 286 + """ 189 287 190 288 # Make sure that all of our (possibly unpackable) data is in a packable 191 289 # format ··· 204 302 # TODO Expand to allow for specifying explicit __key__ 205 303 @property 206 304 def as_wds( self ) -> WDSRawSample: 207 - """Pack this sample's data for writing to webdataset""" 305 + """Pack this sample's data for writing to WebDataset. 306 + 307 + Returns: 308 + A dictionary with ``__key__`` (UUID v1 for sortable keys) and 309 + ``msgpack`` (packed sample data) fields suitable for WebDataset. 310 + 311 + Note: 312 + TODO: Expand to allow specifying explicit ``__key__`` values. 313 + """ 208 314 return { 209 315 # Generates a UUID that is timelike-sortable 210 316 '__key__': str( uuid.uuid1( 0, 0 ) ), ··· 212 318 } 213 319 214 320 def _batch_aggregate( xs: Sequence ): 321 + """Aggregate a sequence of values into a batch-appropriate format. 322 + 323 + Args: 324 + xs: A sequence of values to aggregate. If the first element is a numpy 325 + array, all elements are stacked into a single array. Otherwise, 326 + returns a list. 327 + 328 + Returns: 329 + A numpy array (if elements are arrays) or a list (otherwise). 330 + """ 215 331 216 332 if not xs: 217 333 # Empty sequence 218 334 return [] 219 335 220 - # Aggregate 336 + # Aggregate 221 337 if isinstance( xs[0], np.ndarray ): 222 338 return np.array( list( xs ) ) 223 339 224 340 return list( xs ) 225 341 226 342 class SampleBatch( Generic[DT] ): 343 + """A batch of samples with automatic attribute aggregation. 344 + 345 + This class wraps a sequence of samples and provides magic ``__getattr__`` 346 + access to aggregate sample attributes. When you access an attribute that 347 + exists on the sample type, it automatically aggregates values across all 348 + samples in the batch. 349 + 350 + NDArray fields are stacked into a numpy array with a batch dimension. 351 + Other fields are aggregated into a list. 352 + 353 + Type Parameters: 354 + DT: The sample type, must derive from ``PackableSample``. 355 + 356 + Attributes: 357 + samples: The list of sample instances in this batch. 358 + 359 + Example: 360 + >>> batch = SampleBatch[MyData]([sample1, sample2, sample3]) 361 + >>> batch.embeddings # Returns stacked numpy array of shape (3, ...) 362 + >>> batch.names # Returns list of names 363 + """ 227 364 228 365 def __init__( self, samples: Sequence[DT] ): 229 - """TODO""" 366 + """Create a batch from a sequence of samples. 367 + 368 + Args: 369 + samples: A sequence of sample instances to aggregate into a batch. 370 + Each sample must be an instance of a type derived from 371 + ``PackableSample``. 372 + """ 230 373 self.samples = list( samples ) 231 374 self._aggregate_cache = dict() 232 375 233 376 @property 234 377 def sample_type( self ) -> Type: 235 - """The type of each sample in this batch""" 378 + """The type of each sample in this batch. 379 + 380 + Returns: 381 + The type parameter ``DT`` used when creating this ``SampleBatch[DT]``. 382 + """ 236 383 return typing.get_args( self.__orig_class__)[0] 237 384 238 385 def __getattr__( self, name ): 386 + """Aggregate an attribute across all samples in the batch. 387 + 388 + This magic method enables attribute-style access to aggregated sample 389 + fields. Results are cached for efficiency. 390 + 391 + Args: 392 + name: The attribute name to aggregate across samples. 393 + 394 + Returns: 395 + For NDArray fields: a stacked numpy array with batch dimension. 396 + For other fields: a list of values from each sample. 397 + 398 + Raises: 399 + AttributeError: If the attribute doesn't exist on the sample type. 400 + """ 239 401 # Aggregate named params of sample type 240 402 if name in vars( self.sample_type )['__annotations__']: 241 403 if name not in self._aggregate_cache: ··· 243 405 [ getattr( x, name ) 244 406 for x in self.samples ] 245 407 ) 246 - 408 + 247 409 return self._aggregate_cache[name] 248 - 410 + 249 411 raise AttributeError( f'No sample attribute named {name}' ) 250 412 251 413 ··· 268 430 # IT = TypeVar( 'IT', default = Any ) 269 431 270 432 class Dataset( Generic[ST] ): 271 - """A dataset that ingests and formats raw samples from a WebDataset 272 - 273 - (Abstract base for subclassing) 433 + """A typed dataset built on WebDataset with lens transformations. 434 + 435 + This class wraps WebDataset tar archives and provides type-safe iteration 436 + over samples of a specific ``PackableSample`` type. Samples are stored as 437 + msgpack-serialized data within WebDataset shards. 438 + 439 + The dataset supports: 440 + - Ordered and shuffled iteration 441 + - Automatic batching with ``SampleBatch`` 442 + - Type transformations via the lens system (``as_type()``) 443 + - Export to parquet format 444 + 445 + Type Parameters: 446 + ST: The sample type for this dataset, must derive from ``PackableSample``. 447 + 448 + Attributes: 449 + url: WebDataset brace-notation URL for the tar file(s). 450 + 451 + Example: 452 + >>> ds = Dataset[MyData]("path/to/data-{000000..000009}.tar") 453 + >>> for sample in ds.ordered(batch_size=32): 454 + ... # sample is SampleBatch[MyData] with batch_size samples 455 + ... embeddings = sample.embeddings # shape: (32, ...) 456 + ... 457 + >>> # Transform to a different view 458 + >>> ds_view = ds.as_type(MyDataView) 274 459 """ 275 460 276 461 # sample_class: Type = get_parameters( ) ··· 280 465 281 466 @property 282 467 def sample_type( self ) -> Type: 283 - """The type of each returned sample from this `Dataset`'s iterator""" 284 - # TODO Figure out why linting fails here 468 + """The type of each returned sample from this dataset's iterator. 469 + 470 + Returns: 471 + The type parameter ``ST`` used when creating this ``Dataset[ST]``. 472 + 473 + Note: 474 + Extracts the type parameter at runtime using ``__orig_class__``. 475 + """ 476 + # NOTE: Linting may fail here due to __orig_class__ being a runtime attribute 285 477 return typing.get_args( self.__orig_class__ )[0] 286 478 @property 287 479 def batch_type( self ) -> Type: 288 - """The type of a batch built from `sample_class`""" 480 + """The type of batches produced by this dataset. 481 + 482 + Returns: 483 + ``SampleBatch[ST]`` where ``ST`` is this dataset's sample type. 484 + """ 289 485 # return self.__orig_class__.__args__[1] 290 486 return SampleBatch[self.sample_type] 291 487 ··· 296 492 # 297 493 298 494 def __init__( self, url: str ) -> None: 299 - """TODO""" 495 + """Create a dataset from a WebDataset URL. 496 + 497 + Args: 498 + url: WebDataset brace-notation URL pointing to tar files, e.g., 499 + ``"path/to/file-{000000..000009}.tar"`` for multiple shards or 500 + ``"path/to/file-000000.tar"`` for a single shard. 501 + """ 300 502 super().__init__() 301 503 self.url = url 302 504 ··· 304 506 self._output_lens: Lens | None = None 305 507 306 508 def as_type( self, other: Type[RT] ) -> 'Dataset[RT]': 307 - """TODO""" 509 + """View this dataset through a different sample type using a registered lens. 510 + 511 + Args: 512 + other: The target sample type to transform into. Must be a type 513 + derived from ``PackableSample``. 514 + 515 + Returns: 516 + A new ``Dataset`` instance that yields samples of type ``other`` 517 + by applying the appropriate lens transformation from the global 518 + ``LensNetwork`` registry. 519 + 520 + Raises: 521 + ValueError: If no registered lens exists between the current 522 + sample type and the target type. 523 + """ 308 524 ret = Dataset[other]( self.url ) 309 525 # Get the singleton lens registry 310 526 lenses = LensNetwork() ··· 384 600 buffer_samples: int = 10_000, 385 601 batch_size: int | None = 1, 386 602 ) -> Iterable[ST]: 387 - """Iterate over the dataset in random order 388 - 603 + """Iterate over the dataset in random order. 604 + 389 605 Args: 390 - buffer_shards (int): Asdf 391 - batch_size (:obj:`int`, optional) The size of iterated batches. 392 - Default: 1. If ``None``, iterates over one sample at a time 393 - with no batch dimension. 394 - 606 + buffer_shards: Number of shards to buffer for shuffling at the 607 + shard level. Larger values increase randomness but use more 608 + memory. Default: 100. 609 + buffer_samples: Number of samples to buffer for shuffling within 610 + shards. Larger values increase randomness but use more memory. 611 + Default: 10,000. 612 + batch_size: The size of iterated batches. Default: 1. If ``None``, 613 + iterates over one sample at a time with no batch dimension. 614 + 395 615 Returns: 396 - :obj:`webdataset.DataPipeline` A data pipeline that iterates over 397 - the dataset in its original sample order 398 - 616 + A WebDataset data pipeline that iterates over the dataset in 617 + randomized order. If ``batch_size`` is not ``None``, yields 618 + ``SampleBatch[ST]`` instances; otherwise yields individual ``ST`` 619 + samples. 399 620 """ 400 621 401 622 if batch_size is None: ··· 500 721 # @classmethod 501 722 # TODO replace Any with IT 502 723 def wrap( self, sample: MsgpackRawSample ) -> ST: 503 - """Wrap a `sample` into the appropriate dataset-specific type""" 724 + """Wrap a raw msgpack sample into the appropriate dataset-specific type. 725 + 726 + Args: 727 + sample: A dictionary containing at minimum a ``'msgpack'`` key with 728 + serialized sample bytes. 729 + 730 + Returns: 731 + A deserialized sample of type ``ST``, optionally transformed through 732 + a lens if ``as_type()`` was called. 733 + """ 504 734 assert 'msgpack' in sample 505 735 assert type( sample['msgpack'] ) == bytes 506 736 ··· 524 754 # ) 525 755 526 756 def wrap_batch( self, batch: WDSRawBatch ) -> SampleBatch[ST]: 527 - """Wrap a `batch` of samples into the appropriate dataset-specific type 528 - 529 - This default implementation simply creates a list one sample at a time 757 + """Wrap a batch of raw msgpack samples into a typed SampleBatch. 758 + 759 + Args: 760 + batch: A dictionary containing a ``'msgpack'`` key with a list of 761 + serialized sample bytes. 762 + 763 + Returns: 764 + A ``SampleBatch[ST]`` containing deserialized samples, optionally 765 + transformed through a lens if ``as_type()`` was called. 766 + 767 + Note: 768 + This implementation deserializes samples one at a time, then 769 + aggregates them into a batch. 530 770 """ 531 771 532 772 assert 'msgpack' in batch ··· 572 812 # return decorator 573 813 574 814 def packable( cls ): 575 - """TODO""" 576 - 815 + """Decorator to convert a regular class into a ``PackableSample``. 816 + 817 + This decorator transforms a class into a dataclass that inherits from 818 + ``PackableSample``, enabling automatic msgpack serialization/deserialization 819 + with special handling for NDArray fields. 820 + 821 + Args: 822 + cls: The class to convert. Should have type annotations for its fields. 823 + 824 + Returns: 825 + A new dataclass that inherits from ``PackableSample`` with the same 826 + name and annotations as the original class. 827 + 828 + Example: 829 + >>> @packable 830 + ... class MyData: 831 + ... name: str 832 + ... values: NDArray 833 + ... 834 + >>> sample = MyData(name="test", values=np.array([1, 2, 3])) 835 + >>> bytes_data = sample.packed 836 + >>> restored = MyData.from_bytes(bytes_data) 837 + """ 838 + 577 839 ## 578 840 579 841 class_name = cls.__name__
+177 -24
src/atdata/lens.py
··· 1 - """Lenses between typed datasets""" 1 + """Lens-based type transformations for datasets. 2 + 3 + This module implements a lens system for bidirectional transformations between 4 + different sample types. Lenses enable viewing a dataset through different type 5 + schemas without duplicating the underlying data. 6 + 7 + Key components: 8 + 9 + - ``Lens``: Bidirectional transformation with getter (S -> V) and optional 10 + putter (V, S -> S) 11 + - ``LensNetwork``: Global singleton registry for lens transformations 12 + - ``@lens``: Decorator to create and register lens transformations 13 + 14 + Lenses support the functional programming concept of composable, well-behaved 15 + transformations that satisfy lens laws (GetPut and PutGet). 16 + 17 + Example: 18 + >>> @packable 19 + ... class FullData: 20 + ... name: str 21 + ... age: int 22 + ... embedding: NDArray 23 + ... 24 + >>> @packable 25 + ... class NameOnly: 26 + ... name: str 27 + ... 28 + >>> @lens 29 + ... def name_view(full: FullData) -> NameOnly: 30 + ... return NameOnly(name=full.name) 31 + ... 32 + >>> @name_view.putter 33 + ... def name_view_put(view: NameOnly, source: FullData) -> FullData: 34 + ... return FullData(name=view.name, age=source.age, 35 + ... embedding=source.embedding) 36 + ... 37 + >>> ds = Dataset[FullData]("data.tar") 38 + >>> ds_names = ds.as_type(NameOnly) # Uses registered lens 39 + """ 2 40 3 41 ## 4 42 # Imports ··· 39 77 # Shortcut decorators 40 78 41 79 class Lens( Generic[S, V] ): 42 - """TODO""" 80 + """A bidirectional transformation between two sample types. 81 + 82 + A lens provides a way to view and update data of type ``S`` (source) as if 83 + it were type ``V`` (view). It consists of a getter that transforms ``S -> V`` 84 + and an optional putter that transforms ``(V, S) -> S``, enabling updates to 85 + the view to be reflected back in the source. 43 86 44 - # @property 45 - # def source_type( self ) -> Type[S]: 46 - # """The source type (S) for the lens; what is put to""" 47 - # # TODO Figure out why linting fails here 48 - # return self.__orig_class__.__args__[0] 87 + Type Parameters: 88 + S: The source type, must derive from ``PackableSample``. 89 + V: The view type, must derive from ``PackableSample``. 49 90 50 - # @property 51 - # def view_type( self ) -> Type[V]: 52 - # """The view type (V) for the lens; what is get'd from""" 53 - # # TODO FIgure out why linting fails here 54 - # return self.__orig_class__.__args__[1] 91 + Example: 92 + >>> @lens 93 + ... def name_lens(full: FullData) -> NameOnly: 94 + ... return NameOnly(name=full.name) 95 + ... 96 + >>> @name_lens.putter 97 + ... def name_lens_put(view: NameOnly, source: FullData) -> FullData: 98 + ... return FullData(name=view.name, age=source.age) 99 + """ 55 100 56 101 def __init__( self, get: LensGetter[S, V], 57 102 put: Optional[LensPutter[S, V]] = None 58 103 ) -> None: 59 - """TODO""" 104 + """Initialize a lens with a getter and optional putter function. 105 + 106 + Args: 107 + get: A function that transforms from source type ``S`` to view type 108 + ``V``. Must accept exactly one parameter annotated with the 109 + source type. 110 + put: An optional function that updates the source based on a modified 111 + view. Takes a view of type ``V`` and original source of type ``S``, 112 + and returns an updated source of type ``S``. If not provided, a 113 + trivial putter is used that ignores updates to the view. 114 + 115 + Raises: 116 + AssertionError: If the getter function doesn't have exactly one 117 + parameter. 118 + """ 60 119 ## 61 120 62 121 # Check argument validity ··· 70 129 functools.update_wrapper( self, get ) 71 130 72 131 self.source_type: Type[PackableSample] = input_types[0].annotation 73 - self.view_type = sig.return_annotation 132 + self.view_type: Type[PackableSample] = sig.return_annotation 74 133 75 134 # Store the getter 76 135 self._getter = get 77 - 136 + 78 137 # Determine and store the putter 79 138 if put is None: 80 139 # Trivial putter does not update the source ··· 86 145 # 87 146 88 147 def putter( self, put: LensPutter[S, V] ) -> LensPutter[S, V]: 89 - """TODO""" 148 + """Decorator to register a putter function for this lens. 149 + 150 + Args: 151 + put: A function that takes a view of type ``V`` and source of type 152 + ``S``, and returns an updated source of type ``S``. 153 + 154 + Returns: 155 + The putter function, allowing this to be used as a decorator. 156 + 157 + Example: 158 + >>> @my_lens.putter 159 + ... def my_lens_put(view: ViewType, source: SourceType) -> SourceType: 160 + ... return SourceType(...) 161 + """ 90 162 ## 91 163 self._putter = put 92 164 return put ··· 94 166 # Methods to actually execute transformations 95 167 96 168 def put( self, v: V, s: S ) -> S: 97 - """TODO""" 169 + """Update the source based on a modified view. 170 + 171 + Args: 172 + v: The modified view of type ``V``. 173 + s: The original source of type ``S``. 174 + 175 + Returns: 176 + An updated source of type ``S`` that reflects changes from the view. 177 + """ 98 178 return self._putter( v, s ) 99 179 100 180 def get( self, s: S ) -> V: 101 - """TODO""" 181 + """Transform the source into the view type. 182 + 183 + Args: 184 + s: The source sample of type ``S``. 185 + 186 + Returns: 187 + A view of the source as type ``V``. 188 + """ 102 189 return self( s ) 103 190 104 191 # Convenience to enable calling the lens as its getter 105 - 192 + 106 193 def __call__( self, s: S ) -> V: 194 + """Apply the lens transformation (same as ``get()``). 195 + 196 + Args: 197 + s: The source sample of type ``S``. 198 + 199 + Returns: 200 + A view of the source as type ``V``. 201 + """ 107 202 return self._getter( s ) 108 203 109 204 # TODO Figure out how to properly parameterize this ··· 124 219 # lens = _lens_factory 125 220 126 221 def lens( f: LensGetter[S, V] ) -> Lens[S, V]: 222 + """Decorator to create and register a lens transformation. 223 + 224 + This decorator converts a getter function into a ``Lens`` object and 225 + automatically registers it in the global ``LensNetwork`` registry. 226 + 227 + Args: 228 + f: A getter function that transforms from source type ``S`` to view 229 + type ``V``. Must have exactly one parameter with a type annotation. 230 + 231 + Returns: 232 + A ``Lens[S, V]`` object that can be called to apply the transformation 233 + or decorated with ``@lens_name.putter`` to add a putter function. 234 + 235 + Example: 236 + >>> @lens 237 + ... def extract_name(full: FullData) -> NameOnly: 238 + ... return NameOnly(name=full.name) 239 + ... 240 + >>> @extract_name.putter 241 + ... def extract_name_put(view: NameOnly, source: FullData) -> FullData: 242 + ... return FullData(name=view.name, age=source.age) 243 + """ 127 244 ret = Lens[S, V]( f ) 128 245 _network.register( ret ) 129 246 return ret ··· 136 253 # """TODO""" 137 254 138 255 class LensNetwork: 139 - """TODO""" 256 + """Global registry for lens transformations between sample types. 257 + 258 + This class implements a singleton pattern to maintain a global registry of 259 + all lenses decorated with ``@lens``. It enables looking up transformations 260 + between different ``PackableSample`` types. 261 + 262 + Attributes: 263 + _instance: The singleton instance of this class. 264 + _registry: Dictionary mapping ``(source_type, view_type)`` tuples to 265 + their corresponding ``Lens`` objects. 266 + """ 140 267 141 268 _instance = None 142 269 """The singleton instance""" 143 270 144 271 def __new__(cls, *args, **kwargs): 272 + """Ensure only one instance of LensNetwork exists (singleton pattern).""" 145 273 if cls._instance is None: 146 274 # If no instance exists, create a new one 147 275 cls._instance = super().__new__(cls) 148 276 return cls._instance # Return the existing (or newly created) instance 149 277 150 278 def __init__(self): 279 + """Initialize the lens registry (only on first instantiation).""" 151 280 if not hasattr(self, '_initialized'): # Check if already initialized 152 281 self._registry: Dict[LensSignature, Lens] = dict() 153 282 self._initialized = True 154 283 155 284 def register( self, _lens: Lens ): 156 - """Set `lens` as the canonical view between its source and view types""" 157 - 285 + """Register a lens as the canonical transformation between two types. 286 + 287 + Args: 288 + _lens: The lens to register. Will be stored in the registry under 289 + the key ``(_lens.source_type, _lens.view_type)``. 290 + 291 + Note: 292 + If a lens already exists for the same type pair, it will be 293 + overwritten. 294 + """ 295 + 158 296 # sig = inspect.signature( _lens.get ) 159 297 # input_types = list( sig.parameters.values() ) 160 298 # assert len( input_types ) == 1, \ ··· 169 307 self._registry[_lens.source_type, _lens.view_type] = _lens 170 308 171 309 def transform( self, source: DatasetType, view: DatasetType ) -> Lens: 172 - """TODO""" 310 + """Look up the lens transformation between two sample types. 311 + 312 + Args: 313 + source: The source sample type (must derive from ``PackableSample``). 314 + view: The target view type (must derive from ``PackableSample``). 315 + 316 + Returns: 317 + The registered ``Lens`` that transforms from ``source`` to ``view``. 318 + 319 + Raises: 320 + ValueError: If no lens has been registered for the given type pair. 321 + 322 + Note: 323 + Currently only supports direct transformations. Compositional 324 + transformations (chaining multiple lenses) are not yet implemented. 325 + """ 173 326 174 327 # TODO Handle compositional closure 175 328 ret = self._registry.get( (source, view), None ) 176 329 if ret is None: 177 330 raise ValueError( f'No registered lens from source {source} to view {view}' ) 178 - 331 + 179 332 return ret 180 333 181 334