### Fixed

### Changed
- Adversarial review: Test suite and codebase comprehensive assessment (#181)
- Consolidate test sample type definitions into conftest.py (#184)
- Trim verbose docstrings that restate function signatures (#189)
- Replace assertions with explicit ValueError in Repo.insert (#187)
- Add Redis key prefix constants to eliminate magic strings (#186)
- Convert TODO comments to tracked issues or design notes (#185)
- Add tests for schema_to_type with malformed/edge-case schemas (#183)
- Remove duplicate shard writing logic between Repo.insert and S3DataStore (#182)
- Remove unused Lens import from dataset.py (#188)
- Build comprehensive markdown documentation for atdata (#171)
- Write docs/protocols.md - Abstract protocols reference (#180)
- Write docs/load-dataset.md - HuggingFace-style API (#179)
- Write docs/promotion.md - Local to atmosphere workflow (#178)
- Write docs/atmosphere.md - ATProto publishing and loading (#177)
- Write docs/local-storage.md - LocalIndex, Repo, S3DataStore (#176)
- Write docs/lenses.md - Lens transformations (#175)
- Write docs/datasets.md - Dataset iteration and batching (#174)
- Write docs/packable-samples.md - PackableSample and @packable (#173)
- Write docs/index.md - overview and quick start guide (#172)
- Adversarial review: Post Local-ATProto Reconciliation (#165)
- Add error path tests for Dataset with invalid tar files (#170)
- Convert TODO comment to design note in dataset.py (#169)
- Replace O(n²) string prefix extraction with os.path.commonprefix (#168)
- Remove unused Lens import from dataset.py (#167)
- Extract shared dtype/type conversion to _type_utils.py (#166)
- Local-ATProto Reconciliation Refactor (#111)
- Phase 7: Documentation and Examples (#118)
- Review and update docstrings for new public API (#164)
- Create examples/promote_workflow.py demonstration (#163)
- Create examples/local_workflow.py demonstration (#162)
- Phase 6: Testing (protocols, integration, property tests) (#117)
- Add schema deduplication integration test (#161)
- Add integration test for local to atmosphere round-trip (#160)
- Create test_protocols.py for protocol compliance tests (#159)
- Phase 5: Local to Atmosphere Promotion Workflow (#116)
- Add tests for promote.py (#158)
- Implement schema deduplication helper (#157)
- Create promote.py module with promote_to_atmosphere function (#156)
- Adversarial review: Phase 4 code contraction (#148)
- Simplify _python_type_to_field_type in local.py (#155)
- Clean up unused imports and type: ignore comments (#153)
- Precompile split detection regex patterns (#151)
- Add missing error path tests for invalid msgpack data (#154)
- Remove duplicate S3 write logic between Repo.insert and S3DataStore (#152)
- Remove verbose docstrings that restate function signatures (#150)
- Remove redundant _ensure_good() call in PackableSample.from_data() (#149)
- Phase 4: Integrate with load_dataset() (@handle/dataset resolution) (#115)
- Support auto-type resolution from index schema (#147)
- Add @handle/dataset path resolution (#146)
- Extend load_dataset signature with index parameter (#145)
- Adversarial review: Phase 1 and Phase 2 implementation (#133)
- Phase 3: Implement Concrete Classes (LocalIndex, AtmosphereIndex, S3DataStore) (#114)
- Implement AtmosphereIndex wrapper (#144)
- Implement S3DataStore class (#143)
- Add AbstractIndex protocol methods to Index class (#142)
- Adversarial review: Phase 1 & 2 implementation (#135)
- Add test coverage for _schema_codec utility functions (#141)
- Add missing test for parse_cid with malformed input (#140)
- Merge clean_redis and clean_redis_schemas fixtures (#139)
- Remove redundant return statement in Index.entries (#137)
- DRY: Consolidate Redis deserialization in Index.entries vs LocalDatasetEntry.from_redis (#138)
- Remove unused _decode_bytes_dict function from local.py (#136)
- Adversarial review: Phase 1 & 2 implementation (#134)
- Phase 2: Align Local with ATProto Record Formats (#113)
- Add schema storage to local (LocalSchemaRecord in Redis) (#128)
- Rename BasicIndexEntry to LocalDatasetEntry and implement IndexEntry protocol (#127)
- Rename Index to LocalIndex and implement AbstractIndex protocol (#129)
- Update Repo to use new LocalIndex API (#130)
- Update test_local.py for renamed classes and new API (#131)
- Add libipld dependency for CID generation (#132)
- Phase 2: Refactor local.py to use new protocols (#113)
- Add CID utilities module (_cid.py) with ATProto-compatible CID generation (#132)
- Rename BasicIndexEntry to LocalDatasetEntry with CID + name dual identity (#127)
- Add LocalIndex alias for Index class (#129)
- Update Repo.insert() to require name parameter (#130)
- Update test_local.py for new LocalDatasetEntry API (#131)
- Revise AbstractIndex: Remove single-type generic, add schema decoding (#123)
- Implement dynamic PackableSample class generation from schema (#126)
- Add decode_schema() method to AbstractIndex (#125)
- Remove generic type parameter from AbstractIndex (#124)
- Phase 1: Define Abstract Protocols (_protocols.py) (#112)
- Export protocols from __init__.py (#122)
- Define AbstractDataStore protocol (#121)
- Define AbstractIndex protocol (#120)
- Define IndexEntry protocol (#119)
- Review ATProto vs Local integration architecture convergence (#110)
- Add HuggingFace Datasets-style API to atdata (#103)
- Support streaming mode parameter (#108)
- Add split parameter handling (train/test/validation) (#107)
docs/atmosphere.md
# Atmosphere (ATProto Integration)

The atmosphere module enables publishing and discovering datasets on the ATProto network, creating a federated ecosystem for typed datasets.

## Installation

```bash
pip install atdata[atmosphere]
# or
pip install atproto
```

## Overview

ATProto integration publishes datasets, schemas, and lenses as records in the `ac.foundation.dataset.*` namespace. This enables:

- **Discovery** through the ATProto network
- **Federation** across different hosts
- **Verifiability** through content-addressable records

## AtmosphereClient

The client handles authentication and record operations:

```python
from atdata.atmosphere import AtmosphereClient

client = AtmosphereClient()

# Login with app-specific password (not your main password!)
client.login("alice.bsky.social", "app-password")

print(client.did)     # 'did:plc:...'
print(client.handle)  # 'alice.bsky.social'
```

### Session Management

Save and restore sessions to avoid re-authentication:

```python
# Export session for later
session_string = client.export_session()

# Later: restore session
new_client = AtmosphereClient()
new_client.login_with_session(session_string)
```

### Custom PDS

Connect to a custom PDS instead of bsky.social:

```python
client = AtmosphereClient(base_url="https://pds.example.com")
```

## AtmosphereIndex

The unified interface for ATProto operations, implementing the AbstractIndex protocol:

```python
from atdata.atmosphere import AtmosphereClient, AtmosphereIndex

client = AtmosphereClient()
client.login("handle.bsky.social", "app-password")

index = AtmosphereIndex(client)
```

### Publishing Schemas

```python
@atdata.packable
class ImageSample:
    image: NDArray
    label: str
    confidence: float

# Publish schema
schema_uri = index.publish_schema(
    ImageSample,
    version="1.0.0",
    description="Image classification sample",
)
# Returns: "at://did:plc:.../ac.foundation.dataset.sampleSchema/..."
```

### Publishing Datasets

```python
dataset = atdata.Dataset[ImageSample]("data-{000000..000009}.tar")

entry = index.insert_dataset(
    dataset,
    name="imagenet-subset",
    schema_ref=schema_uri,  # Optional - auto-publishes if omitted
    description="ImageNet subset",
    tags=["images", "classification"],
    license="MIT",
)

print(entry.uri)        # AT URI of the record
print(entry.data_urls)  # WebDataset URLs
```

### Listing and Retrieving

```python
# List your datasets
for entry in index.list_datasets():
    print(f"{entry.name}: {entry.schema_ref}")

# List from another user
for entry in index.list_datasets(repo="did:plc:other-user"):
    print(entry.name)

# Get specific dataset
entry = index.get_dataset("at://did:plc:.../ac.foundation.dataset.record/...")

# List schemas
for schema in index.list_schemas():
    print(f"{schema['name']} v{schema['version']}")

# Decode schema to Python type
SampleType = index.decode_schema(schema_uri)
```

## Lower-Level Publishers

For more control, use the individual publisher classes:

### SchemaPublisher

```python
from atdata.atmosphere import SchemaPublisher

publisher = SchemaPublisher(client)

uri = publisher.publish(
    ImageSample,
    name="ImageSample",
    version="1.0.0",
    description="Image with label",
    metadata={"source": "training"},
)
```

### DatasetPublisher

```python
from atdata.atmosphere import DatasetPublisher

publisher = DatasetPublisher(client)

uri = publisher.publish(
    dataset,
    name="training-images",
    schema_uri=schema_uri,     # Required if auto_publish_schema=False
    auto_publish_schema=True,  # Publish schema automatically
    description="Training images",
    tags=["training", "images"],
    license="MIT",
)
```

### LensPublisher

```python
from atdata.atmosphere import LensPublisher

publisher = LensPublisher(client)

# With code references
uri = publisher.publish(
    name="simplify",
    source_schema=full_schema_uri,
    target_schema=simple_schema_uri,
    description="Extract label only",
    getter_code={
        "repository": "https://github.com/org/repo",
        "commit": "abc123def...",
        "path": "transforms/simplify.py:simplify_getter",
    },
    putter_code={
        "repository": "https://github.com/org/repo",
        "commit": "abc123def...",
        "path": "transforms/simplify.py:simplify_putter",
    },
)

# Or publish from a Lens object
from atdata.lens import lens

@lens
def simplify(src: FullSample) -> SimpleSample:
    return SimpleSample(label=src.label)

uri = publisher.publish_from_lens(
    simplify,
    source_schema=full_schema_uri,
    target_schema=simple_schema_uri,
)
```

## AT URIs

ATProto records are identified by AT URIs:

```python
from atdata.atmosphere import AtUri

# Parse an AT URI
uri = AtUri.parse("at://did:plc:abc123/ac.foundation.dataset.sampleSchema/xyz")

print(uri.authority)   # 'did:plc:abc123'
print(uri.collection)  # 'ac.foundation.dataset.sampleSchema'
print(uri.rkey)        # 'xyz'

# Format back to string
print(str(uri))  # 'at://did:plc:abc123/ac.foundation.dataset.sampleSchema/xyz'
```

## Record Types

### SchemaRecord

```python
from atdata.atmosphere import SchemaRecord, FieldDef, FieldType

schema = SchemaRecord(
    name="ImageSample",
    version="1.0.0",
    fields=[
        FieldDef(
            name="image",
            field_type=FieldType(kind="ndarray", dtype="float32"),
        ),
        FieldDef(
            name="label",
            field_type=FieldType(kind="primitive", primitive="str"),
        ),
    ],
    description="Image with label",
)

record_dict = schema.to_record()
```

### DatasetRecord

```python
from atdata.atmosphere import DatasetRecord, StorageLocation

dataset_record = DatasetRecord(
    name="training-images",
    schema_ref="at://did:plc:.../...",
    storage=StorageLocation(
        kind="external",
        urls=["s3://bucket/data-{000000..000009}.tar"],
    ),
    tags=["training"],
    license="MIT",
)
```

### LensRecord

```python
from atdata.atmosphere import LensRecord, CodeReference

lens_record = LensRecord(
    name="simplify",
    source_schema="at://did:plc:.../.../source",
    target_schema="at://did:plc:.../.../target",
    description="Simplify sample",
    getter_code=CodeReference(
        repository="https://github.com/org/repo",
        commit="abc123",
        path="transforms.py:simplify",
    ),
)
```

## Supported Field Types

Schemas support these field types:

| Python Type | ATProto Type |
|-------------|--------------|
| `str` | `primitive/str` |
| `int` | `primitive/int` |
| `float` | `primitive/float` |
| `bool` | `primitive/bool` |
| `bytes` | `primitive/bytes` |
| `NDArray` | `ndarray` (default dtype: float32) |
| `NDArray[np.float64]` | `ndarray` (dtype: float64) |
| `list[str]` | `array` with items |
| `T \| None` | Optional field |
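As a quick illustration (a sketch, not taken from the docs above; it assumes the authenticated `index` from the earlier sections and a hypothetical `MixedSample` type), a sample type can exercise several rows of this table and be published like any other schema:

```python
import numpy as np
from numpy.typing import NDArray
import atdata

@atdata.packable
class MixedSample:
    text: str                     # primitive/str
    weights: NDArray[np.float64]  # ndarray with dtype float64
    tags: list[str]               # array with str items
    note: str | None              # optional field

# Publishes a schema whose field types mirror the mapping in the table above
schema_uri = index.publish_schema(MixedSample, version="1.0.0")
```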
## Complete Example

```python
import numpy as np
from numpy.typing import NDArray
import atdata
from atdata.atmosphere import AtmosphereClient, AtmosphereIndex
import webdataset as wds

# 1. Define and create samples
@atdata.packable
class FeatureSample:
    features: NDArray
    label: int
    source: str

samples = [
    FeatureSample(
        features=np.random.randn(128).astype(np.float32),
        label=i % 10,
        source="synthetic",
    )
    for i in range(1000)
]

# 2. Write to tar
with wds.writer.TarWriter("features.tar") as sink:
    for i, s in enumerate(samples):
        sink.write({**s.as_wds, "__key__": f"{i:06d}"})

# 3. Authenticate
client = AtmosphereClient()
client.login("myhandle.bsky.social", "app-password")

index = AtmosphereIndex(client)

# 4. Publish schema
schema_uri = index.publish_schema(
    FeatureSample,
    version="1.0.0",
    description="Feature vectors with labels",
)

# 5. Publish dataset
dataset = atdata.Dataset[FeatureSample]("features.tar")
entry = index.insert_dataset(
    dataset,
    name="synthetic-features-v1",
    schema_ref=schema_uri,
    tags=["features", "synthetic"],
)

print(f"Published: {entry.uri}")

# 6. Later: discover and load
for dataset_entry in index.list_datasets():
    print(f"Found: {dataset_entry.name}")

    # Reconstruct type from schema
    SampleType = index.decode_schema(dataset_entry.schema_ref)

    # Load dataset
    ds = atdata.Dataset[SampleType](dataset_entry.data_urls[0])
    for batch in ds.ordered(batch_size=32):
        print(batch.features.shape)
        break
```

## Related

- [Local Storage](local-storage.md) - Redis + S3 backend
- [Promotion](promotion.md) - Promoting local datasets to ATProto
- [Protocols](protocols.md) - AbstractIndex interface
- [Packable Samples](packable-samples.md) - Defining sample types
docs/datasets.md
# Datasets

The `Dataset` class provides typed iteration over WebDataset tar files with automatic batching and lens transformations.

## Creating a Dataset

```python
import atdata

@atdata.packable
class ImageSample:
    image: NDArray
    label: str

# Single shard
dataset = atdata.Dataset[ImageSample]("data-000000.tar")

# Multiple shards with brace notation
dataset = atdata.Dataset[ImageSample]("data-{000000..000009}.tar")
```

The type parameter `[ImageSample]` specifies what sample type the dataset contains. This enables type-safe iteration and automatic deserialization.

## Iteration Modes

### Ordered Iteration

Iterate through samples in their original order:

```python
# With batching (default batch_size=1)
for batch in dataset.ordered(batch_size=32):
    images = batch.image  # numpy array (32, H, W, C)
    labels = batch.label  # list of 32 strings

# Without batching (raw samples)
for sample in dataset.ordered(batch_size=None):
    print(sample.label)
```

### Shuffled Iteration

Iterate with randomized order at both shard and sample levels:

```python
for batch in dataset.shuffled(batch_size=32):
    # Samples are shuffled
    process(batch)

# Control shuffle buffer sizes
for batch in dataset.shuffled(
    buffer_shards=100,     # Shards to buffer (default: 100)
    buffer_samples=10000,  # Samples to buffer (default: 10,000)
    batch_size=32,
):
    process(batch)
```

Larger buffer sizes increase randomness but use more memory.

## SampleBatch

When iterating with a `batch_size`, each iteration yields a `SampleBatch` with automatic attribute aggregation.

```python
@atdata.packable
class Sample:
    features: NDArray  # shape (256,)
    label: str
    score: float

for batch in dataset.ordered(batch_size=16):
    # NDArray fields are stacked with a batch dimension
    features = batch.features  # numpy array (16, 256)

    # Other fields become lists
    labels = batch.label  # list of 16 strings
    scores = batch.score  # list of 16 floats
```

Results are cached, so accessing the same attribute multiple times is efficient.

## Type Transformations with Lenses

View a dataset through a different sample type using registered lenses:

```python
@atdata.packable
class SimplifiedSample:
    label: str

@atdata.lens
def simplify(src: ImageSample) -> SimplifiedSample:
    return SimplifiedSample(label=src.label)

# Transform dataset to different type
simple_ds = dataset.as_type(SimplifiedSample)

for batch in simple_ds.ordered(batch_size=16):
    print(batch.label)  # Only label field available
```

See [Lenses](lenses.md) for details on defining transformations.

## Dataset Properties

### Shard List

Get the list of individual tar files:

```python
dataset = atdata.Dataset[Sample]("data-{000000..000009}.tar")
shards = dataset.shard_list
# ['data-000000.tar', 'data-000001.tar', ..., 'data-000009.tar']
```

### Metadata

Datasets can have associated metadata from a URL:

```python
dataset = atdata.Dataset[Sample](
    "data-{000000..000009}.tar",
    metadata_url="https://example.com/metadata.msgpack"
)

# Fetched and cached on first access
metadata = dataset.metadata  # dict or None
```

## Writing Datasets

Use WebDataset's `TarWriter` or `ShardWriter` to create datasets:

```python
import webdataset as wds

samples = [
    ImageSample(image=np.random.rand(224, 224, 3).astype(np.float32), label="cat")
    for _ in range(100)
]

# Single tar file
with wds.writer.TarWriter("data-000000.tar") as sink:
    for i, sample in enumerate(samples):
        sink.write({**sample.as_wds, "__key__": f"sample_{i:06d}"})

# Multiple shards with automatic splitting
with wds.writer.ShardWriter("data-%06d.tar", maxcount=1000) as sink:
    for i, sample in enumerate(samples):
        sink.write({**sample.as_wds, "__key__": f"sample_{i:06d}"})
```

## Parquet Export

Export dataset contents to parquet format:

```python
# Export entire dataset
dataset.to_parquet("output.parquet")

# Export with custom field mapping
def extract_fields(sample):
    return {"label": sample.label, "score": sample.confidence}

dataset.to_parquet("output.parquet", sample_map=extract_fields)

# Export in segments
dataset.to_parquet("output.parquet", maxcount=10000)
# Creates output-000000.parquet, output-000001.parquet, etc.
```
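The exported files are plain parquet, so they can be read back with any parquet reader. For example, using pandas (not an atdata dependency; shown only for illustration):

```python
import pandas as pd

df = pd.read_parquet("output.parquet")
print(df.columns)  # e.g. the fields produced by sample_map, such as 'label' and 'score'
print(len(df))     # number of exported samples
```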
## URL Formats

WebDataset supports various URL formats:

```python
# Local files
dataset = atdata.Dataset[Sample]("./data/file.tar")
dataset = atdata.Dataset[Sample]("/absolute/path/file-{000000..000009}.tar")

# S3 (requires s3fs)
dataset = atdata.Dataset[Sample]("s3://bucket/path/file-{000000..000009}.tar")

# HTTP/HTTPS
dataset = atdata.Dataset[Sample]("https://example.com/data-{000000..000009}.tar")
```

## Related

- [Packable Samples](packable-samples.md) - Defining typed samples
- [Lenses](lenses.md) - Type transformations
- [load_dataset](load-dataset.md) - HuggingFace-style loading API
docs/index.md
# atdata

A loose federation of distributed, typed datasets built on WebDataset.

## What is atdata?

atdata provides a typed dataset abstraction for machine learning workflows with:

- **Typed samples** with automatic msgpack serialization
- **NDArray handling** with transparent numpy array conversion
- **Lens transformations** for viewing datasets through different schemas
- **Batch aggregation** with automatic numpy stacking
- **WebDataset integration** for efficient large-scale storage
- **ATProto federation** for publishing and discovering datasets

## Installation

```bash
pip install atdata

# With ATProto support
pip install atdata[atmosphere]
```

## Quick Start

### Define a Sample Type

```python
import numpy as np
from numpy.typing import NDArray
import atdata

@atdata.packable
class ImageSample:
    image: NDArray
    label: str
    confidence: float
```

### Create and Write Samples

```python
import webdataset as wds

samples = [
    ImageSample(
        image=np.random.rand(224, 224, 3).astype(np.float32),
        label="cat",
        confidence=0.95,
    )
    for _ in range(100)
]

with wds.writer.TarWriter("data-000000.tar") as sink:
    for i, sample in enumerate(samples):
        sink.write({**sample.as_wds, "__key__": f"sample_{i:06d}"})
```

### Load and Iterate

```python
dataset = atdata.Dataset[ImageSample]("data-000000.tar")

# Iterate with batching
for batch in dataset.shuffled(batch_size=32):
    images = batch.image      # numpy array (32, 224, 224, 3)
    labels = batch.label      # list of 32 strings
    confs = batch.confidence  # list of 32 floats
```

### Use Lenses for Type Transformations

```python
@atdata.packable
class SimplifiedSample:
    label: str

@atdata.lens
def simplify(src: ImageSample) -> SimplifiedSample:
    return SimplifiedSample(label=src.label)

# View dataset through a different type
simple_ds = dataset.as_type(SimplifiedSample)
for batch in simple_ds.ordered(batch_size=16):
    print(batch.label)
```
## HuggingFace-Style Loading

```python
# Load from local path
ds = atdata.load_dataset("path/to/data-{000000..000009}.tar", ImageSample, split="train")

# Load with split detection
ds_dict = atdata.load_dataset("path/to/data/", ImageSample)
train_ds = ds_dict["train"]
test_ds = ds_dict["test"]
```

## Local Storage with Redis + S3

```python
from atdata.local import LocalIndex, Repo

# Set up local index
index = LocalIndex()  # Connects to Redis

# Create repo with S3 storage
repo = Repo(
    s3_credentials={"AWS_ENDPOINT": "http://localhost:9000", ...},
    bucket="my-bucket",
    index=index,
)

# Insert dataset
entry = repo.insert(samples, name="my-dataset")
print(f"Stored at: {entry.data_urls}")
```

## Publish to ATProto Federation

```python
from atdata.atmosphere import AtmosphereClient
from atdata.promote import promote_to_atmosphere

# Authenticate
client = AtmosphereClient()
client.login("handle.bsky.social", "app-password")

# Promote local dataset to federation
entry = index.get_dataset("my-dataset")
at_uri = promote_to_atmosphere(entry, index, client)
print(f"Published at: {at_uri}")
```

## Documentation

- [Packable Samples](packable-samples.md) - Defining typed samples
- [Datasets](datasets.md) - Loading and iterating datasets
- [Lenses](lenses.md) - Type transformations
- [Local Storage](local-storage.md) - Redis + S3 backend
- [Atmosphere](atmosphere.md) - ATProto federation
- [Promotion](promotion.md) - Local to atmosphere workflow
- [load_dataset](load-dataset.md) - HuggingFace-style API
- [Protocols](protocols.md) - Abstract interfaces

## License

MIT
docs/lenses.md
# Lenses

Lenses provide bidirectional transformations between sample types, enabling datasets to be viewed through different schemas without duplicating data.

## Overview

A lens consists of:
- **Getter**: Transforms source type `S` to view type `V`
- **Putter**: Updates source based on a modified view (optional)

## Creating a Lens

Use the `@lens` decorator to define a getter:

```python
import atdata
from numpy.typing import NDArray

@atdata.packable
class FullSample:
    image: NDArray
    label: str
    confidence: float
    metadata: dict

@atdata.packable
class SimpleSample:
    label: str
    confidence: float

@atdata.lens
def simplify(src: FullSample) -> SimpleSample:
    return SimpleSample(label=src.label, confidence=src.confidence)
```

The decorator:
1. Creates a `Lens` object from the getter function
2. Registers it in the global `LensNetwork` registry
3. Extracts source/view types from annotations

## Adding a Putter

To enable bidirectional updates, add a putter:

```python
@simplify.putter
def simplify_put(view: SimpleSample, source: FullSample) -> FullSample:
    return FullSample(
        image=source.image,
        label=view.label,
        confidence=view.confidence,
        metadata=source.metadata,
    )
```

The putter receives:
- `view`: The modified view value
- `source`: The original source value

It returns an updated source that reflects changes from the view.

## Using Lenses with Datasets

Lenses integrate with `Dataset.as_type()`:

```python
dataset = atdata.Dataset[FullSample]("data-{000000..000009}.tar")

# View through a different type
simple_ds = dataset.as_type(SimpleSample)

for batch in simple_ds.ordered(batch_size=32):
    # Only SimpleSample fields available
    labels = batch.label
    scores = batch.confidence
```

## Direct Lens Usage

Lenses can also be called directly:

```python
full = FullSample(
    image=np.zeros((224, 224, 3)),
    label="cat",
    confidence=0.95,
    metadata={"source": "training"}
)

# Apply getter
simple = simplify(full)
# Or: simple = simplify.get(full)

# Apply putter
modified_simple = SimpleSample(label="dog", confidence=0.87)
updated_full = simplify.put(modified_simple, full)
# updated_full has label="dog", confidence=0.87, but retains
# original image and metadata
```

## Lens Laws

Well-behaved lenses should satisfy these properties:

### GetPut Law
If you get a view and immediately put it back, the source is unchanged:
```python
view = lens.get(source)
assert lens.put(view, source) == source
```

### PutGet Law
If you put a view, getting it back yields that view:
```python
updated = lens.put(view, source)
assert lens.get(updated) == view
```

### PutPut Law
Putting twice is equivalent to putting once with the final value:
```python
result1 = lens.put(v2, lens.put(v1, source))
result2 = lens.put(v2, source)
assert result1 == result2
```
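As a rough sanity check, the first two laws can be exercised against the `simplify` lens and `simplify_put` putter defined earlier (a sketch; fields are compared individually because whole-sample equality is awkward for array-valued fields):

```python
import numpy as np

full = FullSample(
    image=np.zeros((4, 4, 3), dtype=np.float32),
    label="cat",
    confidence=0.9,
    metadata={"source": "demo"},
)

# GetPut: putting back an unmodified view leaves the source fields intact
roundtrip = simplify.put(simplify.get(full), full)
assert roundtrip.label == full.label and roundtrip.confidence == full.confidence

# PutGet: after putting a new view, get() returns that view's fields
new_view = SimpleSample(label="dog", confidence=0.5)
assert simplify.get(simplify.put(new_view, full)).label == "dog"
```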
## Trivial Putter

If no putter is defined, a trivial putter is used that ignores view updates:

```python
@atdata.lens
def extract_label(src: FullSample) -> SimpleSample:
    return SimpleSample(label=src.label, confidence=src.confidence)

# Without a putter, put() returns the original source unchanged
view = SimpleSample(label="modified", confidence=0.5)
updated = extract_label.put(view, original)
assert updated == original  # No changes applied
```

## LensNetwork Registry

The `LensNetwork` is a singleton that stores all registered lenses:

```python
from atdata.lens import LensNetwork

network = LensNetwork()

# Look up a specific lens
lens = network.transform(FullSample, SimpleSample)

# Raises ValueError if no lens exists
try:
    lens = network.transform(TypeA, TypeB)
except ValueError:
    print("No lens registered for TypeA -> TypeB")
```

## Example: Feature Extraction

```python
@atdata.packable
class RawSample:
    audio: NDArray
    text: str
    speaker_id: int

@atdata.packable
class TextFeatures:
    text: str
    word_count: int

@atdata.lens
def extract_text(src: RawSample) -> TextFeatures:
    return TextFeatures(
        text=src.text,
        word_count=len(src.text.split())
    )

@extract_text.putter
def extract_text_put(view: TextFeatures, source: RawSample) -> RawSample:
    return RawSample(
        audio=source.audio,
        text=view.text,
        speaker_id=source.speaker_id
    )
```

## Related

- [Datasets](datasets.md) - Using lenses with Dataset.as_type()
- [Packable Samples](packable-samples.md) - Defining sample types
- [Atmosphere](atmosphere.md) - Publishing lenses to ATProto federation
docs/load-dataset.md
# load_dataset API

The `load_dataset()` function provides a HuggingFace Datasets-style interface for loading typed datasets.

## Overview

Key differences from HuggingFace Datasets:
- Requires explicit `sample_type` parameter (typed dataclass) unless using index
- Returns `atdata.Dataset[ST]` instead of HF Dataset
- Built on WebDataset for efficient streaming
- No Arrow caching layer

## Basic Usage

```python
import atdata
from atdata import load_dataset

@atdata.packable
class TextSample:
    text: str
    label: int

# Load a specific split
train_ds = load_dataset("path/to/data.tar", TextSample, split="train")

# Load all splits (returns DatasetDict)
ds_dict = load_dataset("path/to/data/", TextSample)
train_ds = ds_dict["train"]
test_ds = ds_dict["test"]
```

## Path Formats

### WebDataset Brace Notation

```python
# Range notation
ds = load_dataset("data-{000000..000099}.tar", MySample, split="train")

# List notation
ds = load_dataset("data-{train,test,val}.tar", MySample, split="train")
```

### Glob Patterns

```python
# Match all tar files
ds = load_dataset("path/to/*.tar", MySample)

# Match pattern
ds = load_dataset("path/to/train-*.tar", MySample, split="train")
```

### Local Directory

```python
# Scans for .tar files
ds = load_dataset("./my-dataset/", MySample)
```

### Remote URLs

```python
# S3
ds = load_dataset("s3://bucket/data-{000..099}.tar", MySample, split="train")

# HTTP/HTTPS
ds = load_dataset("https://example.com/data.tar", MySample, split="train")

# Google Cloud Storage
ds = load_dataset("gs://bucket/data.tar", MySample, split="train")
```

### Index Lookup

```python
from atdata.local import LocalIndex

index = LocalIndex()

# Load from local index (auto-resolves type from schema)
ds = load_dataset("@local/my-dataset", index=index, split="train")

# With explicit type
ds = load_dataset("@local/my-dataset", MySample, index=index, split="train")
```

## Split Detection

Splits are automatically detected from filenames and directories:

| Pattern | Detected Split |
|---------|---------------|
| `train-*.tar`, `training-*.tar` | train |
| `test-*.tar`, `testing-*.tar` | test |
| `val-*.tar`, `valid-*.tar`, `validation-*.tar` | validation |
| `dev-*.tar`, `development-*.tar` | validation |
| `train/*.tar` (directory) | train |
| `test/*.tar` (directory) | test |

Files without a detected split default to "train".

## DatasetDict

When loading without `split=`, returns a `DatasetDict`:

```python
ds_dict = load_dataset("path/to/data/", MySample)

# Access splits
train_ds = ds_dict["train"]
test_ds = ds_dict["test"]

# Iterate splits
for name, dataset in ds_dict.items():
    print(f"{name}: {len(dataset.shard_list)} shards")

# Properties
print(ds_dict.num_shards)   # {'train': 10, 'test': 2}
print(ds_dict.sample_type)  # <class 'MySample'>
print(ds_dict.streaming)    # False
```

## Explicit Data Files

Override automatic detection with `data_files`:

```python
# Single pattern
ds = load_dataset(
    "path/to/",
    MySample,
    data_files="custom-*.tar",
)

# List of patterns
ds = load_dataset(
    "path/to/",
    MySample,
    data_files=["shard-000.tar", "shard-001.tar"],
)

# Explicit split mapping
ds = load_dataset(
    "path/to/",
    MySample,
    data_files={
        "train": "training-shards-*.tar",
        "test": "eval-data.tar",
    },
)
```

## Streaming Mode

The `streaming` parameter signals intent for streaming mode:

```python
# Mark as streaming
ds_dict = load_dataset("path/to/data.tar", MySample, streaming=True)

# Check streaming status
if ds_dict.streaming:
    print("Streaming mode")
```

Note: atdata datasets are always lazy/streaming via WebDataset pipelines. This parameter primarily signals intent.

## Auto Type Resolution

When using index lookup, the sample type can be resolved automatically:

```python
from atdata.local import LocalIndex

index = LocalIndex()

# No sample_type needed - resolved from schema
ds = load_dataset("@local/my-dataset", index=index, split="train")

# Type is inferred from the stored schema
sample_type = ds.sample_type
```

## Error Handling

```python
try:
    ds = load_dataset("path/to/data.tar", MySample, split="train")
except FileNotFoundError:
    print("No data files found")
except ValueError as e:
    if "Split" in str(e):
        print("Requested split not found")
    else:
        print(f"Invalid configuration: {e}")
except KeyError:
    print("Dataset not found in index")
```

## Complete Example

```python
import numpy as np
from numpy.typing import NDArray
import atdata
from atdata import load_dataset
import webdataset as wds

# 1. Define sample type
@atdata.packable
class ImageSample:
    image: NDArray
    label: str

# 2. Create dataset files
for split in ["train", "test"]:
    with wds.writer.TarWriter(f"{split}-000.tar") as sink:
        for i in range(100):
            sample = ImageSample(
                image=np.random.rand(64, 64, 3).astype(np.float32),
                label=f"sample_{i}",
            )
            sink.write({**sample.as_wds, "__key__": f"{i:06d}"})

# 3. Load with split detection
ds_dict = load_dataset("./", ImageSample)
print(ds_dict.keys())  # dict_keys(['train', 'test'])

# 4. Iterate
for batch in ds_dict["train"].ordered(batch_size=16):
    print(batch.image.shape)  # (16, 64, 64, 3)
    print(batch.label)        # ['sample_0', 'sample_1', ...]
    break

# 5. Load specific split
train_ds = load_dataset("./", ImageSample, split="train")
for batch in train_ds.ordered(batch_size=32):
    process(batch)
```

## Related

- [Datasets](datasets.md) - Dataset iteration and batching
- [Packable Samples](packable-samples.md) - Defining sample types
- [Local Storage](local-storage.md) - LocalIndex for index lookup
- [Protocols](protocols.md) - AbstractIndex interface
docs/local-storage.md
# Local Storage

The local storage module provides a Redis + S3 backend for storing and managing datasets before publishing to the ATProto federation.

## Overview

Local storage uses:
- **Redis** for indexing and tracking dataset metadata
- **S3-compatible storage** for dataset tar files

This enables development and small-scale deployment before promoting to the full ATProto infrastructure.

## LocalIndex

The index tracks datasets in Redis:

```python
from atdata.local import LocalIndex

# Default connection (localhost:6379)
index = LocalIndex()

# Custom Redis connection
import redis
r = redis.Redis(host='custom-host', port=6379)
index = LocalIndex(redis=r)

# With connection kwargs
index = LocalIndex(host='custom-host', port=6379, db=1)
```

### Adding Entries

```python
dataset = atdata.Dataset[ImageSample]("data-{000000..000009}.tar")

entry = index.add_entry(
    dataset,
    name="my-dataset",
    schema_ref="local://schemas/mymodule.ImageSample@1.0.0",  # optional
    metadata={"description": "Training images"},  # optional
)

print(entry.cid)        # Content identifier
print(entry.name)       # "my-dataset"
print(entry.data_urls)  # ["data-{000000..000009}.tar"]
```

### Listing and Retrieving

```python
# Iterate all entries
for entry in index.entries:
    print(f"{entry.name}: {entry.cid}")

# Get as list
all_entries = index.all_entries

# Get by name
entry = index.get_entry_by_name("my-dataset")

# Get by CID
entry = index.get_entry("bafyrei...")
```

## Repo

The Repo class combines S3 storage with Redis indexing:

```python
from atdata.local import Repo

# From credentials file
repo = Repo(
    s3_credentials="path/to/.env",
    hive_path="my-bucket/datasets",
)

# From credentials dict
repo = Repo(
    s3_credentials={
        "AWS_ENDPOINT": "http://localhost:9000",
        "AWS_ACCESS_KEY_ID": "minioadmin",
        "AWS_SECRET_ACCESS_KEY": "minioadmin",
    },
    hive_path="my-bucket/datasets",
)
```

### Credentials File Format

The `.env` file should contain:

```
AWS_ENDPOINT=http://localhost:9000
AWS_ACCESS_KEY_ID=your-access-key
AWS_SECRET_ACCESS_KEY=your-secret-key
```

For AWS S3, omit `AWS_ENDPOINT` to use the default endpoint.
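For example, a Repo pointed at AWS S3 itself might look like this (a sketch; the key values and bucket path are placeholders):

```python
repo = Repo(
    s3_credentials={
        # No AWS_ENDPOINT entry: the default AWS S3 endpoint is used
        "AWS_ACCESS_KEY_ID": "your-access-key",
        "AWS_SECRET_ACCESS_KEY": "your-secret-key",
    },
    hive_path="my-aws-bucket/datasets",
)
```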
### Inserting Datasets

```python
@atdata.packable
class ImageSample:
    image: NDArray
    label: str

# Create dataset from samples
samples = [ImageSample(...) for _ in range(1000)]
with wds.writer.TarWriter("temp.tar") as sink:
    for i, s in enumerate(samples):
        sink.write({**s.as_wds, "__key__": f"{i:06d}"})

dataset = atdata.Dataset[ImageSample]("temp.tar")

# Insert into repo (writes to S3 + indexes in Redis)
entry, stored_dataset = repo.insert(
    dataset,
    name="training-images-v1",
    cache_local=False,  # Stream directly to S3
)

print(entry.cid)                  # Content identifier
print(stored_dataset.url)         # S3 URL for the stored data
print(stored_dataset.shard_list)  # Individual shard URLs
```

### Insert Options

```python
entry, ds = repo.insert(
    dataset,
    name="my-dataset",
    cache_local=True,     # Write locally first, then copy (faster for some workloads)
    maxcount=10000,       # Samples per shard
    maxsize=100_000_000,  # Max shard size in bytes
)
```

## LocalDatasetEntry

Index entries provide content-addressable identification:

```python
entry = index.get_entry_by_name("my-dataset")

# Core properties (IndexEntry protocol)
entry.name        # Human-readable name
entry.schema_ref  # Schema reference
entry.data_urls   # WebDataset URLs
entry.metadata    # Arbitrary metadata dict or None

# Content addressing
entry.cid  # ATProto-compatible CID (content identifier)

# Legacy compatibility
entry.wds_url      # First data URL
entry.sample_kind  # Same as schema_ref
```

The CID is generated from the entry's content (schema_ref + data_urls), ensuring identical data produces identical CIDs whether stored locally or in the atmosphere.

## Schema Storage

Schemas can be stored and retrieved from the index:

```python
# Publish a schema
schema_ref = index.publish_schema(
    ImageSample,
    version="1.0.0",
    description="Image with label annotation",
)
# Returns: "local://schemas/mymodule.ImageSample@1.0.0"

# Retrieve schema record
schema = index.get_schema(schema_ref)
# {
#     "name": "ImageSample",
#     "version": "1.0.0",
#     "fields": [...],
#     "description": "...",
#     "createdAt": "...",
# }

# List all schemas
for schema in index.list_schemas():
    print(f"{schema['name']}@{schema['version']}")

# Reconstruct sample type from schema
SampleType = index.decode_schema(schema_ref)
dataset = atdata.Dataset[SampleType](entry.data_urls[0])
```

## S3DataStore

For direct S3 operations without Redis indexing:

```python
from atdata.local import S3DataStore

store = S3DataStore(
    credentials="path/to/.env",
    bucket="my-bucket",
)

# Write dataset shards
urls = store.write_shards(
    dataset,
    prefix="datasets/v1",
    maxcount=10000,
)
# Returns: ["s3://my-bucket/datasets/v1/data--uuid--000000.tar", ...]

# Check capabilities
store.supports_streaming()  # True
```

## Complete Workflow Example

```python
import numpy as np
from numpy.typing import NDArray
import atdata
from atdata.local import Repo, LocalIndex
import webdataset as wds

# 1. Define sample type
@atdata.packable
class TrainingSample:
    features: NDArray
    label: int
    source: str

# 2. Create samples
samples = [
    TrainingSample(
        features=np.random.randn(128).astype(np.float32),
        label=i % 10,
        source="synthetic",
    )
    for i in range(10000)
]

# 3. Write to local tar
with wds.writer.TarWriter("local-data.tar") as sink:
    for i, sample in enumerate(samples):
        sink.write({**sample.as_wds, "__key__": f"{i:06d}"})

# 4. Create repo and insert
repo = Repo(
    s3_credentials={
        "AWS_ENDPOINT": "http://localhost:9000",
        "AWS_ACCESS_KEY_ID": "minioadmin",
        "AWS_SECRET_ACCESS_KEY": "minioadmin",
    },
    hive_path="datasets-bucket/training",
)

local_ds = atdata.Dataset[TrainingSample]("local-data.tar")
entry, stored_ds = repo.insert(local_ds, name="training-v1")

# 5. Retrieve later
index = LocalIndex()
entry = index.get_entry_by_name("training-v1")
dataset = atdata.Dataset[TrainingSample](entry.data_urls[0])

for batch in dataset.ordered(batch_size=32):
    print(batch.features.shape)  # (32, 128)
```

## Related

- [Datasets](datasets.md) - Dataset iteration and batching
- [Protocols](protocols.md) - AbstractIndex and IndexEntry interfaces
- [Promotion](promotion.md) - Promoting local datasets to ATProto
- [Atmosphere](atmosphere.md) - ATProto federation
docs/packable-samples.md
# Packable Samples

Packable samples are typed dataclasses that can be serialized with msgpack for storage in WebDataset tar files.

## The `@packable` Decorator

The recommended way to define a sample type is with the `@packable` decorator:

```python
import numpy as np
from numpy.typing import NDArray
import atdata

@atdata.packable
class ImageSample:
    image: NDArray
    label: str
    confidence: float
```

This creates a dataclass that:
- Inherits from `PackableSample`
- Has automatic msgpack serialization
- Handles NDArray conversion to/from bytes

## Supported Field Types

### Primitives

```python
@atdata.packable
class PrimitiveSample:
    name: str
    count: int
    score: float
    active: bool
    data: bytes
```

### NumPy Arrays

Fields annotated as `NDArray` are automatically converted:

```python
@atdata.packable
class ArraySample:
    features: NDArray           # Required array
    embeddings: NDArray | None  # Optional array
```

**Note**: Bytes in NDArray-typed fields are always interpreted as serialized arrays. Don't use `NDArray` for raw binary data.

### Lists

```python
@atdata.packable
class ListSample:
    tags: list[str]
    scores: list[float]
```

## Serialization

### Packing to Bytes

```python
sample = ImageSample(
    image=np.random.rand(224, 224, 3).astype(np.float32),
    label="cat",
    confidence=0.95,
)

# Serialize to msgpack bytes
packed_bytes = sample.packed
print(f"Size: {len(packed_bytes)} bytes")
```

### Unpacking from Bytes

```python
# Deserialize from bytes
restored = ImageSample.from_bytes(packed_bytes)

# Arrays are automatically restored
assert np.array_equal(sample.image, restored.image)
assert sample.label == restored.label
```

### WebDataset Format

The `as_wds` property returns a dict ready for WebDataset:

```python
wds_dict = sample.as_wds
# {'__key__': '1234...', 'msgpack': b'...'}
```

Write samples to a tar file:

```python
import webdataset as wds

with wds.writer.TarWriter("data-000000.tar") as sink:
    for i, sample in enumerate(samples):
        # Use custom key or let as_wds generate one
        sink.write({**sample.as_wds, "__key__": f"sample_{i:06d}"})
```

## Direct Inheritance (Alternative)

You can also inherit directly from `PackableSample`:

```python
from dataclasses import dataclass

@dataclass
class DirectSample(atdata.PackableSample):
    name: str
    values: NDArray
```

This is equivalent to using `@packable` but more verbose.

## How It Works

### Serialization Flow

1. **Packing** (`sample.packed`):
   - NDArray fields → converted to bytes via `array_to_bytes()`
   - Other fields → passed through
   - All fields → packed with msgpack

2. **Unpacking** (`Sample.from_bytes()`):
   - Bytes → unpacked with ormsgpack
   - Dict → passed to `__init__`
   - `__post_init__` → calls `_ensure_good()`
   - NDArray fields → bytes converted back to arrays

### The `_ensure_good()` Method

This method runs automatically after construction and handles NDArray conversion:

```python
def _ensure_good(self):
    for field in dataclasses.fields(self):
        if _is_possibly_ndarray_type(field.type):
            value = getattr(self, field.name)
            if isinstance(value, bytes):
                setattr(self, field.name, bytes_to_array(value))
```

## Best Practices

### Do

```python
@atdata.packable
class GoodSample:
    features: NDArray    # Clear type annotation
    label: str           # Simple primitives
    metadata: dict       # Msgpack-compatible dicts
    scores: list[float]  # Typed lists
```

### Don't

```python
@atdata.packable
class BadSample:
    # DON'T: Nested dataclasses not supported
    nested: OtherSample

    # DON'T: Complex objects that aren't msgpack-serializable
    callback: Callable

    # DON'T: Use NDArray for raw bytes
    raw_data: NDArray  # Use 'bytes' type instead
```

## Related

- [Datasets](datasets.md) - Loading and iterating samples
- [Lenses](lenses.md) - Transforming between sample types
docs/promotion.md
···11+# Promotion Workflow
22+33+The promotion workflow migrates datasets from local storage (Redis + S3) to the ATProto atmosphere network, enabling federation and discovery.
44+55+## Overview
66+77+Promotion handles:
88+- **Schema deduplication**: Avoids publishing duplicate schemas
99+- **Data URL preservation**: Keeps existing S3 URLs or copies to new storage
1010+- **Metadata transfer**: Preserves tags, descriptions, and custom metadata
1111+1212+## Basic Usage
1313+1414+```python
1515+from atdata.local import LocalIndex
1616+from atdata.atmosphere import AtmosphereClient
1717+from atdata.promote import promote_to_atmosphere
1818+1919+# Setup
2020+local_index = LocalIndex()
2121+client = AtmosphereClient()
2222+client.login("handle.bsky.social", "app-password")
2323+2424+# Get local entry
2525+entry = local_index.get_entry_by_name("my-dataset")
2626+2727+# Promote to atmosphere
2828+at_uri = promote_to_atmosphere(entry, local_index, client)
2929+print(f"Published: {at_uri}")
3030+```
3131+3232+## With Metadata
3333+3434+```python
3535+at_uri = promote_to_atmosphere(
3636+ entry,
3737+ local_index,
3838+ client,
3939+ name="my-dataset-v2", # Override name
4040+ description="Training images", # Add description
4141+ tags=["images", "training"], # Add discovery tags
4242+ license="MIT", # Specify license
4343+)
4444+```
4545+4646+## Schema Deduplication
4747+4848+The promotion workflow automatically checks for existing schemas:
4949+5050+```python
5151+# First promotion: publishes schema
5252+uri1 = promote_to_atmosphere(entry1, local_index, client)
5353+5454+# Second promotion with same schema type + version: reuses existing schema
5555+uri2 = promote_to_atmosphere(entry2, local_index, client)
5656+```
5757+5858+Schema matching is based on:
5959+- `{module}.{class_name}` (e.g., `mymodule.ImageSample`)
6060+- Version string (e.g., `1.0.0`)
6161+6262+## Data Storage Options
6363+6464+### Use Existing URLs (Default)
6565+6666+By default, promotion keeps the original data URLs:
6767+6868+```python
6969+# Data stays in original S3 location
7070+at_uri = promote_to_atmosphere(entry, local_index, client)
7171+```
7272+7373+### Copy to New Storage
7474+7575+To copy data to a different storage location:
7676+7777+```python
7878+from atdata.local import S3DataStore
7979+8080+# Create new data store
8181+new_store = S3DataStore(
8282+ credentials="new-s3-creds.env",
8383+ bucket="public-datasets",
8484+)
8585+8686+# Promote with data copy
8787+at_uri = promote_to_atmosphere(
8888+ entry,
8989+ local_index,
9090+ client,
9191+ data_store=new_store, # Copy data to new storage
9292+)
9393+```
9494+9595+## Complete Workflow Example
9696+9797+```python
9898+import numpy as np
9999+from numpy.typing import NDArray
100100+import atdata
101101+from atdata.local import LocalIndex, Repo
102102+from atdata.atmosphere import AtmosphereClient
103103+from atdata.promote import promote_to_atmosphere
104104+import webdataset as wds
105105+106106+# 1. Define sample type
107107+@atdata.packable
108108+class FeatureSample:
109109+ features: NDArray
110110+ label: int
111111+112112+# 2. Create local dataset
113113+samples = [
114114+ FeatureSample(
115115+ features=np.random.randn(128).astype(np.float32),
116116+ label=i % 10,
117117+ )
118118+ for i in range(1000)
119119+]
120120+121121+with wds.writer.TarWriter("features.tar") as sink:
122122+ for i, s in enumerate(samples):
123123+ sink.write({**s.as_wds, "__key__": f"{i:06d}"})
124124+125125+# 3. Store in local repo
126126+repo = Repo(
127127+ s3_credentials={
128128+ "AWS_ENDPOINT": "http://localhost:9000",
129129+ "AWS_ACCESS_KEY_ID": "minioadmin",
130130+ "AWS_SECRET_ACCESS_KEY": "minioadmin",
131131+ },
132132+ hive_path="datasets-bucket/features",
133133+)
134134+135135+dataset = atdata.Dataset[FeatureSample]("features.tar")
136136+local_entry, _ = repo.insert(dataset, name="feature-vectors-v1")
137137+138138+# 4. Publish schema to local index
139139+local_index = LocalIndex()
140140+local_index.publish_schema(FeatureSample, version="1.0.0")
141141+142142+# 5. Promote to atmosphere
143143+client = AtmosphereClient()
144144+client.login("myhandle.bsky.social", "app-password")
145145+146146+at_uri = promote_to_atmosphere(
147147+ local_entry,
148148+ local_index,
149149+ client,
150150+ description="Feature vectors for classification",
151151+ tags=["features", "embeddings"],
152152+ license="MIT",
153153+)
154154+155155+print(f"Dataset published: {at_uri}")
156156+157157+# 6. Verify on atmosphere
158158+from atdata.atmosphere import AtmosphereIndex
159159+160160+atm_index = AtmosphereIndex(client)
161161+entry = atm_index.get_dataset(at_uri)
162162+print(f"Name: {entry.name}")
163163+print(f"Schema: {entry.schema_ref}")
164164+print(f"URLs: {entry.data_urls}")
165165+```
166166+167167+## Error Handling
168168+169169+```python
170170+try:
171171+ at_uri = promote_to_atmosphere(entry, local_index, client)
172172+except KeyError as e:
173173+ # Schema not found in local index
174174+ print(f"Missing schema: {e}")
175175+except ValueError as e:
176176+ # Entry has no data URLs
177177+ print(f"Invalid entry: {e}")
178178+```
179179+180180+## Requirements
181181+182182+Before promotion (see the pre-flight sketch after this list):
183183+1. Dataset must be in local index (via `Repo.insert()` or `Index.add_entry()`)
184184+2. Schema must be published to local index (via `Index.publish_schema()`)
185185+3. AtmosphereClient must be authenticated
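
A minimal pre-flight sketch using the calls documented above (the dataset name is a placeholder):

```python
# Hedged sketch: verify the three requirements before calling promote_to_atmosphere.
entry = local_index.get_dataset("feature-vectors-v1")   # 1. raises KeyError if not indexed
schema = local_index.get_schema(entry.schema_ref)       # 2. raises KeyError if not published
assert client.did is not None                           # 3. login() must have succeeded
```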
186186+187187+## Related
188188+189189+- [Local Storage](local-storage.md) - Setting up local datasets
190190+- [Atmosphere](atmosphere.md) - ATProto integration
191191+- [Protocols](protocols.md) - AbstractIndex and AbstractDataStore
+243
docs/protocols.md
···11+# Protocols
22+33+The protocols module defines abstract interfaces that enable interchangeable index backends (local Redis vs ATProto) and data stores (S3 vs PDS blobs).
44+55+## Overview
66+77+Both local and atmosphere implementations solve the same problem: indexed dataset storage with external data URLs. These protocols formalize that common interface:
88+99+- **IndexEntry**: Common interface for dataset index entries
1010+- **AbstractIndex**: Protocol for index operations
1111+- **AbstractDataStore**: Protocol for data storage operations
1212+1313+## IndexEntry Protocol
1414+1515+Represents a dataset entry in any index:
1616+1717+```python
1818+from atdata._protocols import IndexEntry
1919+2020+def process_entry(entry: IndexEntry) -> None:
2121+ print(f"Name: {entry.name}")
2222+ print(f"Schema: {entry.schema_ref}")
2323+ print(f"URLs: {entry.data_urls}")
2424+ print(f"Metadata: {entry.metadata}")
2525+```
2626+2727+### Properties
2828+2929+| Property | Type | Description |
3030+|----------|------|-------------|
3131+| `name` | `str` | Human-readable dataset name |
3232+| `schema_ref` | `str` | Schema reference (local:// or at://) |
3333+| `data_urls` | `list[str]` | WebDataset URLs for the data |
3434+| `metadata` | `dict \| None` | Arbitrary metadata dictionary |
3535+3636+### Implementations
3737+3838+- `LocalDatasetEntry` (from `atdata.local`)
3939+- `AtmosphereIndexEntry` (from `atdata.atmosphere`)
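
Because `IndexEntry` is a structural protocol, any object that exposes these four read-only attributes satisfies it; no inheritance is required. A minimal illustrative entry (not part of atdata) could be:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class InMemoryEntry:
    """Illustrative entry; satisfies IndexEntry structurally."""
    name: str
    schema_ref: str
    data_urls: list[str]
    metadata: Optional[dict] = None

entry = InMemoryEntry(
    name="scratch-dataset",
    schema_ref="local://schemas/myapp.ImageSample@1.0.0",
    data_urls=["./scratch/data-000000.tar"],
)
process_entry(entry)  # works with the function defined above
```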
4040+4141+## AbstractIndex Protocol
4242+4343+Defines operations for managing schemas and datasets:
4444+4545+```python
4646+from atdata._protocols import AbstractIndex
4747+4848+def list_all_datasets(index: AbstractIndex) -> None:
4949+ """Works with LocalIndex or AtmosphereIndex."""
5050+ for entry in index.list_datasets():
5151+ print(f"{entry.name}: {entry.schema_ref}")
5252+```
5353+5454+### Dataset Operations
5555+5656+```python
5757+# Insert a dataset
5858+entry = index.insert_dataset(
5959+ dataset,
6060+ name="my-dataset",
6161+ schema_ref="local://schemas/MySample@1.0.0", # optional
6262+)
6363+6464+# Get by name/reference
6565+entry = index.get_dataset("my-dataset")
6666+6767+# List all datasets
6868+for entry in index.list_datasets():
6969+ print(entry.name)
7070+```
7171+7272+### Schema Operations
7373+7474+```python
7575+# Publish a schema
7676+schema_ref = index.publish_schema(
7777+ MySample,
7878+ version="1.0.0",
7979+)
8080+8181+# Get schema record
8282+schema = index.get_schema(schema_ref)
8383+print(schema["name"], schema["version"])
8484+8585+# List all schemas
8686+for schema in index.list_schemas():
8787+ print(f"{schema['name']}@{schema['version']}")
8888+8989+# Decode schema to Python type
9090+SampleType = index.decode_schema(schema_ref)
9191+dataset = atdata.Dataset[SampleType](entry.data_urls[0])
9292+```
9393+9494+### Implementations
9595+9696+- `LocalIndex` / `Index` (from `atdata.local`)
9797+- `AtmosphereIndex` (from `atdata.atmosphere`)
9898+9999+## AbstractDataStore Protocol
100100+101101+Abstracts over different storage backends:
102102+103103+```python
104104+from atdata._protocols import AbstractDataStore
105105+106106+def write_dataset(store: AbstractDataStore, dataset) -> list[str]:
107107+ """Works with S3DataStore or future PDS blob store."""
108108+ urls = store.write_shards(dataset, prefix="datasets/v1")
109109+ return urls
110110+```
111111+112112+### Methods
113113+114114+```python
115115+# Write dataset shards
116116+urls = store.write_shards(
117117+ dataset,
118118+ prefix="datasets/mnist/v1",
119119+ maxcount=10000, # samples per shard
120120+)
121121+122122+# Resolve URL for reading
123123+readable_url = store.read_url("s3://bucket/path.tar")
124124+125125+# Check streaming support
126126+if store.supports_streaming():
127127+ # Can stream directly
128128+ pass
129129+```
130130+131131+### Implementations
132132+133133+- `S3DataStore` (from `atdata.local`)
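
For custom backends, any class providing `write_shards`, `read_url`, and `supports_streaming` satisfies the protocol. A rough local-directory sketch (illustrative only, and assuming the `as_wds`/`ordered()` sample APIs shown elsewhere in these docs):

```python
from pathlib import Path
import webdataset as wds

class LocalDirStore:
    """Illustrative store that writes shards into a local directory."""

    def __init__(self, root: str) -> None:
        self.root = Path(root)

    def write_shards(self, dataset, *, prefix: str, maxcount: int = 10_000) -> list[str]:
        out_dir = self.root / prefix
        out_dir.mkdir(parents=True, exist_ok=True)
        # Write samples into numbered tar shards
        with wds.writer.ShardWriter(str(out_dir / "shard-%06d.tar"), maxcount=maxcount) as sink:
            for i, sample in enumerate(dataset.ordered(batch_size=None)):
                sink.write({**sample.as_wds, "__key__": f"{i:06d}"})
        return sorted(str(p) for p in out_dir.glob("shard-*.tar"))

    def read_url(self, url: str) -> str:
        return url  # local paths are readable as-is

    def supports_streaming(self) -> bool:
        return False
```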
134134+135135+## Using Protocols for Polymorphism
136136+137137+Write code that works with any backend:
138138+139139+```python
140140+from atdata._protocols import AbstractIndex, IndexEntry
141141+from atdata import Dataset
142142+143143+def backup_all_datasets(
144144+ source: AbstractIndex,
145145+ target: AbstractIndex,
146146+) -> None:
147147+ """Copy all datasets from source index to target."""
148148+ for entry in source.list_datasets():
149149+ # Decode schema from source
150150+ SampleType = source.decode_schema(entry.schema_ref)
151151+152152+ # Publish schema to target
153153+ target_schema = target.publish_schema(SampleType)
154154+155155+ # Load and re-insert dataset
156156+ ds = Dataset[SampleType](entry.data_urls[0])
157157+ target.insert_dataset(
158158+ ds,
159159+ name=entry.name,
160160+ schema_ref=target_schema,
161161+ )
162162+```
163163+164164+## Schema Reference Formats
165165+166166+Schema references vary by backend:
167167+168168+| Backend | Format | Example |
169169+|---------|--------|---------|
170170+| Local | `local://schemas/{module.Class}@{version}` | `local://schemas/myapp.ImageSample@1.0.0` |
171171+| Atmosphere | `at://{did}/{collection}/{rkey}` | `at://did:plc:abc123/ac.foundation.dataset.sampleSchema/xyz` |
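
Since the two formats differ only in their prefix, backend-specific handling can branch on it. An illustrative helper (not a library function):

```python
def describe_schema_ref(ref: str) -> str:
    """Illustrative: classify a schema reference by its prefix."""
    if ref.startswith("local://schemas/"):
        name, _, version = ref.removeprefix("local://schemas/").partition("@")
        return f"local schema {name}@{version}"
    if ref.startswith("at://"):
        return "atmosphere schema record (resolve with index.get_schema(ref))"
    raise ValueError(f"Unrecognized schema reference: {ref}")

print(describe_schema_ref("local://schemas/myapp.ImageSample@1.0.0"))
```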
172172+173173+## Type Checking
174174+175175+The `IndexEntry` protocol is decorated with `@runtime_checkable`, so entries can be verified with `isinstance`:
176176+177177+```python
178178+from atdata._protocols import IndexEntry, AbstractIndex
179179+180180+# Check if object implements protocol
181181+entry = index.get_dataset("test")
182182+assert isinstance(entry, IndexEntry)
183183+184184+# Type hints work with protocols
185185+def process(index: AbstractIndex) -> None:
186186+ ... # IDE provides autocomplete
187187+```
188188+189189+## Complete Example
190190+191191+```python
192192+import atdata
193193+from atdata.local import LocalIndex, S3DataStore
194194+from atdata.atmosphere import AtmosphereClient, AtmosphereIndex
195195+from atdata._protocols import AbstractIndex
196196+import numpy as np
197197+from numpy.typing import NDArray
198198+199199+# Define sample type
200200+@atdata.packable
201201+class FeatureSample:
202202+ features: NDArray
203203+ label: int
204204+205205+# Function works with any index
206206+def count_datasets(index: AbstractIndex) -> int:
207207+ return sum(1 for _ in index.list_datasets())
208208+209209+# Use with local index
210210+local_index = LocalIndex()
211211+print(f"Local datasets: {count_datasets(local_index)}")
212212+213213+# Use with atmosphere index
214214+client = AtmosphereClient()
215215+client.login("handle.bsky.social", "app-password")
216216+atm_index = AtmosphereIndex(client)
217217+print(f"Atmosphere datasets: {count_datasets(atm_index)}")
218218+219219+# Migrate from local to atmosphere
220220+def migrate_dataset(
221221+ name: str,
222222+ source: AbstractIndex,
223223+ target: AbstractIndex,
224224+) -> None:
225225+ entry = source.get_dataset(name)
226226+ SampleType = source.decode_schema(entry.schema_ref)
227227+228228+ # Publish schema
229229+ schema_ref = target.publish_schema(SampleType)
230230+231231+ # Create dataset and insert
232232+ ds = atdata.Dataset[SampleType](entry.data_urls[0])
233233+ target.insert_dataset(ds, name=name, schema_ref=schema_ref)
234234+235235+migrate_dataset("my-features", local_index, atm_index)
236236+```
237237+238238+## Related
239239+240240+- [Local Storage](local-storage.md) - LocalIndex and S3DataStore
241241+- [Atmosphere](atmosphere.md) - AtmosphereIndex
242242+- [Promotion](promotion.md) - Local to atmosphere migration
243243+- [load_dataset](load-dataset.md) - Using indexes with load_dataset()
+312
examples/local_workflow.py
···11+#!/usr/bin/env python3
22+"""Demonstration of atdata local storage workflow.
33+44+This script demonstrates how to use the local module to store and index
55+datasets using Redis and S3-compatible storage.
66+77+Usage:
88+ # Dry run with mocks (no Redis/S3 required):
99+ python local_workflow.py
1010+1111+ # With actual Redis (requires redis-server running):
1212+ python local_workflow.py --redis
1313+1414+ # With Redis and S3 (requires MinIO or AWS):
1515+ python local_workflow.py --redis --s3-endpoint http://localhost:9000
1616+1717+Requirements:
1818+ pip install atdata redis
1919+2020+Note:
2121+ For S3 storage, you can use MinIO for local development:
2222+ docker run -p 9000:9000 minio/minio server /data
2323+"""
2424+2525+import argparse
2626+import tempfile
2727+from datetime import datetime
2828+from pathlib import Path
2929+3030+import numpy as np
3131+from numpy.typing import NDArray
3232+3333+import atdata
3434+from atdata.local import LocalIndex, LocalDatasetEntry, Repo, S3DataStore
3535+3636+3737+# =============================================================================
3838+# Define sample types
3939+# =============================================================================
4040+4141+@atdata.packable
4242+class TrainingSample:
4343+ """A sample containing features and label for training."""
4444+ features: NDArray
4545+ label: int
4646+4747+4848+@atdata.packable
4949+class TextSample:
5050+ """A sample containing text data."""
5151+ text: str
5252+ category: str
5353+5454+5555+# =============================================================================
5656+# Demo functions
5757+# =============================================================================
5858+5959+def demo_local_dataset_entry():
6060+ """Demonstrate LocalDatasetEntry creation and CID generation."""
6161+ print("\n" + "=" * 60)
6262+ print("LocalDatasetEntry Demo")
6363+ print("=" * 60)
6464+6565+ # Create an entry
6666+ entry = LocalDatasetEntry(
6767+ _name="my-dataset",
6868+ _schema_ref="local://schemas/examples.TrainingSample@1.0.0",
6969+ _data_urls=["s3://bucket/data-000000.tar", "s3://bucket/data-000001.tar"],
7070+ _metadata={"source": "example", "samples": 10000},
7171+ )
7272+7373+ print(f"\nEntry name: {entry.name}")
7474+ print(f"Schema ref: {entry.schema_ref}")
7575+ print(f"Data URLs: {entry.data_urls}")
7676+ print(f"Metadata: {entry.metadata}")
7777+ print(f"CID: {entry.cid}")
7878+7979+ # Demonstrate CID determinism
8080+ entry2 = LocalDatasetEntry(
8181+ _name="different-name", # Name doesn't affect CID
8282+ _schema_ref="local://schemas/examples.TrainingSample@1.0.0",
8383+ _data_urls=["s3://bucket/data-000000.tar", "s3://bucket/data-000001.tar"],
8484+ )
8585+8686+ print(f"\nCID comparison (same content, different name):")
8787+ print(f" Entry 1 CID: {entry.cid}")
8888+ print(f" Entry 2 CID: {entry2.cid}")
8989+ print(f" Match: {entry.cid == entry2.cid}")
9090+9191+9292+def demo_local_index_mock():
9393+ """Demonstrate LocalIndex operations with mock data."""
9494+ print("\n" + "=" * 60)
9595+ print("LocalIndex Demo (mock)")
9696+ print("=" * 60)
9797+9898+ # LocalIndex without Redis connection works for read operations
9999+ index = LocalIndex()
100100+101101+ print("\nLocalIndex created (no Redis connection)")
102102+ print("Methods available:")
103103+ print(" - index.insert_dataset(dataset, name='...')")
104104+ print(" - index.get_dataset(name_or_cid)")
105105+ print(" - index.list_datasets()")
106106+ print(" - index.publish_schema(sample_type, version='1.0.0')")
107107+ print(" - index.get_schema(ref)")
108108+ print(" - index.list_schemas()")
109109+ print(" - index.decode_schema(ref) # Returns PackableSample class")
110110+111111+112112+def demo_local_index_redis(redis_host: str = "localhost", redis_port: int = 6379):
113113+ """Demonstrate LocalIndex with actual Redis."""
114114+ print("\n" + "=" * 60)
115115+ print("LocalIndex Demo (Redis)")
116116+ print("=" * 60)
117117+118118+ from redis import Redis
119119+120120+ # Connect to Redis
121121+ try:
122122+ redis = Redis(host=redis_host, port=redis_port)
123123+ redis.ping()
124124+ except Exception as e:
125125+ print(f"Could not connect to Redis: {e}")
126126+ print("Skipping Redis demo.")
127127+ return
128128+129129+ # Create index with Redis
130130+ index = LocalIndex(redis=redis)
131131+ print(f"\nConnected to Redis at {redis_host}:{redis_port}")
132132+133133+ # Publish a schema
134134+ print("\nPublishing TrainingSample schema...")
135135+ schema_ref = index.publish_schema(TrainingSample, version="1.0.0")
136136+ print(f" Schema ref: {schema_ref}")
137137+138138+ # List schemas
139139+ print("\nListing schemas:")
140140+ for schema in index.list_schemas():
141141+ print(f" - {schema.get('name', 'Unknown')} v{schema.get('version', '?')}")
142142+143143+ # Get schema and decode to type
144144+ schema_record = index.get_schema(schema_ref)
145145+ print(f"\nSchema record: {schema_record.get('name')}")
146146+ print(f" Fields: {[f['name'] for f in schema_record.get('fields', [])]}")
147147+148148+ # Decode schema back to a PackableSample class
149149+ decoded_type = index.decode_schema(schema_ref)
150150+ print(f"\nDecoded type: {decoded_type.__name__}")
151151+152152+ # Clean up test data
153153+ for key in redis.scan_iter(match="LocalSchema:*"):
154154+ redis.delete(key)
155155+ print("\nCleaned up test schemas")
156156+157157+158158+def demo_s3_datastore():
159159+ """Demonstrate S3DataStore interface."""
160160+ print("\n" + "=" * 60)
161161+ print("S3DataStore Demo")
162162+ print("=" * 60)
163163+164164+ # S3DataStore with mock credentials (won't actually connect)
165165+ creds = {
166166+ "AWS_ENDPOINT": "http://localhost:9000",
167167+ "AWS_ACCESS_KEY_ID": "minioadmin",
168168+ "AWS_SECRET_ACCESS_KEY": "minioadmin",
169169+ }
170170+171171+ store = S3DataStore(creds, bucket="my-bucket")
172172+173173+ print(f"\nS3DataStore created:")
174174+ print(f" Bucket: {store.bucket}")
175175+ print(f" Supports streaming: {store.supports_streaming()}")
176176+177177+ # read_url returns the URL unchanged (passthrough for WDS)
178178+ url = "s3://my-bucket/data.tar"
179179+ print(f"\nread_url passthrough: {store.read_url(url)}")
180180+181181+182182+def demo_repo_workflow(tmp_path: Path):
183183+ """Demonstrate full Repo workflow with local files."""
184184+ import webdataset as wds
185185+186186+ print("\n" + "=" * 60)
187187+ print("Repo Workflow Demo (local files)")
188188+ print("=" * 60)
189189+190190+ # Create sample data
191191+ samples = [
192192+ TrainingSample(features=np.random.randn(10).astype(np.float32), label=i % 3)
193193+ for i in range(100)
194194+ ]
195195+196196+ print(f"\nCreated {len(samples)} training samples")
197197+198198+ # Create a Dataset and write to local tar file
199199+ tar_path = tmp_path / "local-data-000000.tar"
200200+ with wds.writer.TarWriter(str(tar_path)) as sink:
201201+ for i, sample in enumerate(samples):
202202+ sink.write({**sample.as_wds, "__key__": f"sample_{i:06d}"})
203203+204204+ print(f"Wrote samples to: {tar_path}")
205205+206206+ # Load the dataset back
207207+ ds = atdata.Dataset[TrainingSample](str(tar_path))
208208+ loaded = list(ds.ordered(batch_size=None))
209209+ print(f"Loaded {len(loaded)} samples back")
210210+211211+ # Verify round-trip
212212+ assert len(loaded) == len(samples)
213213+ assert np.allclose(loaded[0].features, samples[0].features)
214214+ print("Round-trip verification: PASSED")
215215+216216+217217+def demo_load_dataset_with_index():
218218+ """Demonstrate load_dataset with index parameter."""
219219+ print("\n" + "=" * 60)
220220+ print("load_dataset with Index Demo")
221221+ print("=" * 60)
222222+223223+ print("""
224224+The load_dataset() function supports an index parameter for both local
225225+and atmosphere backends:
226226+227227+ # Local index lookup
228228+ from atdata import load_dataset
229229+ from atdata.local import LocalIndex
230230+231231+ index = LocalIndex()
232232+ ds = load_dataset('@local/my-dataset', index=index, split='train')
233233+234234+ # The index resolves the dataset name to URLs and schema
235235+ for batch in ds.shuffled(batch_size=32):
236236+ process(batch)
237237+238238+ # Atmosphere lookup (via @handle/dataset syntax)
239239+ ds = load_dataset('@alice.science/mnist', split='train')
240240+241241+ # This automatically:
242242+ # 1. Resolves the handle to a DID
243243+ # 2. Fetches the dataset record from the user's repository
244244+ # 3. Gets the data URLs from the record
245245+ # 4. Resolves the schema for type information
246246+""")
247247+248248+249249+# =============================================================================
250250+# Main
251251+# =============================================================================
252252+253253+def main():
254254+ parser = argparse.ArgumentParser(
255255+ description="Demonstrate atdata local storage workflow",
256256+ formatter_class=argparse.RawDescriptionHelpFormatter,
257257+ epilog=__doc__,
258258+ )
259259+ parser.add_argument(
260260+ "--redis",
261261+ action="store_true",
262262+ help="Run demos that require Redis",
263263+ )
264264+ parser.add_argument(
265265+ "--redis-host",
266266+ default="localhost",
267267+ help="Redis host (default: localhost)",
268268+ )
269269+ parser.add_argument(
270270+ "--redis-port",
271271+ type=int,
272272+ default=6379,
273273+ help="Redis port (default: 6379)",
274274+ )
275275+ parser.add_argument(
276276+ "--s3-endpoint",
277277+ help="S3 endpoint URL for live S3 demo",
278278+ )
279279+280280+ args = parser.parse_args()
281281+282282+ print("=" * 60)
283283+ print("atdata.local Demo")
284284+ print("=" * 60)
285285+ print(f"\nTime: {datetime.now().isoformat()}")
286286+287287+ # Always run these demos (no external services required)
288288+ demo_local_dataset_entry()
289289+ demo_local_index_mock()
290290+ demo_s3_datastore()
291291+ demo_load_dataset_with_index()
292292+293293+ # Run with temp directory for file-based demos
294294+ with tempfile.TemporaryDirectory() as tmp:
295295+ demo_repo_workflow(Path(tmp))
296296+297297+ # Run Redis demo if requested
298298+ if args.redis:
299299+ demo_local_index_redis(args.redis_host, args.redis_port)
300300+ else:
301301+ print("\n" + "=" * 60)
302302+ print("Redis Demo Skipped")
303303+ print("=" * 60)
304304+ print("\nTo run with Redis: python local_workflow.py --redis")
305305+306306+ print("\n" + "=" * 60)
307307+ print("Demo Complete!")
308308+ print("=" * 60)
309309+310310+311311+if __name__ == "__main__":
312312+ main()
+394
examples/promote_workflow.py
···11+#!/usr/bin/env python3
22+"""Demonstration of promoting local datasets to the atmosphere network.
33+44+This script demonstrates the workflow for migrating datasets from local
55+Redis/S3 storage to the federated ATProto atmosphere network.
66+77+Usage:
88+ # Dry run with mocks (no external services required):
99+ python promote_workflow.py
1010+1111+ # With actual ATProto connection:
1212+ python promote_workflow.py --handle your.handle --password your-app-password
1313+1414+Requirements:
1515+ pip install atdata[atmosphere]
1616+1717+Note:
1818+ Use an app-specific password, not your main Bluesky password.
1919+ Create app passwords at: https://bsky.app/settings/app-passwords
2020+"""
2121+2222+import argparse
2323+from datetime import datetime
2424+from unittest.mock import Mock, MagicMock
2525+2626+import numpy as np
2727+from numpy.typing import NDArray
2828+2929+import atdata
3030+from atdata.promote import promote_to_atmosphere
3131+3232+3333+# =============================================================================
3434+# Define sample types
3535+# =============================================================================
3636+3737+@atdata.packable
3838+class ExperimentSample:
3939+ """A sample from a scientific experiment."""
4040+ measurement: NDArray
4141+ timestamp: float
4242+ sensor_id: str
4343+4444+4545+# =============================================================================
4646+# Demo functions
4747+# =============================================================================
4848+4949+def demo_promotion_concept():
5050+ """Explain the promotion workflow concept."""
5151+ print("\n" + "=" * 60)
5252+ print("Promotion Workflow Overview")
5353+ print("=" * 60)
5454+5555+ print("""
5656+The promotion workflow moves datasets from local storage to the atmosphere:
5757+5858+ LOCAL ATMOSPHERE
5959+ ----- ----------
6060+ Redis Index ATProto PDS
6161+ S3 Storage --> (same S3 or new location)
6262+ local://schemas/... at://did:plc:.../schema/...
6363+6464+Steps:
6565+1. Retrieve dataset entry from LocalIndex
6666+2. Get schema from local index
6767+3. Find or publish schema on atmosphere (deduplication)
6868+4. Optionally copy data to new storage location
6969+5. Create dataset record on atmosphere
7070+6. Return AT URI for the published dataset
7171+7272+Key features:
7373+- Schema deduplication: Won't republish identical schemas
7474+- Flexible data handling: Keep existing URLs or copy to new storage
7575+- Metadata preservation: Local metadata carries over to atmosphere
7676+""")
7777+7878+7979+def demo_mock_promotion():
8080+ """Demonstrate promotion with mocked services."""
8181+ print("\n" + "=" * 60)
8282+ print("Mock Promotion Demo")
8383+ print("=" * 60)
8484+8585+ from atdata.local import LocalDatasetEntry
8686+8787+ # Create a mock local entry
8888+ local_entry = LocalDatasetEntry(
8989+ _name="experiment-2024-001",
9090+ _schema_ref="local://schemas/__main__.ExperimentSample@1.0.0",
9191+ _data_urls=[
9292+ "s3://research-bucket/experiments/exp-2024-001/shard-000000.tar",
9393+ "s3://research-bucket/experiments/exp-2024-001/shard-000001.tar",
9494+ ],
9595+ _metadata={
9696+ "experiment_date": "2024-01-15",
9797+ "lab": "Physics Building Room 302",
9898+ "principal_investigator": "Dr. Smith",
9999+ },
100100+ )
101101+102102+ print(f"\nLocal entry to promote:")
103103+ print(f" Name: {local_entry.name}")
104104+ print(f" Schema: {local_entry.schema_ref}")
105105+ print(f" URLs: {len(local_entry.data_urls)} shards")
106106+ print(f" Metadata: {local_entry.metadata}")
107107+108108+ # Create mock local index
109109+ mock_index = Mock()
110110+ mock_index.get_schema.return_value = {
111111+ "name": "__main__.ExperimentSample",
112112+ "version": "1.0.0",
113113+ "description": "A sample from a scientific experiment",
114114+ "fields": [
115115+ {"name": "measurement", "fieldType": {"$type": "local#ndarray", "dtype": "float32"}, "optional": False},
116116+ {"name": "timestamp", "fieldType": {"$type": "local#primitive", "primitive": "float"}, "optional": False},
117117+ {"name": "sensor_id", "fieldType": {"$type": "local#primitive", "primitive": "str"}, "optional": False},
118118+ ],
119119+ }
120120+121121+ # Create mock atmosphere client
122122+ mock_client = Mock()
123123+ mock_client.did = "did:plc:demo123456789"
124124+125125+ # Mock the atmosphere modules
126126+ from unittest.mock import patch
127127+128128+ with patch("atdata.promote._find_existing_schema") as mock_find:
129129+ mock_find.return_value = None # No existing schema
130130+131131+ with patch("atdata.atmosphere.SchemaPublisher") as MockSchemaPublisher:
132132+ mock_schema_pub = MockSchemaPublisher.return_value
133133+ mock_schema_uri = Mock(__str__=lambda s: "at://did:plc:demo123456789/ac.foundation.dataset.sampleSchema/exp001")
134134+ mock_schema_pub.publish.return_value = mock_schema_uri
135135+136136+ with patch("atdata.atmosphere.DatasetPublisher") as MockDatasetPublisher:
137137+ mock_ds_pub = MockDatasetPublisher.return_value
138138+ mock_ds_uri = Mock(__str__=lambda s: "at://did:plc:demo123456789/ac.foundation.dataset.datasetIndex/exp2024001")
139139+ mock_ds_pub.publish_with_urls.return_value = mock_ds_uri
140140+141141+ # Perform the promotion
142142+ result = promote_to_atmosphere(
143143+ local_entry,
144144+ mock_index,
145145+ mock_client,
146146+ tags=["experiment", "physics", "2024"],
147147+ license="CC-BY-4.0",
148148+ )
149149+150150+ print(f"\nPromotion result:")
151151+ print(f" AT URI: {result}")
152152+ print(f"\nPublished:")
153153+ print(f" Schema: at://did:plc:demo123456789/.../exp001")
154154+ print(f" Dataset: at://did:plc:demo123456789/.../exp2024001")
155155+156156+157157+def demo_schema_deduplication():
158158+ """Demonstrate schema deduplication during promotion."""
159159+ print("\n" + "=" * 60)
160160+ print("Schema Deduplication Demo")
161161+ print("=" * 60)
162162+163163+ from atdata.promote import _find_existing_schema
164164+ from unittest.mock import patch
165165+166166+ mock_client = Mock()
167167+168168+ # Scenario 1: Schema already exists
169169+ print("\nScenario 1: Schema already exists on atmosphere")
170170+ with patch("atdata.atmosphere.SchemaLoader") as MockLoader:
171171+ mock_loader = MockLoader.return_value
172172+ mock_loader.list_all.return_value = [
173173+ {
174174+ "uri": "at://did:plc:abc/schema/existing",
175175+ "value": {
176176+ "name": "mymodule.MySample",
177177+ "version": "1.0.0",
178178+ }
179179+ }
180180+ ]
181181+182182+ result = _find_existing_schema(mock_client, "mymodule.MySample", "1.0.0")
183183+ print(f" Looking for: mymodule.MySample@1.0.0")
184184+ print(f" Found: {result}")
185185+ print(f" Action: Reuse existing schema (no republish)")
186186+187187+ # Scenario 2: Different version
188188+ print("\nScenario 2: Same name but different version")
189189+ with patch("atdata.atmosphere.SchemaLoader") as MockLoader:
190190+ mock_loader = MockLoader.return_value
191191+ mock_loader.list_all.return_value = [
192192+ {
193193+ "uri": "at://did:plc:abc/schema/v1",
194194+ "value": {
195195+ "name": "mymodule.MySample",
196196+ "version": "1.0.0", # v1.0.0 exists
197197+ }
198198+ }
199199+ ]
200200+201201+ result = _find_existing_schema(mock_client, "mymodule.MySample", "2.0.0") # Looking for v2.0.0
202202+ print(f" Looking for: mymodule.MySample@2.0.0")
203203+ print(f" Found: {result}")
204204+ print(f" Action: Publish new schema record")
205205+206206+207207+def demo_data_migration_options():
208208+ """Explain data migration options during promotion."""
209209+ print("\n" + "=" * 60)
210210+ print("Data Migration Options")
211211+ print("=" * 60)
212212+213213+ print("""
214214+When promoting, you can choose how to handle the data files:
215215+216216+Option A: Keep existing URLs (default)
217217+-----------------------------------------
218218+ promote_to_atmosphere(entry, index, client)
219219+220220+ - Data stays in original S3 location
221221+ - Dataset record points to existing URLs
222222+ - Fastest option, no data copying
223223+ - Requires original storage to remain accessible
224224+225225+Option B: Copy to new S3 location
226226+-----------------------------------------
227227+ new_store = S3DataStore(creds, bucket='public-bucket')
228228+ promote_to_atmosphere(entry, index, client, data_store=new_store)
229229+230230+ - Data is copied to new bucket
231231+ - Dataset record points to new URLs
232232+ - Good for moving from private to public storage
233233+234234+Option C: Use ATProto blobs (future)
235235+-----------------------------------------
236236+ # Not yet implemented
237237+ promote_to_atmosphere(entry, index, client, data_store='pds-blobs')
238238+239239+ - Data uploaded as ATProto blobs
240240+ - Self-contained in the PDS
241241+ - Size limits apply (ATProto blob limits)
242242+""")
243243+244244+245245+def demo_live_promotion(handle: str, password: str):
246246+ """Demonstrate actual promotion to ATProto."""
247247+ print("\n" + "=" * 60)
248248+ print("Live Promotion Demo")
249249+ print("=" * 60)
250250+251251+ from atdata.atmosphere import AtmosphereClient
252252+ from atdata.local import LocalDatasetEntry
253253+254254+ # Connect to atmosphere
255255+ print(f"\nConnecting as {handle}...")
256256+ client = AtmosphereClient()
257257+ client.login(handle, password)
258258+ print(f"Authenticated! DID: {client.did}")
259259+260260+ # Create a demo local entry (simulating a real local dataset)
261261+ local_entry = LocalDatasetEntry(
262262+ _name="demo-promoted-dataset",
263263+ _schema_ref="local://schemas/__main__.ExperimentSample@1.0.0",
264264+ _data_urls=["s3://example-bucket/demo-data-{000000..000004}.tar"],
265265+ _metadata={"promoted_from": "local_demo", "demo": True},
266266+ )
267267+268268+ # Create a mock local index with our schema
269269+ mock_index = Mock()
270270+ mock_index.get_schema.return_value = {
271271+ "name": "__main__.ExperimentSample",
272272+ "version": "1.0.0",
273273+ "fields": [
274274+ {"name": "measurement", "fieldType": {"$type": "local#ndarray", "dtype": "float32"}, "optional": False},
275275+ {"name": "timestamp", "fieldType": {"$type": "local#primitive", "primitive": "float"}, "optional": False},
276276+ {"name": "sensor_id", "fieldType": {"$type": "local#primitive", "primitive": "str"}, "optional": False},
277277+ ],
278278+ }
279279+280280+ print("\nPromoting dataset to atmosphere...")
281281+ result = promote_to_atmosphere(
282282+ local_entry,
283283+ mock_index,
284284+ client,
285285+ tags=["demo", "atdata"],
286286+ license="MIT",
287287+ )
288288+289289+ print(f"\nPromotion successful!")
290290+ print(f" AT URI: {result}")
291291+ print(f"\nYou can now discover this dataset via:")
292292+ print(f" atdata.load_dataset('@{handle}/demo-promoted-dataset')")
293293+294294+295295+def demo_full_workflow():
296296+ """Show the complete local-to-atmosphere workflow."""
297297+ print("\n" + "=" * 60)
298298+ print("Complete Workflow Example")
299299+ print("=" * 60)
300300+301301+ print("""
302302+Here's a complete example of the local-to-atmosphere workflow:
303303+304304+ import atdata
305305+ from atdata.local import LocalIndex, Repo
306306+ from atdata.atmosphere import AtmosphereClient
307307+ from atdata.promote import promote_to_atmosphere
308308+309309+ # 1. Define your sample type
310310+ @atdata.packable
311311+ class MySample:
312312+ features: NDArray
313313+ label: str
314314+315315+ # 2. Create and index local dataset
316316+ local_index = LocalIndex() # Connects to Redis
317317+ repo = Repo(s3_creds, bucket='my-bucket', index=local_index)
318318+319319+ # Insert dataset (writes to S3, indexes in Redis)
320320+ samples = [MySample(features=..., label=...) for ...]
321321+ entry = repo.insert(samples, name='my-dataset')
322322+323323+ print(f"Local CID: {entry.cid}")
324324+ print(f"Local URLs: {entry.data_urls}")
325325+326326+ # 3. When ready to share, promote to atmosphere
327327+ client = AtmosphereClient()
328328+ client.login('myhandle.bsky.social', 'app-password')
329329+330330+ at_uri = promote_to_atmosphere(
331331+ entry,
332332+ local_index,
333333+ client,
334334+ tags=['ml', 'vision'],
335335+ license='MIT',
336336+ )
337337+338338+ print(f"Published at: {at_uri}")
339339+340340+ # 4. Others can now discover and load your dataset
341341+ # ds = atdata.load_dataset('@myhandle.bsky.social/my-dataset')
342342+""")
343343+344344+345345+# =============================================================================
346346+# Main
347347+# =============================================================================
348348+349349+def main():
350350+ parser = argparse.ArgumentParser(
351351+ description="Demonstrate local to atmosphere promotion workflow",
352352+ formatter_class=argparse.RawDescriptionHelpFormatter,
353353+ epilog=__doc__,
354354+ )
355355+ parser.add_argument(
356356+ "--handle",
357357+ help="Bluesky handle for live demo",
358358+ )
359359+ parser.add_argument(
360360+ "--password",
361361+ help="App-specific password for live demo",
362362+ )
363363+364364+ args = parser.parse_args()
365365+366366+ print("=" * 60)
367367+ print("atdata Promotion Workflow Demo")
368368+ print("=" * 60)
369369+ print(f"\nTime: {datetime.now().isoformat()}")
370370+371371+ # Always run these demos (no external services required)
372372+ demo_promotion_concept()
373373+ demo_mock_promotion()
374374+ demo_schema_deduplication()
375375+ demo_data_migration_options()
376376+ demo_full_workflow()
377377+378378+ # Run live demo if credentials provided
379379+ if args.handle and args.password:
380380+ demo_live_promotion(args.handle, args.password)
381381+ else:
382382+ print("\n" + "=" * 60)
383383+ print("Live Demo Skipped")
384384+ print("=" * 60)
385385+ print("\nTo run with actual ATProto connection:")
386386+ print(" python promote_workflow.py --handle your.handle --password your-app-password")
387387+388388+ print("\n" + "=" * 60)
389389+ print("Demo Complete!")
390390+ print("=" * 60)
391391+392392+393393+if __name__ == "__main__":
394394+ main()
···11+"""CID (Content Identifier) utilities for atdata.
22+33+This module provides utilities for generating ATProto-compatible CIDs from
44+data. CIDs are content-addressable identifiers that can be used to uniquely
55+identify schemas, datasets, and other records.
66+77+The CIDs generated here use:
88+- CIDv1 format
99+- dag-cbor codec (0x71)
1010+- SHA-256 hash (0x12)
1111+1212+This ensures compatibility with ATProto's CID requirements and enables
1313+seamless promotion from local storage to atmosphere (ATProto network).
1414+1515+Example:
1616+ >>> schema = {"name": "ImageSample", "version": "1.0.0", "fields": [...]}
1717+ >>> cid = generate_cid(schema)
1818+ >>> print(cid)
1919+ bafyreihffx5a2e7k6r5zqgp5iwpjqr2gfyheqhzqtlxagvqjqyxzqpzqaa
2020+"""
2121+2222+import hashlib
2323+from typing import Any
2424+2525+import libipld
2626+2727+2828+# CID constants
2929+CID_VERSION_1 = 0x01
3030+CODEC_DAG_CBOR = 0x71
3131+HASH_SHA256 = 0x12
3232+SHA256_SIZE = 0x20
3333+3434+3535+def generate_cid(data: Any) -> str:
3636+ """Generate an ATProto-compatible CID from arbitrary data.
3737+3838+ The data is first encoded as DAG-CBOR, then hashed with SHA-256,
3939+ and finally formatted as a CIDv1 string (base32 multibase).
4040+4141+ Args:
4242+ data: Any data structure that can be encoded as DAG-CBOR.
4343+ This includes dicts, lists, strings, numbers, bytes, etc.
4444+4545+ Returns:
4646+ CIDv1 string in base32 multibase format (starts with 'bafy').
4747+4848+ Raises:
4949+ ValueError: If the data cannot be encoded as DAG-CBOR.
5050+5151+ Example:
5252+ >>> generate_cid({"name": "test", "value": 42})
5353+ 'bafyrei...'
5454+ """
5555+ # Encode data as DAG-CBOR
5656+ try:
5757+ cbor_bytes = libipld.encode_dag_cbor(data)
5858+ except Exception as e:
5959+ raise ValueError(f"Failed to encode data as DAG-CBOR: {e}") from e
6060+6161+ # Hash with SHA-256
6262+ sha256_hash = hashlib.sha256(cbor_bytes).digest()
6363+6464+ # Build raw CID bytes:
6565+ # CIDv1 = version(1) + codec(dag-cbor) + multihash
6666+ # Multihash = code(sha256) + size(32) + digest
6767+ raw_cid_bytes = bytes([CID_VERSION_1, CODEC_DAG_CBOR, HASH_SHA256, SHA256_SIZE]) + sha256_hash
6868+6969+ # Encode to base32 multibase string
7070+ return libipld.encode_cid(raw_cid_bytes)
7171+7272+7373+def generate_cid_from_bytes(data_bytes: bytes) -> str:
7474+ """Generate a CID from raw bytes (already encoded data).
7575+7676+ Use this when you have pre-encoded data (e.g., DAG-CBOR bytes from
7777+ another source) and want to generate its CID without re-encoding.
7878+7979+ Args:
8080+ data_bytes: Raw bytes to hash (treated as opaque blob).
8181+8282+ Returns:
8383+ CIDv1 string in base32 multibase format.
8484+8585+ Example:
8686+ >>> cbor_bytes = libipld.encode_dag_cbor({"key": "value"})
8787+ >>> cid = generate_cid_from_bytes(cbor_bytes)
8888+ """
8989+ sha256_hash = hashlib.sha256(data_bytes).digest()
9090+ raw_cid_bytes = bytes([CID_VERSION_1, CODEC_DAG_CBOR, HASH_SHA256, SHA256_SIZE]) + sha256_hash
9191+ return libipld.encode_cid(raw_cid_bytes)
9292+9393+9494+def verify_cid(cid: str, data: Any) -> bool:
9595+ """Verify that a CID matches the given data.
9696+9797+ Args:
9898+ cid: CID string to verify.
9999+ data: Data that should correspond to the CID.
100100+101101+ Returns:
102102+ True if the CID matches the data, False otherwise.
103103+104104+ Example:
105105+ >>> cid = generate_cid({"name": "test"})
106106+ >>> verify_cid(cid, {"name": "test"})
107107+ True
108108+ >>> verify_cid(cid, {"name": "different"})
109109+ False
110110+ """
111111+ expected_cid = generate_cid(data)
112112+ return cid == expected_cid
113113+114114+115115+def parse_cid(cid: str) -> dict:
116116+ """Parse a CID string into its components.
117117+118118+ Args:
119119+ cid: CID string to parse.
120120+121121+ Returns:
122122+ Dictionary with 'version', 'codec', and 'hash' keys.
123123+ The 'hash' value is itself a dict with 'code', 'size', and 'digest'.
124124+125125+ Example:
126126+ >>> info = parse_cid('bafyrei...')
127127+ >>> info['version']
128128+ 1
129129+ >>> info['codec']
130130+ 113 # 0x71 = dag-cbor
131131+ """
132132+ return libipld.decode_cid(cid)
133133+134134+135135+__all__ = [
136136+ "generate_cid",
137137+ "generate_cid_from_bytes",
138138+ "verify_cid",
139139+ "parse_cid",
140140+]
+192-99
src/atdata/_hf_api.py
···3131import re
3232from pathlib import Path
3333from typing import (
3434+ TYPE_CHECKING,
3435 Any,
3536 Generic,
3637 Iterator,
3738 Mapping,
3939+ Optional,
3840 Type,
3941 TypeVar,
4042 Union,
···4244)
43454446from .dataset import Dataset, PackableSample
4747+4848+if TYPE_CHECKING:
4949+ from ._protocols import AbstractIndex
45504651##
4752# Type variables
···134139135140136141def _is_brace_pattern(path: str) -> bool:
137137- """Check if path contains WebDataset brace expansion notation.
138138-139139- Examples:
140140- >>> _is_brace_pattern("data-{000000..000099}.tar")
141141- True
142142- >>> _is_brace_pattern("data-{train,test}.tar")
143143- True
144144- >>> _is_brace_pattern("data-000000.tar")
145145- False
146146- """
142142+ """Check if path contains WebDataset brace expansion notation like {000..099}."""
147143 return bool(re.search(r"\{[^}]+\}", path))
148144149145150146def _is_glob_pattern(path: str) -> bool:
151151- """Check if path contains glob wildcards.
152152-153153- Examples:
154154- >>> _is_glob_pattern("data-*.tar")
155155- True
156156- >>> _is_glob_pattern("data-000000.tar")
157157- False
158158- """
147147+ """Check if path contains glob wildcards (* or ?)."""
159148 return "*" in path or "?" in path
160149161150162151def _is_remote_url(path: str) -> bool:
163163- """Check if path is a remote URL (s3, http, etc.).
164164-165165- Examples:
166166- >>> _is_remote_url("s3://bucket/path")
167167- True
168168- >>> _is_remote_url("https://example.com/data.tar")
169169- True
170170- >>> _is_remote_url("/local/path/data.tar")
171171- False
172172- """
152152+ """Check if path is a remote URL (s3://, gs://, http://, https://, az://)."""
173153 return path.startswith(("s3://", "gs://", "http://", "https://", "az://"))
174154175155176156def _expand_local_glob(pattern: str) -> list[str]:
177177- """Expand a local glob pattern to list of paths.
178178-179179- Args:
180180- pattern: Glob pattern like "path/to/*.tar"
181181-182182- Returns:
183183- Sorted list of matching file paths.
184184- """
157157+ """Expand local glob pattern to sorted list of matching file paths."""
185158 base_path = Path(pattern).parent
186159 glob_part = Path(pattern).name
187160···192165 return [str(p) for p in matches if p.is_file()]
193166194167195195-# Common split name patterns in filenames
196196-_SPLIT_PATTERNS = [
168168+# Pre-compiled split name patterns (pattern, split_name)
169169+_SPLIT_PATTERNS: list[tuple[re.Pattern[str], str]] = [
197170 # Patterns like "dataset-train-000000.tar" (split in middle with delimiters)
198198- (r"[_-](train|training)[_-]", "train"),
199199- (r"[_-](test|testing)[_-]", "test"),
200200- (r"[_-](val|valid|validation)[_-]", "validation"),
201201- (r"[_-](dev|development)[_-]", "validation"),
171171+ (re.compile(r"[_-](train|training)[_-]"), "train"),
172172+ (re.compile(r"[_-](test|testing)[_-]"), "test"),
173173+ (re.compile(r"[_-](val|valid|validation)[_-]"), "validation"),
174174+ (re.compile(r"[_-](dev|development)[_-]"), "validation"),
202175 # Patterns at start of filename like "train-000.tar" or "test_data.tar"
203203- (r"^(train|training)[_-]", "train"),
204204- (r"^(test|testing)[_-]", "test"),
205205- (r"^(val|valid|validation)[_-]", "validation"),
206206- (r"^(dev|development)[_-]", "validation"),
176176+ (re.compile(r"^(train|training)[_-]"), "train"),
177177+ (re.compile(r"^(test|testing)[_-]"), "test"),
178178+ (re.compile(r"^(val|valid|validation)[_-]"), "validation"),
179179+ (re.compile(r"^(dev|development)[_-]"), "validation"),
207180 # Patterns in directory path like "/path/train/shard-000.tar"
208208- (r"[/\\](train|training)[/\\]", "train"),
209209- (r"[/\\](test|testing)[/\\]", "test"),
210210- (r"[/\\](val|valid|validation)[/\\]", "validation"),
211211- (r"[/\\](dev|development)[/\\]", "validation"),
181181+ (re.compile(r"[/\\](train|training)[/\\]"), "train"),
182182+ (re.compile(r"[/\\](test|testing)[/\\]"), "test"),
183183+ (re.compile(r"[/\\](val|valid|validation)[/\\]"), "validation"),
184184+ (re.compile(r"[/\\](dev|development)[/\\]"), "validation"),
212185 # Patterns at start of path like "train/shard-000.tar"
213213- (r"^(train|training)[/\\]", "train"),
214214- (r"^(test|testing)[/\\]", "test"),
215215- (r"^(val|valid|validation)[/\\]", "validation"),
216216- (r"^(dev|development)[/\\]", "validation"),
186186+ (re.compile(r"^(train|training)[/\\]"), "train"),
187187+ (re.compile(r"^(test|testing)[/\\]"), "test"),
188188+ (re.compile(r"^(val|valid|validation)[/\\]"), "validation"),
189189+ (re.compile(r"^(dev|development)[/\\]"), "validation"),
217190]
218191219192220193def _detect_split_from_path(path: str) -> str | None:
221221- """Attempt to detect split name from a file path.
222222-223223- Args:
224224- path: File path to analyze.
225225-226226- Returns:
227227- Detected split name ("train", "test", "validation") or None.
228228- """
229229- # Extract just the filename for pattern matching on full paths
194194+ """Detect split name (train/test/validation) from file path."""
230195 filename = Path(path).name
231196 path_lower = path.lower()
232197 filename_lower = filename.lower()
233198234199 # Check filename first (more specific)
235200 for pattern, split_name in _SPLIT_PATTERNS:
236236- if re.search(pattern, filename_lower):
201201+ if pattern.search(filename_lower):
237202 return split_name
238203239239- # Fall back to full path (catches directory patterns like "train/...")
204204+ # Fall back to full path (catches directory patterns)
240205 for pattern, split_name in _SPLIT_PATTERNS:
241241- if re.search(pattern, path_lower):
206206+ if pattern.search(path_lower):
242207 return split_name
243208244209 return None
···356321 >>> _shards_to_wds_url(["train.tar"])
357322 "train.tar"
358323 """
324324+ import os.path
325325+359326 if len(shards) == 0:
360327 raise ValueError("Cannot create URL from empty shard list")
361328362329 if len(shards) == 1:
363330 return shards[0]
364331365365- # Find common prefix across ALL shards
366366- prefix = shards[0]
367367- for s in shards[1:]:
368368- # Shorten prefix until it matches
369369- while not s.startswith(prefix) and prefix:
370370- prefix = prefix[:-1]
332332+ # Find common prefix using os.path.commonprefix (O(n) vs O(n²))
333333+ prefix = os.path.commonprefix(shards)
371334372372- # Find common suffix across ALL shards
373373- suffix = shards[0]
374374- for s in shards[1:]:
375375- # Shorten suffix until it matches
376376- while not s.endswith(suffix) and suffix:
377377- suffix = suffix[1:]
335335+ # Find common suffix by reversing strings
336336+ reversed_shards = [s[::-1] for s in shards]
337337+ suffix = os.path.commonprefix(reversed_shards)[::-1]
378338379339 prefix_len = len(prefix)
380340 suffix_len = len(suffix)
···427387428388429389##
390390+# Index-based path resolution
391391+392392+393393+def _is_indexed_path(path: str) -> bool:
394394+ """Check if path uses @handle/dataset notation for index lookup.
395395+396396+ Examples:
397397+ >>> _is_indexed_path("@maxine.science/mnist")
398398+ True
399399+ >>> _is_indexed_path("@did:plc:abc123/my-dataset")
400400+ True
401401+ >>> _is_indexed_path("s3://bucket/data.tar")
402402+ False
403403+ """
404404+ return path.startswith("@")
405405+406406+407407+def _parse_indexed_path(path: str) -> tuple[str, str]:
408408+ """Parse @handle/dataset path into (handle_or_did, dataset_name).
409409+410410+ Args:
411411+ path: Path in format "@handle/dataset" or "@did:plc:xxx/dataset"
412412+413413+ Returns:
414414+ Tuple of (handle_or_did, dataset_name)
415415+416416+ Raises:
417417+ ValueError: If path format is invalid.
418418+ """
419419+ if not path.startswith("@"):
420420+ raise ValueError(f"Not an indexed path: {path}")
421421+422422+ # Remove leading @
423423+ rest = path[1:]
424424+425425+ # Split on first / (handle can contain . but dataset name is after /)
426426+ if "/" not in rest:
427427+ raise ValueError(
428428+ f"Invalid indexed path format: {path}. "
429429+ "Expected @handle/dataset or @did:plc:xxx/dataset"
430430+ )
431431+432432+ # Find the split point - for DIDs, the format is did:plc:xxx/dataset
433433+ # For handles, it's handle.domain/dataset
434434+ parts = rest.split("/", 1)
435435+ if len(parts) != 2 or not parts[0] or not parts[1]:
436436+ raise ValueError(f"Invalid indexed path: {path}")
437437+438438+ return parts[0], parts[1]
439439+440440+441441+def _resolve_indexed_path(
442442+ path: str,
443443+ index: "AbstractIndex",
444444+) -> tuple[list[str], str]:
445445+ """Resolve @handle/dataset path to URLs and schema_ref via index lookup.
446446+447447+ Args:
448448+ path: Path in @handle/dataset format.
449449+ index: Index to use for lookup.
450450+451451+ Returns:
452452+ Tuple of (data_urls, schema_ref).
453453+454454+ Raises:
455455+ KeyError: If dataset not found in index.
456456+ """
457457+ handle_or_did, dataset_name = _parse_indexed_path(path)
458458+459459+ # For AtmosphereIndex, we need to resolve handle to DID first
460460+ # For LocalIndex, the handle is ignored and we just look up by name
461461+ entry = index.get_dataset(dataset_name)
462462+463463+ return entry.data_urls, entry.schema_ref
464464+465465+466466+##
430467# Main load_dataset function
431468432469···438475 split: str,
439476 data_files: str | list[str] | dict[str, str | list[str]] | None = None,
440477 streaming: bool = False,
478478+ index: Optional["AbstractIndex"] = None,
441479) -> Dataset[ST]: ...
442480443481···449487 split: None = None,
450488 data_files: str | list[str] | dict[str, str | list[str]] | None = None,
451489 streaming: bool = False,
490490+ index: Optional["AbstractIndex"] = None,
452491) -> DatasetDict[ST]: ...
492492+493493+494494+@overload
495495+def load_dataset(
496496+ path: str,
497497+ sample_type: None = None,
498498+ *,
499499+ split: str,
500500+ data_files: str | list[str] | dict[str, str | list[str]] | None = None,
501501+ streaming: bool = False,
502502+ index: "AbstractIndex",
503503+) -> Dataset[PackableSample]: ...
504504+505505+506506+@overload
507507+def load_dataset(
508508+ path: str,
509509+ sample_type: None = None,
510510+ *,
511511+ split: None = None,
512512+ data_files: str | list[str] | dict[str, str | list[str]] | None = None,
513513+ streaming: bool = False,
514514+ index: "AbstractIndex",
515515+) -> DatasetDict[PackableSample]: ...
453516454517455518def load_dataset(
456519 path: str,
457457- sample_type: Type[ST],
520520+ sample_type: Type[ST] | None = None,
458521 *,
459522 split: str | None = None,
460523 data_files: str | list[str] | dict[str, str | list[str]] | None = None,
461524 streaming: bool = False,
525525+ index: Optional["AbstractIndex"] = None,
462526) -> Dataset[ST] | DatasetDict[ST]:
463463- """Load a dataset from local files or remote URLs.
527527+ """Load a dataset from local files, remote URLs, or an index.
464528465529 This function provides a HuggingFace Datasets-style interface for loading
466530 atdata typed datasets. It handles path resolution, split detection, and
···469533470534 Args:
471535 path: Path to dataset. Can be:
536536+ - Index lookup: "@handle/dataset-name" or "@local/dataset-name"
472537 - WebDataset brace notation: "path/to/{train,test}-{000..099}.tar"
473538 - Local directory: "./data/" (scans for .tar files)
474539 - Glob pattern: "path/to/*.tar"
475540 - Remote URL: "s3://bucket/path/data-*.tar"
476541 - Single file: "path/to/data.tar"
477542478478- sample_type: The PackableSample subclass defining the schema for
479479- samples in this dataset. This is required (unlike HF Datasets)
480480- because atdata uses typed dataclasses.
543543+ sample_type: The PackableSample subclass defining the schema. Can be
544544+ None if index is provided - the type will be resolved from the
545545+ schema stored in the index.
481546482547 split: Which split to load. If None, returns a DatasetDict with all
483548 detected splits. If specified (e.g., "train", "test"), returns
···490555491556 streaming: If True, explicitly marks the dataset for streaming mode.
492557 Note: atdata Datasets are already lazy/streaming via WebDataset
493493- pipelines, so this parameter primarily signals intent. When True,
494494- shard list precomputation is skipped. Default False.
558558+ pipelines, so this parameter primarily signals intent.
559559+560560+ index: Optional AbstractIndex for dataset lookup. Required when using
561561+ @handle/dataset syntax or when sample_type is None. Can be a
562562+ LocalIndex or AtmosphereIndex.
495563496564 Returns:
497565 If split is None: DatasetDict[ST] with all detected splits.
498566 If split is specified: Dataset[ST] for that split.
499567500568 Raises:
501501- ValueError: If the specified split is not found.
569569+ ValueError: If the specified split is not found, or if sample_type
570570+ is None without an index.
502571 FileNotFoundError: If no data files are found at the path.
572572+ KeyError: If dataset not found in index.
503573504574 Example:
505505- >>> @atdata.packable
506506- ... class TextData:
507507- ... text: str
508508- ... label: int
575575+ >>> # Load from local path with explicit type
576576+ >>> train_ds = load_dataset("./data/train-*.tar", TextData, split="train")
509577 >>>
510510- >>> # Load single split
511511- >>> train_ds = load_dataset("./data/train-*.tar", TextData, split="train")
578578+ >>> # Load from index with auto-type resolution
579579+ >>> index = LocalIndex()
580580+ >>> ds = load_dataset("@local/my-dataset", index=index, split="train")
512581 >>>
513582 >>> # Load all splits
514583 >>> ds_dict = load_dataset("./data/", TextData)
515584 >>> train_ds = ds_dict["train"]
516516- >>> test_ds = ds_dict["test"]
517517- >>>
518518- >>> # Explicit data files
519519- >>> ds_dict = load_dataset("./data/", TextData, data_files={
520520- ... "train": "train-*.tar",
521521- ... "test": "test-*.tar",
522522- ... })
523585 """
586586+ # Handle @handle/dataset indexed path resolution
587587+ if _is_indexed_path(path):
588588+ if index is None:
589589+ raise ValueError(
590590+ f"Index required for indexed path: {path}. "
591591+ "Pass index=LocalIndex() or index=AtmosphereIndex(client)."
592592+ )
593593+594594+ data_urls, schema_ref = _resolve_indexed_path(path, index)
595595+596596+ # Resolve sample_type from schema if not provided
597597+ if sample_type is None:
598598+ sample_type = index.decode_schema(schema_ref)
599599+600600+ # For indexed datasets, we treat all URLs as a single "train" split
601601+ url = _shards_to_wds_url(data_urls)
602602+ ds = Dataset[sample_type](url)
603603+604604+ if split is not None:
605605+ # Indexed datasets are single-split by default
606606+ return ds
607607+608608+ return DatasetDict({"train": ds}, sample_type=sample_type, streaming=streaming)
609609+610610+ # Validate sample_type for non-indexed paths
611611+ if sample_type is None:
612612+ raise ValueError(
613613+ "sample_type is required for non-indexed paths. "
614614+ "Use @handle/dataset with an index for auto-type resolution."
615615+ )
616616+524617 # Resolve path to split -> shard URL mapping
525618 splits_shards = _resolve_shards(path, data_files)
526619
+319
src/atdata/_protocols.py
···11+"""Protocol definitions for atdata index and storage abstractions.
22+33+This module defines the abstract protocols that enable interchangeable
44+index backends (local Redis vs ATProto PDS) and data stores (S3 vs PDS blobs).
55+66+The key insight is that both local and atmosphere implementations solve the
77+same problem: indexed dataset storage with external data URLs. These protocols
88+formalize that common interface.
99+1010+Note:
1111+ Protocol methods use ``...`` (Ellipsis) as the body per PEP 544. This is
1212+ the standard Python syntax for Protocol definitions - these are interface
1313+ specifications, not stub implementations. Concrete classes (LocalIndex,
1414+ AtmosphereIndex, etc.) provide the actual implementations.
1515+1616+Protocols:
1717+ IndexEntry: Common interface for dataset index entries
1818+ AbstractIndex: Protocol for index operations (schemas, datasets, lenses)
1919+ AbstractDataStore: Protocol for data storage operations
2020+2121+Example:
2222+ >>> def process_datasets(index: AbstractIndex) -> None:
2323+ ... for entry in index.list_datasets():
2424+ ... print(f"{entry.name}: {entry.data_urls}")
2525+ ...
2626+ >>> # Works with either LocalIndex or AtmosphereIndex
2727+ >>> process_datasets(local_index)
2828+ >>> process_datasets(atmosphere_index)
2929+"""
3030+3131+from typing import (
3232+ Iterator,
3333+ Optional,
3434+ Protocol,
3535+ Type,
3636+ TYPE_CHECKING,
3737+ runtime_checkable,
3838+)
3939+4040+if TYPE_CHECKING:
4141+ from .dataset import PackableSample, Dataset
4242+4343+4444+##
4545+# IndexEntry Protocol
4646+4747+4848+@runtime_checkable
4949+class IndexEntry(Protocol):
5050+ """Common interface for index entries (local or atmosphere).
5151+5252+ Both LocalDatasetEntry and atmosphere DatasetRecord-based entries
5353+ should satisfy this protocol, enabling code that works with either.
5454+5555+ Properties:
5656+ name: Human-readable dataset name
5757+ schema_ref: Reference to schema (local:// path or AT URI)
5858+ data_urls: WebDataset URLs for the data
5959+ metadata: Arbitrary metadata dict, or None
6060+ """
6161+6262+ @property
6363+ def name(self) -> str:
6464+ """Human-readable dataset name."""
6565+ ...
6666+6767+ @property
6868+ def schema_ref(self) -> str:
6969+ """Reference to the schema for this dataset.
7070+7171+ For local: 'local://schemas/{module.Class}@{version}'
7272+ For atmosphere: 'at://did:plc:.../ac.foundation.dataset.sampleSchema/...'
7373+ """
7474+ ...
7575+7676+ @property
7777+ def data_urls(self) -> list[str]:
7878+ """WebDataset URLs for the data.
7979+8080+ These are the URLs that can be passed to atdata.Dataset() or
8181+ used with WebDataset directly. May use brace notation for shards.
8282+ """
8383+ ...
8484+8585+ @property
8686+ def metadata(self) -> Optional[dict]:
8787+ """Arbitrary metadata dictionary, or None if not set."""
8888+ ...
8989+9090+9191+##
9292+# AbstractIndex Protocol
9393+9494+9595+class AbstractIndex(Protocol):
9696+ """Protocol for index operations - implemented by LocalIndex and AtmosphereIndex.
9797+9898+ This protocol defines the common interface for managing dataset metadata:
9999+ - Publishing and retrieving schemas
100100+ - Inserting and listing datasets
101101+ - (Future) Publishing and retrieving lenses
102102+103103+ A single index can hold datasets of many different sample types. The sample
104104+ type is tracked via schema references, not as a generic parameter on the index.
105105+106106+ Example:
107107+ >>> def publish_and_list(index: AbstractIndex) -> None:
108108+ ... # Publish schemas for different types
109109+ ... schema1 = index.publish_schema(ImageSample, version="1.0.0")
110110+ ... schema2 = index.publish_schema(TextSample, version="1.0.0")
111111+ ...
112112+ ... # Insert datasets of different types
113113+ ... index.insert_dataset(image_ds, name="images")
114114+ ... index.insert_dataset(text_ds, name="texts")
115115+ ...
116116+ ... # List all datasets (mixed types)
117117+ ... for entry in index.list_datasets():
118118+ ... print(f"{entry.name} -> {entry.schema_ref}")
119119+ """
120120+121121+ # Dataset operations
122122+123123+ def insert_dataset(
124124+ self,
125125+ ds: "Dataset",
126126+ *,
127127+ name: str,
128128+ schema_ref: Optional[str] = None,
129129+ **kwargs,
130130+ ) -> IndexEntry:
131131+ """Insert a dataset into the index.
132132+133133+ The sample type is inferred from ``ds.sample_type``. If schema_ref is not
134134+ provided, the schema may be auto-published based on the sample type.
135135+136136+ Args:
137137+ ds: The Dataset to register in the index (any sample type).
138138+ name: Human-readable name for the dataset.
139139+ schema_ref: Optional explicit schema reference. If not provided,
140140+ the schema may be auto-published or inferred from ds.sample_type.
141141+ **kwargs: Additional backend-specific options.
142142+143143+ Returns:
144144+ IndexEntry for the inserted dataset.
145145+ """
146146+ ...
147147+148148+ def get_dataset(self, ref: str) -> IndexEntry:
149149+ """Get a dataset entry by name or reference.
150150+151151+ Args:
152152+ ref: Dataset name, path, or full reference string.
153153+154154+ Returns:
155155+ IndexEntry for the dataset.
156156+157157+ Raises:
158158+ KeyError: If dataset not found.
159159+ """
160160+ ...
161161+162162+ def list_datasets(self) -> Iterator[IndexEntry]:
163163+ """List all dataset entries in this index.
164164+165165+ Yields:
166166+ IndexEntry for each dataset (may be of different sample types).
167167+ """
168168+ ...
169169+170170+ # Schema operations
171171+172172+ def publish_schema(
173173+ self,
174174+ sample_type: "Type[PackableSample]",
175175+ *,
176176+ version: str = "1.0.0",
177177+ **kwargs,
178178+ ) -> str:
179179+ """Publish a schema for a sample type.
180180+181181+ Args:
182182+ sample_type: The PackableSample subclass to publish.
183183+ version: Semantic version string for the schema.
184184+ **kwargs: Additional backend-specific options.
185185+186186+ Returns:
187187+ Schema reference string:
188188+ - Local: 'local://schemas/{module.Class}@{version}'
189189+ - Atmosphere: 'at://did:plc:.../ac.foundation.dataset.sampleSchema/...'
190190+ """
191191+ ...
192192+193193+ def get_schema(self, ref: str) -> dict:
194194+ """Get a schema record by reference.
195195+196196+ Args:
197197+ ref: Schema reference string (local:// or at://).
198198+199199+ Returns:
200200+ Schema record as a dictionary with fields like 'name', 'version',
201201+ 'fields', etc.
202202+203203+ Raises:
204204+ KeyError: If schema not found.
205205+ """
206206+ ...
207207+208208+ def list_schemas(self) -> Iterator[dict]:
209209+ """List all schema records in this index.
210210+211211+ Yields:
212212+ Schema records as dictionaries.
213213+ """
214214+ ...
215215+216216+ def decode_schema(self, ref: str) -> "Type[PackableSample]":
217217+ """Reconstruct a Python PackableSample type from a stored schema.
218218+219219+ This method enables loading datasets without knowing the sample type
220220+ ahead of time. The index retrieves the schema record and dynamically
221221+ generates a PackableSample subclass matching the schema definition.
222222+223223+ Args:
224224+ ref: Schema reference string (local:// or at://).
225225+226226+ Returns:
227227+ A dynamically generated PackableSample subclass with fields
228228+ matching the schema definition. The class can be used with
229229+ ``Dataset[T]`` to load and iterate over samples.
230230+231231+ Raises:
232232+ KeyError: If schema not found.
233233+ ValueError: If schema cannot be decoded (unsupported field types).
234234+235235+ Example:
236236+ >>> entry = index.get_dataset("my-dataset")
237237+ >>> SampleType = index.decode_schema(entry.schema_ref)
238238+ >>> ds = Dataset[SampleType](entry.data_urls[0])
239239+ >>> for sample in ds.ordered():
240240+ ... print(sample) # sample is instance of SampleType
241241+ """
242242+ ...
243243+244244+245245+##
246246+# AbstractDataStore Protocol
247247+248248+249249+class AbstractDataStore(Protocol):
250250+ """Protocol for data storage operations.
251251+252252+ This protocol abstracts over different storage backends for dataset data:
253253+ - S3DataStore: S3-compatible object storage
254254+ - PDSBlobStore: ATProto PDS blob storage (future)
255255+256256+ The separation of index (metadata) from data store (actual files) allows
257257+ flexible deployment: local index with S3 storage, atmosphere index with
258258+ S3 storage, or atmosphere index with PDS blobs.
259259+260260+ Example:
261261+ >>> store = S3DataStore(credentials, bucket="my-bucket")
262262+ >>> urls = store.write_shards(dataset, prefix="training/v1")
263263+ >>> print(urls)
264264+ ['s3://my-bucket/training/v1/shard-000000.tar', ...]
265265+ """
266266+267267+ def write_shards(
268268+ self,
269269+ ds: "Dataset",
270270+ *,
271271+ prefix: str,
272272+ **kwargs,
273273+ ) -> list[str]:
274274+ """Write dataset shards to storage.
275275+276276+ Args:
277277+ ds: The Dataset to write.
278278+ prefix: Path prefix for the shards (e.g., 'datasets/mnist/v1').
279279+ **kwargs: Backend-specific options (e.g., maxcount for shard size).
280280+281281+ Returns:
282282+ List of URLs for the written shards, suitable for use with
283283+ WebDataset or atdata.Dataset().
284284+ """
285285+ ...
286286+287287+ def read_url(self, url: str) -> str:
288288+ """Resolve a storage URL for reading.
289289+290290+ Some storage backends may need to transform URLs (e.g., signing S3 URLs
291291+ or resolving blob references). This method returns a URL that can be
292292+ used directly with WebDataset.
293293+294294+ Args:
295295+ url: Storage URL to resolve.
296296+297297+ Returns:
298298+ WebDataset-compatible URL for reading.
299299+ """
300300+ ...
301301+302302+ def supports_streaming(self) -> bool:
303303+ """Whether this store supports streaming reads.
304304+305305+ Returns:
306306+ True if the store supports efficient streaming (like S3),
307307+ False if data must be fully downloaded first.
308308+ """
309309+ ...
310310+311311+312312+##
313313+# Module exports
314314+315315+__all__ = [
316316+ "IndexEntry",
317317+ "AbstractIndex",
318318+ "AbstractDataStore",
319319+]
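Because `IndexEntry` is decorated with `@runtime_checkable`, protocol conformance can be verified structurally at runtime. A minimal sketch follows (the `InMemoryEntry` class is hypothetical, purely for illustration; runtime checks only verify member presence, not signatures):

```python
from typing import Optional

from atdata._protocols import IndexEntry


class InMemoryEntry:
    """Hypothetical entry used only to illustrate structural conformance."""

    def __init__(self, name: str, schema_ref: str, data_urls: list[str]):
        self._name, self._schema_ref, self._data_urls = name, schema_ref, data_urls

    @property
    def name(self) -> str:
        return self._name

    @property
    def schema_ref(self) -> str:
        return self._schema_ref

    @property
    def data_urls(self) -> list[str]:
        return self._data_urls

    @property
    def metadata(self) -> Optional[dict]:
        return None


entry = InMemoryEntry("demo", "local://schemas/demo.Sample@1.0.0", ["demo-000000.tar"])
# runtime_checkable protocols only check that the members exist.
assert isinstance(entry, IndexEntry)
```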
+237
src/atdata/_schema_codec.py
···11+"""Schema codec for dynamic PackableSample type generation.
22+33+This module provides functionality to reconstruct Python PackableSample types
44+from schema records. This enables loading datasets without knowing the sample
55+type ahead of time - the type can be dynamically generated from stored schema
66+metadata.
77+88+The schema format follows the ATProto record structure defined in
99+``atmosphere/_types.py``, with field types supporting primitives, ndarrays,
1010+arrays, and schema references.
1111+1212+Example:
1313+ >>> schema = {
1414+ ... "name": "ImageSample",
1515+ ... "version": "1.0.0",
1616+ ... "fields": [
1717+ ... {"name": "image", "fieldType": {"$type": "...#ndarray", "dtype": "float32"}, "optional": False},
1818+ ... {"name": "label", "fieldType": {"$type": "...#primitive", "primitive": "str"}, "optional": False},
1919+ ... ]
2020+ ... }
2121+ >>> ImageSample = schema_to_type(schema)
2222+ >>> sample = ImageSample(image=np.zeros((64, 64)), label="cat")
2323+"""
2424+2525+from dataclasses import field, make_dataclass
2626+from typing import Any, Optional, Type
2727+import hashlib
2828+2929+from numpy.typing import NDArray
3030+3131+# Import PackableSample for inheritance
3232+from .dataset import PackableSample
3333+3434+3535+# Type cache to avoid regenerating identical types
3636+_type_cache: dict[str, Type[PackableSample]] = {}
3737+3838+3939+def _schema_cache_key(schema: dict) -> str:
4040+ """Generate a cache key for a schema.
4141+4242+ Uses name + version + field signature to identify unique schemas.
4343+ """
4444+ name = schema.get("name", "Unknown")
4545+ version = schema.get("version", "0.0.0")
4646+ fields = schema.get("fields", [])
4747+4848+ # Create a stable string representation of fields
4949+ field_sig = ";".join(
5050+ f"{f['name']}:{f['fieldType'].get('$type', '')}:{f.get('optional', False)}"
5151+ for f in fields
5252+ )
5353+5454+ # Hash for compactness
5555+ sig_hash = hashlib.md5(field_sig.encode()).hexdigest()[:8]
5656+ return f"{name}@{version}#{sig_hash}"
5757+5858+5959+def _field_type_to_python(field_type: dict, optional: bool = False) -> Any:
6060+ """Convert a schema field type to a Python type annotation.
6161+6262+ Args:
6363+ field_type: Field type dict with '$type' and type-specific fields.
6464+ optional: Whether this field is optional (can be None).
6565+6666+ Returns:
6767+ Python type annotation suitable for dataclass field.
6868+6969+ Raises:
7070+ ValueError: If field type is not supported.
7171+ """
7272+ type_str = field_type.get("$type", "")
7373+7474+ # Extract kind from $type (e.g., "ac.foundation.dataset.schemaType#primitive" -> "primitive")
7575+ if "#" in type_str:
7676+ kind = type_str.split("#")[-1]
7777+ else:
7878+ # Fallback for simplified format
7979+ kind = field_type.get("kind", "")
8080+8181+ python_type: Any
8282+8383+ if kind == "primitive":
8484+ primitive = field_type.get("primitive", "str")
8585+ primitive_map = {
8686+ "str": str,
8787+ "int": int,
8888+ "float": float,
8989+ "bool": bool,
9090+ "bytes": bytes,
9191+ }
9292+ python_type = primitive_map.get(primitive)
9393+ if python_type is None:
9494+ raise ValueError(f"Unknown primitive type: {primitive}")
9595+9696+ elif kind == "ndarray":
9797+ # NDArray type - dtype info is available but we use generic NDArray
9898+ # The dtype is handled at runtime by PackableSample serialization
9999+ python_type = NDArray
100100+101101+ elif kind == "array":
102102+ # List type - recursively resolve item type
103103+ items = field_type.get("items")
104104+ if items:
105105+ item_type = _field_type_to_python(items, optional=False)
106106+ python_type = list[item_type]
107107+ else:
108108+ python_type = list
109109+110110+ elif kind == "ref":
111111+ # Reference to another schema - not yet supported for dynamic generation
112112+ raise ValueError(
113113+ f"Schema references ('ref') are not yet supported for dynamic type generation. "
114114+ f"Referenced schema: {field_type.get('ref')}"
115115+ )
116116+117117+ else:
118118+ raise ValueError(f"Unknown field type kind: {kind}")
119119+120120+ # Wrap in Optional if needed
121121+ if optional:
122122+ python_type = Optional[python_type]
123123+124124+ return python_type
125125+126126+127127+def schema_to_type(
128128+ schema: dict,
129129+ *,
130130+ use_cache: bool = True,
131131+) -> Type[PackableSample]:
132132+ """Generate a PackableSample subclass from a schema record.
133133+134134+ This function dynamically creates a dataclass that inherits from PackableSample,
135135+ with fields matching the schema definition. The generated class can be used
136136+ with ``Dataset[T]`` to load and process samples.
137137+138138+ Args:
139139+ schema: Schema record dict with 'name', 'version', 'fields', etc.
140140+ Fields should have 'name', 'fieldType', and 'optional' keys.
141141+ use_cache: If True, cache and reuse generated types for identical schemas.
142142+ Defaults to True.
143143+144144+ Returns:
145145+ A dynamically generated PackableSample subclass.
146146+147147+ Raises:
148148+ ValueError: If schema is malformed or contains unsupported types.
149149+150150+ Example:
151151+ >>> schema = index.get_schema("local://schemas/MySample@1.0.0")
152152+ >>> MySample = schema_to_type(schema)
153153+ >>> ds = Dataset[MySample]("data.tar")
154154+ >>> for sample in ds.ordered():
155155+ ... print(sample)
156156+ """
157157+ # Check cache first
158158+ if use_cache:
159159+ cache_key = _schema_cache_key(schema)
160160+ if cache_key in _type_cache:
161161+ return _type_cache[cache_key]
162162+163163+ # Extract schema metadata
164164+ name = schema.get("name")
165165+ if not name:
166166+ raise ValueError("Schema must have a 'name' field")
167167+168168+ version = schema.get("version", "1.0.0")
169169+ fields_data = schema.get("fields", [])
170170+171171+ if not fields_data:
172172+ raise ValueError("Schema must have at least one field")
173173+174174+ # Build field definitions for make_dataclass
175175+ # Format: (name, type) or (name, type, field())
176176+ dataclass_fields: list[tuple[str, Any] | tuple[str, Any, Any]] = []
177177+178178+ for field_def in fields_data:
179179+ field_name = field_def.get("name")
180180+ if not field_name:
181181+ raise ValueError("Each field must have a 'name'")
182182+183183+ field_type_dict = field_def.get("fieldType", {})
184184+ is_optional = field_def.get("optional", False)
185185+186186+ # Convert to Python type
187187+ python_type = _field_type_to_python(field_type_dict, optional=is_optional)
188188+189189+ # Optional fields need a default value of None
190190+ if is_optional:
191191+ dataclass_fields.append((field_name, python_type, field(default=None)))
192192+ else:
193193+ dataclass_fields.append((field_name, python_type))
194194+195195+ # Create the dataclass dynamically
196196+ # We need to make it inherit from PackableSample and call __post_init__
197197+ generated_class = make_dataclass(
198198+ name,
199199+ dataclass_fields,
200200+ bases=(PackableSample,),
201201+ namespace={
202202+ "__post_init__": lambda self: PackableSample.__post_init__(self),
203203+ "__schema_version__": version,
204204+ "__schema_ref__": schema.get("$ref", None), # Store original ref if available
205205+ },
206206+ )
207207+208208+ # Cache the generated type
209209+ if use_cache:
210210+ cache_key = _schema_cache_key(schema)
211211+ _type_cache[cache_key] = generated_class
212212+213213+ return generated_class
214214+215215+216216+def clear_type_cache() -> None:
217217+ """Clear the cached generated types.
218218+219219+ Useful for testing or when schema definitions change.
220220+ """
221221+ _type_cache.clear()
222222+223223+224224+def get_cached_types() -> dict[str, Type[PackableSample]]:
225225+ """Get a copy of the current type cache.
226226+227227+ Returns:
228228+ Dictionary mapping cache keys to generated types.
229229+ """
230230+ return dict(_type_cache)
231231+232232+233233+__all__ = [
234234+ "schema_to_type",
235235+ "clear_type_cache",
236236+ "get_cached_types",
237237+]
+36
src/atdata/_type_utils.py
···11+"""Shared type conversion utilities for schema handling.
22+33+This module provides common type mapping functions used by both local.py
44+and atmosphere/schema.py to avoid code duplication.
55+"""
66+77+from typing import Any
88+99+# Mapping from numpy dtype strings to schema dtype names
1010+NUMPY_DTYPE_MAP = {
1111+ "float16": "float16", "float32": "float32", "float64": "float64",
1212+ "int8": "int8", "int16": "int16", "int32": "int32", "int64": "int64",
1313+ "uint8": "uint8", "uint16": "uint16", "uint32": "uint32", "uint64": "uint64",
1414+ "bool": "bool", "complex64": "complex64", "complex128": "complex128",
1515+}
1616+1717+# Mapping from Python primitive types to schema type names
1818+PRIMITIVE_TYPE_MAP = {
1919+ str: "str", int: "int", float: "float", bool: "bool", bytes: "bytes",
2020+}
2121+2222+2323+def numpy_dtype_to_string(dtype: Any) -> str:
2424+ """Convert a numpy dtype annotation to a schema dtype string.
2525+2626+ Args:
2727+ dtype: A numpy dtype or type annotation containing dtype info.
2828+2929+ Returns:
3030+ Schema dtype string (e.g., "float32", "int64"). Defaults to "float32".
3131+ """
3232+ dtype_str = str(dtype)
3333+ for key, value in NUMPY_DTYPE_MAP.items():
3434+ if key in dtype_str:
3535+ return value
3636+ return "float32"
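A small usage sketch for the shared helpers (assuming the module is importable as `atdata._type_utils`):

```python
import numpy as np

from atdata._type_utils import NUMPY_DTYPE_MAP, PRIMITIVE_TYPE_MAP, numpy_dtype_to_string

# Exact dtypes and dtype-bearing annotations both reduce to a schema dtype string.
assert numpy_dtype_to_string(np.dtype("float64")) == "float64"
assert numpy_dtype_to_string(np.dtype("int64")) == "int64"

# Anything without a recognizable dtype falls back to "float32".
assert numpy_dtype_to_string(object) == "float32"

assert PRIMITIVE_TYPE_MAP[bytes] == "bytes"
assert "complex128" in NUMPY_DTYPE_MAP
```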
+216
src/atdata/atmosphere/__init__.py
···3030 pip install atproto
3131"""
32323333+from typing import Iterator, Optional, Type, TYPE_CHECKING
3434+3335from .client import AtmosphereClient
3436from .schema import SchemaPublisher, SchemaLoader
3537from .records import DatasetPublisher, DatasetLoader
···4143 LensRecord,
4244)
43454646+if TYPE_CHECKING:
4747+ from ..dataset import PackableSample, Dataset
4848+4949+5050+class AtmosphereIndexEntry:
5151+ """Entry wrapper for ATProto dataset records implementing IndexEntry protocol.
5252+5353+ Attributes:
5454+ _uri: AT URI of the record.
5555+ _record: Raw record dictionary.
5656+ """
5757+5858+ def __init__(self, uri: str, record: dict):
5959+ self._uri = uri
6060+ self._record = record
6161+6262+ @property
6363+ def name(self) -> str:
6464+ """Human-readable dataset name."""
6565+ return self._record.get("name", "")
6666+6767+ @property
6868+ def schema_ref(self) -> str:
6969+ """AT URI of the schema record."""
7070+ return self._record.get("schemaRef", "")
7171+7272+ @property
7373+ def data_urls(self) -> list[str]:
7474+ """WebDataset URLs from external storage."""
7575+ storage = self._record.get("storage", {})
7676+ storage_type = storage.get("$type", "")
7777+ if "storageExternal" in storage_type:
7878+ return storage.get("urls", [])
7979+ return []
8080+8181+ @property
8282+ def metadata(self) -> Optional[dict]:
8383+ """Metadata from the record, if any."""
8484+ import msgpack
8585+ metadata_bytes = self._record.get("metadata")
8686+ if metadata_bytes is None:
8787+ return None
8888+ return msgpack.unpackb(metadata_bytes, raw=False)
8989+9090+ @property
9191+ def uri(self) -> str:
9292+ """AT URI of this record."""
9393+ return self._uri
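A sketch of the wrapper over a raw record dict. The record below is hypothetical and only approximates the lexicon shape; `data_urls` keys off the `storageExternal` substring in the storage `$type`, exactly as in the property above:

```python
record = {
    "name": "mnist-train",
    "schemaRef": "at://did:plc:example/ac.foundation.dataset.sampleSchema/abc123",
    "storage": {
        "$type": "ac.foundation.dataset.storageExternal",  # assumed name; only the substring matters here
        "urls": ["s3://bucket/mnist/shard-{000000..000009}.tar"],
    },
}

entry = AtmosphereIndexEntry("at://did:plc:example/ac.foundation.dataset.record/xyz", record)
assert entry.name == "mnist-train"
assert entry.data_urls == ["s3://bucket/mnist/shard-{000000..000009}.tar"]
assert entry.metadata is None  # no msgpack-packed metadata on this record
```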
9494+9595+9696+class AtmosphereIndex:
9797+ """ATProto index implementing AbstractIndex protocol.
9898+9999+ Wraps SchemaPublisher/Loader and DatasetPublisher/Loader to provide
100100+ a unified interface compatible with LocalIndex.
101101+102102+ Example:
103103+ >>> client = AtmosphereClient()
104104+ >>> client.login("handle.bsky.social", "app-password")
105105+ >>>
106106+ >>> index = AtmosphereIndex(client)
107107+ >>> schema_ref = index.publish_schema(MySample, version="1.0.0")
108108+ >>> entry = index.insert_dataset(dataset, name="my-data")
109109+ """
110110+111111+ def __init__(self, client: AtmosphereClient):
112112+ """Initialize the atmosphere index.
113113+114114+ Args:
115115+ client: Authenticated AtmosphereClient instance.
116116+ """
117117+ self.client = client
118118+ self._schema_publisher = SchemaPublisher(client)
119119+ self._schema_loader = SchemaLoader(client)
120120+ self._dataset_publisher = DatasetPublisher(client)
121121+ self._dataset_loader = DatasetLoader(client)
122122+123123+ # Dataset operations
124124+125125+ def insert_dataset(
126126+ self,
127127+ ds: "Dataset",
128128+ *,
129129+ name: str,
130130+ schema_ref: Optional[str] = None,
131131+ **kwargs,
132132+ ) -> AtmosphereIndexEntry:
133133+ """Insert a dataset into ATProto.
134134+135135+ Args:
136136+ ds: The Dataset to publish.
137137+ name: Human-readable name.
138138+ schema_ref: Optional schema AT URI. If None, auto-publishes schema.
139139+ **kwargs: Additional options (description, tags, license).
140140+141141+ Returns:
142142+ AtmosphereIndexEntry for the inserted dataset.
143143+ """
144144+ uri = self._dataset_publisher.publish(
145145+ ds,
146146+ name=name,
147147+ schema_uri=schema_ref,
148148+ description=kwargs.get("description"),
149149+ tags=kwargs.get("tags"),
150150+ license=kwargs.get("license"),
151151+ auto_publish_schema=(schema_ref is None),
152152+ )
153153+ record = self._dataset_loader.get(uri)
154154+ return AtmosphereIndexEntry(str(uri), record)
155155+156156+ def get_dataset(self, ref: str) -> AtmosphereIndexEntry:
157157+ """Get a dataset by AT URI.
158158+159159+ Args:
160160+ ref: AT URI of the dataset record.
161161+162162+ Returns:
163163+ AtmosphereIndexEntry for the dataset.
164164+165165+ Raises:
166166+ ValueError: If record is not a dataset.
167167+ """
168168+ record = self._dataset_loader.get(ref)
169169+ return AtmosphereIndexEntry(ref, record)
170170+171171+ def list_datasets(self, repo: Optional[str] = None) -> Iterator[AtmosphereIndexEntry]:
172172+ """List dataset entries from a repository.
173173+174174+ Args:
175175+ repo: DID of repository. Defaults to authenticated user.
176176+177177+ Yields:
178178+ AtmosphereIndexEntry for each dataset.
179179+ """
180180+ records = self._dataset_loader.list_all(repo=repo)
181181+ for rec in records:
182182+ uri = rec.get("uri", "")
183183+ yield AtmosphereIndexEntry(uri, rec.get("value", rec))
184184+185185+ # Schema operations
186186+187187+ def publish_schema(
188188+ self,
189189+ sample_type: "Type[PackableSample]",
190190+ *,
191191+ version: str = "1.0.0",
192192+ **kwargs,
193193+ ) -> str:
194194+ """Publish a schema to ATProto.
195195+196196+ Args:
197197+ sample_type: The PackableSample subclass to publish.
198198+ version: Semantic version string.
199199+ **kwargs: Additional options (description, metadata).
200200+201201+ Returns:
202202+ AT URI of the schema record.
203203+ """
204204+ uri = self._schema_publisher.publish(
205205+ sample_type,
206206+ version=version,
207207+ description=kwargs.get("description"),
208208+ metadata=kwargs.get("metadata"),
209209+ )
210210+ return str(uri)
211211+212212+ def get_schema(self, ref: str) -> dict:
213213+ """Get a schema record by AT URI.
214214+215215+ Args:
216216+ ref: AT URI of the schema record.
217217+218218+ Returns:
219219+ Schema record dictionary.
220220+221221+ Raises:
222222+ ValueError: If record is not a schema.
223223+ """
224224+ return self._schema_loader.get(ref)
225225+226226+ def list_schemas(self, repo: Optional[str] = None) -> Iterator[dict]:
227227+ """List schema records from a repository.
228228+229229+ Args:
230230+ repo: DID of repository. Defaults to authenticated user.
231231+232232+ Yields:
233233+ Schema records.
234234+ """
235235+ records = self._schema_loader.list_all(repo=repo)
236236+ for rec in records:
237237+ yield rec.get("value", rec)
238238+239239+ def decode_schema(self, ref: str) -> "Type[PackableSample]":
240240+ """Reconstruct a Python type from a schema record.
241241+242242+ Args:
243243+ ref: AT URI of the schema record.
244244+245245+ Returns:
246246+ Dynamically generated PackableSample subclass.
247247+248248+ Raises:
249249+ ValueError: If schema cannot be decoded.
250250+ """
251251+ from .._schema_codec import schema_to_type
252252+253253+ schema = self.get_schema(ref)
254254+ return schema_to_type(schema)
255255+256256+44257__all__ = [
45258 # Client
46259 "AtmosphereClient",
260260+ # Unified index (AbstractIndex protocol)
261261+ "AtmosphereIndex",
262262+ "AtmosphereIndexEntry",
47263 # Schema operations
48264 "SchemaPublisher",
49265 "SchemaLoader",
+2-24
src/atdata/atmosphere/schema.py
···1717 FieldType,
1818 LEXICON_NAMESPACE,
1919)
2020+from .._type_utils import numpy_dtype_to_string
20212122# Import for type checking only to avoid circular imports
2223from typing import TYPE_CHECKING
···205206206207 def _numpy_dtype_to_string(self, dtype) -> str:
207208 """Convert a numpy dtype annotation to a string."""
208208- dtype_str = str(dtype)
209209- # Handle common numpy dtypes
210210- dtype_map = {
211211- "float16": "float16",
212212- "float32": "float32",
213213- "float64": "float64",
214214- "int8": "int8",
215215- "int16": "int16",
216216- "int32": "int32",
217217- "int64": "int64",
218218- "uint8": "uint8",
219219- "uint16": "uint16",
220220- "uint32": "uint32",
221221- "uint64": "uint64",
222222- "bool": "bool",
223223- "complex64": "complex64",
224224- "complex128": "complex128",
225225- }
226226-227227- for key, value in dtype_map.items():
228228- if key in dtype_str:
229229- return value
230230-231231- return "float32" # Default fallback
209209+ return numpy_dtype_to_string(dtype)
232210233211234212class SchemaLoader:
+11-17
src/atdata/dataset.py
···5454 Sequence,
5555 Iterable,
5656 Callable,
5757- Union,
5858- #
5957 Self,
6058 Generic,
6159 Type,
···187185 continue
188186189187 elif isinstance( var_cur_value, bytes ):
190190- # TODO This does create a constraint that serialized bytes
191191- # in a field that might be an NDArray are always interpreted
192192- # as being the NDArray interpretation
188188+ # Design note: bytes in NDArray-typed fields are always interpreted
189189+ # as serialized arrays. This means raw bytes fields must not be
190190+ # annotated as NDArray.
193191 setattr( self, var_name, eh.bytes_to_array( var_cur_value ) )
194192195193 def __post_init__( self ):
···202200 """Create a sample instance from unpacked msgpack data.
203201204202 Args:
205205- data: A dictionary of unpacked msgpack data with keys matching
206206- the sample's field names.
203203+ data: Dictionary with keys matching the sample's field names.
207204208205 Returns:
209209- A new instance of this sample class with fields populated from
210210- the data dictionary and NDArray fields auto-converted from bytes.
206206+ New instance with NDArray fields auto-converted from bytes.
211207 """
212212- ret = cls( **data )
213213- ret._ensure_good()
214214- return ret
208208+ return cls( **data )
215209216210 @classmethod
217211 def from_bytes( cls, bs: bytes ) -> Self:
···253247254248 return ret
255249256256- # TODO Expand to allow for specifying explicit __key__
257250 @property
258251 def as_wds( self ) -> WDSRawSample:
259252 """Pack this sample's data for writing to WebDataset.
···263256 ``msgpack`` (packed sample data) fields suitable for WebDataset.
264257265258 Note:
266266- TODO: Expand to allow specifying explicit ``__key__`` values.
259259+ Keys are auto-generated as UUID v1 for time-sortable ordering.
260260+ Custom key specification is not currently supported.
267261 """
268262 return {
269263 # Generates a UUID that is timelike-sortable
···575569 wds.filters.map( self.wrap_batch ),
576570 )
577571578578- # TODO Rewrite to eliminate `pandas` dependency directly calling
579579- # `fastparquet`
572572+ # Design note: Uses pandas for parquet export. Could be replaced with
573573+ # direct fastparquet calls to reduce dependencies if needed.
580574 def to_parquet( self, path: Pathlike,
581575 sample_map: Optional[SampleExportMap] = None,
582576 maxcount: Optional[int] = None,
···721715 def __post_init__( self ):
722716 return PackableSample.__post_init__( self )
723717724724- # TODO This doesn't properly carry over the original
718718+ # Restore original class identity for better repr/debugging
725719 as_packable.__name__ = class_name
726720 as_packable.__annotations__ = class_annotations
727721
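The design note about bytes fields and the UUID key note are easiest to see on a concrete sample type. A sketch, assuming `@packable` is importable from the package top level (the `AudioSample` class is hypothetical):

```python
import numpy as np
from numpy.typing import NDArray

from atdata import packable  # assumed top-level export of the decorator


@packable
class AudioSample:
    waveform: NDArray[np.float32]  # serialized to bytes on disk, restored to an ndarray on load
    raw_payload: bytes             # plain bytes field: never reinterpreted as an array


sample = AudioSample(waveform=np.zeros(16, dtype=np.float32), raw_payload=b"\x00\x01")

# '__key__' is a time-sortable UUID v1; explicit custom keys are not supported.
wds_sample = sample.as_wds
```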
+715-163
src/atdata/local.py
···6677The main classes are:
88- Repo: Manages dataset storage in S3 with Redis indexing
99-- Index: Redis-backed index for tracking dataset metadata
1010-- BasicIndexEntry: Index entry representing a stored dataset
99+- LocalIndex: Redis-backed index for tracking dataset metadata
1010+- LocalDatasetEntry: Index entry representing a stored dataset
11111212This is intended for development and small-scale deployment before
1313-migrating to the full atproto PDS infrastructure.
1313+migrating to the full atproto PDS infrastructure. The implementation
1414+uses ATProto-compatible CIDs for content addressing, enabling seamless
1515+promotion from local storage to the atmosphere (ATProto network).
1416"""
15171618##
···2022 PackableSample,
2123 Dataset,
2224)
2525+from atdata._cid import generate_cid
2626+from atdata._protocols import IndexEntry
2727+from atdata._type_utils import numpy_dtype_to_string, PRIMITIVE_TYPE_MAP
23282424-import os
2529from pathlib import Path
2630from uuid import uuid4
2731from tempfile import TemporaryDirectory
···4448from typing import (
4549 Any,
4650 Optional,
4747- Dict,
4851 Type,
4952 TypeVar,
5053 Generator,
5454+ Iterator,
5155 BinaryIO,
5656+ Union,
5257 cast,
5858+ get_type_hints,
5959+ get_origin,
6060+ get_args,
5361)
6262+import types
6363+from dataclasses import fields, is_dataclass
6464+from datetime import datetime, timezone
6565+import json
54665567T = TypeVar( 'T', bound = PackableSample )
56686969+# Redis key prefixes for index entries and schemas
7070+REDIS_KEY_DATASET_ENTRY = "LocalDatasetEntry"
7171+REDIS_KEY_SCHEMA = "LocalSchema"
7272+57735874##
5975# Helpers
60766177def _kind_str_for_sample_type( st: Type[PackableSample] ) -> str:
6262- """Convert a sample type to a fully-qualified string identifier.
7878+ """Return fully-qualified 'module.name' string for a sample type."""
7979+ return f'{st.__module__}.{st.__name__}'
8080+8181+8282+def _create_s3_write_callbacks(
8383+ credentials: dict[str, Any],
8484+ temp_dir: str,
8585+ written_shards: list[str],
8686+ fs: S3FileSystem | None,
8787+ cache_local: bool,
8888+ add_s3_prefix: bool = False,
8989+) -> tuple:
9090+ """Create opener and post callbacks for ShardWriter with S3 upload.
63916492 Args:
6565- st: The sample type class.
9393+ credentials: S3 credentials dict.
9494+ temp_dir: Temporary directory for local caching.
9595+ written_shards: List to append written shard paths to.
9696+ fs: S3FileSystem for direct writes (used when cache_local=False).
9797+ cache_local: If True, write locally then copy to S3.
9898+ add_s3_prefix: If True, prepend 's3://' to shard paths.
669967100 Returns:
6868- A string in the format 'module.name' identifying the sample type.
101101+ Tuple of (writer_opener, writer_post) callbacks.
69102 """
7070- return f'{st.__module__}.{st.__name__}'
103103+ if cache_local:
104104+ import boto3
711057272-def _decode_bytes_dict( d: dict[bytes, bytes] ) -> dict[str, str]:
7373- """Decode a dictionary with byte keys and values to strings.
106106+ s3_client_kwargs = {
107107+ 'aws_access_key_id': credentials['AWS_ACCESS_KEY_ID'],
108108+ 'aws_secret_access_key': credentials['AWS_SECRET_ACCESS_KEY']
109109+ }
110110+ if 'AWS_ENDPOINT' in credentials:
111111+ s3_client_kwargs['endpoint_url'] = credentials['AWS_ENDPOINT']
112112+ s3_client = boto3.client('s3', **s3_client_kwargs)
741137575- Redis returns dictionaries with bytes keys/values, this converts them to strings.
114114+ def _writer_opener(p: str):
115115+ local_path = Path(temp_dir) / p
116116+ local_path.parent.mkdir(parents=True, exist_ok=True)
117117+ return open(local_path, 'wb')
118118+119119+ def _writer_post(p: str):
120120+ local_path = Path(temp_dir) / p
121121+ path_parts = Path(p).parts
122122+ bucket = path_parts[0]
123123+ key = str(Path(*path_parts[1:]))
124124+125125+ with open(local_path, 'rb') as f_in:
126126+ s3_client.put_object(Bucket=bucket, Key=key, Body=f_in.read())
127127+128128+ local_path.unlink()
129129+ if add_s3_prefix:
130130+ written_shards.append(f"s3://{p}")
131131+ else:
132132+ written_shards.append(p)
133133+134134+ return _writer_opener, _writer_post
135135+ else:
136136+ assert fs is not None, "S3FileSystem required when cache_local=False"
137137+138138+ def _direct_opener(s: str):
139139+ return cast(BinaryIO, fs.open(f's3://{s}', 'wb'))
140140+141141+ def _direct_post(s: str):
142142+ if add_s3_prefix:
143143+ written_shards.append(f"s3://{s}")
144144+ else:
145145+ written_shards.append(s)
146146+147147+ return _direct_opener, _direct_post
148148+149149+##
150150+# Schema helpers
151151+152152+def _schema_ref_from_type(sample_type: Type[PackableSample], version: str = "1.0.0") -> str:
153153+ """Generate 'local://schemas/{module.Class}@{version}' reference."""
154154+ kind_str = _kind_str_for_sample_type(sample_type)
155155+ return f"local://schemas/{kind_str}@{version}"
156156+157157+158158+def _parse_schema_ref(ref: str) -> tuple[str, str]:
159159+ """Parse 'local://schemas/{module.Class}@{version}' into (module.Class, version)."""
160160+ if not ref.startswith("local://schemas/"):
161161+ raise ValueError(f"Invalid local schema reference: {ref}")
162162+163163+ path = ref[len("local://schemas/"):]
164164+ if "@" not in path:
165165+ raise ValueError(f"Schema reference must include version (@version): {ref}")
166166+167167+ kind_str, version = path.rsplit("@", 1)
168168+ return kind_str, version
169169+170170+171171+def _python_type_to_field_type(python_type: Any) -> dict:
172172+ """Convert Python type annotation to schema field type dict."""
173173+ # Handle primitives
174174+ if python_type in PRIMITIVE_TYPE_MAP:
175175+ return {"$type": "local#primitive", "primitive": PRIMITIVE_TYPE_MAP[python_type]}
176176+177177+ # Check for NDArray
178178+ type_str = str(python_type)
179179+ if "NDArray" in type_str or "ndarray" in type_str.lower():
180180+ dtype = "float32" # Default
181181+ args = get_args(python_type)
182182+ if args:
183183+ dtype = numpy_dtype_to_string(args[-1])
186186+ return {"$type": "local#ndarray", "dtype": dtype}
187187+188188+ # Check for list/array types
189189+ origin = get_origin(python_type)
190190+ if origin is list:
191191+ args = get_args(python_type)
192192+ if args:
193193+ items = _python_type_to_field_type(args[0])
194194+ return {"$type": "local#array", "items": items}
195195+ else:
196196+ return {"$type": "local#array", "items": {"$type": "local#primitive", "primitive": "str"}}
197197+198198+ # Check for nested dataclass (not yet supported)
199199+ if is_dataclass(python_type):
200200+ raise TypeError(
201201+ f"Nested dataclass types not yet supported: {python_type.__name__}. "
202202+ "Publish nested types separately and use references."
203203+ )
204204+205205+ raise TypeError(f"Unsupported type for schema field: {python_type}")
206206+207207+208208+def _build_schema_record(
209209+ sample_type: Type[PackableSample],
210210+ *,
211211+ version: str = "1.0.0",
212212+ description: str | None = None,
213213+) -> dict:
214214+ """Build a schema record dict from a PackableSample type.
7621577216 Args:
7878- d: Dictionary with bytes keys and values.
217217+ sample_type: The PackableSample subclass to introspect.
218218+ version: Semantic version string.
219219+ description: Optional human-readable description.
7922080221 Returns:
8181- Dictionary with UTF-8 decoded string keys and values.
222222+ Schema record dict suitable for Redis storage.
223223+224224+ Raises:
225225+ ValueError: If sample_type is not a dataclass.
226226+ TypeError: If a field type is not supported.
82227 """
228228+ if not is_dataclass(sample_type):
229229+ raise ValueError(f"{sample_type.__name__} must be a dataclass (use @packable)")
230230+231231+ field_defs = []
232232+ type_hints = get_type_hints(sample_type)
233233+234234+ for f in fields(sample_type):
235235+ field_type = type_hints.get(f.name, f.type)
236236+237237+ # Check for Optional types (Union with None)
238238+ is_optional = False
239239+ origin = get_origin(field_type)
240240+241241+ if origin is Union or isinstance(field_type, types.UnionType):
242242+ args = get_args(field_type)
243243+ non_none_args = [a for a in args if a is not type(None)]
244244+ if type(None) in args:
245245+ is_optional = True
246246+ if len(non_none_args) == 1:
247247+ field_type = non_none_args[0]
248248+ elif len(non_none_args) > 1:
249249+ raise TypeError(f"Complex union types not supported: {field_type}")
250250+251251+ field_type_dict = _python_type_to_field_type(field_type)
252252+253253+ field_defs.append({
254254+ "name": f.name,
255255+ "fieldType": field_type_dict,
256256+ "optional": is_optional,
257257+ })
258258+83259 return {
8484- k.decode('utf-8'): v.decode('utf-8')
8585- for k, v in d.items()
260260+ "name": sample_type.__name__,
261261+ "version": version,
262262+ "fields": field_defs,
263263+ "description": description,
264264+ "createdAt": datetime.now(timezone.utc).isoformat(),
86265 }
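For reference, a sketch of the record `_build_schema_record()` produces for a hypothetical sample type (the helper is internal to `local.py`; `ImageSample` and its field names are illustrative):

```python
from typing import Optional

import numpy as np
from numpy.typing import NDArray

from atdata import packable


@packable
class ImageSample:
    image: NDArray[np.float32]
    label: str
    caption: Optional[str] = None


record = _build_schema_record(ImageSample, version="1.0.0", description="demo")
# record resembles:
# {
#     "name": "ImageSample",
#     "version": "1.0.0",
#     "fields": [
#         {"name": "image",   "fieldType": {"$type": "local#ndarray", "dtype": "float32"}, "optional": False},
#         {"name": "label",   "fieldType": {"$type": "local#primitive", "primitive": "str"}, "optional": False},
#         {"name": "caption", "fieldType": {"$type": "local#primitive", "primitive": "str"}, "optional": True},
#     ],
#     "description": "demo",
#     "createdAt": "<UTC ISO-8601 timestamp>",
# }
```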
8726688267···90269# Redis object model
9127092271@dataclass
9393-class BasicIndexEntry:
9494- """Index entry for a dataset stored in the repository.
272272+class LocalDatasetEntry:
273273+ """Index entry for a dataset stored in the local repository.
274274+275275+ Implements the IndexEntry protocol for compatibility with AbstractIndex.
276276+ Uses dual identity: a content-addressable CID (ATProto-compatible) and
277277+ a human-readable name.
952789696- Tracks metadata about a dataset stored in S3, including its location,
9797- type, and unique identifier.
279279+ The CID is generated from the entry's content (schema_ref + data_urls),
280280+ ensuring the same data produces the same CID whether stored locally or
281281+ in the atmosphere. This enables seamless promotion from local to ATProto.
98282 """
99283 ##
100284101101- wds_url: str
102102- """WebDataset URL for the dataset tar files, for use with atdata.Dataset."""
285285+ _name: str
286286+ """Human-readable name for this dataset."""
287287+288288+ _schema_ref: str
289289+ """Reference to the schema for this dataset (local:// path)."""
290290+291291+ _data_urls: list[str]
292292+ """WebDataset URLs for the data."""
293293+294294+ _metadata: dict | None = None
295295+ """Arbitrary metadata dictionary, or None if not set."""
296296+297297+ _cid: str | None = field(default=None, repr=False)
298298+ """Content identifier (ATProto-compatible CID). Generated from content if not provided."""
299299+300300+ # Legacy field for backwards compatibility during migration
301301+ _legacy_uuid: str | None = field(default=None, repr=False)
302302+ """Legacy UUID for backwards compatibility with existing Redis entries."""
303303+304304+ def __post_init__(self):
305305+ """Generate CID from content if not provided."""
306306+ if self._cid is None:
307307+ self._cid = self._generate_cid()
308308+309309+ def _generate_cid(self) -> str:
310310+ """Generate ATProto-compatible CID from entry content."""
311311+ # CID is based on schema_ref and data_urls - the identity of the dataset
312312+ content = {
313313+ "schema_ref": self._schema_ref,
314314+ "data_urls": self._data_urls,
315315+ }
316316+ return generate_cid(content)
103317104104- sample_kind: str
105105- """Fully-qualified sample type name (e.g., 'module.ClassName')."""
318318+ # IndexEntry protocol properties
106319107107- metadata_url: str | None
108108- """S3 URL to the dataset's metadata msgpack file, if any."""
320320+ @property
321321+ def name(self) -> str:
322322+ """Human-readable dataset name."""
323323+ return self._name
109324110110- uuid: str = field( default_factory = lambda: str( uuid4() ) )
111111- """Unique identifier for this dataset entry. Defaults to a new UUID if not provided."""
325325+ @property
326326+ def schema_ref(self) -> str:
327327+ """Reference to the schema for this dataset."""
328328+ return self._schema_ref
112329113113- def write_to( self, redis: Redis ):
330330+ @property
331331+ def data_urls(self) -> list[str]:
332332+ """WebDataset URLs for the data."""
333333+ return self._data_urls
334334+335335+ @property
336336+ def metadata(self) -> dict | None:
337337+ """Arbitrary metadata dictionary, or None if not set."""
338338+ return self._metadata
339339+340340+ # Additional properties
341341+342342+ @property
343343+ def cid(self) -> str:
344344+ """Content identifier (ATProto-compatible CID)."""
345345+ assert self._cid is not None
346346+ return self._cid
347347+348348+ # Legacy compatibility
349349+350350+ @property
351351+ def wds_url(self) -> str:
352352+ """Legacy property: returns first data URL for backwards compatibility."""
353353+ return self._data_urls[0] if self._data_urls else ""
354354+355355+ @property
356356+ def sample_kind(self) -> str:
357357+ """Legacy property: returns schema_ref for backwards compatibility."""
358358+ return self._schema_ref
359359+360360+ def write_to(self, redis: Redis):
114361 """Persist this index entry to Redis.
115362116116- Stores the entry as a Redis hash with key 'BasicIndexEntry:{uuid}'.
363363+ Stores the entry as a Redis hash with key '{REDIS_KEY_DATASET_ENTRY}:{cid}'.
117364118365 Args:
119366 redis: Redis connection to write to.
120367 """
121121- save_key = f'BasicIndexEntry:{self.uuid}'
122122- # Filter out None values - Redis doesn't accept None
123123- data = {k: v for k, v in asdict(self).items() if v is not None}
124124- # redis-py typing uses untyped dict, so type checker complains about dict[str, Any]
125125- redis.hset( save_key, mapping = data ) # type: ignore[arg-type]
368368+ save_key = f'{REDIS_KEY_DATASET_ENTRY}:{self.cid}'
369369+ data = {
370370+ 'name': self._name,
371371+ 'schema_ref': self._schema_ref,
372372+ 'data_urls': msgpack.packb(self._data_urls), # Serialize list
373373+ 'cid': self.cid,
374374+ }
375375+ if self._metadata is not None:
376376+ data['metadata'] = msgpack.packb(self._metadata)
377377+ if self._legacy_uuid is not None:
378378+ data['legacy_uuid'] = self._legacy_uuid
379379+380380+ redis.hset(save_key, mapping=data) # type: ignore[arg-type]
381381+382382+ @classmethod
383383+ def from_redis(cls, redis: Redis, cid: str) -> "LocalDatasetEntry":
384384+ """Load an entry from Redis by CID.
385385+386386+ Args:
387387+ redis: Redis connection to read from.
388388+ cid: Content identifier of the entry to load.
389389+390390+ Returns:
391391+ LocalDatasetEntry loaded from Redis.
392392+393393+ Raises:
394394+ KeyError: If entry not found.
395395+ """
396396+ save_key = f'{REDIS_KEY_DATASET_ENTRY}:{cid}'
397397+ raw_data = redis.hgetall(save_key)
398398+ if not raw_data:
399399+ raise KeyError(f"{REDIS_KEY_DATASET_ENTRY} not found: {cid}")
126400127127-def _s3_env( credentials_path: str | Path ) -> dict[str, Any]:
128128- """Load S3 credentials from a .env file.
401401+ # Decode string fields, keep binary fields as bytes for msgpack
402402+ raw_data_typed = cast(dict[bytes, bytes], raw_data)
403403+ name = raw_data_typed[b'name'].decode('utf-8')
404404+ schema_ref = raw_data_typed[b'schema_ref'].decode('utf-8')
405405+ cid_value = raw_data_typed.get(b'cid', b'').decode('utf-8') or None
406406+ legacy_uuid = raw_data_typed.get(b'legacy_uuid', b'').decode('utf-8') or None
407407+408408+ # Deserialize msgpack fields (stored as raw bytes)
409409+ data_urls = msgpack.unpackb(raw_data_typed[b'data_urls'])
410410+ metadata = None
411411+ if b'metadata' in raw_data_typed:
412412+ metadata = msgpack.unpackb(raw_data_typed[b'metadata'])
413413+414414+ return cls(
415415+ _name=name,
416416+ _schema_ref=schema_ref,
417417+ _data_urls=data_urls,
418418+ _metadata=metadata,
419419+ _cid=cid_value,
420420+ _legacy_uuid=legacy_uuid,
421421+ )
129422130130- Args:
131131- credentials_path: Path to .env file containing S3 credentials.
132423133133- Returns:
134134- Dictionary with AWS_ENDPOINT, AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY.
424424+# Backwards compatibility alias
425425+BasicIndexEntry = LocalDatasetEntry
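A sketch of the CID behaviour described above: the identifier is derived purely from `schema_ref` + `data_urls`, so two entries describing the same data share a CID regardless of name or metadata (all values below are placeholders):

```python
a = LocalDatasetEntry(
    _name="mnist",
    _schema_ref="local://schemas/examples.MnistSample@1.0.0",
    _data_urls=["s3://bucket/mnist/shard-{000000..000009}.tar"],
)
b = LocalDatasetEntry(
    _name="mnist-copy",
    _schema_ref="local://schemas/examples.MnistSample@1.0.0",
    _data_urls=["s3://bucket/mnist/shard-{000000..000009}.tar"],
    _metadata={"note": "same content, different name"},
)

assert a.cid == b.cid
# write_to() persists each entry as a Redis hash under 'LocalDatasetEntry:{cid}'.
```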
135426136136- Raises:
137137- AssertionError: If required credentials are missing from the file.
138138- """
139139- ##
427427+def _s3_env( credentials_path: str | Path ) -> dict[str, Any]:
428428+ """Load S3 credentials (AWS_ENDPOINT, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY) from .env file."""
140429 credentials_path = Path( credentials_path )
141430 env_values = dotenv_values( credentials_path )
142431 assert 'AWS_ENDPOINT' in env_values
···153442 }
154443155444def _s3_from_credentials( creds: str | Path | dict ) -> S3FileSystem:
156156- """Create an S3FileSystem from credentials.
157157-158158- Args:
159159- creds: Either a path to a .env file with credentials, or a dict
160160- containing AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and optionally
161161- AWS_ENDPOINT.
162162-163163- Returns:
164164- Configured S3FileSystem instance.
165165- """
166166- ##
445445+ """Create S3FileSystem from credentials dict or .env file path."""
167446 if not isinstance( creds, dict ):
168447 creds = _s3_env( creds )
169448···248527249528 ##
250529251251- def insert( self, ds: Dataset[T],
252252- #
530530+ def insert(self,
531531+ ds: Dataset[T],
532532+ *,
533533+ name: str,
253534 cache_local: bool = False,
254254- #
255255- **kwargs
256256- ) -> tuple[BasicIndexEntry, Dataset[T]]:
535535+ schema_ref: str | None = None,
536536+ **kwargs
537537+ ) -> tuple[LocalDatasetEntry, Dataset[T]]:
257538 """Insert a dataset into the repository.
258539259540 Writes the dataset to S3 as WebDataset tar files, stores metadata,
···261542262543 Args:
263544 ds: The dataset to insert.
545545+ name: Human-readable name for the dataset.
264546 cache_local: If True, write to local temporary storage first, then
265547 copy to S3. This can be faster for some workloads.
548548+ schema_ref: Optional schema reference. If None, generates from sample type.
266549 **kwargs: Additional arguments passed to wds.ShardWriter.
267550268551 Returns:
269552 A tuple of (index_entry, new_dataset) where:
270270- - index_entry: BasicIndexEntry for the stored dataset
553553+ - index_entry: LocalDatasetEntry for the stored dataset
271554 - new_dataset: Dataset object pointing to the stored copy
272555273556 Raises:
274274- AssertionError: If S3 credentials or hive_path are not configured.
557557+ ValueError: If S3 credentials or hive_path are not configured.
275558 RuntimeError: If no shards were written.
276559 """
277277-278278- assert self.s3_credentials is not None
279279- assert self.hive_bucket is not None
280280- assert self.hive_path is not None
560560+ if self.s3_credentials is None:
561561+ raise ValueError("S3 credentials required for insert(). Initialize Repo with s3_credentials.")
562562+ if self.hive_bucket is None or self.hive_path is None:
563563+ raise ValueError("hive_path required for insert(). Initialize Repo with hive_path.")
281564282565 new_uuid = str( uuid4() )
283566···305588 / f'atdata--{new_uuid}--%06d.tar'
306589 ).as_posix()
307590591591+ written_shards: list[str] = []
308592 with TemporaryDirectory() as temp_dir:
309309-310310- if cache_local:
311311- # For cache_local, we need to use boto3 directly to avoid s3fs async issues with moto
312312- import boto3
313313-314314- # Create boto3 client from credentials
315315- s3_client_kwargs = {
316316- 'aws_access_key_id': self.s3_credentials['AWS_ACCESS_KEY_ID'],
317317- 'aws_secret_access_key': self.s3_credentials['AWS_SECRET_ACCESS_KEY']
318318- }
319319- if 'AWS_ENDPOINT' in self.s3_credentials:
320320- s3_client_kwargs['endpoint_url'] = self.s3_credentials['AWS_ENDPOINT']
321321- s3_client = boto3.client('s3', **s3_client_kwargs)
322322-323323- def _writer_opener( p: str ):
324324- local_cache_path = Path( temp_dir ) / p
325325- local_cache_path.parent.mkdir( parents = True, exist_ok = True )
326326- return open( local_cache_path, 'wb' )
327327- writer_opener = _writer_opener
328328-329329- def _writer_post( p: str ):
330330- local_cache_path = Path( temp_dir ) / p
331331-332332- # Copy to S3 using boto3 client (avoids s3fs async issues)
333333- path_parts = Path( p ).parts
334334- bucket = path_parts[0]
335335- key = str( Path( *path_parts[1:] ) )
336336-337337- with open( local_cache_path, 'rb' ) as f_in:
338338- s3_client.put_object( Bucket=bucket, Key=key, Body=f_in.read() )
339339-340340- # Delete local cache file
341341- local_cache_path.unlink()
342342-343343- written_shards.append( p )
344344- writer_post = _writer_post
345345-346346- else:
347347- # Use s3:// prefix to ensure s3fs treats paths as S3 paths
348348- writer_opener = lambda s: cast( BinaryIO, hive_fs.open( f's3://{s}', 'wb' ) )
349349- writer_post = lambda s: written_shards.append( s )
593593+ writer_opener, writer_post = _create_s3_write_callbacks(
594594+ credentials=self.s3_credentials,
595595+ temp_dir=temp_dir,
596596+ written_shards=written_shards,
597597+ fs=hive_fs,
598598+ cache_local=cache_local,
599599+ add_s3_prefix=False,
600600+ )
350601351351- written_shards = []
352602 with wds.writer.ShardWriter(
353603 shard_pattern,
354354- opener = writer_opener,
355355- post = writer_post,
604604+ opener=writer_opener,
605605+ post=writer_post,
356606 **kwargs,
357607 ) as sink:
358358- for sample in ds.ordered( batch_size = None ):
359359- sink.write( sample.as_wds )
608608+ for sample in ds.ordered(batch_size=None):
609609+ sink.write(sample.as_wds)
360610361611 # Make a new Dataset object for the written dataset copy
362612 if len( written_shards ) == 0:
···379629 new_dataset_url = shard_s3_format.format( shard_id = shard_id_braced )
380630381631 new_dataset = Dataset[ds.sample_type](
382382- url = new_dataset_url,
383383- metadata_url = metadata_path.as_posix(),
632632+ url=new_dataset_url,
633633+ metadata_url=metadata_path.as_posix(),
384634 )
385635386386- # Add to index
387387- new_entry = self.index.add_entry( new_dataset, uuid = new_uuid )
636636+ # Add to index (use ds._metadata to avoid network requests)
637637+ new_entry = self.index.add_entry(
638638+ new_dataset,
639639+ name=name,
640640+ schema_ref=schema_ref,
641641+ metadata=ds._metadata,
642642+ )
388643389644 return new_entry, new_dataset
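A usage sketch for the updated `insert()` signature: `name` is now a required keyword argument, and missing S3 or hive configuration raises `ValueError` instead of tripping an assertion. Assumes `repo` is an already-configured `Repo` and `ds` is an existing `atdata.Dataset`; the dataset name is a placeholder:

```python
entry, stored_ds = repo.insert(ds, name="mnist-train", cache_local=True)

print(entry.cid)         # content-addressed identifier for the new index entry
print(entry.schema_ref)  # auto-generated local:// schema reference
print(stored_ds.url)     # brace-notation URL over the written shards
```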
390645···392647class Index:
393648 """Redis-backed index for tracking datasets in a repository.
394649395395- Maintains a registry of BasicIndexEntry objects in Redis, allowing
650650+ Maintains a registry of LocalDatasetEntry objects in Redis, allowing
396651 enumeration and lookup of stored datasets.
397652398653 Attributes:
···401656402657 ##
403658404404- def __init__( self,
405405- redis: Redis | None = None,
406406- **kwargs
407407- ) -> None:
659659+ def __init__(self,
660660+ redis: Redis | None = None,
661661+ **kwargs
662662+ ) -> None:
408663 """Initialize an index.
409664410665 Args:
···418673 if redis is not None:
419674 self._redis = redis
420675 else:
421421- self._redis: Redis = Redis( **kwargs )
676676+ self._redis: Redis = Redis(**kwargs)
422677423678 @property
424424- def all_entries( self ) -> list[BasicIndexEntry]:
679679+ def all_entries(self) -> list[LocalDatasetEntry]:
425680 """Get all index entries as a list.
426681427682 Returns:
428428- List of all BasicIndexEntry objects in the index.
683683+ List of all LocalDatasetEntry objects in the index.
429684 """
430430- return list( self.entries )
685685+ return list(self.entries)
431686432687 @property
433433- def entries( self ) -> Generator[BasicIndexEntry, None, None]:
688688+ def entries(self) -> Generator[LocalDatasetEntry, None, None]:
434689 """Iterate over all index entries.
435690436436- Scans Redis for all BasicIndexEntry keys and yields them one at a time.
691691+ Scans Redis for LocalDatasetEntry keys and yields them one at a time.
437692438693 Yields:
439439- BasicIndexEntry objects from the index.
694694+ LocalDatasetEntry objects from the index.
440695 """
441441- ##
442442- for key in self._redis.scan_iter( match = 'BasicIndexEntry:*' ):
443443- # hgetall returns dict[bytes, bytes] which we decode to dict[str, str]
444444- cur_entry_data = _decode_bytes_dict( cast(dict[bytes, bytes], self._redis.hgetall( key )) )
445445-446446- # Provide default None for optional fields that may be missing
447447- # Type checker complains about None in dict[str, str], but BasicIndexEntry accepts it
448448- cur_entry_data: dict[str, Any] = dict( **cur_entry_data )
449449- cur_entry_data.setdefault('metadata_url', None)
450450-451451- cur_entry = BasicIndexEntry( **cur_entry_data )
452452- yield cur_entry
696696+ prefix = f'{REDIS_KEY_DATASET_ENTRY}:'
697697+ for key in self._redis.scan_iter(match=f'{prefix}*'):
698698+ key_str = key.decode('utf-8') if isinstance(key, bytes) else key
699699+ cid = key_str[len(prefix):]
700700+ yield LocalDatasetEntry.from_redis(self._redis, cid)
453701454454- return
455455-456456- def add_entry( self, ds: Dataset,
457457- uuid: str | None = None,
458458- ) -> BasicIndexEntry:
702702+ def add_entry(self,
703703+ ds: Dataset,
704704+ *,
705705+ name: str,
706706+ schema_ref: str | None = None,
707707+ metadata: dict | None = None,
708708+ ) -> LocalDatasetEntry:
459709 """Add a dataset to the index.
460710461461- Creates a BasicIndexEntry for the dataset and persists it to Redis.
711711+ Creates a LocalDatasetEntry for the dataset and persists it to Redis.
462712463713 Args:
464714 ds: The dataset to add to the index.
465465- uuid: Optional UUID for the entry. If None, a new UUID is generated.
715715+ name: Human-readable name for the dataset.
716716+ schema_ref: Optional schema reference. If None, generates from sample type.
717717+ metadata: Optional metadata dictionary. If None, uses ds._metadata if available.
466718467719 Returns:
468468- The created BasicIndexEntry object.
720720+ The created LocalDatasetEntry object.
469721 """
470722 ##
471471- temp_sample_kind = _kind_str_for_sample_type( ds.sample_type )
723723+ if schema_ref is None:
724724+ schema_ref = f"local://schemas/{_kind_str_for_sample_type(ds.sample_type)}@1.0.0"
725725+726726+ # Normalize URL to list
727727+ data_urls = [ds.url]
728728+729729+ # Use provided metadata, or fall back to dataset's cached metadata
730730+ # (avoid triggering network requests via ds.metadata property)
731731+ entry_metadata = metadata if metadata is not None else ds._metadata
732732+733733+ entry = LocalDatasetEntry(
734734+ _name=name,
735735+ _schema_ref=schema_ref,
736736+ _data_urls=data_urls,
737737+ _metadata=entry_metadata,
738738+ )
739739+740740+ entry.write_to(self._redis)
741741+742742+ return entry
743743+744744+ def get_entry(self, cid: str) -> LocalDatasetEntry:
745745+ """Get an entry by its CID.
746746+747747+ Args:
748748+ cid: Content identifier of the entry.
749749+750750+ Returns:
751751+ LocalDatasetEntry for the given CID.
752752+753753+ Raises:
754754+ KeyError: If entry not found.
755755+ """
756756+ return LocalDatasetEntry.from_redis(self._redis, cid)
757757+758758+ def get_entry_by_name(self, name: str) -> LocalDatasetEntry:
759759+ """Get an entry by its human-readable name.
760760+761761+ Args:
762762+ name: Human-readable name of the entry.
763763+764764+ Returns:
765765+ LocalDatasetEntry with the given name.
766766+767767+ Raises:
768768+ KeyError: If no entry with that name exists.
769769+ """
770770+ for entry in self.entries:
771771+ if entry.name == name:
772772+ return entry
+        raise KeyError(f"No entry with name: {name}")
+
+    # AbstractIndex protocol methods
+
+    def insert_dataset(
+        self,
+        ds: Dataset,
+        *,
+        name: str,
+        schema_ref: str | None = None,
+        **kwargs,
+    ) -> LocalDatasetEntry:
+        """Insert a dataset into the index (AbstractIndex protocol).
+
+        Args:
+            ds: The Dataset to register.
+            name: Human-readable name for the dataset.
+            schema_ref: Optional schema reference.
+            **kwargs: Additional options (metadata supported).
+
+        Returns:
+            IndexEntry for the inserted dataset.
+        """
+        metadata = kwargs.get('metadata')
+        return self.add_entry(ds, name=name, schema_ref=schema_ref, metadata=metadata)
+
+    def get_dataset(self, ref: str) -> LocalDatasetEntry:
+        """Get a dataset entry by name (AbstractIndex protocol).
+
+        Args:
+            ref: Dataset name.
+
+        Returns:
+            IndexEntry for the dataset.
+
+        Raises:
+            KeyError: If dataset not found.
+        """
+        return self.get_entry_by_name(ref)
+
+    def list_datasets(self) -> Iterator[LocalDatasetEntry]:
+        """List all dataset entries (AbstractIndex protocol).
+
+        Yields:
+            IndexEntry for each dataset.
+        """
+        return self.entries
+
+    # Schema operations
+
+    def publish_schema(
+        self,
+        sample_type: Type[PackableSample],
+        *,
+        version: str = "1.0.0",
+        description: str | None = None,
+    ) -> str:
+        """Publish a schema for a sample type to Redis.
+
+        Args:
+            sample_type: The PackableSample subclass to publish.
+            version: Semantic version string (e.g., '1.0.0').
+            description: Optional human-readable description.
+
+        Returns:
+            Schema reference string: 'local://schemas/{module.Class}@{version}'.
+
+        Raises:
+            ValueError: If sample_type is not a dataclass.
+            TypeError: If a field type is not supported.
+        """
+        schema_record = _build_schema_record(
+            sample_type,
+            version=version,
+            description=description,
+        )
+
+        schema_ref = _schema_ref_from_type(sample_type, version)
+        kind_str, _ = _parse_schema_ref(schema_ref)
+
+        # Store in Redis
+        redis_key = f"{REDIS_KEY_SCHEMA}:{kind_str}@{version}"
+        schema_json = json.dumps(schema_record)
+        self._redis.set(redis_key, schema_json)
+
+        return schema_ref
+
+    def get_schema(self, ref: str) -> dict:
+        """Get a schema record by reference.
+
+        Args:
+            ref: Schema reference string (local://schemas/...).
+
+        Returns:
+            Schema record as a dictionary.
+
+        Raises:
+            KeyError: If schema not found.
+            ValueError: If reference format is invalid.
+        """
+        kind_str, version = _parse_schema_ref(ref)
+        redis_key = f"{REDIS_KEY_SCHEMA}:{kind_str}@{version}"
+
+        schema_json = self._redis.get(redis_key)
+        if schema_json is None:
+            raise KeyError(f"Schema not found: {ref}")
+
+        if isinstance(schema_json, bytes):
+            schema_json = schema_json.decode('utf-8')
+
+        schema = json.loads(schema_json)
+        # Add $ref for decode_schema compatibility
+        schema['$ref'] = ref
+        return schema
+
+    def list_schemas(self) -> Generator[dict, None, None]:
+        """List all schema records in this index.
+
+        Yields:
+            Schema records as dictionaries.
+        """
+        prefix = f'{REDIS_KEY_SCHEMA}:'
+        for key in self._redis.scan_iter(match=f'{prefix}*'):
+            key_str = key.decode('utf-8') if isinstance(key, bytes) else key
+            # Extract kind_str@version from key
+            schema_id = key_str[len(prefix):]
+
+            schema_json = self._redis.get(key)
+            if schema_json is None:
+                continue
+
+            if isinstance(schema_json, bytes):
+                schema_json = schema_json.decode('utf-8')
+
+            schema = json.loads(schema_json)
+            schema['$ref'] = f"local://schemas/{schema_id}"
+            yield schema
+
+    def decode_schema(self, ref: str) -> Type[PackableSample]:
+        """Reconstruct a Python PackableSample type from a stored schema.
+
+        This method enables loading datasets without knowing the sample type
+        ahead of time. The index retrieves the schema record and dynamically
+        generates a PackableSample subclass matching the schema definition.
+
+        Args:
+            ref: Schema reference string (local://schemas/...).
+
+        Returns:
+            A dynamically generated PackableSample subclass.
+
+        Raises:
+            KeyError: If schema not found.
+            ValueError: If schema cannot be decoded.
+        """
+        from atdata._schema_codec import schema_to_type
+
+        schema = self.get_schema(ref)
+        return schema_to_type(schema)
+
+
+# Backwards compatibility alias
+LocalIndex = Index
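To make the new surface easier to review, here is a minimal sketch of how the AbstractIndex methods and the schema round-trip on `LocalIndex` are expected to be exercised. The `ds` dataset and `MySample` class are placeholders for review purposes, not part of this diff, and a running local Redis instance is assumed.

```python
# Review sketch only (not part of the diff). Assumes a running Redis instance,
# an existing Dataset `ds`, and a @packable sample class `MySample`.
from atdata.local import LocalIndex

index = LocalIndex()

# AbstractIndex protocol methods: register, fetch, and enumerate entries.
entry = index.insert_dataset(ds, name="mnist-train", metadata={"split": "train"})
fetched = index.get_dataset("mnist-train")        # raises KeyError if missing
names = [e.name for e in index.list_datasets()]

# Schema operations: publish MySample, then rebuild the type from its reference.
ref = index.publish_schema(MySample, version="1.0.0")
RebuiltSample = index.decode_schema(ref)          # dynamically generated PackableSample subclass
```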
+
+
+class S3DataStore:
+    """S3-compatible data store implementing AbstractDataStore protocol.
+
+    Handles writing dataset shards to S3-compatible object storage and
+    resolving URLs for reading.

-        if uuid is None:
-            ret_data = BasicIndexEntry(
-                wds_url = ds.url,
-                sample_kind = temp_sample_kind,
-                metadata_url = ds.metadata_url,
-            )
+    Attributes:
+        credentials: S3 credentials dictionary.
+        bucket: Target bucket name.
+        _fs: S3FileSystem instance.
+    """
+
+    def __init__(
+        self,
+        credentials: str | Path | dict[str, Any],
+        *,
+        bucket: str,
+    ) -> None:
+        """Initialize an S3 data store.
+
+        Args:
+            credentials: Path to .env file or dict with AWS_ACCESS_KEY_ID,
+                AWS_SECRET_ACCESS_KEY, and optionally AWS_ENDPOINT.
+            bucket: Name of the S3 bucket for storage.
+        """
+        if isinstance(credentials, dict):
+            self.credentials = credentials
        else:
-            ret_data = BasicIndexEntry(
-                wds_url = ds.url,
-                sample_kind = temp_sample_kind,
-                metadata_url = ds.metadata_url,
-                uuid = uuid,
+            self.credentials = _s3_env(credentials)
+
+        self.bucket = bucket
+        self._fs = _s3_from_credentials(self.credentials)
+
+    def write_shards(
+        self,
+        ds: Dataset,
+        *,
+        prefix: str,
+        cache_local: bool = False,
+        **kwargs,
+    ) -> list[str]:
+        """Write dataset shards to S3.
+
+        Args:
+            ds: The Dataset to write.
+            prefix: Path prefix within bucket (e.g., 'datasets/mnist/v1').
+            cache_local: If True, write locally first then copy to S3.
+            **kwargs: Additional args passed to wds.ShardWriter (e.g., maxcount).
+
+        Returns:
+            List of S3 URLs for the written shards.
+
+        Raises:
+            RuntimeError: If no shards were written.
+        """
+        new_uuid = str(uuid4())
+        shard_pattern = f"{self.bucket}/{prefix}/data--{new_uuid}--%06d.tar"
+
+        written_shards: list[str] = []
+
+        with TemporaryDirectory() as temp_dir:
+            writer_opener, writer_post = _create_s3_write_callbacks(
+                credentials=self.credentials,
+                temp_dir=temp_dir,
+                written_shards=written_shards,
+                fs=self._fs,
+                cache_local=cache_local,
+                add_s3_prefix=True,
            )

-        ret_data.write_to( self._redis )
+            with wds.writer.ShardWriter(
+                shard_pattern,
+                opener=writer_opener,
+                post=writer_post,
+                **kwargs,
+            ) as sink:
+                for sample in ds.ordered(batch_size=None):
+                    sink.write(sample.as_wds)
+
+        if len(written_shards) == 0:
+            raise RuntimeError("No shards written")
+
+        return written_shards
+
+    def read_url(self, url: str) -> str:
+        """Resolve an S3 URL for reading.
+
+        For S3, URLs are returned as-is (WebDataset handles s3:// directly).
+
+        Args:
+            url: S3 URL to resolve.

-        return ret_data
+        Returns:
+            The URL unchanged.
+        """
+        return url
+
+    def supports_streaming(self) -> bool:
+        """S3 supports streaming reads.
+
+        Returns:
+            True.
+        """
+        return True


#
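Likewise, a hedged sketch of how `S3DataStore` is meant to satisfy the AbstractDataStore protocol; the credential values, bucket name, and `ds` dataset below are illustrative only and not part of this diff.

```python
# Review sketch only (not part of the diff). Credentials, bucket, and `ds` are illustrative.
from atdata.local import S3DataStore

store = S3DataStore(
    {
        "AWS_ACCESS_KEY_ID": "example-key",
        "AWS_SECRET_ACCESS_KEY": "example-secret",
        "AWS_ENDPOINT": "https://s3.example.com",  # optional
    },
    bucket="atdata-shards",
)

# Shard names follow data--<uuid>--NNNNNN.tar under the given prefix;
# extra kwargs (e.g. maxcount) are forwarded to wds.ShardWriter.
urls = store.write_shards(ds, prefix="datasets/mnist/v1", maxcount=1000)

# Reads stream directly from S3, so read_url() returns the URL unchanged.
assert store.supports_streaming()
first_shard = store.read_url(urls[0])
```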
src/atdata/promote.py (+197)

+"""Promotion workflow for migrating datasets from local to atmosphere.
+
+This module provides functionality to promote locally-indexed datasets to the
+ATProto atmosphere network. This enables sharing datasets with the broader
+federation while maintaining schema consistency.
+
+Example:
+    >>> from atdata.local import LocalIndex, Repo
+    >>> from atdata.atmosphere import AtmosphereClient, AtmosphereIndex
+    >>> from atdata.promote import promote_to_atmosphere
+    >>>
+    >>> # Setup
+    >>> local_index = LocalIndex()
+    >>> client = AtmosphereClient()
+    >>> client.login("handle.bsky.social", "app-password")
+    >>>
+    >>> # Promote a dataset
+    >>> entry = local_index.get_dataset("my-dataset")
+    >>> at_uri = promote_to_atmosphere(entry, local_index, client)
+"""
+
+from typing import TYPE_CHECKING, Type
+
+if TYPE_CHECKING:
+    from .local import LocalDatasetEntry, Index as LocalIndex
+    from .atmosphere import AtmosphereClient, AtUri
+    from .atmosphere._types import AtUri as AtUriType
+    from .dataset import PackableSample
+    from ._protocols import AbstractDataStore
+
+
+def _find_existing_schema(
+    client: "AtmosphereClient",
+    name: str,
+    version: str,
+) -> str | None:
+    """Check if a schema with the given name and version already exists.
+
+    Args:
+        client: Authenticated atmosphere client.
+        name: Schema name to search for.
+        version: Schema version to match.
+
+    Returns:
+        AT URI of existing schema if found, None otherwise.
+    """
+    from .atmosphere import SchemaLoader
+
+    loader = SchemaLoader(client)
+    for record in loader.list_all():
+        rec_value = record.get("value", record)
+        if rec_value.get("name") == name and rec_value.get("version") == version:
+            return record.get("uri", "")
+    return None
+
+
+def _find_or_publish_schema(
+    sample_type: "Type[PackableSample]",
+    version: str,
+    client: "AtmosphereClient",
+    description: str | None = None,
+) -> str:
+    """Find existing schema or publish a new one.
+
+    Checks if a schema with the same name and version already exists on the
+    user's atmosphere repository. If found, returns the existing URI to avoid
+    duplicates. Otherwise, publishes a new schema record.
+
+    Args:
+        sample_type: The PackableSample subclass to publish.
+        version: Semantic version string.
+        client: Authenticated atmosphere client.
+        description: Optional schema description.
+
+    Returns:
+        AT URI of the schema (existing or newly published).
+    """
+    from .atmosphere import SchemaPublisher
+
+    schema_name = f"{sample_type.__module__}.{sample_type.__name__}"
+
+    # Check for existing schema
+    existing = _find_existing_schema(client, schema_name, version)
+    if existing:
+        return existing
+
+    # Publish new schema
+    publisher = SchemaPublisher(client)
+    uri = publisher.publish(
+        sample_type,
+        version=version,
+        description=description,
+    )
+    return str(uri)
+
+
+def promote_to_atmosphere(
+    local_entry: "LocalDatasetEntry",
+    local_index: "LocalIndex",
+    atmosphere_client: "AtmosphereClient",
+    *,
+    data_store: "AbstractDataStore | None" = None,
+    name: str | None = None,
+    description: str | None = None,
+    tags: list[str] | None = None,
+    license: str | None = None,
+) -> str:
+    """Promote a local dataset to the atmosphere network.
+
+    This function takes a locally-indexed dataset and publishes it to ATProto,
+    making it discoverable on the federated atmosphere network.
+
+    Args:
+        local_entry: The LocalDatasetEntry to promote.
+        local_index: Local index containing the schema for this entry.
+        atmosphere_client: Authenticated AtmosphereClient.
+        data_store: Optional data store for copying data to new location.
+            If None, the existing data_urls are used as-is.
+        name: Override name for the atmosphere record. Defaults to local name.
+        description: Optional description for the dataset.
+        tags: Optional tags for discovery.
+        license: Optional license identifier.
+
+    Returns:
+        AT URI of the created atmosphere dataset record.
+
+    Raises:
+        KeyError: If schema not found in local index.
+        ValueError: If local entry has no data URLs.
+
+    Example:
+        >>> entry = local_index.get_dataset("mnist-train")
+        >>> uri = promote_to_atmosphere(entry, local_index, client)
+        >>> print(uri)
+        at://did:plc:abc123/ac.foundation.dataset.datasetIndex/...
+    """
+    from .atmosphere import DatasetPublisher
+    from ._schema_codec import schema_to_type
+
+    # Validate entry has data
+    if not local_entry.data_urls:
+        raise ValueError(f"Local entry '{local_entry.name}' has no data URLs")
+
+    # Get schema from local index
+    schema_ref = local_entry.schema_ref
+    schema_record = local_index.get_schema(schema_ref)
+
+    # Reconstruct sample type from schema
+    sample_type = schema_to_type(schema_record)
+    schema_version = schema_record.get("version", "1.0.0")
+
+    # Find or publish schema on atmosphere (deduplication)
+    atmosphere_schema_uri = _find_or_publish_schema(
+        sample_type,
+        schema_version,
+        atmosphere_client,
+        description=schema_record.get("description"),
+    )
+
+    # Determine data URLs
+    if data_store is not None:
+        # Copy data to new storage location
+        # Create a temporary Dataset to write through the data store
+        from .dataset import Dataset
+
+        # Build WDS URL from data_urls
+        if len(local_entry.data_urls) == 1:
+            wds_url = local_entry.data_urls[0]
+        else:
+            # Join multiple URLs into a single space-separated WDS url string
+            wds_url = " ".join(local_entry.data_urls)
+
+        ds = Dataset[sample_type](wds_url)
+        prefix = f"promoted/{local_entry.name}"
+        data_urls = data_store.write_shards(ds, prefix=prefix)
+    else:
+        # Use existing URLs as-is
+        data_urls = local_entry.data_urls
+
+    # Publish dataset record to atmosphere
+    publisher = DatasetPublisher(atmosphere_client)
+    uri = publisher.publish_with_urls(
+        urls=data_urls,
+        schema_uri=atmosphere_schema_uri,
+        name=name or local_entry.name,
+        description=description,
+        tags=tags,
+        license=license,
+        metadata=local_entry.metadata,
+    )
+
+    return str(uri)
+
+
+__all__ = [
+    "promote_to_atmosphere",
+]
tests/conftest.py (+31)

"""Pytest configuration for atdata tests."""
+
+import pytest
+from redis import Redis
+
+
+@pytest.fixture
+def redis_connection():
+    """Provide a Redis connection, skip test if Redis is not available."""
+    try:
+        redis = Redis()
+        redis.ping()
+        yield redis
+    except Exception:
+        pytest.skip("Redis server not available")
+
+
+@pytest.fixture
+def clean_redis(redis_connection):
+    """Provide a Redis connection with automatic cleanup of test keys.
+
+    Clears LocalDatasetEntry, BasicIndexEntry (legacy), and LocalSchema keys
+    before and after each test to ensure test isolation.
+    """
+    def _clear_all():
+        for pattern in ('LocalDatasetEntry:*', 'BasicIndexEntry:*', 'LocalSchema:*'):
+            for key in redis_connection.scan_iter(match=pattern):
+                redis_connection.delete(key)
+
+    _clear_all()
+    yield redis_connection
+    _clear_all()
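For completeness, a small sketch (assumed usage, not part of the diff) of how a test would consume `clean_redis`: requesting the fixture yields the shared connection with the atdata key prefixes cleared before and after the test body.

```python
# Review sketch only (not part of the diff).
def test_redis_starts_clean(clean_redis):
    # clean_redis wipes LocalDatasetEntry:*, BasicIndexEntry:*, and LocalSchema:*
    # before the test runs and again afterwards, so no entries should be present here.
    assert list(clean_redis.scan_iter(match="LocalDatasetEntry:*")) == []
```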