···11-"""A loose federation of distributed, typed datasets"""
11+"""A loose federation of distributed, typed datasets.
22+33+``atdata`` provides a typed dataset abstraction built on WebDataset, with support
44+for:
55+66+- **Typed samples** with automatic msgpack serialization
77+- **NDArray handling** with transparent bytes conversion
88+- **Lens transformations** for viewing datasets through different type schemas
99+- **Batch aggregation** with automatic numpy array stacking
1010+- **WebDataset integration** for efficient large-scale dataset storage
1111+1212+Quick Start:
1313+ >>> import atdata
1414+ >>> import numpy as np
1515+ >>>
1616+ >>> @atdata.packable
1717+ ... class MyData:
1818+ ... features: np.ndarray
1919+ ... label: str
2020+ >>>
2121+ >>> # Create dataset from WebDataset tar files
2222+ >>> ds = atdata.Dataset[MyData]("path/to/data-{000000..000009}.tar")
2323+ >>>
2424+ >>> # Iterate with automatic batching
2525+ >>> for batch in ds.shuffled(batch_size=32):
2626+ ... features = batch.features # numpy array (32, ...)
2727+ ... labels = batch.label # list of 32 strings
2828+2929+Main Components:
3030+ - ``PackableSample``: Base class for msgpack-serializable samples
3131+ - ``Dataset``: Typed dataset wrapper for WebDataset
3232+ - ``SampleBatch``: Automatic batch aggregation
3333+ - ``Lens``: Bidirectional type transformations
3434+ - ``@packable``: Decorator for creating PackableSample classes
3535+ - ``@lens``: Decorator for creating lens transformations
3636+"""
237338##
439# Expose components
+39-3
src/atdata/_helpers.py
···11-"""Assorted helper methods for `atdata`"""
11+"""Helper utilities for numpy array serialization.
22+33+This module provides utility functions for converting numpy arrays to and from
44+bytes for msgpack serialization. The functions use numpy's native save/load
55+format to preserve array dtype and shape information.
66+77+Functions:
88+ - ``array_to_bytes()``: Serialize numpy array to bytes
99+ - ``bytes_to_array()``: Deserialize bytes to numpy array
1010+1111+These helpers are used internally by ``PackableSample`` to enable transparent
1212+handling of NDArray fields during msgpack packing/unpacking.
1313+"""
214315##
416# Imports
···1123##
12241325def array_to_bytes( x: np.ndarray ) -> bytes:
1414- """Convert `numpy` array to a format suitable for packing"""
2626+ """Convert a numpy array to bytes for msgpack serialization.
2727+2828+ Uses numpy's native ``save()`` format to preserve array dtype and shape.
2929+3030+ Args:
3131+ x: A numpy array to serialize.
3232+3333+ Returns:
3434+ Raw bytes representing the serialized array.
3535+3636+ Note:
3737+ Uses ``allow_pickle=True`` to support object dtypes.
3838+ """
1539 np_bytes = BytesIO()
1640 np.save( np_bytes, x, allow_pickle = True )
1741 return np_bytes.getvalue()
18421943def bytes_to_array( b: bytes ) -> np.ndarray:
2020- """Convert packed bytes back to a `numpy` array"""
4444+ """Convert serialized bytes back to a numpy array.
4545+4646+ Reverses the serialization performed by ``array_to_bytes()``.
4747+4848+ Args:
4949+ b: Raw bytes from a serialized numpy array.
5050+5151+ Returns:
5252+ The deserialized numpy array with original dtype and shape.
5353+5454+ Note:
5555+ Uses ``allow_pickle=True`` to support object dtypes.
5656+ """
2157 np_bytes = BytesIO( b )
2258 return np.load( np_bytes, allow_pickle = True )
+299-37
src/atdata/dataset.py
···11-"""Schematized WebDatasets"""
11+"""Core dataset and sample infrastructure for typed WebDatasets.
22+33+This module provides the core components for working with typed, msgpack-serialized
44+samples in WebDataset format:
55+66+- ``PackableSample``: Base class for msgpack-serializable samples with automatic
77+ NDArray handling
88+- ``SampleBatch``: Automatic batching with attribute aggregation
99+- ``Dataset``: Generic typed dataset wrapper for WebDataset tar files
1010+- ``@packable``: Decorator to convert regular classes into PackableSample subclasses
1111+1212+The implementation handles automatic conversion between numpy arrays and bytes
1313+during serialization, enabling efficient storage of numerical data in WebDataset
1414+archives.
1515+1616+Example:
1717+ >>> @packable
1818+ ... class ImageSample:
1919+ ... image: NDArray
2020+ ... label: str
2121+ ...
2222+ >>> ds = Dataset[ImageSample]("data-{000000..000009}.tar")
2323+ >>> for batch in ds.shuffled(batch_size=32):
2424+ ... images = batch.image # Stacked numpy array (32, H, W, C)
2525+ ... labels = batch.label # List of 32 strings
2626+"""
227328##
429# Imports
···107132# return eh.bytes_to_array( self.raw_bytes )
108133109134def _make_packable( x ):
135135+ """Convert a value to a msgpack-compatible format.
136136+137137+ Args:
138138+ x: A value to convert. If it's a numpy array, converts to bytes.
139139+ Otherwise returns the value unchanged.
140140+141141+ Returns:
142142+ The value in a format suitable for msgpack serialization.
143143+ """
110144 # if isinstance( x, ArrayBytes ):
111145 # return x.raw_bytes
112146 if isinstance( x, np.ndarray ):
···114148 return x
115149116150def _is_possibly_ndarray_type( t ):
117117- """Checks if a type annotation is possibly an NDArray."""
151151+ """Check if a type annotation is or contains NDArray.
152152+153153+ Args:
154154+ t: A type annotation to check.
155155+156156+ Returns:
157157+ ``True`` if the type is ``NDArray`` or a union containing ``NDArray``
158158+ (e.g., ``NDArray | None``), ``False`` otherwise.
159159+ """
118160119161 # Directly an NDArray
120162 if t == NDArray:
···133175134176@dataclass
135177class PackableSample( ABC ):
136136- """A sample that can be packed and unpacked with msgpack"""
178178+ """Base class for samples that can be serialized with msgpack.
179179+180180+ This abstract base class provides automatic serialization/deserialization
181181+ for dataclass-based samples. Fields annotated as ``NDArray`` or
182182+ ``NDArray | None`` are automatically converted between numpy arrays and
183183+ bytes during packing/unpacking.
184184+185185+ Subclasses should be defined either by:
186186+ 1. Direct inheritance with the ``@dataclass`` decorator
187187+ 2. Using the ``@packable`` decorator (recommended)
188188+189189+ Example:
190190+ >>> @packable
191191+ ... class MyData:
192192+ ... name: str
193193+ ... embeddings: NDArray
194194+ ...
195195+ >>> sample = MyData(name="test", embeddings=np.array([1.0, 2.0]))
196196+ >>> packed = sample.packed # Serialize to bytes
197197+ >>> restored = MyData.from_bytes(packed) # Deserialize
198198+ """
137199138200 def _ensure_good( self ):
139139- """TODO Stupid kludge because of __post_init__ nonsense for wrapped classes"""
201201+ """Auto-convert annotated NDArray fields from bytes to numpy arrays.
202202+203203+ This method scans all dataclass fields and for any field annotated as
204204+ ``NDArray`` or ``NDArray | None``, automatically converts bytes values
205205+ to numpy arrays using the helper deserialization function. This enables
206206+ transparent handling of array serialization in msgpack data.
207207+208208+ Note:
209209+ This is called during ``__post_init__`` to ensure proper type
210210+ conversion after deserialization.
211211+ """
140212141213 # Auto-convert known types when annotated
142214 # for var_name, var_type in vars( self.__class__ )['__annotations__'].items():
···173245174246 @classmethod
175247 def from_data( cls, data: MsgpackRawSample ) -> Self:
176176- """Create a sample instance from unpacked msgpack data"""
248248+ """Create a sample instance from unpacked msgpack data.
249249+250250+ Args:
251251+ data: A dictionary of unpacked msgpack data with keys matching
252252+ the sample's field names.
253253+254254+ Returns:
255255+ A new instance of this sample class with fields populated from
256256+ the data dictionary and NDArray fields auto-converted from bytes.
257257+ """
177258 ret = cls( **data )
178259 ret._ensure_good()
179260 return ret
180261181262 @classmethod
182263 def from_bytes( cls, bs: bytes ) -> Self:
183183- """Create a sample instance from raw msgpack bytes"""
264264+ """Create a sample instance from raw msgpack bytes.
265265+266266+ Args:
267267+ bs: Raw bytes from a msgpack-serialized sample.
268268+269269+ Returns:
270270+ A new instance of this sample class deserialized from the bytes.
271271+ """
184272 return cls.from_data( ormsgpack.unpackb( bs ) )
185273186274 @property
187275 def packed( self ) -> bytes:
188188- """Pack this sample's data into msgpack bytes"""
276276+ """Pack this sample's data into msgpack bytes.
277277+278278+ NDArray fields are automatically converted to bytes before packing.
279279+ All other fields are packed as-is if they're msgpack-compatible.
280280+281281+ Returns:
282282+ Raw msgpack bytes representing this sample's data.
283283+284284+ Raises:
285285+ RuntimeError: If msgpack serialization fails.
286286+ """
189287190288 # Make sure that all of our (possibly unpackable) data is in a packable
191289 # format
···204302 # TODO Expand to allow for specifying explicit __key__
205303 @property
206304 def as_wds( self ) -> WDSRawSample:
207207- """Pack this sample's data for writing to webdataset"""
305305+ """Pack this sample's data for writing to WebDataset.
306306+307307+ Returns:
308308+ A dictionary with ``__key__`` (UUID v1 for sortable keys) and
309309+ ``msgpack`` (packed sample data) fields suitable for WebDataset.
310310+311311+ Note:
312312+ TODO: Expand to allow specifying explicit ``__key__`` values.
313313+ """
208314 return {
209315 # Generates a UUID that is timelike-sortable
210316 '__key__': str( uuid.uuid1( 0, 0 ) ),
···212318 }
213319214320def _batch_aggregate( xs: Sequence ):
321321+ """Aggregate a sequence of values into a batch-appropriate format.
322322+323323+ Args:
324324+ xs: A sequence of values to aggregate. If the first element is a numpy
325325+ array, all elements are stacked into a single array. Otherwise,
326326+ returns a list.
327327+328328+ Returns:
329329+ A numpy array (if elements are arrays) or a list (otherwise).
330330+ """
215331216332 if not xs:
217333 # Empty sequence
218334 return []
219335220220- # Aggregate
336336+ # Aggregate
221337 if isinstance( xs[0], np.ndarray ):
222338 return np.array( list( xs ) )
223339224340 return list( xs )
225341226342class SampleBatch( Generic[DT] ):
343343+ """A batch of samples with automatic attribute aggregation.
344344+345345+ This class wraps a sequence of samples and provides magic ``__getattr__``
346346+ access to aggregate sample attributes. When you access an attribute that
347347+ exists on the sample type, it automatically aggregates values across all
348348+ samples in the batch.
349349+350350+ NDArray fields are stacked into a numpy array with a batch dimension.
351351+ Other fields are aggregated into a list.
352352+353353+ Type Parameters:
354354+ DT: The sample type, must derive from ``PackableSample``.
355355+356356+ Attributes:
357357+ samples: The list of sample instances in this batch.
358358+359359+ Example:
360360+ >>> batch = SampleBatch[MyData]([sample1, sample2, sample3])
361361+ >>> batch.embeddings # Returns stacked numpy array of shape (3, ...)
362362+ >>> batch.names # Returns list of names
363363+ """
227364228365 def __init__( self, samples: Sequence[DT] ):
229229- """TODO"""
366366+ """Create a batch from a sequence of samples.
367367+368368+ Args:
369369+ samples: A sequence of sample instances to aggregate into a batch.
370370+ Each sample must be an instance of a type derived from
371371+ ``PackableSample``.
372372+ """
230373 self.samples = list( samples )
231374 self._aggregate_cache = dict()
232375233376 @property
234377 def sample_type( self ) -> Type:
235235- """The type of each sample in this batch"""
378378+ """The type of each sample in this batch.
379379+380380+ Returns:
381381+ The type parameter ``DT`` used when creating this ``SampleBatch[DT]``.
382382+ """
236383 return typing.get_args( self.__orig_class__)[0]
237384238385 def __getattr__( self, name ):
386386+ """Aggregate an attribute across all samples in the batch.
387387+388388+ This magic method enables attribute-style access to aggregated sample
389389+ fields. Results are cached for efficiency.
390390+391391+ Args:
392392+ name: The attribute name to aggregate across samples.
393393+394394+ Returns:
395395+ For NDArray fields: a stacked numpy array with batch dimension.
396396+ For other fields: a list of values from each sample.
397397+398398+ Raises:
399399+ AttributeError: If the attribute doesn't exist on the sample type.
400400+ """
239401 # Aggregate named params of sample type
240402 if name in vars( self.sample_type )['__annotations__']:
241403 if name not in self._aggregate_cache:
···243405 [ getattr( x, name )
244406 for x in self.samples ]
245407 )
246246-408408+247409 return self._aggregate_cache[name]
248248-410410+249411 raise AttributeError( f'No sample attribute named {name}' )
250412251413···268430# IT = TypeVar( 'IT', default = Any )
269431270432class Dataset( Generic[ST] ):
271271- """A dataset that ingests and formats raw samples from a WebDataset
272272-273273- (Abstract base for subclassing)
433433+ """A typed dataset built on WebDataset with lens transformations.
434434+435435+ This class wraps WebDataset tar archives and provides type-safe iteration
436436+ over samples of a specific ``PackableSample`` type. Samples are stored as
437437+ msgpack-serialized data within WebDataset shards.
438438+439439+ The dataset supports:
440440+ - Ordered and shuffled iteration
441441+ - Automatic batching with ``SampleBatch``
442442+ - Type transformations via the lens system (``as_type()``)
443443+ - Export to parquet format
444444+445445+ Type Parameters:
446446+ ST: The sample type for this dataset, must derive from ``PackableSample``.
447447+448448+ Attributes:
449449+ url: WebDataset brace-notation URL for the tar file(s).
450450+451451+ Example:
452452+ >>> ds = Dataset[MyData]("path/to/data-{000000..000009}.tar")
453453+ >>> for sample in ds.ordered(batch_size=32):
454454+ ... # sample is SampleBatch[MyData] with batch_size samples
455455+ ... embeddings = sample.embeddings # shape: (32, ...)
456456+ ...
457457+ >>> # Transform to a different view
458458+ >>> ds_view = ds.as_type(MyDataView)
274459 """
275460276461 # sample_class: Type = get_parameters( )
···280465281466 @property
282467 def sample_type( self ) -> Type:
283283- """The type of each returned sample from this `Dataset`'s iterator"""
284284- # TODO Figure out why linting fails here
468468+ """The type of each returned sample from this dataset's iterator.
469469+470470+ Returns:
471471+ The type parameter ``ST`` used when creating this ``Dataset[ST]``.
472472+473473+ Note:
474474+ Extracts the type parameter at runtime using ``__orig_class__``.
475475+ """
476476+ # NOTE: Linting may fail here due to __orig_class__ being a runtime attribute
285477 return typing.get_args( self.__orig_class__ )[0]
286478 @property
287479 def batch_type( self ) -> Type:
288288- """The type of a batch built from `sample_class`"""
480480+ """The type of batches produced by this dataset.
481481+482482+ Returns:
483483+ ``SampleBatch[ST]`` where ``ST`` is this dataset's sample type.
484484+ """
289485 # return self.__orig_class__.__args__[1]
290486 return SampleBatch[self.sample_type]
291487···296492 #
297493298494 def __init__( self, url: str ) -> None:
299299- """TODO"""
495495+ """Create a dataset from a WebDataset URL.
496496+497497+ Args:
498498+ url: WebDataset brace-notation URL pointing to tar files, e.g.,
499499+ ``"path/to/file-{000000..000009}.tar"`` for multiple shards or
500500+ ``"path/to/file-000000.tar"`` for a single shard.
501501+ """
300502 super().__init__()
301503 self.url = url
302504···304506 self._output_lens: Lens | None = None
305507306508 def as_type( self, other: Type[RT] ) -> 'Dataset[RT]':
307307- """TODO"""
509509+ """View this dataset through a different sample type using a registered lens.
510510+511511+ Args:
512512+ other: The target sample type to transform into. Must be a type
513513+ derived from ``PackableSample``.
514514+515515+ Returns:
516516+ A new ``Dataset`` instance that yields samples of type ``other``
517517+ by applying the appropriate lens transformation from the global
518518+ ``LensNetwork`` registry.
519519+520520+ Raises:
521521+ ValueError: If no registered lens exists between the current
522522+ sample type and the target type.
523523+ """
308524 ret = Dataset[other]( self.url )
309525 # Get the singleton lens registry
310526 lenses = LensNetwork()
···384600 buffer_samples: int = 10_000,
385601 batch_size: int | None = 1,
386602 ) -> Iterable[ST]:
387387- """Iterate over the dataset in random order
388388-603603+ """Iterate over the dataset in random order.
604604+389605 Args:
390390- buffer_shards (int): Asdf
391391- batch_size (:obj:`int`, optional) The size of iterated batches.
392392- Default: 1. If ``None``, iterates over one sample at a time
393393- with no batch dimension.
394394-606606+ buffer_shards: Number of shards to buffer for shuffling at the
607607+ shard level. Larger values increase randomness but use more
608608+ memory. Default: 100.
609609+ buffer_samples: Number of samples to buffer for shuffling within
610610+ shards. Larger values increase randomness but use more memory.
611611+ Default: 10,000.
612612+ batch_size: The size of iterated batches. Default: 1. If ``None``,
613613+ iterates over one sample at a time with no batch dimension.
614614+395615 Returns:
396396- :obj:`webdataset.DataPipeline` A data pipeline that iterates over
397397- the dataset in its original sample order
398398-616616+ A WebDataset data pipeline that iterates over the dataset in
617617+ randomized order. If ``batch_size`` is not ``None``, yields
618618+ ``SampleBatch[ST]`` instances; otherwise yields individual ``ST``
619619+ samples.
399620 """
400621401622 if batch_size is None:
···500721 # @classmethod
501722 # TODO replace Any with IT
502723 def wrap( self, sample: MsgpackRawSample ) -> ST:
503503- """Wrap a `sample` into the appropriate dataset-specific type"""
724724+ """Wrap a raw msgpack sample into the appropriate dataset-specific type.
725725+726726+ Args:
727727+ sample: A dictionary containing at minimum a ``'msgpack'`` key with
728728+ serialized sample bytes.
729729+730730+ Returns:
731731+ A deserialized sample of type ``ST``, optionally transformed through
732732+ a lens if ``as_type()`` was called.
733733+ """
504734 assert 'msgpack' in sample
505735 assert type( sample['msgpack'] ) == bytes
506736···524754 # )
525755526756 def wrap_batch( self, batch: WDSRawBatch ) -> SampleBatch[ST]:
527527- """Wrap a `batch` of samples into the appropriate dataset-specific type
528528-529529- This default implementation simply creates a list one sample at a time
757757+ """Wrap a batch of raw msgpack samples into a typed SampleBatch.
758758+759759+ Args:
760760+ batch: A dictionary containing a ``'msgpack'`` key with a list of
761761+ serialized sample bytes.
762762+763763+ Returns:
764764+ A ``SampleBatch[ST]`` containing deserialized samples, optionally
765765+ transformed through a lens if ``as_type()`` was called.
766766+767767+ Note:
768768+ This implementation deserializes samples one at a time, then
769769+ aggregates them into a batch.
530770 """
531771532772 assert 'msgpack' in batch
···572812# return decorator
573813574814def packable( cls ):
575575- """TODO"""
576576-815815+ """Decorator to convert a regular class into a ``PackableSample``.
816816+817817+ This decorator transforms a class into a dataclass that inherits from
818818+ ``PackableSample``, enabling automatic msgpack serialization/deserialization
819819+ with special handling for NDArray fields.
820820+821821+ Args:
822822+ cls: The class to convert. Should have type annotations for its fields.
823823+824824+ Returns:
825825+ A new dataclass that inherits from ``PackableSample`` with the same
826826+ name and annotations as the original class.
827827+828828+ Example:
829829+ >>> @packable
830830+ ... class MyData:
831831+ ... name: str
832832+ ... values: NDArray
833833+ ...
834834+ >>> sample = MyData(name="test", values=np.array([1, 2, 3]))
835835+ >>> bytes_data = sample.packed
836836+ >>> restored = MyData.from_bytes(bytes_data)
837837+ """
838838+577839 ##
578840579841 class_name = cls.__name__
+177-24
src/atdata/lens.py
···11-"""Lenses between typed datasets"""
11+"""Lens-based type transformations for datasets.
22+33+This module implements a lens system for bidirectional transformations between
44+different sample types. Lenses enable viewing a dataset through different type
55+schemas without duplicating the underlying data.
66+77+Key components:
88+99+- ``Lens``: Bidirectional transformation with getter (S -> V) and optional
1010+ putter (V, S -> S)
1111+- ``LensNetwork``: Global singleton registry for lens transformations
1212+- ``@lens``: Decorator to create and register lens transformations
1313+1414+Lenses support the functional programming concept of composable, well-behaved
1515+transformations that satisfy lens laws (GetPut and PutGet).
1616+1717+Example:
1818+ >>> @packable
1919+ ... class FullData:
2020+ ... name: str
2121+ ... age: int
2222+ ... embedding: NDArray
2323+ ...
2424+ >>> @packable
2525+ ... class NameOnly:
2626+ ... name: str
2727+ ...
2828+ >>> @lens
2929+ ... def name_view(full: FullData) -> NameOnly:
3030+ ... return NameOnly(name=full.name)
3131+ ...
3232+ >>> @name_view.putter
3333+ ... def name_view_put(view: NameOnly, source: FullData) -> FullData:
3434+ ... return FullData(name=view.name, age=source.age,
3535+ ... embedding=source.embedding)
3636+ ...
3737+ >>> ds = Dataset[FullData]("data.tar")
3838+ >>> ds_names = ds.as_type(NameOnly) # Uses registered lens
3939+"""
240341##
442# Imports
···3977# Shortcut decorators
40784179class Lens( Generic[S, V] ):
4242- """TODO"""
8080+ """A bidirectional transformation between two sample types.
8181+8282+ A lens provides a way to view and update data of type ``S`` (source) as if
8383+ it were type ``V`` (view). It consists of a getter that transforms ``S -> V``
8484+ and an optional putter that transforms ``(V, S) -> S``, enabling updates to
8585+ the view to be reflected back in the source.
43864444- # @property
4545- # def source_type( self ) -> Type[S]:
4646- # """The source type (S) for the lens; what is put to"""
4747- # # TODO Figure out why linting fails here
4848- # return self.__orig_class__.__args__[0]
8787+ Type Parameters:
8888+ S: The source type, must derive from ``PackableSample``.
8989+ V: The view type, must derive from ``PackableSample``.
49905050- # @property
5151- # def view_type( self ) -> Type[V]:
5252- # """The view type (V) for the lens; what is get'd from"""
5353- # # TODO FIgure out why linting fails here
5454- # return self.__orig_class__.__args__[1]
9191+ Example:
9292+ >>> @lens
9393+ ... def name_lens(full: FullData) -> NameOnly:
9494+ ... return NameOnly(name=full.name)
9595+ ...
9696+ >>> @name_lens.putter
9797+ ... def name_lens_put(view: NameOnly, source: FullData) -> FullData:
9898+ ... return FullData(name=view.name, age=source.age)
9999+ """
5510056101 def __init__( self, get: LensGetter[S, V],
57102 put: Optional[LensPutter[S, V]] = None
58103 ) -> None:
5959- """TODO"""
104104+ """Initialize a lens with a getter and optional putter function.
105105+106106+ Args:
107107+ get: A function that transforms from source type ``S`` to view type
108108+ ``V``. Must accept exactly one parameter annotated with the
109109+ source type.
110110+ put: An optional function that updates the source based on a modified
111111+ view. Takes a view of type ``V`` and original source of type ``S``,
112112+ and returns an updated source of type ``S``. If not provided, a
113113+ trivial putter is used that ignores updates to the view.
114114+115115+ Raises:
116116+ AssertionError: If the getter function doesn't have exactly one
117117+ parameter.
118118+ """
60119 ##
6112062121 # Check argument validity
···70129 functools.update_wrapper( self, get )
7113072131 self.source_type: Type[PackableSample] = input_types[0].annotation
7373- self.view_type = sig.return_annotation
132132+ self.view_type: Type[PackableSample] = sig.return_annotation
7413375134 # Store the getter
76135 self._getter = get
7777-136136+78137 # Determine and store the putter
79138 if put is None:
80139 # Trivial putter does not update the source
···86145 #
8714688147 def putter( self, put: LensPutter[S, V] ) -> LensPutter[S, V]:
8989- """TODO"""
148148+ """Decorator to register a putter function for this lens.
149149+150150+ Args:
151151+ put: A function that takes a view of type ``V`` and source of type
152152+ ``S``, and returns an updated source of type ``S``.
153153+154154+ Returns:
155155+ The putter function, allowing this to be used as a decorator.
156156+157157+ Example:
158158+ >>> @my_lens.putter
159159+ ... def my_lens_put(view: ViewType, source: SourceType) -> SourceType:
160160+ ... return SourceType(...)
161161+ """
90162 ##
91163 self._putter = put
92164 return put
···94166 # Methods to actually execute transformations
9516796168 def put( self, v: V, s: S ) -> S:
9797- """TODO"""
169169+ """Update the source based on a modified view.
170170+171171+ Args:
172172+ v: The modified view of type ``V``.
173173+ s: The original source of type ``S``.
174174+175175+ Returns:
176176+ An updated source of type ``S`` that reflects changes from the view.
177177+ """
98178 return self._putter( v, s )
99179100180 def get( self, s: S ) -> V:
101101- """TODO"""
181181+ """Transform the source into the view type.
182182+183183+ Args:
184184+ s: The source sample of type ``S``.
185185+186186+ Returns:
187187+ A view of the source as type ``V``.
188188+ """
102189 return self( s )
103190104191 # Convenience to enable calling the lens as its getter
105105-192192+106193 def __call__( self, s: S ) -> V:
194194+ """Apply the lens transformation (same as ``get()``).
195195+196196+ Args:
197197+ s: The source sample of type ``S``.
198198+199199+ Returns:
200200+ A view of the source as type ``V``.
201201+ """
107202 return self._getter( s )
108203109204# TODO Figure out how to properly parameterize this
···124219# lens = _lens_factory
125220126221def lens( f: LensGetter[S, V] ) -> Lens[S, V]:
222222+ """Decorator to create and register a lens transformation.
223223+224224+ This decorator converts a getter function into a ``Lens`` object and
225225+ automatically registers it in the global ``LensNetwork`` registry.
226226+227227+ Args:
228228+ f: A getter function that transforms from source type ``S`` to view
229229+ type ``V``. Must have exactly one parameter with a type annotation.
230230+231231+ Returns:
232232+ A ``Lens[S, V]`` object that can be called to apply the transformation
233233+ or decorated with ``@lens_name.putter`` to add a putter function.
234234+235235+ Example:
236236+ >>> @lens
237237+ ... def extract_name(full: FullData) -> NameOnly:
238238+ ... return NameOnly(name=full.name)
239239+ ...
240240+ >>> @extract_name.putter
241241+ ... def extract_name_put(view: NameOnly, source: FullData) -> FullData:
242242+ ... return FullData(name=view.name, age=source.age)
243243+ """
127244 ret = Lens[S, V]( f )
128245 _network.register( ret )
129246 return ret
···136253# """TODO"""
137254138255class LensNetwork:
139139- """TODO"""
256256+ """Global registry for lens transformations between sample types.
257257+258258+ This class implements a singleton pattern to maintain a global registry of
259259+ all lenses decorated with ``@lens``. It enables looking up transformations
260260+ between different ``PackableSample`` types.
261261+262262+ Attributes:
263263+ _instance: The singleton instance of this class.
264264+ _registry: Dictionary mapping ``(source_type, view_type)`` tuples to
265265+ their corresponding ``Lens`` objects.
266266+ """
140267141268 _instance = None
142269 """The singleton instance"""
143270144271 def __new__(cls, *args, **kwargs):
272272+ """Ensure only one instance of LensNetwork exists (singleton pattern)."""
145273 if cls._instance is None:
146274 # If no instance exists, create a new one
147275 cls._instance = super().__new__(cls)
148276 return cls._instance # Return the existing (or newly created) instance
149277150278 def __init__(self):
279279+ """Initialize the lens registry (only on first instantiation)."""
151280 if not hasattr(self, '_initialized'): # Check if already initialized
152281 self._registry: Dict[LensSignature, Lens] = dict()
153282 self._initialized = True
154283155284 def register( self, _lens: Lens ):
156156- """Set `lens` as the canonical view between its source and view types"""
157157-285285+ """Register a lens as the canonical transformation between two types.
286286+287287+ Args:
288288+ _lens: The lens to register. Will be stored in the registry under
289289+ the key ``(_lens.source_type, _lens.view_type)``.
290290+291291+ Note:
292292+ If a lens already exists for the same type pair, it will be
293293+ overwritten.
294294+ """
295295+158296 # sig = inspect.signature( _lens.get )
159297 # input_types = list( sig.parameters.values() )
160298 # assert len( input_types ) == 1, \
···169307 self._registry[_lens.source_type, _lens.view_type] = _lens
170308171309 def transform( self, source: DatasetType, view: DatasetType ) -> Lens:
172172- """TODO"""
310310+ """Look up the lens transformation between two sample types.
311311+312312+ Args:
313313+ source: The source sample type (must derive from ``PackableSample``).
314314+ view: The target view type (must derive from ``PackableSample``).
315315+316316+ Returns:
317317+ The registered ``Lens`` that transforms from ``source`` to ``view``.
318318+319319+ Raises:
320320+ ValueError: If no lens has been registered for the given type pair.
321321+322322+ Note:
323323+ Currently only supports direct transformations. Compositional
324324+ transformations (chaining multiple lenses) are not yet implemented.
325325+ """
173326174327 # TODO Handle compositional closure
175328 ret = self._registry.get( (source, view), None )
176329 if ret is None:
177330 raise ValueError( f'No registered lens from source {source} to view {view}' )
178178-331331+179332 return ret
180333181334