A loose federation of distributed, typed datasets

Merge branch 'feature/atmosphere-local-harmonization' into dev/maxine-at-forecast

+7139 -520
+86
CHANGELOG.md
··· 11 11 ### Fixed 12 12 13 13 ### Changed 14 + - Adversarial review: Test suite and codebase comprehensive assessment (#181) 15 + - Consolidate test sample type definitions into conftest.py (#184) 16 + - Trim verbose docstrings that restate function signatures (#189) 17 + - Replace assertions with explicit ValueError in Repo.insert (#187) 18 + - Add Redis key prefix constants to eliminate magic strings (#186) 19 + - Convert TODO comments to tracked issues or design notes (#185) 20 + - Add tests for schema_to_type with malformed/edge-case schemas (#183) 21 + - Remove duplicate shard writing logic between Repo.insert and S3DataStore (#182) 22 + - Remove unused Lens import from dataset.py (#188) 23 + - Build comprehensive markdown documentation for atdata (#171) 24 + - Write docs/protocols.md - Abstract protocols reference (#180) 25 + - Write docs/load-dataset.md - HuggingFace-style API (#179) 26 + - Write docs/promotion.md - Local to atmosphere workflow (#178) 27 + - Write docs/atmosphere.md - ATProto publishing and loading (#177) 28 + - Write docs/local-storage.md - LocalIndex, Repo, S3DataStore (#176) 29 + - Write docs/lenses.md - Lens transformations (#175) 30 + - Write docs/datasets.md - Dataset iteration and batching (#174) 31 + - Write docs/packable-samples.md - PackableSample and @packable (#173) 32 + - Write docs/index.md - overview and quick start guide (#172) 33 + - Adversarial review: Post Local-ATProto Reconciliation (#165) 34 + - Add error path tests for Dataset with invalid tar files (#170) 35 + - Convert TODO comment to design note in dataset.py (#169) 36 + - Replace O(n²) string prefix extraction with os.path.commonprefix (#168) 37 + - Remove unused Lens import from dataset.py (#167) 38 + - Extract shared dtype/type conversion to _type_utils.py (#166) 39 + - Local-ATProto Reconciliation Refactor (#111) 40 + - Phase 7: Documentation and Examples (#118) 41 + - Review and update docstrings for new public API (#164) 42 + - Create examples/promote_workflow.py demonstration (#163) 43 + - Create examples/local_workflow.py demonstration (#162) 44 + - Phase 6: Testing (protocols, integration, property tests) (#117) 45 + - Add schema deduplication integration test (#161) 46 + - Add integration test for local to atmosphere round-trip (#160) 47 + - Create test_protocols.py for protocol compliance tests (#159) 48 + - Phase 5: Local to Atmosphere Promotion Workflow (#116) 49 + - Add tests for promote.py (#158) 50 + - Implement schema deduplication helper (#157) 51 + - Create promote.py module with promote_to_atmosphere function (#156) 52 + - Adversarial review: Phase 4 code contraction (#148) 53 + - Simplify _python_type_to_field_type in local.py (#155) 54 + - Clean up unused imports and type: ignore comments (#153) 55 + - Precompile split detection regex patterns (#151) 56 + - Add missing error path tests for invalid msgpack data (#154) 57 + - Remove duplicate S3 write logic between Repo.insert and S3DataStore (#152) 58 + - Remove verbose docstrings that restate function signatures (#150) 59 + - Remove redundant _ensure_good() call in PackableSample.from_data() (#149) 60 + - Phase 4: Integrate with load_dataset() (@handle/dataset resolution) (#115) 61 + - Support auto-type resolution from index schema (#147) 62 + - Add @handle/dataset path resolution (#146) 63 + - Extend load_dataset signature with index parameter (#145) 64 + - Adversarial review: Phase 1 and Phase 2 implementation (#133) 65 + - Phase 3: Implement Concrete Classes (LocalIndex, AtmosphereIndex, S3DataStore) (#114) 66 + - 
Implement AtmosphereIndex wrapper (#144) 67 + - Implement S3DataStore class (#143) 68 + - Add AbstractIndex protocol methods to Index class (#142) 69 + - Adversarial review: Phase 1 & 2 implementation (#135) 70 + - Add test coverage for _schema_codec utility functions (#141) 71 + - Add missing test for parse_cid with malformed input (#140) 72 + - Merge clean_redis and clean_redis_schemas fixtures (#139) 73 + - Remove redundant return statement in Index.entries (#137) 74 + - DRY: Consolidate Redis deserialization in Index.entries vs LocalDatasetEntry.from_redis (#138) 75 + - Remove unused _decode_bytes_dict function from local.py (#136) 76 + - Adversarial review: Phase 1 & 2 implementation (#134) 77 + - Phase 2: Align Local with ATProto Record Formats (#113) 78 + - Add schema storage to local (LocalSchemaRecord in Redis) (#128) 79 + - Rename BasicIndexEntry to LocalDatasetEntry and implement IndexEntry protocol (#127) 80 + - Rename Index to LocalIndex and implement AbstractIndex protocol (#129) 81 + - Update Repo to use new LocalIndex API (#130) 82 + - Update test_local.py for renamed classes and new API (#131) 83 + - Add libipld dependency for CID generation (#132) 84 + - Phase 2: Refactor local.py to use new protocols (#113) 85 + - Add CID utilities module (_cid.py) with ATProto-compatible CID generation (#132) 86 + - Rename BasicIndexEntry to LocalDatasetEntry with CID + name dual identity (#127) 87 + - Add LocalIndex alias for Index class (#129) 88 + - Update Repo.insert() to require name parameter (#130) 89 + - Update test_local.py for new LocalDatasetEntry API (#131) 90 + - Revise AbstractIndex: Remove single-type generic, add schema decoding (#123) 91 + - Implement dynamic PackableSample class generation from schema (#126) 92 + - Add decode_schema() method to AbstractIndex (#125) 93 + - Remove generic type parameter from AbstractIndex (#124) 94 + - Phase 1: Define Abstract Protocols (_protocols.py) (#112) 95 + - Export protocols from __init__.py (#122) 96 + - Define AbstractDataStore protocol (#121) 97 + - Define AbstractIndex protocol (#120) 98 + - Define IndexEntry protocol (#119) 99 + - Review ATProto vs Local integration architecture convergence (#110) 14 100 - Add HuggingFace Datasets-style API to atdata (#103) 15 101 - Support streaming mode parameter (#108) 16 102 - Add split parameter handling (train/test/validation) (#107)
+374
docs/atmosphere.md
··· 1 + # Atmosphere (ATProto Integration) 2 + 3 + The atmosphere module enables publishing and discovering datasets on the ATProto network, creating a federated ecosystem for typed datasets. 4 + 5 + ## Installation 6 + 7 + ```bash 8 + pip install atdata[atmosphere] 9 + # or 10 + pip install atproto 11 + ``` 12 + 13 + ## Overview 14 + 15 + ATProto integration publishes datasets, schemas, and lenses as records in the `ac.foundation.dataset.*` namespace. This enables: 16 + 17 + - **Discovery** through the ATProto network 18 + - **Federation** across different hosts 19 + - **Verifiability** through content-addressable records 20 + 21 + ## AtmosphereClient 22 + 23 + The client handles authentication and record operations: 24 + 25 + ```python 26 + from atdata.atmosphere import AtmosphereClient 27 + 28 + client = AtmosphereClient() 29 + 30 + # Login with app-specific password (not your main password!) 31 + client.login("alice.bsky.social", "app-password") 32 + 33 + print(client.did) # 'did:plc:...' 34 + print(client.handle) # 'alice.bsky.social' 35 + ``` 36 + 37 + ### Session Management 38 + 39 + Save and restore sessions to avoid re-authentication: 40 + 41 + ```python 42 + # Export session for later 43 + session_string = client.export_session() 44 + 45 + # Later: restore session 46 + new_client = AtmosphereClient() 47 + new_client.login_with_session(session_string) 48 + ``` 49 + 50 + ### Custom PDS 51 + 52 + Connect to a custom PDS instead of bsky.social: 53 + 54 + ```python 55 + client = AtmosphereClient(base_url="https://pds.example.com") 56 + ``` 57 + 58 + ## AtmosphereIndex 59 + 60 + The unified interface for ATProto operations, implementing the AbstractIndex protocol: 61 + 62 + ```python 63 + from atdata.atmosphere import AtmosphereClient, AtmosphereIndex 64 + 65 + client = AtmosphereClient() 66 + client.login("handle.bsky.social", "app-password") 67 + 68 + index = AtmosphereIndex(client) 69 + ``` 70 + 71 + ### Publishing Schemas 72 + 73 + ```python 74 + @atdata.packable 75 + class ImageSample: 76 + image: NDArray 77 + label: str 78 + confidence: float 79 + 80 + # Publish schema 81 + schema_uri = index.publish_schema( 82 + ImageSample, 83 + version="1.0.0", 84 + description="Image classification sample", 85 + ) 86 + # Returns: "at://did:plc:.../ac.foundation.dataset.sampleSchema/..." 
87 + ``` 88 + 89 + ### Publishing Datasets 90 + 91 + ```python 92 + dataset = atdata.Dataset[ImageSample]("data-{000000..000009}.tar") 93 + 94 + entry = index.insert_dataset( 95 + dataset, 96 + name="imagenet-subset", 97 + schema_ref=schema_uri, # Optional - auto-publishes if omitted 98 + description="ImageNet subset", 99 + tags=["images", "classification"], 100 + license="MIT", 101 + ) 102 + 103 + print(entry.uri) # AT URI of the record 104 + print(entry.data_urls) # WebDataset URLs 105 + ``` 106 + 107 + ### Listing and Retrieving 108 + 109 + ```python 110 + # List your datasets 111 + for entry in index.list_datasets(): 112 + print(f"{entry.name}: {entry.schema_ref}") 113 + 114 + # List from another user 115 + for entry in index.list_datasets(repo="did:plc:other-user"): 116 + print(entry.name) 117 + 118 + # Get specific dataset 119 + entry = index.get_dataset("at://did:plc:.../ac.foundation.dataset.record/...") 120 + 121 + # List schemas 122 + for schema in index.list_schemas(): 123 + print(f"{schema['name']} v{schema['version']}") 124 + 125 + # Decode schema to Python type 126 + SampleType = index.decode_schema(schema_uri) 127 + ``` 128 + 129 + ## Lower-Level Publishers 130 + 131 + For more control, use the individual publisher classes: 132 + 133 + ### SchemaPublisher 134 + 135 + ```python 136 + from atdata.atmosphere import SchemaPublisher 137 + 138 + publisher = SchemaPublisher(client) 139 + 140 + uri = publisher.publish( 141 + ImageSample, 142 + name="ImageSample", 143 + version="1.0.0", 144 + description="Image with label", 145 + metadata={"source": "training"}, 146 + ) 147 + ``` 148 + 149 + ### DatasetPublisher 150 + 151 + ```python 152 + from atdata.atmosphere import DatasetPublisher 153 + 154 + publisher = DatasetPublisher(client) 155 + 156 + uri = publisher.publish( 157 + dataset, 158 + name="training-images", 159 + schema_uri=schema_uri, # Required if auto_publish_schema=False 160 + auto_publish_schema=True, # Publish schema automatically 161 + description="Training images", 162 + tags=["training", "images"], 163 + license="MIT", 164 + ) 165 + ``` 166 + 167 + ### LensPublisher 168 + 169 + ```python 170 + from atdata.atmosphere import LensPublisher 171 + 172 + publisher = LensPublisher(client) 173 + 174 + # With code references 175 + uri = publisher.publish( 176 + name="simplify", 177 + source_schema=full_schema_uri, 178 + target_schema=simple_schema_uri, 179 + description="Extract label only", 180 + getter_code={ 181 + "repository": "https://github.com/org/repo", 182 + "commit": "abc123def...", 183 + "path": "transforms/simplify.py:simplify_getter", 184 + }, 185 + putter_code={ 186 + "repository": "https://github.com/org/repo", 187 + "commit": "abc123def...", 188 + "path": "transforms/simplify.py:simplify_putter", 189 + }, 190 + ) 191 + 192 + # Or publish from a Lens object 193 + from atdata.lens import lens 194 + 195 + @lens 196 + def simplify(src: FullSample) -> SimpleSample: 197 + return SimpleSample(label=src.label) 198 + 199 + uri = publisher.publish_from_lens( 200 + simplify, 201 + source_schema=full_schema_uri, 202 + target_schema=simple_schema_uri, 203 + ) 204 + ``` 205 + 206 + ## AT URIs 207 + 208 + ATProto records are identified by AT URIs: 209 + 210 + ```python 211 + from atdata.atmosphere import AtUri 212 + 213 + # Parse an AT URI 214 + uri = AtUri.parse("at://did:plc:abc123/ac.foundation.dataset.sampleSchema/xyz") 215 + 216 + print(uri.authority) # 'did:plc:abc123' 217 + print(uri.collection) # 'ac.foundation.dataset.sampleSchema' 218 + print(uri.rkey) # 'xyz' 219 + 
220 + # Format back to string 221 + print(str(uri)) # 'at://did:plc:abc123/ac.foundation.dataset.sampleSchema/xyz' 222 + ``` 223 + 224 + ## Record Types 225 + 226 + ### SchemaRecord 227 + 228 + ```python 229 + from atdata.atmosphere import SchemaRecord, FieldDef, FieldType 230 + 231 + schema = SchemaRecord( 232 + name="ImageSample", 233 + version="1.0.0", 234 + fields=[ 235 + FieldDef( 236 + name="image", 237 + field_type=FieldType(kind="ndarray", dtype="float32"), 238 + ), 239 + FieldDef( 240 + name="label", 241 + field_type=FieldType(kind="primitive", primitive="str"), 242 + ), 243 + ], 244 + description="Image with label", 245 + ) 246 + 247 + record_dict = schema.to_record() 248 + ``` 249 + 250 + ### DatasetRecord 251 + 252 + ```python 253 + from atdata.atmosphere import DatasetRecord, StorageLocation 254 + 255 + dataset_record = DatasetRecord( 256 + name="training-images", 257 + schema_ref="at://did:plc:.../...", 258 + storage=StorageLocation( 259 + kind="external", 260 + urls=["s3://bucket/data-{000000..000009}.tar"], 261 + ), 262 + tags=["training"], 263 + license="MIT", 264 + ) 265 + ``` 266 + 267 + ### LensRecord 268 + 269 + ```python 270 + from atdata.atmosphere import LensRecord, CodeReference 271 + 272 + lens_record = LensRecord( 273 + name="simplify", 274 + source_schema="at://did:plc:.../.../source", 275 + target_schema="at://did:plc:.../.../target", 276 + description="Simplify sample", 277 + getter_code=CodeReference( 278 + repository="https://github.com/org/repo", 279 + commit="abc123", 280 + path="transforms.py:simplify", 281 + ), 282 + ) 283 + ``` 284 + 285 + ## Supported Field Types 286 + 287 + Schemas support these field types: 288 + 289 + | Python Type | ATProto Type | 290 + |-------------|--------------| 291 + | `str` | `primitive/str` | 292 + | `int` | `primitive/int` | 293 + | `float` | `primitive/float` | 294 + | `bool` | `primitive/bool` | 295 + | `bytes` | `primitive/bytes` | 296 + | `NDArray` | `ndarray` (default dtype: float32) | 297 + | `NDArray[np.float64]` | `ndarray` (dtype: float64) | 298 + | `list[str]` | `array` with items | 299 + | `T \| None` | Optional field | 300 + 301 + ## Complete Example 302 + 303 + ```python 304 + import numpy as np 305 + from numpy.typing import NDArray 306 + import atdata 307 + from atdata.atmosphere import AtmosphereClient, AtmosphereIndex 308 + import webdataset as wds 309 + 310 + # 1. Define and create samples 311 + @atdata.packable 312 + class FeatureSample: 313 + features: NDArray 314 + label: int 315 + source: str 316 + 317 + samples = [ 318 + FeatureSample( 319 + features=np.random.randn(128).astype(np.float32), 320 + label=i % 10, 321 + source="synthetic", 322 + ) 323 + for i in range(1000) 324 + ] 325 + 326 + # 2. Write to tar 327 + with wds.writer.TarWriter("features.tar") as sink: 328 + for i, s in enumerate(samples): 329 + sink.write({**s.as_wds, "__key__": f"{i:06d}"}) 330 + 331 + # 3. Authenticate 332 + client = AtmosphereClient() 333 + client.login("myhandle.bsky.social", "app-password") 334 + 335 + index = AtmosphereIndex(client) 336 + 337 + # 4. Publish schema 338 + schema_uri = index.publish_schema( 339 + FeatureSample, 340 + version="1.0.0", 341 + description="Feature vectors with labels", 342 + ) 343 + 344 + # 5. Publish dataset 345 + dataset = atdata.Dataset[FeatureSample]("features.tar") 346 + entry = index.insert_dataset( 347 + dataset, 348 + name="synthetic-features-v1", 349 + schema_ref=schema_uri, 350 + tags=["features", "synthetic"], 351 + ) 352 + 353 + print(f"Published: {entry.uri}") 354 + 355 + # 6. 
Later: discover and load 356 + for dataset_entry in index.list_datasets(): 357 + print(f"Found: {dataset_entry.name}") 358 + 359 + # Reconstruct type from schema 360 + SampleType = index.decode_schema(dataset_entry.schema_ref) 361 + 362 + # Load dataset 363 + ds = atdata.Dataset[SampleType](dataset_entry.data_urls[0]) 364 + for batch in ds.ordered(batch_size=32): 365 + print(batch.features.shape) 366 + break 367 + ``` 368 + 369 + ## Related 370 + 371 + - [Local Storage](local-storage.md) - Redis + S3 backend 372 + - [Promotion](promotion.md) - Promoting local datasets to ATProto 373 + - [Protocols](protocols.md) - AbstractIndex interface 374 + - [Packable Samples](packable-samples.md) - Defining sample types
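The field-type table in docs/atmosphere.md above maps Python annotations to ATProto schema types but stops short of a code sample. The sketch below exercises each row with a hypothetical `MixedSample` class; the commented `publish_schema` call assumes an authenticated `AtmosphereIndex` as set up earlier on that page.

```python
import numpy as np
from numpy.typing import NDArray
import atdata

# Hypothetical sample touching each row of the field-type table:
# primitives, a default-dtype ndarray, an explicit-dtype ndarray,
# a typed list, and an optional field.
@atdata.packable
class MixedSample:
    name: str                     # primitive/str
    count: int                    # primitive/int
    score: float                  # primitive/float
    active: bool                  # primitive/bool
    raw: bytes                    # primitive/bytes
    embedding: NDArray            # ndarray (default dtype: float32)
    weights: NDArray[np.float64]  # ndarray (dtype: float64)
    tags: list[str]               # array with items
    note: str | None              # optional field

# Publishing this schema should yield one field definition per annotation above
# (assuming an authenticated AtmosphereIndex named `index`):
# schema_uri = index.publish_schema(MixedSample, version="1.0.0")
```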
+193
docs/datasets.md
··· 1 + # Datasets 2 + 3 + The `Dataset` class provides typed iteration over WebDataset tar files with automatic batching and lens transformations. 4 + 5 + ## Creating a Dataset 6 + 7 + ```python 8 + import atdata 9 + 10 + @atdata.packable 11 + class ImageSample: 12 + image: NDArray 13 + label: str 14 + 15 + # Single shard 16 + dataset = atdata.Dataset[ImageSample]("data-000000.tar") 17 + 18 + # Multiple shards with brace notation 19 + dataset = atdata.Dataset[ImageSample]("data-{000000..000009}.tar") 20 + ``` 21 + 22 + The type parameter `[ImageSample]` specifies what sample type the dataset contains. This enables type-safe iteration and automatic deserialization. 23 + 24 + ## Iteration Modes 25 + 26 + ### Ordered Iteration 27 + 28 + Iterate through samples in their original order: 29 + 30 + ```python 31 + # With batching (default batch_size=1) 32 + for batch in dataset.ordered(batch_size=32): 33 + images = batch.image # numpy array (32, H, W, C) 34 + labels = batch.label # list of 32 strings 35 + 36 + # Without batching (raw samples) 37 + for sample in dataset.ordered(batch_size=None): 38 + print(sample.label) 39 + ``` 40 + 41 + ### Shuffled Iteration 42 + 43 + Iterate with randomized order at both shard and sample levels: 44 + 45 + ```python 46 + for batch in dataset.shuffled(batch_size=32): 47 + # Samples are shuffled 48 + process(batch) 49 + 50 + # Control shuffle buffer sizes 51 + for batch in dataset.shuffled( 52 + buffer_shards=100, # Shards to buffer (default: 100) 53 + buffer_samples=10000, # Samples to buffer (default: 10,000) 54 + batch_size=32, 55 + ): 56 + process(batch) 57 + ``` 58 + 59 + Larger buffer sizes increase randomness but use more memory. 60 + 61 + ## SampleBatch 62 + 63 + When iterating with a `batch_size`, each iteration yields a `SampleBatch` with automatic attribute aggregation. 64 + 65 + ```python 66 + @atdata.packable 67 + class Sample: 68 + features: NDArray # shape (256,) 69 + label: str 70 + score: float 71 + 72 + for batch in dataset.ordered(batch_size=16): 73 + # NDArray fields are stacked with a batch dimension 74 + features = batch.features # numpy array (16, 256) 75 + 76 + # Other fields become lists 77 + labels = batch.label # list of 16 strings 78 + scores = batch.score # list of 16 floats 79 + ``` 80 + 81 + Results are cached, so accessing the same attribute multiple times is efficient. 82 + 83 + ## Type Transformations with Lenses 84 + 85 + View a dataset through a different sample type using registered lenses: 86 + 87 + ```python 88 + @atdata.packable 89 + class SimplifiedSample: 90 + label: str 91 + 92 + @atdata.lens 93 + def simplify(src: ImageSample) -> SimplifiedSample: 94 + return SimplifiedSample(label=src.label) 95 + 96 + # Transform dataset to different type 97 + simple_ds = dataset.as_type(SimplifiedSample) 98 + 99 + for batch in simple_ds.ordered(batch_size=16): 100 + print(batch.label) # Only label field available 101 + ``` 102 + 103 + See [Lenses](lenses.md) for details on defining transformations. 
104 + 105 + ## Dataset Properties 106 + 107 + ### Shard List 108 + 109 + Get the list of individual tar files: 110 + 111 + ```python 112 + dataset = atdata.Dataset[Sample]("data-{000000..000009}.tar") 113 + shards = dataset.shard_list 114 + # ['data-000000.tar', 'data-000001.tar', ..., 'data-000009.tar'] 115 + ``` 116 + 117 + ### Metadata 118 + 119 + Datasets can have associated metadata from a URL: 120 + 121 + ```python 122 + dataset = atdata.Dataset[Sample]( 123 + "data-{000000..000009}.tar", 124 + metadata_url="https://example.com/metadata.msgpack" 125 + ) 126 + 127 + # Fetched and cached on first access 128 + metadata = dataset.metadata # dict or None 129 + ``` 130 + 131 + ## Writing Datasets 132 + 133 + Use WebDataset's `TarWriter` or `ShardWriter` to create datasets: 134 + 135 + ```python 136 + import webdataset as wds 137 + 138 + samples = [ 139 + ImageSample(image=np.random.rand(224, 224, 3).astype(np.float32), label="cat") 140 + for _ in range(100) 141 + ] 142 + 143 + # Single tar file 144 + with wds.writer.TarWriter("data-000000.tar") as sink: 145 + for i, sample in enumerate(samples): 146 + sink.write({**sample.as_wds, "__key__": f"sample_{i:06d}"}) 147 + 148 + # Multiple shards with automatic splitting 149 + with wds.writer.ShardWriter("data-%06d.tar", maxcount=1000) as sink: 150 + for i, sample in enumerate(samples): 151 + sink.write({**sample.as_wds, "__key__": f"sample_{i:06d}"}) 152 + ``` 153 + 154 + ## Parquet Export 155 + 156 + Export dataset contents to parquet format: 157 + 158 + ```python 159 + # Export entire dataset 160 + dataset.to_parquet("output.parquet") 161 + 162 + # Export with custom field mapping 163 + def extract_fields(sample): 164 + return {"label": sample.label, "score": sample.confidence} 165 + 166 + dataset.to_parquet("output.parquet", sample_map=extract_fields) 167 + 168 + # Export in segments 169 + dataset.to_parquet("output.parquet", maxcount=10000) 170 + # Creates output-000000.parquet, output-000001.parquet, etc. 171 + ``` 172 + 173 + ## URL Formats 174 + 175 + WebDataset supports various URL formats: 176 + 177 + ```python 178 + # Local files 179 + dataset = atdata.Dataset[Sample]("./data/file.tar") 180 + dataset = atdata.Dataset[Sample]("/absolute/path/file-{000000..000009}.tar") 181 + 182 + # S3 (requires s3fs) 183 + dataset = atdata.Dataset[Sample]("s3://bucket/path/file-{000000..000009}.tar") 184 + 185 + # HTTP/HTTPS 186 + dataset = atdata.Dataset[Sample]("https://example.com/data-{000000..000009}.tar") 187 + ``` 188 + 189 + ## Related 190 + 191 + - [Packable Samples](packable-samples.md) - Defining typed samples 192 + - [Lenses](lenses.md) - Type transformations 193 + - [load_dataset](load-dataset.md) - HuggingFace-style loading API
+150
docs/index.md
··· 1 + # atdata 2 + 3 + A loose federation of distributed, typed datasets built on WebDataset. 4 + 5 + ## What is atdata? 6 + 7 + atdata provides a typed dataset abstraction for machine learning workflows with: 8 + 9 + - **Typed samples** with automatic msgpack serialization 10 + - **NDArray handling** with transparent numpy array conversion 11 + - **Lens transformations** for viewing datasets through different schemas 12 + - **Batch aggregation** with automatic numpy stacking 13 + - **WebDataset integration** for efficient large-scale storage 14 + - **ATProto federation** for publishing and discovering datasets 15 + 16 + ## Installation 17 + 18 + ```bash 19 + pip install atdata 20 + 21 + # With ATProto support 22 + pip install atdata[atmosphere] 23 + ``` 24 + 25 + ## Quick Start 26 + 27 + ### Define a Sample Type 28 + 29 + ```python 30 + import numpy as np 31 + from numpy.typing import NDArray 32 + import atdata 33 + 34 + @atdata.packable 35 + class ImageSample: 36 + image: NDArray 37 + label: str 38 + confidence: float 39 + ``` 40 + 41 + ### Create and Write Samples 42 + 43 + ```python 44 + import webdataset as wds 45 + 46 + samples = [ 47 + ImageSample( 48 + image=np.random.rand(224, 224, 3).astype(np.float32), 49 + label="cat", 50 + confidence=0.95, 51 + ) 52 + for _ in range(100) 53 + ] 54 + 55 + with wds.writer.TarWriter("data-000000.tar") as sink: 56 + for i, sample in enumerate(samples): 57 + sink.write({**sample.as_wds, "__key__": f"sample_{i:06d}"}) 58 + ``` 59 + 60 + ### Load and Iterate 61 + 62 + ```python 63 + dataset = atdata.Dataset[ImageSample]("data-000000.tar") 64 + 65 + # Iterate with batching 66 + for batch in dataset.shuffled(batch_size=32): 67 + images = batch.image # numpy array (32, 224, 224, 3) 68 + labels = batch.label # list of 32 strings 69 + confs = batch.confidence # list of 32 floats 70 + ``` 71 + 72 + ### Use Lenses for Type Transformations 73 + 74 + ```python 75 + @atdata.packable 76 + class SimplifiedSample: 77 + label: str 78 + 79 + @atdata.lens 80 + def simplify(src: ImageSample) -> SimplifiedSample: 81 + return SimplifiedSample(label=src.label) 82 + 83 + # View dataset through a different type 84 + simple_ds = dataset.as_type(SimplifiedSample) 85 + for batch in simple_ds.ordered(batch_size=16): 86 + print(batch.label) 87 + ``` 88 + 89 + ## HuggingFace-Style Loading 90 + 91 + ```python 92 + # Load from local path 93 + ds = atdata.load_dataset("path/to/data-{000000..000009}.tar", split="train") 94 + 95 + # Load with split detection 96 + ds_dict = atdata.load_dataset("path/to/data/") 97 + train_ds = ds_dict["train"] 98 + test_ds = ds_dict["test"] 99 + ``` 100 + 101 + ## Local Storage with Redis + S3 102 + 103 + ```python 104 + from atdata.local import LocalIndex, Repo 105 + 106 + # Set up local index 107 + index = LocalIndex() # Connects to Redis 108 + 109 + # Create repo with S3 storage 110 + repo = Repo( 111 + s3_credentials={"AWS_ENDPOINT": "http://localhost:9000", ...}, 112 + bucket="my-bucket", 113 + index=index, 114 + ) 115 + 116 + # Insert dataset 117 + entry = repo.insert(samples, name="my-dataset") 118 + print(f"Stored at: {entry.data_urls}") 119 + ``` 120 + 121 + ## Publish to ATProto Federation 122 + 123 + ```python 124 + from atdata.atmosphere import AtmosphereClient 125 + from atdata.promote import promote_to_atmosphere 126 + 127 + # Authenticate 128 + client = AtmosphereClient() 129 + client.login("handle.bsky.social", "app-password") 130 + 131 + # Promote local dataset to federation 132 + entry = index.get_dataset("my-dataset") 133 + at_uri 
= promote_to_atmosphere(entry, index, client) 134 + print(f"Published at: {at_uri}") 135 + ``` 136 + 137 + ## Documentation 138 + 139 + - [Packable Samples](packable-samples.md) - Defining typed samples 140 + - [Datasets](datasets.md) - Loading and iterating datasets 141 + - [Lenses](lenses.md) - Type transformations 142 + - [Local Storage](local-storage.md) - Redis + S3 backend 143 + - [Atmosphere](atmosphere.md) - ATProto federation 144 + - [Promotion](promotion.md) - Local to atmosphere workflow 145 + - [load_dataset](load-dataset.md) - HuggingFace-style API 146 + - [Protocols](protocols.md) - Abstract interfaces 147 + 148 + ## License 149 + 150 + MIT
+195
docs/lenses.md
··· 1 + # Lenses 2 + 3 + Lenses provide bidirectional transformations between sample types, enabling datasets to be viewed through different schemas without duplicating data. 4 + 5 + ## Overview 6 + 7 + A lens consists of: 8 + - **Getter**: Transforms source type `S` to view type `V` 9 + - **Putter**: Updates source based on a modified view (optional) 10 + 11 + ## Creating a Lens 12 + 13 + Use the `@lens` decorator to define a getter: 14 + 15 + ```python 16 + import atdata 17 + from numpy.typing import NDArray 18 + 19 + @atdata.packable 20 + class FullSample: 21 + image: NDArray 22 + label: str 23 + confidence: float 24 + metadata: dict 25 + 26 + @atdata.packable 27 + class SimpleSample: 28 + label: str 29 + confidence: float 30 + 31 + @atdata.lens 32 + def simplify(src: FullSample) -> SimpleSample: 33 + return SimpleSample(label=src.label, confidence=src.confidence) 34 + ``` 35 + 36 + The decorator: 37 + 1. Creates a `Lens` object from the getter function 38 + 2. Registers it in the global `LensNetwork` registry 39 + 3. Extracts source/view types from annotations 40 + 41 + ## Adding a Putter 42 + 43 + To enable bidirectional updates, add a putter: 44 + 45 + ```python 46 + @simplify.putter 47 + def simplify_put(view: SimpleSample, source: FullSample) -> FullSample: 48 + return FullSample( 49 + image=source.image, 50 + label=view.label, 51 + confidence=view.confidence, 52 + metadata=source.metadata, 53 + ) 54 + ``` 55 + 56 + The putter receives: 57 + - `view`: The modified view value 58 + - `source`: The original source value 59 + 60 + It returns an updated source that reflects changes from the view. 61 + 62 + ## Using Lenses with Datasets 63 + 64 + Lenses integrate with `Dataset.as_type()`: 65 + 66 + ```python 67 + dataset = atdata.Dataset[FullSample]("data-{000000..000009}.tar") 68 + 69 + # View through a different type 70 + simple_ds = dataset.as_type(SimpleSample) 71 + 72 + for batch in simple_ds.ordered(batch_size=32): 73 + # Only SimpleSample fields available 74 + labels = batch.label 75 + scores = batch.confidence 76 + ``` 77 + 78 + ## Direct Lens Usage 79 + 80 + Lenses can also be called directly: 81 + 82 + ```python 83 + full = FullSample( 84 + image=np.zeros((224, 224, 3)), 85 + label="cat", 86 + confidence=0.95, 87 + metadata={"source": "training"} 88 + ) 89 + 90 + # Apply getter 91 + simple = simplify(full) 92 + # Or: simple = simplify.get(full) 93 + 94 + # Apply putter 95 + modified_simple = SimpleSample(label="dog", confidence=0.87) 96 + updated_full = simplify.put(modified_simple, full) 97 + # updated_full has label="dog", confidence=0.87, but retains 98 + # original image and metadata 99 + ``` 100 + 101 + ## Lens Laws 102 + 103 + Well-behaved lenses should satisfy these properties: 104 + 105 + ### GetPut Law 106 + If you get a view and immediately put it back, the source is unchanged: 107 + ```python 108 + view = lens.get(source) 109 + assert lens.put(view, source) == source 110 + ``` 111 + 112 + ### PutGet Law 113 + If you put a view, getting it back yields that view: 114 + ```python 115 + updated = lens.put(view, source) 116 + assert lens.get(updated) == view 117 + ``` 118 + 119 + ### PutPut Law 120 + Putting twice is equivalent to putting once with the final value: 121 + ```python 122 + result1 = lens.put(v2, lens.put(v1, source)) 123 + result2 = lens.put(v2, source) 124 + assert result1 == result2 125 + ``` 126 + 127 + ## Trivial Putter 128 + 129 + If no putter is defined, a trivial putter is used that ignores view updates: 130 + 131 + ```python 132 + @atdata.lens 133 + 
def extract_label(src: FullSample) -> SimpleSample: 134 + return SimpleSample(label=src.label, confidence=src.confidence) 135 + 136 + # Without a putter, put() returns the original source unchanged 137 + view = SimpleSample(label="modified", confidence=0.5) 138 + updated = extract_label.put(view, original) 139 + assert updated == original # No changes applied 140 + ``` 141 + 142 + ## LensNetwork Registry 143 + 144 + The `LensNetwork` is a singleton that stores all registered lenses: 145 + 146 + ```python 147 + from atdata.lens import LensNetwork 148 + 149 + network = LensNetwork() 150 + 151 + # Look up a specific lens 152 + lens = network.transform(FullSample, SimpleSample) 153 + 154 + # Raises ValueError if no lens exists 155 + try: 156 + lens = network.transform(TypeA, TypeB) 157 + except ValueError: 158 + print("No lens registered for TypeA -> TypeB") 159 + ``` 160 + 161 + ## Example: Feature Extraction 162 + 163 + ```python 164 + @atdata.packable 165 + class RawSample: 166 + audio: NDArray 167 + text: str 168 + speaker_id: int 169 + 170 + @atdata.packable 171 + class TextFeatures: 172 + text: str 173 + word_count: int 174 + 175 + @atdata.lens 176 + def extract_text(src: RawSample) -> TextFeatures: 177 + return TextFeatures( 178 + text=src.text, 179 + word_count=len(src.text.split()) 180 + ) 181 + 182 + @extract_text.putter 183 + def extract_text_put(view: TextFeatures, source: RawSample) -> RawSample: 184 + return RawSample( 185 + audio=source.audio, 186 + text=view.text, 187 + speaker_id=source.speaker_id 188 + ) 189 + ``` 190 + 191 + ## Related 192 + 193 + - [Datasets](datasets.md) - Using lenses with Dataset.as_type() 194 + - [Packable Samples](packable-samples.md) - Defining sample types 195 + - [Atmosphere](atmosphere.md) - Publishing lenses to ATProto federation
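The lens-law snippets in docs/lenses.md are fragments; the following self-contained sketch re-declares the `simplify` lens from that page and asserts all three laws, assuming packable samples compare by field equality (as the page's own `assert ... == source` examples do).

```python
import numpy as np
from numpy.typing import NDArray
import atdata

@atdata.packable
class FullSample:
    image: NDArray
    label: str
    confidence: float
    metadata: dict

@atdata.packable
class SimpleSample:
    label: str
    confidence: float

@atdata.lens
def simplify(src: FullSample) -> SimpleSample:
    return SimpleSample(label=src.label, confidence=src.confidence)

@simplify.putter
def simplify_put(view: SimpleSample, source: FullSample) -> FullSample:
    return FullSample(
        image=source.image,
        label=view.label,
        confidence=view.confidence,
        metadata=source.metadata,
    )

source = FullSample(
    image=np.zeros((2, 2, 3), dtype=np.float32),
    label="cat",
    confidence=0.9,
    metadata={"source": "unit-test"},
)
v1 = SimpleSample(label="dog", confidence=0.1)
v2 = SimpleSample(label="bird", confidence=0.7)

# GetPut: putting back an unmodified view leaves the source unchanged
assert simplify.put(simplify.get(source), source) == source

# PutGet: getting after a put returns exactly the view that was put
assert simplify.get(simplify.put(v1, source)) == v1

# PutPut: the last put wins
assert simplify.put(v2, simplify.put(v1, source)) == simplify.put(v2, source)
```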
+248
docs/load-dataset.md
··· 1 + # load_dataset API 2 + 3 + The `load_dataset()` function provides a HuggingFace Datasets-style interface for loading typed datasets. 4 + 5 + ## Overview 6 + 7 + Key differences from HuggingFace Datasets: 8 + - Requires explicit `sample_type` parameter (typed dataclass) unless using index 9 + - Returns `atdata.Dataset[ST]` instead of HF Dataset 10 + - Built on WebDataset for efficient streaming 11 + - No Arrow caching layer 12 + 13 + ## Basic Usage 14 + 15 + ```python 16 + import atdata 17 + from atdata import load_dataset 18 + 19 + @atdata.packable 20 + class TextSample: 21 + text: str 22 + label: int 23 + 24 + # Load a specific split 25 + train_ds = load_dataset("path/to/data.tar", TextSample, split="train") 26 + 27 + # Load all splits (returns DatasetDict) 28 + ds_dict = load_dataset("path/to/data/", TextSample) 29 + train_ds = ds_dict["train"] 30 + test_ds = ds_dict["test"] 31 + ``` 32 + 33 + ## Path Formats 34 + 35 + ### WebDataset Brace Notation 36 + 37 + ```python 38 + # Range notation 39 + ds = load_dataset("data-{000000..000099}.tar", MySample, split="train") 40 + 41 + # List notation 42 + ds = load_dataset("data-{train,test,val}.tar", MySample, split="train") 43 + ``` 44 + 45 + ### Glob Patterns 46 + 47 + ```python 48 + # Match all tar files 49 + ds = load_dataset("path/to/*.tar", MySample) 50 + 51 + # Match pattern 52 + ds = load_dataset("path/to/train-*.tar", MySample, split="train") 53 + ``` 54 + 55 + ### Local Directory 56 + 57 + ```python 58 + # Scans for .tar files 59 + ds = load_dataset("./my-dataset/", MySample) 60 + ``` 61 + 62 + ### Remote URLs 63 + 64 + ```python 65 + # S3 66 + ds = load_dataset("s3://bucket/data-{000..099}.tar", MySample, split="train") 67 + 68 + # HTTP/HTTPS 69 + ds = load_dataset("https://example.com/data.tar", MySample, split="train") 70 + 71 + # Google Cloud Storage 72 + ds = load_dataset("gs://bucket/data.tar", MySample, split="train") 73 + ``` 74 + 75 + ### Index Lookup 76 + 77 + ```python 78 + from atdata.local import LocalIndex 79 + 80 + index = LocalIndex() 81 + 82 + # Load from local index (auto-resolves type from schema) 83 + ds = load_dataset("@local/my-dataset", index=index, split="train") 84 + 85 + # With explicit type 86 + ds = load_dataset("@local/my-dataset", MySample, index=index, split="train") 87 + ``` 88 + 89 + ## Split Detection 90 + 91 + Splits are automatically detected from filenames and directories: 92 + 93 + | Pattern | Detected Split | 94 + |---------|---------------| 95 + | `train-*.tar`, `training-*.tar` | train | 96 + | `test-*.tar`, `testing-*.tar` | test | 97 + | `val-*.tar`, `valid-*.tar`, `validation-*.tar` | validation | 98 + | `dev-*.tar`, `development-*.tar` | validation | 99 + | `train/*.tar` (directory) | train | 100 + | `test/*.tar` (directory) | test | 101 + 102 + Files without a detected split default to "train". 
103 + 104 + ## DatasetDict 105 + 106 + When loading without `split=`, returns a `DatasetDict`: 107 + 108 + ```python 109 + ds_dict = load_dataset("path/to/data/", MySample) 110 + 111 + # Access splits 112 + train_ds = ds_dict["train"] 113 + test_ds = ds_dict["test"] 114 + 115 + # Iterate splits 116 + for name, dataset in ds_dict.items(): 117 + print(f"{name}: {len(dataset.shard_list)} shards") 118 + 119 + # Properties 120 + print(ds_dict.num_shards) # {'train': 10, 'test': 2} 121 + print(ds_dict.sample_type) # <class 'MySample'> 122 + print(ds_dict.streaming) # False 123 + ``` 124 + 125 + ## Explicit Data Files 126 + 127 + Override automatic detection with `data_files`: 128 + 129 + ```python 130 + # Single pattern 131 + ds = load_dataset( 132 + "path/to/", 133 + MySample, 134 + data_files="custom-*.tar", 135 + ) 136 + 137 + # List of patterns 138 + ds = load_dataset( 139 + "path/to/", 140 + MySample, 141 + data_files=["shard-000.tar", "shard-001.tar"], 142 + ) 143 + 144 + # Explicit split mapping 145 + ds = load_dataset( 146 + "path/to/", 147 + MySample, 148 + data_files={ 149 + "train": "training-shards-*.tar", 150 + "test": "eval-data.tar", 151 + }, 152 + ) 153 + ``` 154 + 155 + ## Streaming Mode 156 + 157 + The `streaming` parameter signals intent for streaming mode: 158 + 159 + ```python 160 + # Mark as streaming 161 + ds_dict = load_dataset("path/to/data.tar", MySample, streaming=True) 162 + 163 + # Check streaming status 164 + if ds_dict.streaming: 165 + print("Streaming mode") 166 + ``` 167 + 168 + Note: atdata datasets are always lazy/streaming via WebDataset pipelines. This parameter primarily signals intent. 169 + 170 + ## Auto Type Resolution 171 + 172 + When using index lookup, the sample type can be resolved automatically: 173 + 174 + ```python 175 + from atdata.local import LocalIndex 176 + 177 + index = LocalIndex() 178 + 179 + # No sample_type needed - resolved from schema 180 + ds = load_dataset("@local/my-dataset", index=index, split="train") 181 + 182 + # Type is inferred from the stored schema 183 + sample_type = ds.sample_type 184 + ``` 185 + 186 + ## Error Handling 187 + 188 + ```python 189 + try: 190 + ds = load_dataset("path/to/data.tar", MySample, split="train") 191 + except FileNotFoundError: 192 + print("No data files found") 193 + except ValueError as e: 194 + if "Split" in str(e): 195 + print("Requested split not found") 196 + else: 197 + print(f"Invalid configuration: {e}") 198 + except KeyError: 199 + print("Dataset not found in index") 200 + ``` 201 + 202 + ## Complete Example 203 + 204 + ```python 205 + import numpy as np 206 + from numpy.typing import NDArray 207 + import atdata 208 + from atdata import load_dataset 209 + import webdataset as wds 210 + 211 + # 1. Define sample type 212 + @atdata.packable 213 + class ImageSample: 214 + image: NDArray 215 + label: str 216 + 217 + # 2. Create dataset files 218 + for split in ["train", "test"]: 219 + with wds.writer.TarWriter(f"{split}-000.tar") as sink: 220 + for i in range(100): 221 + sample = ImageSample( 222 + image=np.random.rand(64, 64, 3).astype(np.float32), 223 + label=f"sample_{i}", 224 + ) 225 + sink.write({**sample.as_wds, "__key__": f"{i:06d}"}) 226 + 227 + # 3. Load with split detection 228 + ds_dict = load_dataset("./", ImageSample) 229 + print(ds_dict.keys()) # dict_keys(['train', 'test']) 230 + 231 + # 4. Iterate 232 + for batch in ds_dict["train"].ordered(batch_size=16): 233 + print(batch.image.shape) # (16, 64, 64, 3) 234 + print(batch.label) # ['sample_0', 'sample_1', ...] 
235 + break 236 + 237 + # 5. Load specific split 238 + train_ds = load_dataset("./", ImageSample, split="train") 239 + for batch in train_ds.ordered(batch_size=32): 240 + process(batch) 241 + ``` 242 + 243 + ## Related 244 + 245 + - [Datasets](datasets.md) - Dataset iteration and batching 246 + - [Packable Samples](packable-samples.md) - Defining sample types 247 + - [Local Storage](local-storage.md) - LocalIndex for index lookup 248 + - [Protocols](protocols.md) - AbstractIndex interface
+279
docs/local-storage.md
··· 1 + # Local Storage 2 + 3 + The local storage module provides a Redis + S3 backend for storing and managing datasets before publishing to the ATProto federation. 4 + 5 + ## Overview 6 + 7 + Local storage uses: 8 + - **Redis** for indexing and tracking dataset metadata 9 + - **S3-compatible storage** for dataset tar files 10 + 11 + This enables development and small-scale deployment before promoting to the full ATProto infrastructure. 12 + 13 + ## LocalIndex 14 + 15 + The index tracks datasets in Redis: 16 + 17 + ```python 18 + from atdata.local import LocalIndex 19 + 20 + # Default connection (localhost:6379) 21 + index = LocalIndex() 22 + 23 + # Custom Redis connection 24 + import redis 25 + r = redis.Redis(host='custom-host', port=6379) 26 + index = LocalIndex(redis=r) 27 + 28 + # With connection kwargs 29 + index = LocalIndex(host='custom-host', port=6379, db=1) 30 + ``` 31 + 32 + ### Adding Entries 33 + 34 + ```python 35 + dataset = atdata.Dataset[ImageSample]("data-{000000..000009}.tar") 36 + 37 + entry = index.add_entry( 38 + dataset, 39 + name="my-dataset", 40 + schema_ref="local://schemas/mymodule.ImageSample@1.0.0", # optional 41 + metadata={"description": "Training images"}, # optional 42 + ) 43 + 44 + print(entry.cid) # Content identifier 45 + print(entry.name) # "my-dataset" 46 + print(entry.data_urls) # ["data-{000000..000009}.tar"] 47 + ``` 48 + 49 + ### Listing and Retrieving 50 + 51 + ```python 52 + # Iterate all entries 53 + for entry in index.entries: 54 + print(f"{entry.name}: {entry.cid}") 55 + 56 + # Get as list 57 + all_entries = index.all_entries 58 + 59 + # Get by name 60 + entry = index.get_entry_by_name("my-dataset") 61 + 62 + # Get by CID 63 + entry = index.get_entry("bafyrei...") 64 + ``` 65 + 66 + ## Repo 67 + 68 + The Repo class combines S3 storage with Redis indexing: 69 + 70 + ```python 71 + from atdata.local import Repo 72 + 73 + # From credentials file 74 + repo = Repo( 75 + s3_credentials="path/to/.env", 76 + hive_path="my-bucket/datasets", 77 + ) 78 + 79 + # From credentials dict 80 + repo = Repo( 81 + s3_credentials={ 82 + "AWS_ENDPOINT": "http://localhost:9000", 83 + "AWS_ACCESS_KEY_ID": "minioadmin", 84 + "AWS_SECRET_ACCESS_KEY": "minioadmin", 85 + }, 86 + hive_path="my-bucket/datasets", 87 + ) 88 + ``` 89 + 90 + ### Credentials File Format 91 + 92 + The `.env` file should contain: 93 + 94 + ``` 95 + AWS_ENDPOINT=http://localhost:9000 96 + AWS_ACCESS_KEY_ID=your-access-key 97 + AWS_SECRET_ACCESS_KEY=your-secret-key 98 + ``` 99 + 100 + For AWS S3, omit `AWS_ENDPOINT` to use the default endpoint. 101 + 102 + ### Inserting Datasets 103 + 104 + ```python 105 + @atdata.packable 106 + class ImageSample: 107 + image: NDArray 108 + label: str 109 + 110 + # Create dataset from samples 111 + samples = [ImageSample(...) 
for _ in range(1000)] 112 + with wds.writer.TarWriter("temp.tar") as sink: 113 + for i, s in enumerate(samples): 114 + sink.write({**s.as_wds, "__key__": f"{i:06d}"}) 115 + 116 + dataset = atdata.Dataset[ImageSample]("temp.tar") 117 + 118 + # Insert into repo (writes to S3 + indexes in Redis) 119 + entry, stored_dataset = repo.insert( 120 + dataset, 121 + name="training-images-v1", 122 + cache_local=False, # Stream directly to S3 123 + ) 124 + 125 + print(entry.cid) # Content identifier 126 + print(stored_dataset.url) # S3 URL for the stored data 127 + print(stored_dataset.shard_list) # Individual shard URLs 128 + ``` 129 + 130 + ### Insert Options 131 + 132 + ```python 133 + entry, ds = repo.insert( 134 + dataset, 135 + name="my-dataset", 136 + cache_local=True, # Write locally first, then copy (faster for some workloads) 137 + maxcount=10000, # Samples per shard 138 + maxsize=100_000_000, # Max shard size in bytes 139 + ) 140 + ``` 141 + 142 + ## LocalDatasetEntry 143 + 144 + Index entries provide content-addressable identification: 145 + 146 + ```python 147 + entry = index.get_entry_by_name("my-dataset") 148 + 149 + # Core properties (IndexEntry protocol) 150 + entry.name # Human-readable name 151 + entry.schema_ref # Schema reference 152 + entry.data_urls # WebDataset URLs 153 + entry.metadata # Arbitrary metadata dict or None 154 + 155 + # Content addressing 156 + entry.cid # ATProto-compatible CID (content identifier) 157 + 158 + # Legacy compatibility 159 + entry.wds_url # First data URL 160 + entry.sample_kind # Same as schema_ref 161 + ``` 162 + 163 + The CID is generated from the entry's content (schema_ref + data_urls), ensuring identical data produces identical CIDs whether stored locally or in the atmosphere. 164 + 165 + ## Schema Storage 166 + 167 + Schemas can be stored and retrieved from the index: 168 + 169 + ```python 170 + # Publish a schema 171 + schema_ref = index.publish_schema( 172 + ImageSample, 173 + version="1.0.0", 174 + description="Image with label annotation", 175 + ) 176 + # Returns: "local://schemas/mymodule.ImageSample@1.0.0" 177 + 178 + # Retrieve schema record 179 + schema = index.get_schema(schema_ref) 180 + # { 181 + # "name": "ImageSample", 182 + # "version": "1.0.0", 183 + # "fields": [...], 184 + # "description": "...", 185 + # "createdAt": "...", 186 + # } 187 + 188 + # List all schemas 189 + for schema in index.list_schemas(): 190 + print(f"{schema['name']}@{schema['version']}") 191 + 192 + # Reconstruct sample type from schema 193 + SampleType = index.decode_schema(schema_ref) 194 + dataset = atdata.Dataset[SampleType](entry.data_urls[0]) 195 + ``` 196 + 197 + ## S3DataStore 198 + 199 + For direct S3 operations without Redis indexing: 200 + 201 + ```python 202 + from atdata.local import S3DataStore 203 + 204 + store = S3DataStore( 205 + credentials="path/to/.env", 206 + bucket="my-bucket", 207 + ) 208 + 209 + # Write dataset shards 210 + urls = store.write_shards( 211 + dataset, 212 + prefix="datasets/v1", 213 + maxcount=10000, 214 + ) 215 + # Returns: ["s3://my-bucket/datasets/v1/data--uuid--000000.tar", ...] 216 + 217 + # Check capabilities 218 + store.supports_streaming() # True 219 + ``` 220 + 221 + ## Complete Workflow Example 222 + 223 + ```python 224 + import numpy as np 225 + from numpy.typing import NDArray 226 + import atdata 227 + from atdata.local import Repo, LocalIndex 228 + import webdataset as wds 229 + 230 + # 1. 
Define sample type 231 + @atdata.packable 232 + class TrainingSample: 233 + features: NDArray 234 + label: int 235 + source: str 236 + 237 + # 2. Create samples 238 + samples = [ 239 + TrainingSample( 240 + features=np.random.randn(128).astype(np.float32), 241 + label=i % 10, 242 + source="synthetic", 243 + ) 244 + for i in range(10000) 245 + ] 246 + 247 + # 3. Write to local tar 248 + with wds.writer.TarWriter("local-data.tar") as sink: 249 + for i, sample in enumerate(samples): 250 + sink.write({**sample.as_wds, "__key__": f"{i:06d}"}) 251 + 252 + # 4. Create repo and insert 253 + repo = Repo( 254 + s3_credentials={ 255 + "AWS_ENDPOINT": "http://localhost:9000", 256 + "AWS_ACCESS_KEY_ID": "minioadmin", 257 + "AWS_SECRET_ACCESS_KEY": "minioadmin", 258 + }, 259 + hive_path="datasets-bucket/training", 260 + ) 261 + 262 + local_ds = atdata.Dataset[TrainingSample]("local-data.tar") 263 + entry, stored_ds = repo.insert(local_ds, name="training-v1") 264 + 265 + # 5. Retrieve later 266 + index = LocalIndex() 267 + entry = index.get_entry_by_name("training-v1") 268 + dataset = atdata.Dataset[TrainingSample](entry.data_urls[0]) 269 + 270 + for batch in dataset.ordered(batch_size=32): 271 + print(batch.features.shape) # (32, 128) 272 + ``` 273 + 274 + ## Related 275 + 276 + - [Datasets](datasets.md) - Dataset iteration and batching 277 + - [Protocols](protocols.md) - AbstractIndex and IndexEntry interfaces 278 + - [Promotion](promotion.md) - Promoting local datasets to ATProto 279 + - [Atmosphere](atmosphere.md) - ATProto federation
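docs/local-storage.md shows `S3DataStore.write_shards()` but not reading the shards back. The sketch below round-trips through a store; the bucket, prefix, and credentials path are placeholders, `temp.tar` reuses the file written under "Inserting Datasets", and `read_url` is the resolver described in docs/protocols.md for `AbstractDataStore` implementations.

```python
import atdata
from atdata.local import S3DataStore
from numpy.typing import NDArray

@atdata.packable
class ImageSample:
    image: NDArray
    label: str

# Source shards written earlier (see "Inserting Datasets" above).
dataset = atdata.Dataset[ImageSample]("temp.tar")

# Direct S3 write without Redis indexing; credentials and bucket are placeholders.
store = S3DataStore(
    credentials="path/to/.env",
    bucket="my-bucket",
)
urls = store.write_shards(dataset, prefix="datasets/v1", maxcount=10000)

# Read the shards back as a typed Dataset; read_url resolves a stored URL for
# reading, and s3:// URLs require s3fs as noted in docs/datasets.md.
roundtrip = atdata.Dataset[ImageSample](store.read_url(urls[0]))
for batch in roundtrip.ordered(batch_size=8):
    print(batch.label)
    break
```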
+183
docs/packable-samples.md
··· 1 + # Packable Samples 2 + 3 + Packable samples are typed dataclasses that can be serialized with msgpack for storage in WebDataset tar files. 4 + 5 + ## The `@packable` Decorator 6 + 7 + The recommended way to define a sample type is with the `@packable` decorator: 8 + 9 + ```python 10 + import numpy as np 11 + from numpy.typing import NDArray 12 + import atdata 13 + 14 + @atdata.packable 15 + class ImageSample: 16 + image: NDArray 17 + label: str 18 + confidence: float 19 + ``` 20 + 21 + This creates a dataclass that: 22 + - Inherits from `PackableSample` 23 + - Has automatic msgpack serialization 24 + - Handles NDArray conversion to/from bytes 25 + 26 + ## Supported Field Types 27 + 28 + ### Primitives 29 + 30 + ```python 31 + @atdata.packable 32 + class PrimitiveSample: 33 + name: str 34 + count: int 35 + score: float 36 + active: bool 37 + data: bytes 38 + ``` 39 + 40 + ### NumPy Arrays 41 + 42 + Fields annotated as `NDArray` are automatically converted: 43 + 44 + ```python 45 + @atdata.packable 46 + class ArraySample: 47 + features: NDArray # Required array 48 + embeddings: NDArray | None # Optional array 49 + ``` 50 + 51 + **Note**: Bytes in NDArray-typed fields are always interpreted as serialized arrays. Don't use `NDArray` for raw binary data. 52 + 53 + ### Lists 54 + 55 + ```python 56 + @atdata.packable 57 + class ListSample: 58 + tags: list[str] 59 + scores: list[float] 60 + ``` 61 + 62 + ## Serialization 63 + 64 + ### Packing to Bytes 65 + 66 + ```python 67 + sample = ImageSample( 68 + image=np.random.rand(224, 224, 3).astype(np.float32), 69 + label="cat", 70 + confidence=0.95, 71 + ) 72 + 73 + # Serialize to msgpack bytes 74 + packed_bytes = sample.packed 75 + print(f"Size: {len(packed_bytes)} bytes") 76 + ``` 77 + 78 + ### Unpacking from Bytes 79 + 80 + ```python 81 + # Deserialize from bytes 82 + restored = ImageSample.from_bytes(packed_bytes) 83 + 84 + # Arrays are automatically restored 85 + assert np.array_equal(sample.image, restored.image) 86 + assert sample.label == restored.label 87 + ``` 88 + 89 + ### WebDataset Format 90 + 91 + The `as_wds` property returns a dict ready for WebDataset: 92 + 93 + ```python 94 + wds_dict = sample.as_wds 95 + # {'__key__': '1234...', 'msgpack': b'...'} 96 + ``` 97 + 98 + Write samples to a tar file: 99 + 100 + ```python 101 + import webdataset as wds 102 + 103 + with wds.writer.TarWriter("data-000000.tar") as sink: 104 + for i, sample in enumerate(samples): 105 + # Use custom key or let as_wds generate one 106 + sink.write({**sample.as_wds, "__key__": f"sample_{i:06d}"}) 107 + ``` 108 + 109 + ## Direct Inheritance (Alternative) 110 + 111 + You can also inherit directly from `PackableSample`: 112 + 113 + ```python 114 + from dataclasses import dataclass 115 + 116 + @dataclass 117 + class DirectSample(atdata.PackableSample): 118 + name: str 119 + values: NDArray 120 + ``` 121 + 122 + This is equivalent to using `@packable` but more verbose. 123 + 124 + ## How It Works 125 + 126 + ### Serialization Flow 127 + 128 + 1. **Packing** (`sample.packed`): 129 + - NDArray fields → converted to bytes via `array_to_bytes()` 130 + - Other fields → passed through 131 + - All fields → packed with msgpack 132 + 133 + 2. 
**Unpacking** (`Sample.from_bytes()`): 134 + - Bytes → unpacked with ormsgpack 135 + - Dict → passed to `__init__` 136 + - `__post_init__` → calls `_ensure_good()` 137 + - NDArray fields → bytes converted back to arrays 138 + 139 + ### The `_ensure_good()` Method 140 + 141 + This method runs automatically after construction and handles NDArray conversion: 142 + 143 + ```python 144 + def _ensure_good(self): 145 + for field in dataclasses.fields(self): 146 + if _is_possibly_ndarray_type(field.type): 147 + value = getattr(self, field.name) 148 + if isinstance(value, bytes): 149 + setattr(self, field.name, bytes_to_array(value)) 150 + ``` 151 + 152 + ## Best Practices 153 + 154 + ### Do 155 + 156 + ```python 157 + @atdata.packable 158 + class GoodSample: 159 + features: NDArray # Clear type annotation 160 + label: str # Simple primitives 161 + metadata: dict # Msgpack-compatible dicts 162 + scores: list[float] # Typed lists 163 + ``` 164 + 165 + ### Don't 166 + 167 + ```python 168 + @atdata.packable 169 + class BadSample: 170 + # DON'T: Nested dataclasses not supported 171 + nested: OtherSample 172 + 173 + # DON'T: Complex objects that aren't msgpack-serializable 174 + callback: Callable 175 + 176 + # DON'T: Use NDArray for raw bytes 177 + raw_data: NDArray # Use 'bytes' type instead 178 + ``` 179 + 180 + ## Related 181 + 182 + - [Datasets](datasets.md) - Loading and iterating samples 183 + - [Lenses](lenses.md) - Transforming between sample types
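To back up the note in docs/packable-samples.md that raw binary belongs in a `bytes` field while `NDArray` fields always round-trip as numpy arrays, here is a small sketch; `BlobSample` is a hypothetical type.

```python
import numpy as np
from numpy.typing import NDArray
import atdata

@atdata.packable
class BlobSample:
    features: NDArray   # serialized and restored as a numpy array
    payload: bytes      # passed through untouched as raw bytes

sample = BlobSample(
    features=np.arange(4, dtype=np.float32),
    payload=b"\x00\x01\x02\x03",
)

# Round-trip through msgpack bytes
restored = BlobSample.from_bytes(sample.packed)

assert isinstance(restored.features, np.ndarray)
assert np.array_equal(restored.features, sample.features)
assert restored.payload == sample.payload  # still raw bytes, not an array
```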
+191
docs/promotion.md
··· 1 + # Promotion Workflow 2 + 3 + The promotion workflow migrates datasets from local storage (Redis + S3) to the ATProto atmosphere network, enabling federation and discovery. 4 + 5 + ## Overview 6 + 7 + Promotion handles: 8 + - **Schema deduplication**: Avoids publishing duplicate schemas 9 + - **Data URL preservation**: Keeps existing S3 URLs or copies to new storage 10 + - **Metadata transfer**: Preserves tags, descriptions, and custom metadata 11 + 12 + ## Basic Usage 13 + 14 + ```python 15 + from atdata.local import LocalIndex 16 + from atdata.atmosphere import AtmosphereClient 17 + from atdata.promote import promote_to_atmosphere 18 + 19 + # Setup 20 + local_index = LocalIndex() 21 + client = AtmosphereClient() 22 + client.login("handle.bsky.social", "app-password") 23 + 24 + # Get local entry 25 + entry = local_index.get_entry_by_name("my-dataset") 26 + 27 + # Promote to atmosphere 28 + at_uri = promote_to_atmosphere(entry, local_index, client) 29 + print(f"Published: {at_uri}") 30 + ``` 31 + 32 + ## With Metadata 33 + 34 + ```python 35 + at_uri = promote_to_atmosphere( 36 + entry, 37 + local_index, 38 + client, 39 + name="my-dataset-v2", # Override name 40 + description="Training images", # Add description 41 + tags=["images", "training"], # Add discovery tags 42 + license="MIT", # Specify license 43 + ) 44 + ``` 45 + 46 + ## Schema Deduplication 47 + 48 + The promotion workflow automatically checks for existing schemas: 49 + 50 + ```python 51 + # First promotion: publishes schema 52 + uri1 = promote_to_atmosphere(entry1, local_index, client) 53 + 54 + # Second promotion with same schema type + version: reuses existing schema 55 + uri2 = promote_to_atmosphere(entry2, local_index, client) 56 + ``` 57 + 58 + Schema matching is based on: 59 + - `{module}.{class_name}` (e.g., `mymodule.ImageSample`) 60 + - Version string (e.g., `1.0.0`) 61 + 62 + ## Data Storage Options 63 + 64 + ### Use Existing URLs (Default) 65 + 66 + By default, promotion keeps the original data URLs: 67 + 68 + ```python 69 + # Data stays in original S3 location 70 + at_uri = promote_to_atmosphere(entry, local_index, client) 71 + ``` 72 + 73 + ### Copy to New Storage 74 + 75 + To copy data to a different storage location: 76 + 77 + ```python 78 + from atdata.local import S3DataStore 79 + 80 + # Create new data store 81 + new_store = S3DataStore( 82 + credentials="new-s3-creds.env", 83 + bucket="public-datasets", 84 + ) 85 + 86 + # Promote with data copy 87 + at_uri = promote_to_atmosphere( 88 + entry, 89 + local_index, 90 + client, 91 + data_store=new_store, # Copy data to new storage 92 + ) 93 + ``` 94 + 95 + ## Complete Workflow Example 96 + 97 + ```python 98 + import numpy as np 99 + from numpy.typing import NDArray 100 + import atdata 101 + from atdata.local import LocalIndex, Repo 102 + from atdata.atmosphere import AtmosphereClient 103 + from atdata.promote import promote_to_atmosphere 104 + import webdataset as wds 105 + 106 + # 1. Define sample type 107 + @atdata.packable 108 + class FeatureSample: 109 + features: NDArray 110 + label: int 111 + 112 + # 2. Create local dataset 113 + samples = [ 114 + FeatureSample( 115 + features=np.random.randn(128).astype(np.float32), 116 + label=i % 10, 117 + ) 118 + for i in range(1000) 119 + ] 120 + 121 + with wds.writer.TarWriter("features.tar") as sink: 122 + for i, s in enumerate(samples): 123 + sink.write({**s.as_wds, "__key__": f"{i:06d}"}) 124 + 125 + # 3. 
Store in local repo 126 + repo = Repo( 127 + s3_credentials={ 128 + "AWS_ENDPOINT": "http://localhost:9000", 129 + "AWS_ACCESS_KEY_ID": "minioadmin", 130 + "AWS_SECRET_ACCESS_KEY": "minioadmin", 131 + }, 132 + hive_path="datasets-bucket/features", 133 + ) 134 + 135 + dataset = atdata.Dataset[FeatureSample]("features.tar") 136 + local_entry, _ = repo.insert(dataset, name="feature-vectors-v1") 137 + 138 + # 4. Publish schema to local index 139 + local_index = LocalIndex() 140 + local_index.publish_schema(FeatureSample, version="1.0.0") 141 + 142 + # 5. Promote to atmosphere 143 + client = AtmosphereClient() 144 + client.login("myhandle.bsky.social", "app-password") 145 + 146 + at_uri = promote_to_atmosphere( 147 + local_entry, 148 + local_index, 149 + client, 150 + description="Feature vectors for classification", 151 + tags=["features", "embeddings"], 152 + license="MIT", 153 + ) 154 + 155 + print(f"Dataset published: {at_uri}") 156 + 157 + # 6. Verify on atmosphere 158 + from atdata.atmosphere import AtmosphereIndex 159 + 160 + atm_index = AtmosphereIndex(client) 161 + entry = atm_index.get_dataset(at_uri) 162 + print(f"Name: {entry.name}") 163 + print(f"Schema: {entry.schema_ref}") 164 + print(f"URLs: {entry.data_urls}") 165 + ``` 166 + 167 + ## Error Handling 168 + 169 + ```python 170 + try: 171 + at_uri = promote_to_atmosphere(entry, local_index, client) 172 + except KeyError as e: 173 + # Schema not found in local index 174 + print(f"Missing schema: {e}") 175 + except ValueError as e: 176 + # Entry has no data URLs 177 + print(f"Invalid entry: {e}") 178 + ``` 179 + 180 + ## Requirements 181 + 182 + Before promotion: 183 + 1. Dataset must be in local index (via `Repo.insert()` or `Index.add_entry()`) 184 + 2. Schema must be published to local index (via `Index.publish_schema()`) 185 + 3. AtmosphereClient must be authenticated 186 + 187 + ## Related 188 + 189 + - [Local Storage](local-storage.md) - Setting up local datasets 190 + - [Atmosphere](atmosphere.md) - ATProto integration 191 + - [Protocols](protocols.md) - AbstractIndex and AbstractDataStore
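The requirements list at the end of docs/promotion.md has no accompanying snippet; this pre-flight sketch checks the first two items using the names from that page's workflow example. A missing schema surfaces as `KeyError`, per the page's error-handling section.

```python
from atdata.local import LocalIndex

local_index = LocalIndex()

# 1. Dataset must already be in the local index (via Repo.insert or add_entry)
entry = local_index.get_entry_by_name("feature-vectors-v1")

# 2. Schema must already be published to the local index; a missing schema
#    raises KeyError here rather than later during promotion.
schema = local_index.get_schema(entry.schema_ref)
print(f"Ready to promote {entry.name} ({schema['name']} v{schema['version']})")

# 3. With an authenticated AtmosphereClient, promotion can proceed:
# at_uri = promote_to_atmosphere(entry, local_index, client)
```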
+243
docs/protocols.md
··· 1 + # Protocols 2 + 3 + The protocols module defines abstract interfaces that enable interchangeable index backends (local Redis vs ATProto) and data stores (S3 vs PDS blobs). 4 + 5 + ## Overview 6 + 7 + Both local and atmosphere implementations solve the same problem: indexed dataset storage with external data URLs. These protocols formalize that common interface: 8 + 9 + - **IndexEntry**: Common interface for dataset index entries 10 + - **AbstractIndex**: Protocol for index operations 11 + - **AbstractDataStore**: Protocol for data storage operations 12 + 13 + ## IndexEntry Protocol 14 + 15 + Represents a dataset entry in any index: 16 + 17 + ```python 18 + from atdata._protocols import IndexEntry 19 + 20 + def process_entry(entry: IndexEntry) -> None: 21 + print(f"Name: {entry.name}") 22 + print(f"Schema: {entry.schema_ref}") 23 + print(f"URLs: {entry.data_urls}") 24 + print(f"Metadata: {entry.metadata}") 25 + ``` 26 + 27 + ### Properties 28 + 29 + | Property | Type | Description | 30 + |----------|------|-------------| 31 + | `name` | `str` | Human-readable dataset name | 32 + | `schema_ref` | `str` | Schema reference (local:// or at://) | 33 + | `data_urls` | `list[str]` | WebDataset URLs for the data | 34 + | `metadata` | `dict \| None` | Arbitrary metadata dictionary | 35 + 36 + ### Implementations 37 + 38 + - `LocalDatasetEntry` (from `atdata.local`) 39 + - `AtmosphereIndexEntry` (from `atdata.atmosphere`) 40 + 41 + ## AbstractIndex Protocol 42 + 43 + Defines operations for managing schemas and datasets: 44 + 45 + ```python 46 + from atdata._protocols import AbstractIndex 47 + 48 + def list_all_datasets(index: AbstractIndex) -> None: 49 + """Works with LocalIndex or AtmosphereIndex.""" 50 + for entry in index.list_datasets(): 51 + print(f"{entry.name}: {entry.schema_ref}") 52 + ``` 53 + 54 + ### Dataset Operations 55 + 56 + ```python 57 + # Insert a dataset 58 + entry = index.insert_dataset( 59 + dataset, 60 + name="my-dataset", 61 + schema_ref="local://schemas/MySample@1.0.0", # optional 62 + ) 63 + 64 + # Get by name/reference 65 + entry = index.get_dataset("my-dataset") 66 + 67 + # List all datasets 68 + for entry in index.list_datasets(): 69 + print(entry.name) 70 + ``` 71 + 72 + ### Schema Operations 73 + 74 + ```python 75 + # Publish a schema 76 + schema_ref = index.publish_schema( 77 + MySample, 78 + version="1.0.0", 79 + ) 80 + 81 + # Get schema record 82 + schema = index.get_schema(schema_ref) 83 + print(schema["name"], schema["version"]) 84 + 85 + # List all schemas 86 + for schema in index.list_schemas(): 87 + print(f"{schema['name']}@{schema['version']}") 88 + 89 + # Decode schema to Python type 90 + SampleType = index.decode_schema(schema_ref) 91 + dataset = atdata.Dataset[SampleType](entry.data_urls[0]) 92 + ``` 93 + 94 + ### Implementations 95 + 96 + - `LocalIndex` / `Index` (from `atdata.local`) 97 + - `AtmosphereIndex` (from `atdata.atmosphere`) 98 + 99 + ## AbstractDataStore Protocol 100 + 101 + Abstracts over different storage backends: 102 + 103 + ```python 104 + from atdata._protocols import AbstractDataStore 105 + 106 + def write_dataset(store: AbstractDataStore, dataset) -> list[str]: 107 + """Works with S3DataStore or future PDS blob store.""" 108 + urls = store.write_shards(dataset, prefix="datasets/v1") 109 + return urls 110 + ``` 111 + 112 + ### Methods 113 + 114 + ```python 115 + # Write dataset shards 116 + urls = store.write_shards( 117 + dataset, 118 + prefix="datasets/mnist/v1", 119 + maxcount=10000, # samples per shard 120 + ) 121 + 122 + # 
Resolve URL for reading 123 + readable_url = store.read_url("s3://bucket/path.tar") 124 + 125 + # Check streaming support 126 + if store.supports_streaming(): 127 + # Can stream directly 128 + pass 129 + ``` 130 + 131 + ### Implementations 132 + 133 + - `S3DataStore` (from `atdata.local`) 134 + 135 + ## Using Protocols for Polymorphism 136 + 137 + Write code that works with any backend: 138 + 139 + ```python 140 + from atdata._protocols import AbstractIndex, IndexEntry 141 + from atdata import Dataset 142 + 143 + def backup_all_datasets( 144 + source: AbstractIndex, 145 + target: AbstractIndex, 146 + ) -> None: 147 + """Copy all datasets from source index to target.""" 148 + for entry in source.list_datasets(): 149 + # Decode schema from source 150 + SampleType = source.decode_schema(entry.schema_ref) 151 + 152 + # Publish schema to target 153 + target_schema = target.publish_schema(SampleType) 154 + 155 + # Load and re-insert dataset 156 + ds = Dataset[SampleType](entry.data_urls[0]) 157 + target.insert_dataset( 158 + ds, 159 + name=entry.name, 160 + schema_ref=target_schema, 161 + ) 162 + ``` 163 + 164 + ## Schema Reference Formats 165 + 166 + Schema references vary by backend: 167 + 168 + | Backend | Format | Example | 169 + |---------|--------|---------| 170 + | Local | `local://schemas/{module.Class}@{version}` | `local://schemas/myapp.ImageSample@1.0.0` | 171 + | Atmosphere | `at://{did}/{collection}/{rkey}` | `at://did:plc:abc123/ac.foundation.dataset.sampleSchema/xyz` | 172 + 173 + ## Type Checking 174 + 175 + Protocols are runtime-checkable: 176 + 177 + ```python 178 + from atdata._protocols import IndexEntry, AbstractIndex 179 + 180 + # Check if object implements protocol 181 + entry = index.get_dataset("test") 182 + assert isinstance(entry, IndexEntry) 183 + 184 + # Type hints work with protocols 185 + def process(index: AbstractIndex) -> None: 186 + ... 
# IDE provides autocomplete 187 + ``` 188 + 189 + ## Complete Example 190 + 191 + ```python 192 + import atdata 193 + from atdata.local import LocalIndex, S3DataStore 194 + from atdata.atmosphere import AtmosphereClient, AtmosphereIndex 195 + from atdata._protocols import AbstractIndex 196 + import numpy as np 197 + from numpy.typing import NDArray 198 + 199 + # Define sample type 200 + @atdata.packable 201 + class FeatureSample: 202 + features: NDArray 203 + label: int 204 + 205 + # Function works with any index 206 + def count_datasets(index: AbstractIndex) -> int: 207 + return sum(1 for _ in index.list_datasets()) 208 + 209 + # Use with local index 210 + local_index = LocalIndex() 211 + print(f"Local datasets: {count_datasets(local_index)}") 212 + 213 + # Use with atmosphere index 214 + client = AtmosphereClient() 215 + client.login("handle.bsky.social", "app-password") 216 + atm_index = AtmosphereIndex(client) 217 + print(f"Atmosphere datasets: {count_datasets(atm_index)}") 218 + 219 + # Migrate from local to atmosphere 220 + def migrate_dataset( 221 + name: str, 222 + source: AbstractIndex, 223 + target: AbstractIndex, 224 + ) -> None: 225 + entry = source.get_dataset(name) 226 + SampleType = source.decode_schema(entry.schema_ref) 227 + 228 + # Publish schema 229 + schema_ref = target.publish_schema(SampleType) 230 + 231 + # Create dataset and insert 232 + ds = atdata.Dataset[SampleType](entry.data_urls[0]) 233 + target.insert_dataset(ds, name=name, schema_ref=schema_ref) 234 + 235 + migrate_dataset("my-features", local_index, atm_index) 236 + ``` 237 + 238 + ## Related 239 + 240 + - [Local Storage](local-storage.md) - LocalIndex and S3DataStore 241 + - [Atmosphere](atmosphere.md) - AtmosphereIndex 242 + - [Promotion](promotion.md) - Local to atmosphere migration 243 + - [load_dataset](load-dataset.md) - Using indexes with load_dataset()
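A note on how structural these protocols are: because `IndexEntry` is runtime-checkable and matched purely by attribute presence, any object that exposes the four properties satisfies it without inheriting from anything in atdata. The stub below is a hypothetical illustration, not a class the library ships.

```python
from dataclasses import dataclass

from atdata._protocols import IndexEntry


# Hypothetical stand-in entry, used only to illustrate structural typing.
@dataclass
class StubEntry:
    name: str
    schema_ref: str
    data_urls: list[str]
    metadata: dict | None = None


entry = StubEntry(
    name="scratch",
    schema_ref="local://schemas/example.Sample@1.0.0",
    data_urls=["s3://bucket/scratch-000000.tar"],
)

# Passes because isinstance() against a runtime-checkable protocol only
# verifies that the named members exist on the instance.
assert isinstance(entry, IndexEntry)
```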
+312
examples/local_workflow.py
··· 1 + #!/usr/bin/env python3 2 + """Demonstration of atdata local storage workflow. 3 + 4 + This script demonstrates how to use the local module to store and index 5 + datasets using Redis and S3-compatible storage. 6 + 7 + Usage: 8 + # Dry run with mocks (no Redis/S3 required): 9 + python local_workflow.py 10 + 11 + # With actual Redis (requires redis-server running): 12 + python local_workflow.py --redis 13 + 14 + # With Redis and S3 (requires MinIO or AWS): 15 + python local_workflow.py --redis --s3-endpoint http://localhost:9000 16 + 17 + Requirements: 18 + pip install atdata redis 19 + 20 + Note: 21 + For S3 storage, you can use MinIO for local development: 22 + docker run -p 9000:9000 minio/minio server /data 23 + """ 24 + 25 + import argparse 26 + import tempfile 27 + from datetime import datetime 28 + from pathlib import Path 29 + 30 + import numpy as np 31 + from numpy.typing import NDArray 32 + 33 + import atdata 34 + from atdata.local import LocalIndex, LocalDatasetEntry, Repo, S3DataStore 35 + 36 + 37 + # ============================================================================= 38 + # Define sample types 39 + # ============================================================================= 40 + 41 + @atdata.packable 42 + class TrainingSample: 43 + """A sample containing features and label for training.""" 44 + features: NDArray 45 + label: int 46 + 47 + 48 + @atdata.packable 49 + class TextSample: 50 + """A sample containing text data.""" 51 + text: str 52 + category: str 53 + 54 + 55 + # ============================================================================= 56 + # Demo functions 57 + # ============================================================================= 58 + 59 + def demo_local_dataset_entry(): 60 + """Demonstrate LocalDatasetEntry creation and CID generation.""" 61 + print("\n" + "=" * 60) 62 + print("LocalDatasetEntry Demo") 63 + print("=" * 60) 64 + 65 + # Create an entry 66 + entry = LocalDatasetEntry( 67 + _name="my-dataset", 68 + _schema_ref="local://schemas/examples.TrainingSample@1.0.0", 69 + _data_urls=["s3://bucket/data-000000.tar", "s3://bucket/data-000001.tar"], 70 + _metadata={"source": "example", "samples": 10000}, 71 + ) 72 + 73 + print(f"\nEntry name: {entry.name}") 74 + print(f"Schema ref: {entry.schema_ref}") 75 + print(f"Data URLs: {entry.data_urls}") 76 + print(f"Metadata: {entry.metadata}") 77 + print(f"CID: {entry.cid}") 78 + 79 + # Demonstrate CID determinism 80 + entry2 = LocalDatasetEntry( 81 + _name="different-name", # Name doesn't affect CID 82 + _schema_ref="local://schemas/examples.TrainingSample@1.0.0", 83 + _data_urls=["s3://bucket/data-000000.tar", "s3://bucket/data-000001.tar"], 84 + ) 85 + 86 + print(f"\nCID comparison (same content, different name):") 87 + print(f" Entry 1 CID: {entry.cid}") 88 + print(f" Entry 2 CID: {entry2.cid}") 89 + print(f" Match: {entry.cid == entry2.cid}") 90 + 91 + 92 + def demo_local_index_mock(): 93 + """Demonstrate LocalIndex operations with mock data.""" 94 + print("\n" + "=" * 60) 95 + print("LocalIndex Demo (mock)") 96 + print("=" * 60) 97 + 98 + # LocalIndex without Redis connection works for read operations 99 + index = LocalIndex() 100 + 101 + print("\nLocalIndex created (no Redis connection)") 102 + print("Methods available:") 103 + print(" - index.insert_dataset(dataset, name='...')") 104 + print(" - index.get_dataset(name_or_cid)") 105 + print(" - index.list_datasets()") 106 + print(" - index.publish_schema(sample_type, version='1.0.0')") 107 + print(" - index.get_schema(ref)") 108 + 
print(" - index.list_schemas()") 109 + print(" - index.decode_schema(ref) # Returns PackableSample class") 110 + 111 + 112 + def demo_local_index_redis(redis_host: str = "localhost", redis_port: int = 6379): 113 + """Demonstrate LocalIndex with actual Redis.""" 114 + print("\n" + "=" * 60) 115 + print("LocalIndex Demo (Redis)") 116 + print("=" * 60) 117 + 118 + from redis import Redis 119 + 120 + # Connect to Redis 121 + try: 122 + redis = Redis(host=redis_host, port=redis_port) 123 + redis.ping() 124 + except Exception as e: 125 + print(f"Could not connect to Redis: {e}") 126 + print("Skipping Redis demo.") 127 + return 128 + 129 + # Create index with Redis 130 + index = LocalIndex(redis=redis) 131 + print(f"\nConnected to Redis at {redis_host}:{redis_port}") 132 + 133 + # Publish a schema 134 + print("\nPublishing TrainingSample schema...") 135 + schema_ref = index.publish_schema(TrainingSample, version="1.0.0") 136 + print(f" Schema ref: {schema_ref}") 137 + 138 + # List schemas 139 + print("\nListing schemas:") 140 + for schema in index.list_schemas(): 141 + print(f" - {schema.get('name', 'Unknown')} v{schema.get('version', '?')}") 142 + 143 + # Get schema and decode to type 144 + schema_record = index.get_schema(schema_ref) 145 + print(f"\nSchema record: {schema_record.get('name')}") 146 + print(f" Fields: {[f['name'] for f in schema_record.get('fields', [])]}") 147 + 148 + # Decode schema back to a PackableSample class 149 + decoded_type = index.decode_schema(schema_ref) 150 + print(f"\nDecoded type: {decoded_type.__name__}") 151 + 152 + # Clean up test data 153 + for key in redis.scan_iter(match="LocalSchema:*"): 154 + redis.delete(key) 155 + print("\nCleaned up test schemas") 156 + 157 + 158 + def demo_s3_datastore(): 159 + """Demonstrate S3DataStore interface.""" 160 + print("\n" + "=" * 60) 161 + print("S3DataStore Demo") 162 + print("=" * 60) 163 + 164 + # S3DataStore with mock credentials (won't actually connect) 165 + creds = { 166 + "AWS_ENDPOINT": "http://localhost:9000", 167 + "AWS_ACCESS_KEY_ID": "minioadmin", 168 + "AWS_SECRET_ACCESS_KEY": "minioadmin", 169 + } 170 + 171 + store = S3DataStore(creds, bucket="my-bucket") 172 + 173 + print(f"\nS3DataStore created:") 174 + print(f" Bucket: {store.bucket}") 175 + print(f" Supports streaming: {store.supports_streaming()}") 176 + 177 + # read_url returns the URL unchanged (passthrough for WDS) 178 + url = "s3://my-bucket/data.tar" 179 + print(f"\nread_url passthrough: {store.read_url(url)}") 180 + 181 + 182 + def demo_repo_workflow(tmp_path: Path): 183 + """Demonstrate full Repo workflow with local files.""" 184 + import webdataset as wds 185 + 186 + print("\n" + "=" * 60) 187 + print("Repo Workflow Demo (local files)") 188 + print("=" * 60) 189 + 190 + # Create sample data 191 + samples = [ 192 + TrainingSample(features=np.random.randn(10).astype(np.float32), label=i % 3) 193 + for i in range(100) 194 + ] 195 + 196 + print(f"\nCreated {len(samples)} training samples") 197 + 198 + # Create a Dataset and write to local tar file 199 + tar_path = tmp_path / "local-data-000000.tar" 200 + with wds.writer.TarWriter(str(tar_path)) as sink: 201 + for i, sample in enumerate(samples): 202 + sink.write({**sample.as_wds, "__key__": f"sample_{i:06d}"}) 203 + 204 + print(f"Wrote samples to: {tar_path}") 205 + 206 + # Load the dataset back 207 + ds = atdata.Dataset[TrainingSample](str(tar_path)) 208 + loaded = list(ds.ordered(batch_size=None)) 209 + print(f"Loaded {len(loaded)} samples back") 210 + 211 + # Verify round-trip 212 + assert 
len(loaded) == len(samples) 213 + assert np.allclose(loaded[0].features, samples[0].features) 214 + print("Round-trip verification: PASSED") 215 + 216 + 217 + def demo_load_dataset_with_index(): 218 + """Demonstrate load_dataset with index parameter.""" 219 + print("\n" + "=" * 60) 220 + print("load_dataset with Index Demo") 221 + print("=" * 60) 222 + 223 + print(""" 224 + The load_dataset() function supports an index parameter for both local 225 + and atmosphere backends: 226 + 227 + # Local index lookup 228 + from atdata import load_dataset 229 + from atdata.local import LocalIndex 230 + 231 + index = LocalIndex() 232 + ds = load_dataset('@local/my-dataset', index=index, split='train') 233 + 234 + # The index resolves the dataset name to URLs and schema 235 + for batch in ds.shuffled(batch_size=32): 236 + process(batch) 237 + 238 + # Atmosphere lookup (via @handle/dataset syntax) 239 + ds = load_dataset('@alice.science/mnist', split='train') 240 + 241 + # This automatically: 242 + # 1. Resolves the handle to a DID 243 + # 2. Fetches the dataset record from the user's repository 244 + # 3. Gets the data URLs from the record 245 + # 4. Resolves the schema for type information 246 + """) 247 + 248 + 249 + # ============================================================================= 250 + # Main 251 + # ============================================================================= 252 + 253 + def main(): 254 + parser = argparse.ArgumentParser( 255 + description="Demonstrate atdata local storage workflow", 256 + formatter_class=argparse.RawDescriptionHelpFormatter, 257 + epilog=__doc__, 258 + ) 259 + parser.add_argument( 260 + "--redis", 261 + action="store_true", 262 + help="Run demos that require Redis", 263 + ) 264 + parser.add_argument( 265 + "--redis-host", 266 + default="localhost", 267 + help="Redis host (default: localhost)", 268 + ) 269 + parser.add_argument( 270 + "--redis-port", 271 + type=int, 272 + default=6379, 273 + help="Redis port (default: 6379)", 274 + ) 275 + parser.add_argument( 276 + "--s3-endpoint", 277 + help="S3 endpoint URL for live S3 demo", 278 + ) 279 + 280 + args = parser.parse_args() 281 + 282 + print("=" * 60) 283 + print("atdata.local Demo") 284 + print("=" * 60) 285 + print(f"\nTime: {datetime.now().isoformat()}") 286 + 287 + # Always run these demos (no external services required) 288 + demo_local_dataset_entry() 289 + demo_local_index_mock() 290 + demo_s3_datastore() 291 + demo_load_dataset_with_index() 292 + 293 + # Run with temp directory for file-based demos 294 + with tempfile.TemporaryDirectory() as tmp: 295 + demo_repo_workflow(Path(tmp)) 296 + 297 + # Run Redis demo if requested 298 + if args.redis: 299 + demo_local_index_redis(args.redis_host, args.redis_port) 300 + else: 301 + print("\n" + "=" * 60) 302 + print("Redis Demo Skipped") 303 + print("=" * 60) 304 + print("\nTo run with Redis: python local_workflow.py --redis") 305 + 306 + print("\n" + "=" * 60) 307 + print("Demo Complete!") 308 + print("=" * 60) 309 + 310 + 311 + if __name__ == "__main__": 312 + main()
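The CID determinism shown in `demo_local_dataset_entry` can also be exercised directly with the `generate_cid`/`verify_cid` helpers added in this change set. The payload layout below is an assumption chosen for illustration; the exact fields that `LocalDatasetEntry.cid` hashes are not spelled out in this example file.

```python
from atdata._cid import generate_cid, verify_cid

# Illustrative payload only -- assumed shape, not the documented
# LocalDatasetEntry hashing layout.
payload = {
    "schemaRef": "local://schemas/examples.TrainingSample@1.0.0",
    "dataUrls": ["s3://bucket/data-000000.tar", "s3://bucket/data-000001.tar"],
}

cid = generate_cid(payload)
assert verify_cid(cid, payload)                            # same content -> same CID
assert not verify_cid(cid, {**payload, "dataUrls": []})    # changed content -> different CID
```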
+394
examples/promote_workflow.py
··· 1 + #!/usr/bin/env python3 2 + """Demonstration of promoting local datasets to the atmosphere network. 3 + 4 + This script demonstrates the workflow for migrating datasets from local 5 + Redis/S3 storage to the federated ATProto atmosphere network. 6 + 7 + Usage: 8 + # Dry run with mocks (no external services required): 9 + python promote_workflow.py 10 + 11 + # With actual ATProto connection: 12 + python promote_workflow.py --handle your.handle --password your-app-password 13 + 14 + Requirements: 15 + pip install atdata[atmosphere] 16 + 17 + Note: 18 + Use an app-specific password, not your main Bluesky password. 19 + Create app passwords at: https://bsky.app/settings/app-passwords 20 + """ 21 + 22 + import argparse 23 + from datetime import datetime 24 + from unittest.mock import Mock, MagicMock 25 + 26 + import numpy as np 27 + from numpy.typing import NDArray 28 + 29 + import atdata 30 + from atdata.promote import promote_to_atmosphere 31 + 32 + 33 + # ============================================================================= 34 + # Define sample types 35 + # ============================================================================= 36 + 37 + @atdata.packable 38 + class ExperimentSample: 39 + """A sample from a scientific experiment.""" 40 + measurement: NDArray 41 + timestamp: float 42 + sensor_id: str 43 + 44 + 45 + # ============================================================================= 46 + # Demo functions 47 + # ============================================================================= 48 + 49 + def demo_promotion_concept(): 50 + """Explain the promotion workflow concept.""" 51 + print("\n" + "=" * 60) 52 + print("Promotion Workflow Overview") 53 + print("=" * 60) 54 + 55 + print(""" 56 + The promotion workflow moves datasets from local storage to the atmosphere: 57 + 58 + LOCAL ATMOSPHERE 59 + ----- ---------- 60 + Redis Index ATProto PDS 61 + S3 Storage --> (same S3 or new location) 62 + local://schemas/... at://did:plc:.../schema/... 63 + 64 + Steps: 65 + 1. Retrieve dataset entry from LocalIndex 66 + 2. Get schema from local index 67 + 3. Find or publish schema on atmosphere (deduplication) 68 + 4. Optionally copy data to new storage location 69 + 5. Create dataset record on atmosphere 70 + 6. Return AT URI for the published dataset 71 + 72 + Key features: 73 + - Schema deduplication: Won't republish identical schemas 74 + - Flexible data handling: Keep existing URLs or copy to new storage 75 + - Metadata preservation: Local metadata carries over to atmosphere 76 + """) 77 + 78 + 79 + def demo_mock_promotion(): 80 + """Demonstrate promotion with mocked services.""" 81 + print("\n" + "=" * 60) 82 + print("Mock Promotion Demo") 83 + print("=" * 60) 84 + 85 + from atdata.local import LocalDatasetEntry 86 + 87 + # Create a mock local entry 88 + local_entry = LocalDatasetEntry( 89 + _name="experiment-2024-001", 90 + _schema_ref="local://schemas/__main__.ExperimentSample@1.0.0", 91 + _data_urls=[ 92 + "s3://research-bucket/experiments/exp-2024-001/shard-000000.tar", 93 + "s3://research-bucket/experiments/exp-2024-001/shard-000001.tar", 94 + ], 95 + _metadata={ 96 + "experiment_date": "2024-01-15", 97 + "lab": "Physics Building Room 302", 98 + "principal_investigator": "Dr. 
Smith", 99 + }, 100 + ) 101 + 102 + print(f"\nLocal entry to promote:") 103 + print(f" Name: {local_entry.name}") 104 + print(f" Schema: {local_entry.schema_ref}") 105 + print(f" URLs: {len(local_entry.data_urls)} shards") 106 + print(f" Metadata: {local_entry.metadata}") 107 + 108 + # Create mock local index 109 + mock_index = Mock() 110 + mock_index.get_schema.return_value = { 111 + "name": "__main__.ExperimentSample", 112 + "version": "1.0.0", 113 + "description": "A sample from a scientific experiment", 114 + "fields": [ 115 + {"name": "measurement", "fieldType": {"$type": "local#ndarray", "dtype": "float32"}, "optional": False}, 116 + {"name": "timestamp", "fieldType": {"$type": "local#primitive", "primitive": "float"}, "optional": False}, 117 + {"name": "sensor_id", "fieldType": {"$type": "local#primitive", "primitive": "str"}, "optional": False}, 118 + ], 119 + } 120 + 121 + # Create mock atmosphere client 122 + mock_client = Mock() 123 + mock_client.did = "did:plc:demo123456789" 124 + 125 + # Mock the atmosphere modules 126 + from unittest.mock import patch 127 + 128 + with patch("atdata.promote._find_existing_schema") as mock_find: 129 + mock_find.return_value = None # No existing schema 130 + 131 + with patch("atdata.atmosphere.SchemaPublisher") as MockSchemaPublisher: 132 + mock_schema_pub = MockSchemaPublisher.return_value 133 + mock_schema_uri = Mock(__str__=lambda s: "at://did:plc:demo123456789/ac.foundation.dataset.sampleSchema/exp001") 134 + mock_schema_pub.publish.return_value = mock_schema_uri 135 + 136 + with patch("atdata.atmosphere.DatasetPublisher") as MockDatasetPublisher: 137 + mock_ds_pub = MockDatasetPublisher.return_value 138 + mock_ds_uri = Mock(__str__=lambda s: "at://did:plc:demo123456789/ac.foundation.dataset.datasetIndex/exp2024001") 139 + mock_ds_pub.publish_with_urls.return_value = mock_ds_uri 140 + 141 + # Perform the promotion 142 + result = promote_to_atmosphere( 143 + local_entry, 144 + mock_index, 145 + mock_client, 146 + tags=["experiment", "physics", "2024"], 147 + license="CC-BY-4.0", 148 + ) 149 + 150 + print(f"\nPromotion result:") 151 + print(f" AT URI: {result}") 152 + print(f"\nPublished:") 153 + print(f" Schema: at://did:plc:demo123456789/.../exp001") 154 + print(f" Dataset: at://did:plc:demo123456789/.../exp2024001") 155 + 156 + 157 + def demo_schema_deduplication(): 158 + """Demonstrate schema deduplication during promotion.""" 159 + print("\n" + "=" * 60) 160 + print("Schema Deduplication Demo") 161 + print("=" * 60) 162 + 163 + from atdata.promote import _find_existing_schema 164 + from unittest.mock import patch 165 + 166 + mock_client = Mock() 167 + 168 + # Scenario 1: Schema already exists 169 + print("\nScenario 1: Schema already exists on atmosphere") 170 + with patch("atdata.atmosphere.SchemaLoader") as MockLoader: 171 + mock_loader = MockLoader.return_value 172 + mock_loader.list_all.return_value = [ 173 + { 174 + "uri": "at://did:plc:abc/schema/existing", 175 + "value": { 176 + "name": "mymodule.MySample", 177 + "version": "1.0.0", 178 + } 179 + } 180 + ] 181 + 182 + result = _find_existing_schema(mock_client, "mymodule.MySample", "1.0.0") 183 + print(f" Looking for: mymodule.MySample@1.0.0") 184 + print(f" Found: {result}") 185 + print(f" Action: Reuse existing schema (no republish)") 186 + 187 + # Scenario 2: Different version 188 + print("\nScenario 2: Same name but different version") 189 + with patch("atdata.atmosphere.SchemaLoader") as MockLoader: 190 + mock_loader = MockLoader.return_value 191 + 
mock_loader.list_all.return_value = [ 192 + { 193 + "uri": "at://did:plc:abc/schema/v1", 194 + "value": { 195 + "name": "mymodule.MySample", 196 + "version": "1.0.0", # v1.0.0 exists 197 + } 198 + } 199 + ] 200 + 201 + result = _find_existing_schema(mock_client, "mymodule.MySample", "2.0.0") # Looking for v2.0.0 202 + print(f" Looking for: mymodule.MySample@2.0.0") 203 + print(f" Found: {result}") 204 + print(f" Action: Publish new schema record") 205 + 206 + 207 + def demo_data_migration_options(): 208 + """Explain data migration options during promotion.""" 209 + print("\n" + "=" * 60) 210 + print("Data Migration Options") 211 + print("=" * 60) 212 + 213 + print(""" 214 + When promoting, you can choose how to handle the data files: 215 + 216 + Option A: Keep existing URLs (default) 217 + ----------------------------------------- 218 + promote_to_atmosphere(entry, index, client) 219 + 220 + - Data stays in original S3 location 221 + - Dataset record points to existing URLs 222 + - Fastest option, no data copying 223 + - Requires original storage to remain accessible 224 + 225 + Option B: Copy to new S3 location 226 + ----------------------------------------- 227 + new_store = S3DataStore(creds, bucket='public-bucket') 228 + promote_to_atmosphere(entry, index, client, data_store=new_store) 229 + 230 + - Data is copied to new bucket 231 + - Dataset record points to new URLs 232 + - Good for moving from private to public storage 233 + 234 + Option C: Use ATProto blobs (future) 235 + ----------------------------------------- 236 + # Not yet implemented 237 + promote_to_atmosphere(entry, index, client, data_store='pds-blobs') 238 + 239 + - Data uploaded as ATProto blobs 240 + - Self-contained in the PDS 241 + - Size limits apply (ATProto blob limits) 242 + """) 243 + 244 + 245 + def demo_live_promotion(handle: str, password: str): 246 + """Demonstrate actual promotion to ATProto.""" 247 + print("\n" + "=" * 60) 248 + print("Live Promotion Demo") 249 + print("=" * 60) 250 + 251 + from atdata.atmosphere import AtmosphereClient 252 + from atdata.local import LocalDatasetEntry 253 + 254 + # Connect to atmosphere 255 + print(f"\nConnecting as {handle}...") 256 + client = AtmosphereClient() 257 + client.login(handle, password) 258 + print(f"Authenticated! 
DID: {client.did}") 259 + 260 + # Create a demo local entry (simulating a real local dataset) 261 + local_entry = LocalDatasetEntry( 262 + _name="demo-promoted-dataset", 263 + _schema_ref="local://schemas/__main__.ExperimentSample@1.0.0", 264 + _data_urls=["s3://example-bucket/demo-data-{000000..000004}.tar"], 265 + _metadata={"promoted_from": "local_demo", "demo": True}, 266 + ) 267 + 268 + # Create a mock local index with our schema 269 + mock_index = Mock() 270 + mock_index.get_schema.return_value = { 271 + "name": "__main__.ExperimentSample", 272 + "version": "1.0.0", 273 + "fields": [ 274 + {"name": "measurement", "fieldType": {"$type": "local#ndarray", "dtype": "float32"}, "optional": False}, 275 + {"name": "timestamp", "fieldType": {"$type": "local#primitive", "primitive": "float"}, "optional": False}, 276 + {"name": "sensor_id", "fieldType": {"$type": "local#primitive", "primitive": "str"}, "optional": False}, 277 + ], 278 + } 279 + 280 + print("\nPromoting dataset to atmosphere...") 281 + result = promote_to_atmosphere( 282 + local_entry, 283 + mock_index, 284 + client, 285 + tags=["demo", "atdata"], 286 + license="MIT", 287 + ) 288 + 289 + print(f"\nPromotion successful!") 290 + print(f" AT URI: {result}") 291 + print(f"\nYou can now discover this dataset via:") 292 + print(f" atdata.load_dataset('@{handle}/demo-promoted-dataset')") 293 + 294 + 295 + def demo_full_workflow(): 296 + """Show the complete local-to-atmosphere workflow.""" 297 + print("\n" + "=" * 60) 298 + print("Complete Workflow Example") 299 + print("=" * 60) 300 + 301 + print(""" 302 + Here's a complete example of the local-to-atmosphere workflow: 303 + 304 + import atdata 305 + from atdata.local import LocalIndex, Repo 306 + from atdata.atmosphere import AtmosphereClient 307 + from atdata.promote import promote_to_atmosphere 308 + 309 + # 1. Define your sample type 310 + @atdata.packable 311 + class MySample: 312 + features: NDArray 313 + label: str 314 + 315 + # 2. Create and index local dataset 316 + local_index = LocalIndex() # Connects to Redis 317 + repo = Repo(s3_creds, bucket='my-bucket', index=local_index) 318 + 319 + # Insert dataset (writes to S3, indexes in Redis) 320 + samples = [MySample(features=..., label=...) for ...] 321 + entry = repo.insert(samples, name='my-dataset') 322 + 323 + print(f"Local CID: {entry.cid}") 324 + print(f"Local URLs: {entry.data_urls}") 325 + 326 + # 3. When ready to share, promote to atmosphere 327 + client = AtmosphereClient() 328 + client.login('myhandle.bsky.social', 'app-password') 329 + 330 + at_uri = promote_to_atmosphere( 331 + entry, 332 + local_index, 333 + client, 334 + tags=['ml', 'vision'], 335 + license='MIT', 336 + ) 337 + 338 + print(f"Published at: {at_uri}") 339 + 340 + # 4. 
Others can now discover and load your dataset 341 + # ds = atdata.load_dataset('@myhandle.bsky.social/my-dataset') 342 + """) 343 + 344 + 345 + # ============================================================================= 346 + # Main 347 + # ============================================================================= 348 + 349 + def main(): 350 + parser = argparse.ArgumentParser( 351 + description="Demonstrate local to atmosphere promotion workflow", 352 + formatter_class=argparse.RawDescriptionHelpFormatter, 353 + epilog=__doc__, 354 + ) 355 + parser.add_argument( 356 + "--handle", 357 + help="Bluesky handle for live demo", 358 + ) 359 + parser.add_argument( 360 + "--password", 361 + help="App-specific password for live demo", 362 + ) 363 + 364 + args = parser.parse_args() 365 + 366 + print("=" * 60) 367 + print("atdata Promotion Workflow Demo") 368 + print("=" * 60) 369 + print(f"\nTime: {datetime.now().isoformat()}") 370 + 371 + # Always run these demos (no external services required) 372 + demo_promotion_concept() 373 + demo_mock_promotion() 374 + demo_schema_deduplication() 375 + demo_data_migration_options() 376 + demo_full_workflow() 377 + 378 + # Run live demo if credentials provided 379 + if args.handle and args.password: 380 + demo_live_promotion(args.handle, args.password) 381 + else: 382 + print("\n" + "=" * 60) 383 + print("Live Demo Skipped") 384 + print("=" * 60) 385 + print("\nTo run with actual ATProto connection:") 386 + print(" python promote_workflow.py --handle your.handle --password your-app-password") 387 + 388 + print("\n" + "=" * 60) 389 + print("Demo Complete!") 390 + print("=" * 60) 391 + 392 + 393 + if __name__ == "__main__": 394 + main()
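After a successful promotion, the dataset can be loaded back by name. Per the `load_dataset` signature introduced in this change set, `@handle/dataset` paths are resolved through an index, so the consumer side looks roughly like the sketch below (handle and dataset name are placeholders).

```python
from atdata import load_dataset
from atdata.atmosphere import AtmosphereClient, AtmosphereIndex

client = AtmosphereClient()
client.login("myhandle.bsky.social", "app-password")

# The index resolves the name to data URLs and a schema reference; the sample
# type is reconstructed from the stored schema, so no explicit type is needed.
ds = load_dataset(
    "@myhandle.bsky.social/my-dataset",
    index=AtmosphereIndex(client),
    split="train",
)
```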
+1
pyproject.toml
··· 10 10 dependencies = [
11 11     "atproto>=0.0.65",
12 12     "fastparquet>=2024.11.0",
13 +     "libipld>=3.3.2",
13 14     "msgpack>=1.1.2",
14 15     "numpy>=2.3.4",
15 16     "ormsgpack>=1.11.0",
+19
src/atdata/__init__.py
··· 56 56     DatasetDict,
57 57 )
58 58
59 + from ._protocols import (
60 +     IndexEntry,
61 +     AbstractIndex,
62 +     AbstractDataStore,
63 + )
64 +
65 + from ._schema_codec import (
66 +     schema_to_type,
67 + )
68 +
69 + from ._cid import (
70 +     generate_cid,
71 +     verify_cid,
72 + )
73 +
74 + from .promote import (
75 +     promote_to_atmosphere,
76 + )
77 +
59 78 # ATProto integration (lazy import to avoid requiring atproto package)
60 79 from . import atmosphere
61 80
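With this hunk the protocol types and helpers are also re-exported from the package root, so downstream code can avoid importing from the private `_protocols` module. A small sketch:

```python
import atdata


def count_datasets(index: atdata.AbstractIndex) -> int:
    # Same polymorphic helper as in docs/protocols.md, but typed against the
    # root-level re-export instead of atdata._protocols.
    return sum(1 for _ in index.list_datasets())
```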
+140
src/atdata/_cid.py
··· 1 + """CID (Content Identifier) utilities for atdata. 2 + 3 + This module provides utilities for generating ATProto-compatible CIDs from 4 + data. CIDs are content-addressable identifiers that can be used to uniquely 5 + identify schemas, datasets, and other records. 6 + 7 + The CIDs generated here use: 8 + - CIDv1 format 9 + - dag-cbor codec (0x71) 10 + - SHA-256 hash (0x12) 11 + 12 + This ensures compatibility with ATProto's CID requirements and enables 13 + seamless promotion from local storage to atmosphere (ATProto network). 14 + 15 + Example: 16 + >>> schema = {"name": "ImageSample", "version": "1.0.0", "fields": [...]} 17 + >>> cid = generate_cid(schema) 18 + >>> print(cid) 19 + bafyreihffx5a2e7k6r5zqgp5iwpjqr2gfyheqhzqtlxagvqjqyxzqpzqaa 20 + """ 21 + 22 + import hashlib 23 + from typing import Any 24 + 25 + import libipld 26 + 27 + 28 + # CID constants 29 + CID_VERSION_1 = 0x01 30 + CODEC_DAG_CBOR = 0x71 31 + HASH_SHA256 = 0x12 32 + SHA256_SIZE = 0x20 33 + 34 + 35 + def generate_cid(data: Any) -> str: 36 + """Generate an ATProto-compatible CID from arbitrary data. 37 + 38 + The data is first encoded as DAG-CBOR, then hashed with SHA-256, 39 + and finally formatted as a CIDv1 string (base32 multibase). 40 + 41 + Args: 42 + data: Any data structure that can be encoded as DAG-CBOR. 43 + This includes dicts, lists, strings, numbers, bytes, etc. 44 + 45 + Returns: 46 + CIDv1 string in base32 multibase format (starts with 'bafy'). 47 + 48 + Raises: 49 + ValueError: If the data cannot be encoded as DAG-CBOR. 50 + 51 + Example: 52 + >>> generate_cid({"name": "test", "value": 42}) 53 + 'bafyrei...' 54 + """ 55 + # Encode data as DAG-CBOR 56 + try: 57 + cbor_bytes = libipld.encode_dag_cbor(data) 58 + except Exception as e: 59 + raise ValueError(f"Failed to encode data as DAG-CBOR: {e}") from e 60 + 61 + # Hash with SHA-256 62 + sha256_hash = hashlib.sha256(cbor_bytes).digest() 63 + 64 + # Build raw CID bytes: 65 + # CIDv1 = version(1) + codec(dag-cbor) + multihash 66 + # Multihash = code(sha256) + size(32) + digest 67 + raw_cid_bytes = bytes([CID_VERSION_1, CODEC_DAG_CBOR, HASH_SHA256, SHA256_SIZE]) + sha256_hash 68 + 69 + # Encode to base32 multibase string 70 + return libipld.encode_cid(raw_cid_bytes) 71 + 72 + 73 + def generate_cid_from_bytes(data_bytes: bytes) -> str: 74 + """Generate a CID from raw bytes (already encoded data). 75 + 76 + Use this when you have pre-encoded data (e.g., DAG-CBOR bytes from 77 + another source) and want to generate its CID without re-encoding. 78 + 79 + Args: 80 + data_bytes: Raw bytes to hash (treated as opaque blob). 81 + 82 + Returns: 83 + CIDv1 string in base32 multibase format. 84 + 85 + Example: 86 + >>> cbor_bytes = libipld.encode_dag_cbor({"key": "value"}) 87 + >>> cid = generate_cid_from_bytes(cbor_bytes) 88 + """ 89 + sha256_hash = hashlib.sha256(data_bytes).digest() 90 + raw_cid_bytes = bytes([CID_VERSION_1, CODEC_DAG_CBOR, HASH_SHA256, SHA256_SIZE]) + sha256_hash 91 + return libipld.encode_cid(raw_cid_bytes) 92 + 93 + 94 + def verify_cid(cid: str, data: Any) -> bool: 95 + """Verify that a CID matches the given data. 96 + 97 + Args: 98 + cid: CID string to verify. 99 + data: Data that should correspond to the CID. 100 + 101 + Returns: 102 + True if the CID matches the data, False otherwise. 
103 + 104 + Example: 105 + >>> cid = generate_cid({"name": "test"}) 106 + >>> verify_cid(cid, {"name": "test"}) 107 + True 108 + >>> verify_cid(cid, {"name": "different"}) 109 + False 110 + """ 111 + expected_cid = generate_cid(data) 112 + return cid == expected_cid 113 + 114 + 115 + def parse_cid(cid: str) -> dict: 116 + """Parse a CID string into its components. 117 + 118 + Args: 119 + cid: CID string to parse. 120 + 121 + Returns: 122 + Dictionary with 'version', 'codec', and 'hash' keys. 123 + The 'hash' value is itself a dict with 'code', 'size', and 'digest'. 124 + 125 + Example: 126 + >>> info = parse_cid('bafyrei...') 127 + >>> info['version'] 128 + 1 129 + >>> info['codec'] 130 + 113 # 0x71 = dag-cbor 131 + """ 132 + return libipld.decode_cid(cid) 133 + 134 + 135 + __all__ = [ 136 + "generate_cid", 137 + "generate_cid_from_bytes", 138 + "verify_cid", 139 + "parse_cid", 140 + ]
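Tying the constants above to `parse_cid`: a generated CID should decode back to the version and codec it was built with. The expected keys below follow the `parse_cid` docstring.

```python
from atdata._cid import generate_cid, parse_cid

cid = generate_cid({"name": "test", "value": 42})
info = parse_cid(cid)

assert info["version"] == 1         # CIDv1
assert info["codec"] == 0x71        # dag-cbor
assert info["hash"]["size"] == 32   # SHA-256 digest length
```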
+192 -99
src/atdata/_hf_api.py
··· 31 31 import re 32 32 from pathlib import Path 33 33 from typing import ( 34 + TYPE_CHECKING, 34 35 Any, 35 36 Generic, 36 37 Iterator, 37 38 Mapping, 39 + Optional, 38 40 Type, 39 41 TypeVar, 40 42 Union, ··· 42 44 ) 43 45 44 46 from .dataset import Dataset, PackableSample 47 + 48 + if TYPE_CHECKING: 49 + from ._protocols import AbstractIndex 45 50 46 51 ## 47 52 # Type variables ··· 134 139 135 140 136 141 def _is_brace_pattern(path: str) -> bool: 137 - """Check if path contains WebDataset brace expansion notation. 138 - 139 - Examples: 140 - >>> _is_brace_pattern("data-{000000..000099}.tar") 141 - True 142 - >>> _is_brace_pattern("data-{train,test}.tar") 143 - True 144 - >>> _is_brace_pattern("data-000000.tar") 145 - False 146 - """ 142 + """Check if path contains WebDataset brace expansion notation like {000..099}.""" 147 143 return bool(re.search(r"\{[^}]+\}", path)) 148 144 149 145 150 146 def _is_glob_pattern(path: str) -> bool: 151 - """Check if path contains glob wildcards. 152 - 153 - Examples: 154 - >>> _is_glob_pattern("data-*.tar") 155 - True 156 - >>> _is_glob_pattern("data-000000.tar") 157 - False 158 - """ 147 + """Check if path contains glob wildcards (* or ?).""" 159 148 return "*" in path or "?" in path 160 149 161 150 162 151 def _is_remote_url(path: str) -> bool: 163 - """Check if path is a remote URL (s3, http, etc.). 164 - 165 - Examples: 166 - >>> _is_remote_url("s3://bucket/path") 167 - True 168 - >>> _is_remote_url("https://example.com/data.tar") 169 - True 170 - >>> _is_remote_url("/local/path/data.tar") 171 - False 172 - """ 152 + """Check if path is a remote URL (s3://, gs://, http://, https://, az://).""" 173 153 return path.startswith(("s3://", "gs://", "http://", "https://", "az://")) 174 154 175 155 176 156 def _expand_local_glob(pattern: str) -> list[str]: 177 - """Expand a local glob pattern to list of paths. 178 - 179 - Args: 180 - pattern: Glob pattern like "path/to/*.tar" 181 - 182 - Returns: 183 - Sorted list of matching file paths. 
184 - """ 157 + """Expand local glob pattern to sorted list of matching file paths.""" 185 158 base_path = Path(pattern).parent 186 159 glob_part = Path(pattern).name 187 160 ··· 192 165 return [str(p) for p in matches if p.is_file()] 193 166 194 167 195 - # Common split name patterns in filenames 196 - _SPLIT_PATTERNS = [ 168 + # Pre-compiled split name patterns (pattern, split_name) 169 + _SPLIT_PATTERNS: list[tuple[re.Pattern[str], str]] = [ 197 170 # Patterns like "dataset-train-000000.tar" (split in middle with delimiters) 198 - (r"[_-](train|training)[_-]", "train"), 199 - (r"[_-](test|testing)[_-]", "test"), 200 - (r"[_-](val|valid|validation)[_-]", "validation"), 201 - (r"[_-](dev|development)[_-]", "validation"), 171 + (re.compile(r"[_-](train|training)[_-]"), "train"), 172 + (re.compile(r"[_-](test|testing)[_-]"), "test"), 173 + (re.compile(r"[_-](val|valid|validation)[_-]"), "validation"), 174 + (re.compile(r"[_-](dev|development)[_-]"), "validation"), 202 175 # Patterns at start of filename like "train-000.tar" or "test_data.tar" 203 - (r"^(train|training)[_-]", "train"), 204 - (r"^(test|testing)[_-]", "test"), 205 - (r"^(val|valid|validation)[_-]", "validation"), 206 - (r"^(dev|development)[_-]", "validation"), 176 + (re.compile(r"^(train|training)[_-]"), "train"), 177 + (re.compile(r"^(test|testing)[_-]"), "test"), 178 + (re.compile(r"^(val|valid|validation)[_-]"), "validation"), 179 + (re.compile(r"^(dev|development)[_-]"), "validation"), 207 180 # Patterns in directory path like "/path/train/shard-000.tar" 208 - (r"[/\\](train|training)[/\\]", "train"), 209 - (r"[/\\](test|testing)[/\\]", "test"), 210 - (r"[/\\](val|valid|validation)[/\\]", "validation"), 211 - (r"[/\\](dev|development)[/\\]", "validation"), 181 + (re.compile(r"[/\\](train|training)[/\\]"), "train"), 182 + (re.compile(r"[/\\](test|testing)[/\\]"), "test"), 183 + (re.compile(r"[/\\](val|valid|validation)[/\\]"), "validation"), 184 + (re.compile(r"[/\\](dev|development)[/\\]"), "validation"), 212 185 # Patterns at start of path like "train/shard-000.tar" 213 - (r"^(train|training)[/\\]", "train"), 214 - (r"^(test|testing)[/\\]", "test"), 215 - (r"^(val|valid|validation)[/\\]", "validation"), 216 - (r"^(dev|development)[/\\]", "validation"), 186 + (re.compile(r"^(train|training)[/\\]"), "train"), 187 + (re.compile(r"^(test|testing)[/\\]"), "test"), 188 + (re.compile(r"^(val|valid|validation)[/\\]"), "validation"), 189 + (re.compile(r"^(dev|development)[/\\]"), "validation"), 217 190 ] 218 191 219 192 220 193 def _detect_split_from_path(path: str) -> str | None: 221 - """Attempt to detect split name from a file path. 222 - 223 - Args: 224 - path: File path to analyze. 225 - 226 - Returns: 227 - Detected split name ("train", "test", "validation") or None. 
228 - """ 229 - # Extract just the filename for pattern matching on full paths 194 + """Detect split name (train/test/validation) from file path.""" 230 195 filename = Path(path).name 231 196 path_lower = path.lower() 232 197 filename_lower = filename.lower() 233 198 234 199 # Check filename first (more specific) 235 200 for pattern, split_name in _SPLIT_PATTERNS: 236 - if re.search(pattern, filename_lower): 201 + if pattern.search(filename_lower): 237 202 return split_name 238 203 239 - # Fall back to full path (catches directory patterns like "train/...") 204 + # Fall back to full path (catches directory patterns) 240 205 for pattern, split_name in _SPLIT_PATTERNS: 241 - if re.search(pattern, path_lower): 206 + if pattern.search(path_lower): 242 207 return split_name 243 208 244 209 return None ··· 356 321 >>> _shards_to_wds_url(["train.tar"]) 357 322 "train.tar" 358 323 """ 324 + import os.path 325 + 359 326 if len(shards) == 0: 360 327 raise ValueError("Cannot create URL from empty shard list") 361 328 362 329 if len(shards) == 1: 363 330 return shards[0] 364 331 365 - # Find common prefix across ALL shards 366 - prefix = shards[0] 367 - for s in shards[1:]: 368 - # Shorten prefix until it matches 369 - while not s.startswith(prefix) and prefix: 370 - prefix = prefix[:-1] 332 + # Find common prefix using os.path.commonprefix (O(n) vs O(n²)) 333 + prefix = os.path.commonprefix(shards) 371 334 372 - # Find common suffix across ALL shards 373 - suffix = shards[0] 374 - for s in shards[1:]: 375 - # Shorten suffix until it matches 376 - while not s.endswith(suffix) and suffix: 377 - suffix = suffix[1:] 335 + # Find common suffix by reversing strings 336 + reversed_shards = [s[::-1] for s in shards] 337 + suffix = os.path.commonprefix(reversed_shards)[::-1] 378 338 379 339 prefix_len = len(prefix) 380 340 suffix_len = len(suffix) ··· 427 387 428 388 429 389 ## 390 + # Index-based path resolution 391 + 392 + 393 + def _is_indexed_path(path: str) -> bool: 394 + """Check if path uses @handle/dataset notation for index lookup. 395 + 396 + Examples: 397 + >>> _is_indexed_path("@maxine.science/mnist") 398 + True 399 + >>> _is_indexed_path("@did:plc:abc123/my-dataset") 400 + True 401 + >>> _is_indexed_path("s3://bucket/data.tar") 402 + False 403 + """ 404 + return path.startswith("@") 405 + 406 + 407 + def _parse_indexed_path(path: str) -> tuple[str, str]: 408 + """Parse @handle/dataset path into (handle_or_did, dataset_name). 409 + 410 + Args: 411 + path: Path in format "@handle/dataset" or "@did:plc:xxx/dataset" 412 + 413 + Returns: 414 + Tuple of (handle_or_did, dataset_name) 415 + 416 + Raises: 417 + ValueError: If path format is invalid. 418 + """ 419 + if not path.startswith("@"): 420 + raise ValueError(f"Not an indexed path: {path}") 421 + 422 + # Remove leading @ 423 + rest = path[1:] 424 + 425 + # Split on first / (handle can contain . but dataset name is after /) 426 + if "/" not in rest: 427 + raise ValueError( 428 + f"Invalid indexed path format: {path}. 
" 429 + "Expected @handle/dataset or @did:plc:xxx/dataset" 430 + ) 431 + 432 + # Find the split point - for DIDs, the format is did:plc:xxx/dataset 433 + # For handles, it's handle.domain/dataset 434 + parts = rest.split("/", 1) 435 + if len(parts) != 2 or not parts[0] or not parts[1]: 436 + raise ValueError(f"Invalid indexed path: {path}") 437 + 438 + return parts[0], parts[1] 439 + 440 + 441 + def _resolve_indexed_path( 442 + path: str, 443 + index: "AbstractIndex", 444 + ) -> tuple[list[str], str]: 445 + """Resolve @handle/dataset path to URLs and schema_ref via index lookup. 446 + 447 + Args: 448 + path: Path in @handle/dataset format. 449 + index: Index to use for lookup. 450 + 451 + Returns: 452 + Tuple of (data_urls, schema_ref). 453 + 454 + Raises: 455 + KeyError: If dataset not found in index. 456 + """ 457 + handle_or_did, dataset_name = _parse_indexed_path(path) 458 + 459 + # For AtmosphereIndex, we need to resolve handle to DID first 460 + # For LocalIndex, the handle is ignored and we just look up by name 461 + entry = index.get_dataset(dataset_name) 462 + 463 + return entry.data_urls, entry.schema_ref 464 + 465 + 466 + ## 430 467 # Main load_dataset function 431 468 432 469 ··· 438 475 split: str, 439 476 data_files: str | list[str] | dict[str, str | list[str]] | None = None, 440 477 streaming: bool = False, 478 + index: Optional["AbstractIndex"] = None, 441 479 ) -> Dataset[ST]: ... 442 480 443 481 ··· 449 487 split: None = None, 450 488 data_files: str | list[str] | dict[str, str | list[str]] | None = None, 451 489 streaming: bool = False, 490 + index: Optional["AbstractIndex"] = None, 452 491 ) -> DatasetDict[ST]: ... 492 + 493 + 494 + @overload 495 + def load_dataset( 496 + path: str, 497 + sample_type: None = None, 498 + *, 499 + split: str, 500 + data_files: str | list[str] | dict[str, str | list[str]] | None = None, 501 + streaming: bool = False, 502 + index: "AbstractIndex", 503 + ) -> Dataset[PackableSample]: ... 504 + 505 + 506 + @overload 507 + def load_dataset( 508 + path: str, 509 + sample_type: None = None, 510 + *, 511 + split: None = None, 512 + data_files: str | list[str] | dict[str, str | list[str]] | None = None, 513 + streaming: bool = False, 514 + index: "AbstractIndex", 515 + ) -> DatasetDict[PackableSample]: ... 453 516 454 517 455 518 def load_dataset( 456 519 path: str, 457 - sample_type: Type[ST], 520 + sample_type: Type[ST] | None = None, 458 521 *, 459 522 split: str | None = None, 460 523 data_files: str | list[str] | dict[str, str | list[str]] | None = None, 461 524 streaming: bool = False, 525 + index: Optional["AbstractIndex"] = None, 462 526 ) -> Dataset[ST] | DatasetDict[ST]: 463 - """Load a dataset from local files or remote URLs. 527 + """Load a dataset from local files, remote URLs, or an index. 464 528 465 529 This function provides a HuggingFace Datasets-style interface for loading 466 530 atdata typed datasets. It handles path resolution, split detection, and ··· 469 533 470 534 Args: 471 535 path: Path to dataset. Can be: 536 + - Index lookup: "@handle/dataset-name" or "@local/dataset-name" 472 537 - WebDataset brace notation: "path/to/{train,test}-{000..099}.tar" 473 538 - Local directory: "./data/" (scans for .tar files) 474 539 - Glob pattern: "path/to/*.tar" 475 540 - Remote URL: "s3://bucket/path/data-*.tar" 476 541 - Single file: "path/to/data.tar" 477 542 478 - sample_type: The PackableSample subclass defining the schema for 479 - samples in this dataset. 
This is required (unlike HF Datasets) 480 - because atdata uses typed dataclasses. 543 + sample_type: The PackableSample subclass defining the schema. Can be 544 + None if index is provided - the type will be resolved from the 545 + schema stored in the index. 481 546 482 547 split: Which split to load. If None, returns a DatasetDict with all 483 548 detected splits. If specified (e.g., "train", "test"), returns ··· 490 555 491 556 streaming: If True, explicitly marks the dataset for streaming mode. 492 557 Note: atdata Datasets are already lazy/streaming via WebDataset 493 - pipelines, so this parameter primarily signals intent. When True, 494 - shard list precomputation is skipped. Default False. 558 + pipelines, so this parameter primarily signals intent. 559 + 560 + index: Optional AbstractIndex for dataset lookup. Required when using 561 + @handle/dataset syntax or when sample_type is None. Can be a 562 + LocalIndex or AtmosphereIndex. 495 563 496 564 Returns: 497 565 If split is None: DatasetDict[ST] with all detected splits. 498 566 If split is specified: Dataset[ST] for that split. 499 567 500 568 Raises: 501 - ValueError: If the specified split is not found. 569 + ValueError: If the specified split is not found, or if sample_type 570 + is None without an index. 502 571 FileNotFoundError: If no data files are found at the path. 572 + KeyError: If dataset not found in index. 503 573 504 574 Example: 505 - >>> @atdata.packable 506 - ... class TextData: 507 - ... text: str 508 - ... label: int 575 + >>> # Load from local path with explicit type 576 + >>> train_ds = load_dataset("./data/train-*.tar", TextData, split="train") 509 577 >>> 510 - >>> # Load single split 511 - >>> train_ds = load_dataset("./data/train-*.tar", TextData, split="train") 578 + >>> # Load from index with auto-type resolution 579 + >>> index = LocalIndex() 580 + >>> ds = load_dataset("@local/my-dataset", index=index, split="train") 512 581 >>> 513 582 >>> # Load all splits 514 583 >>> ds_dict = load_dataset("./data/", TextData) 515 584 >>> train_ds = ds_dict["train"] 516 - >>> test_ds = ds_dict["test"] 517 - >>> 518 - >>> # Explicit data files 519 - >>> ds_dict = load_dataset("./data/", TextData, data_files={ 520 - ... "train": "train-*.tar", 521 - ... "test": "test-*.tar", 522 - ... }) 523 585 """ 586 + # Handle @handle/dataset indexed path resolution 587 + if _is_indexed_path(path): 588 + if index is None: 589 + raise ValueError( 590 + f"Index required for indexed path: {path}. " 591 + "Pass index=LocalIndex() or index=AtmosphereIndex(client)." 592 + ) 593 + 594 + data_urls, schema_ref = _resolve_indexed_path(path, index) 595 + 596 + # Resolve sample_type from schema if not provided 597 + if sample_type is None: 598 + sample_type = index.decode_schema(schema_ref) 599 + 600 + # For indexed datasets, we treat all URLs as a single "train" split 601 + url = _shards_to_wds_url(data_urls) 602 + ds = Dataset[sample_type](url) 603 + 604 + if split is not None: 605 + # Indexed datasets are single-split by default 606 + return ds 607 + 608 + return DatasetDict({"train": ds}, sample_type=sample_type, streaming=streaming) 609 + 610 + # Validate sample_type for non-indexed paths 611 + if sample_type is None: 612 + raise ValueError( 613 + "sample_type is required for non-indexed paths. " 614 + "Use @handle/dataset with an index for auto-type resolution." 615 + ) 616 + 524 617 # Resolve path to split -> shard URL mapping 525 618 splits_shards = _resolve_shards(path, data_files) 526 619
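For readers unfamiliar with the brace-notation trick being replaced here, the new `os.path.commonprefix` approach can be illustrated standalone: the common prefix of the shard names, plus the common prefix of the reversed names (i.e. the common suffix), bracket the varying range.

```python
import os.path

shards = ["data-000000.tar", "data-000001.tar", "data-000002.tar"]

prefix = os.path.commonprefix(shards)                            # "data-00000"
suffix = os.path.commonprefix([s[::-1] for s in shards])[::-1]   # ".tar"

lo = shards[0][len(prefix):len(shards[0]) - len(suffix)]         # "0"
hi = shards[-1][len(prefix):len(shards[-1]) - len(suffix)]       # "2"

print(f"{prefix}{{{lo}..{hi}}}{suffix}")                         # data-00000{0..2}.tar
```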
+319
src/atdata/_protocols.py
··· 1 + """Protocol definitions for atdata index and storage abstractions. 2 + 3 + This module defines the abstract protocols that enable interchangeable 4 + index backends (local Redis vs ATProto PDS) and data stores (S3 vs PDS blobs). 5 + 6 + The key insight is that both local and atmosphere implementations solve the 7 + same problem: indexed dataset storage with external data URLs. These protocols 8 + formalize that common interface. 9 + 10 + Note: 11 + Protocol methods use ``...`` (Ellipsis) as the body per PEP 544. This is 12 + the standard Python syntax for Protocol definitions - these are interface 13 + specifications, not stub implementations. Concrete classes (LocalIndex, 14 + AtmosphereIndex, etc.) provide the actual implementations. 15 + 16 + Protocols: 17 + IndexEntry: Common interface for dataset index entries 18 + AbstractIndex: Protocol for index operations (schemas, datasets, lenses) 19 + AbstractDataStore: Protocol for data storage operations 20 + 21 + Example: 22 + >>> def process_datasets(index: AbstractIndex) -> None: 23 + ... for entry in index.list_datasets(): 24 + ... print(f"{entry.name}: {entry.data_urls}") 25 + ... 26 + >>> # Works with either LocalIndex or AtmosphereIndex 27 + >>> process_datasets(local_index) 28 + >>> process_datasets(atmosphere_index) 29 + """ 30 + 31 + from typing import ( 32 + Iterator, 33 + Optional, 34 + Protocol, 35 + Type, 36 + TYPE_CHECKING, 37 + runtime_checkable, 38 + ) 39 + 40 + if TYPE_CHECKING: 41 + from .dataset import PackableSample, Dataset 42 + 43 + 44 + ## 45 + # IndexEntry Protocol 46 + 47 + 48 + @runtime_checkable 49 + class IndexEntry(Protocol): 50 + """Common interface for index entries (local or atmosphere). 51 + 52 + Both LocalDatasetEntry and atmosphere DatasetRecord-based entries 53 + should satisfy this protocol, enabling code that works with either. 54 + 55 + Properties: 56 + name: Human-readable dataset name 57 + schema_ref: Reference to schema (local:// path or AT URI) 58 + data_urls: WebDataset URLs for the data 59 + metadata: Arbitrary metadata dict, or None 60 + """ 61 + 62 + @property 63 + def name(self) -> str: 64 + """Human-readable dataset name.""" 65 + ... 66 + 67 + @property 68 + def schema_ref(self) -> str: 69 + """Reference to the schema for this dataset. 70 + 71 + For local: 'local://schemas/{module.Class}@{version}' 72 + For atmosphere: 'at://did:plc:.../ac.foundation.dataset.sampleSchema/...' 73 + """ 74 + ... 75 + 76 + @property 77 + def data_urls(self) -> list[str]: 78 + """WebDataset URLs for the data. 79 + 80 + These are the URLs that can be passed to atdata.Dataset() or 81 + used with WebDataset directly. May use brace notation for shards. 82 + """ 83 + ... 84 + 85 + @property 86 + def metadata(self) -> Optional[dict]: 87 + """Arbitrary metadata dictionary, or None if not set.""" 88 + ... 89 + 90 + 91 + ## 92 + # AbstractIndex Protocol 93 + 94 + 95 + class AbstractIndex(Protocol): 96 + """Protocol for index operations - implemented by LocalIndex and AtmosphereIndex. 97 + 98 + This protocol defines the common interface for managing dataset metadata: 99 + - Publishing and retrieving schemas 100 + - Inserting and listing datasets 101 + - (Future) Publishing and retrieving lenses 102 + 103 + A single index can hold datasets of many different sample types. The sample 104 + type is tracked via schema references, not as a generic parameter on the index. 105 + 106 + Example: 107 + >>> def publish_and_list(index: AbstractIndex) -> None: 108 + ... # Publish schemas for different types 109 + ... 
schema1 = index.publish_schema(ImageSample, version="1.0.0") 110 + ... schema2 = index.publish_schema(TextSample, version="1.0.0") 111 + ... 112 + ... # Insert datasets of different types 113 + ... index.insert_dataset(image_ds, name="images") 114 + ... index.insert_dataset(text_ds, name="texts") 115 + ... 116 + ... # List all datasets (mixed types) 117 + ... for entry in index.list_datasets(): 118 + ... print(f"{entry.name} -> {entry.schema_ref}") 119 + """ 120 + 121 + # Dataset operations 122 + 123 + def insert_dataset( 124 + self, 125 + ds: "Dataset", 126 + *, 127 + name: str, 128 + schema_ref: Optional[str] = None, 129 + **kwargs, 130 + ) -> IndexEntry: 131 + """Insert a dataset into the index. 132 + 133 + The sample type is inferred from ``ds.sample_type``. If schema_ref is not 134 + provided, the schema may be auto-published based on the sample type. 135 + 136 + Args: 137 + ds: The Dataset to register in the index (any sample type). 138 + name: Human-readable name for the dataset. 139 + schema_ref: Optional explicit schema reference. If not provided, 140 + the schema may be auto-published or inferred from ds.sample_type. 141 + **kwargs: Additional backend-specific options. 142 + 143 + Returns: 144 + IndexEntry for the inserted dataset. 145 + """ 146 + ... 147 + 148 + def get_dataset(self, ref: str) -> IndexEntry: 149 + """Get a dataset entry by name or reference. 150 + 151 + Args: 152 + ref: Dataset name, path, or full reference string. 153 + 154 + Returns: 155 + IndexEntry for the dataset. 156 + 157 + Raises: 158 + KeyError: If dataset not found. 159 + """ 160 + ... 161 + 162 + def list_datasets(self) -> Iterator[IndexEntry]: 163 + """List all dataset entries in this index. 164 + 165 + Yields: 166 + IndexEntry for each dataset (may be of different sample types). 167 + """ 168 + ... 169 + 170 + # Schema operations 171 + 172 + def publish_schema( 173 + self, 174 + sample_type: "Type[PackableSample]", 175 + *, 176 + version: str = "1.0.0", 177 + **kwargs, 178 + ) -> str: 179 + """Publish a schema for a sample type. 180 + 181 + Args: 182 + sample_type: The PackableSample subclass to publish. 183 + version: Semantic version string for the schema. 184 + **kwargs: Additional backend-specific options. 185 + 186 + Returns: 187 + Schema reference string: 188 + - Local: 'local://schemas/{module.Class}@{version}' 189 + - Atmosphere: 'at://did:plc:.../ac.foundation.dataset.sampleSchema/...' 190 + """ 191 + ... 192 + 193 + def get_schema(self, ref: str) -> dict: 194 + """Get a schema record by reference. 195 + 196 + Args: 197 + ref: Schema reference string (local:// or at://). 198 + 199 + Returns: 200 + Schema record as a dictionary with fields like 'name', 'version', 201 + 'fields', etc. 202 + 203 + Raises: 204 + KeyError: If schema not found. 205 + """ 206 + ... 207 + 208 + def list_schemas(self) -> Iterator[dict]: 209 + """List all schema records in this index. 210 + 211 + Yields: 212 + Schema records as dictionaries. 213 + """ 214 + ... 215 + 216 + def decode_schema(self, ref: str) -> "Type[PackableSample]": 217 + """Reconstruct a Python PackableSample type from a stored schema. 218 + 219 + This method enables loading datasets without knowing the sample type 220 + ahead of time. The index retrieves the schema record and dynamically 221 + generates a PackableSample subclass matching the schema definition. 222 + 223 + Args: 224 + ref: Schema reference string (local:// or at://). 
225 + 226 + Returns: 227 + A dynamically generated PackableSample subclass with fields 228 + matching the schema definition. The class can be used with 229 + ``Dataset[T]`` to load and iterate over samples. 230 + 231 + Raises: 232 + KeyError: If schema not found. 233 + ValueError: If schema cannot be decoded (unsupported field types). 234 + 235 + Example: 236 + >>> entry = index.get_dataset("my-dataset") 237 + >>> SampleType = index.decode_schema(entry.schema_ref) 238 + >>> ds = Dataset[SampleType](entry.data_urls[0]) 239 + >>> for sample in ds.ordered(): 240 + ... print(sample) # sample is instance of SampleType 241 + """ 242 + ... 243 + 244 + 245 + ## 246 + # AbstractDataStore Protocol 247 + 248 + 249 + class AbstractDataStore(Protocol): 250 + """Protocol for data storage operations. 251 + 252 + This protocol abstracts over different storage backends for dataset data: 253 + - S3DataStore: S3-compatible object storage 254 + - PDSBlobStore: ATProto PDS blob storage (future) 255 + 256 + The separation of index (metadata) from data store (actual files) allows 257 + flexible deployment: local index with S3 storage, atmosphere index with 258 + S3 storage, or atmosphere index with PDS blobs. 259 + 260 + Example: 261 + >>> store = S3DataStore(credentials, bucket="my-bucket") 262 + >>> urls = store.write_shards(dataset, prefix="training/v1") 263 + >>> print(urls) 264 + ['s3://my-bucket/training/v1/shard-000000.tar', ...] 265 + """ 266 + 267 + def write_shards( 268 + self, 269 + ds: "Dataset", 270 + *, 271 + prefix: str, 272 + **kwargs, 273 + ) -> list[str]: 274 + """Write dataset shards to storage. 275 + 276 + Args: 277 + ds: The Dataset to write. 278 + prefix: Path prefix for the shards (e.g., 'datasets/mnist/v1'). 279 + **kwargs: Backend-specific options (e.g., maxcount for shard size). 280 + 281 + Returns: 282 + List of URLs for the written shards, suitable for use with 283 + WebDataset or atdata.Dataset(). 284 + """ 285 + ... 286 + 287 + def read_url(self, url: str) -> str: 288 + """Resolve a storage URL for reading. 289 + 290 + Some storage backends may need to transform URLs (e.g., signing S3 URLs 291 + or resolving blob references). This method returns a URL that can be 292 + used directly with WebDataset. 293 + 294 + Args: 295 + url: Storage URL to resolve. 296 + 297 + Returns: 298 + WebDataset-compatible URL for reading. 299 + """ 300 + ... 301 + 302 + def supports_streaming(self) -> bool: 303 + """Whether this store supports streaming reads. 304 + 305 + Returns: 306 + True if the store supports efficient streaming (like S3), 307 + False if data must be fully downloaded first. 308 + """ 309 + ... 310 + 311 + 312 + ## 313 + # Module exports 314 + 315 + __all__ = [ 316 + "IndexEntry", 317 + "AbstractIndex", 318 + "AbstractDataStore", 319 + ]
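To make the structural contract concrete, here is a hypothetical file-system-backed store that satisfies `AbstractDataStore` by shape alone. It is a sketch under the assumption that `Dataset.ordered(batch_size=None)` yields samples exposing `as_wds`, as the example scripts in this change set do; it is not part of atdata.

```python
from pathlib import Path

import webdataset as wds


class LocalDirStore:
    """Hypothetical data store writing a single tar per call (illustration only)."""

    def __init__(self, root: str) -> None:
        self.root = Path(root)

    def write_shards(self, ds, *, prefix: str, **kwargs) -> list[str]:
        out_dir = self.root / prefix
        out_dir.mkdir(parents=True, exist_ok=True)
        path = out_dir / "shard-000000.tar"
        with wds.writer.TarWriter(str(path)) as sink:
            for i, sample in enumerate(ds.ordered(batch_size=None)):
                sink.write({**sample.as_wds, "__key__": f"sample_{i:06d}"})
        return [str(path)]

    def read_url(self, url: str) -> str:
        return url  # local paths need no signing or rewriting

    def supports_streaming(self) -> bool:
        return False  # conservative choice for this sketch: treat as full local reads
```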
+237
src/atdata/_schema_codec.py
··· 1 + """Schema codec for dynamic PackableSample type generation. 2 + 3 + This module provides functionality to reconstruct Python PackableSample types 4 + from schema records. This enables loading datasets without knowing the sample 5 + type ahead of time - the type can be dynamically generated from stored schema 6 + metadata. 7 + 8 + The schema format follows the ATProto record structure defined in 9 + ``atmosphere/_types.py``, with field types supporting primitives, ndarrays, 10 + arrays, and schema references. 11 + 12 + Example: 13 + >>> schema = { 14 + ... "name": "ImageSample", 15 + ... "version": "1.0.0", 16 + ... "fields": [ 17 + ... {"name": "image", "fieldType": {"$type": "...#ndarray", "dtype": "float32"}, "optional": False}, 18 + ... {"name": "label", "fieldType": {"$type": "...#primitive", "primitive": "str"}, "optional": False}, 19 + ... ] 20 + ... } 21 + >>> ImageSample = schema_to_type(schema) 22 + >>> sample = ImageSample(image=np.zeros((64, 64)), label="cat") 23 + """ 24 + 25 + from dataclasses import dataclass, field, make_dataclass 26 + from typing import Any, Optional, Type, Union, get_origin 27 + import hashlib 28 + 29 + from numpy.typing import NDArray 30 + 31 + # Import PackableSample for inheritance 32 + from .dataset import PackableSample 33 + 34 + 35 + # Type cache to avoid regenerating identical types 36 + _type_cache: dict[str, Type[PackableSample]] = {} 37 + 38 + 39 + def _schema_cache_key(schema: dict) -> str: 40 + """Generate a cache key for a schema. 41 + 42 + Uses name + version + field signature to identify unique schemas. 43 + """ 44 + name = schema.get("name", "Unknown") 45 + version = schema.get("version", "0.0.0") 46 + fields = schema.get("fields", []) 47 + 48 + # Create a stable string representation of fields 49 + field_sig = ";".join( 50 + f"{f['name']}:{f['fieldType'].get('$type', '')}:{f.get('optional', False)}" 51 + for f in fields 52 + ) 53 + 54 + # Hash for compactness 55 + sig_hash = hashlib.md5(field_sig.encode()).hexdigest()[:8] 56 + return f"{name}@{version}#{sig_hash}" 57 + 58 + 59 + def _field_type_to_python(field_type: dict, optional: bool = False) -> Any: 60 + """Convert a schema field type to a Python type annotation. 61 + 62 + Args: 63 + field_type: Field type dict with '$type' and type-specific fields. 64 + optional: Whether this field is optional (can be None). 65 + 66 + Returns: 67 + Python type annotation suitable for dataclass field. 68 + 69 + Raises: 70 + ValueError: If field type is not supported. 
71 + """ 72 + type_str = field_type.get("$type", "") 73 + 74 + # Extract kind from $type (e.g., "ac.foundation.dataset.schemaType#primitive" -> "primitive") 75 + if "#" in type_str: 76 + kind = type_str.split("#")[-1] 77 + else: 78 + # Fallback for simplified format 79 + kind = field_type.get("kind", "") 80 + 81 + python_type: Any 82 + 83 + if kind == "primitive": 84 + primitive = field_type.get("primitive", "str") 85 + primitive_map = { 86 + "str": str, 87 + "int": int, 88 + "float": float, 89 + "bool": bool, 90 + "bytes": bytes, 91 + } 92 + python_type = primitive_map.get(primitive) 93 + if python_type is None: 94 + raise ValueError(f"Unknown primitive type: {primitive}") 95 + 96 + elif kind == "ndarray": 97 + # NDArray type - dtype info is available but we use generic NDArray 98 + # The dtype is handled at runtime by PackableSample serialization 99 + python_type = NDArray 100 + 101 + elif kind == "array": 102 + # List type - recursively resolve item type 103 + items = field_type.get("items") 104 + if items: 105 + item_type = _field_type_to_python(items, optional=False) 106 + python_type = list[item_type] 107 + else: 108 + python_type = list 109 + 110 + elif kind == "ref": 111 + # Reference to another schema - not yet supported for dynamic generation 112 + raise ValueError( 113 + f"Schema references ('ref') are not yet supported for dynamic type generation. " 114 + f"Referenced schema: {field_type.get('ref')}" 115 + ) 116 + 117 + else: 118 + raise ValueError(f"Unknown field type kind: {kind}") 119 + 120 + # Wrap in Optional if needed 121 + if optional: 122 + python_type = Optional[python_type] 123 + 124 + return python_type 125 + 126 + 127 + def schema_to_type( 128 + schema: dict, 129 + *, 130 + use_cache: bool = True, 131 + ) -> Type[PackableSample]: 132 + """Generate a PackableSample subclass from a schema record. 133 + 134 + This function dynamically creates a dataclass that inherits from PackableSample, 135 + with fields matching the schema definition. The generated class can be used 136 + with ``Dataset[T]`` to load and process samples. 137 + 138 + Args: 139 + schema: Schema record dict with 'name', 'version', 'fields', etc. 140 + Fields should have 'name', 'fieldType', and 'optional' keys. 141 + use_cache: If True, cache and reuse generated types for identical schemas. 142 + Defaults to True. 143 + 144 + Returns: 145 + A dynamically generated PackableSample subclass. 146 + 147 + Raises: 148 + ValueError: If schema is malformed or contains unsupported types. 149 + 150 + Example: 151 + >>> schema = index.get_schema("local://schemas/MySample@1.0.0") 152 + >>> MySample = schema_to_type(schema) 153 + >>> ds = Dataset[MySample]("data.tar") 154 + >>> for sample in ds.ordered(): 155 + ... 
print(sample) 156 + """ 157 + # Check cache first 158 + if use_cache: 159 + cache_key = _schema_cache_key(schema) 160 + if cache_key in _type_cache: 161 + return _type_cache[cache_key] 162 + 163 + # Extract schema metadata 164 + name = schema.get("name") 165 + if not name: 166 + raise ValueError("Schema must have a 'name' field") 167 + 168 + version = schema.get("version", "1.0.0") 169 + fields_data = schema.get("fields", []) 170 + 171 + if not fields_data: 172 + raise ValueError("Schema must have at least one field") 173 + 174 + # Build field definitions for make_dataclass 175 + # Format: (name, type) or (name, type, field()) 176 + dataclass_fields: list[tuple[str, Any] | tuple[str, Any, Any]] = [] 177 + 178 + for field_def in fields_data: 179 + field_name = field_def.get("name") 180 + if not field_name: 181 + raise ValueError("Each field must have a 'name'") 182 + 183 + field_type_dict = field_def.get("fieldType", {}) 184 + is_optional = field_def.get("optional", False) 185 + 186 + # Convert to Python type 187 + python_type = _field_type_to_python(field_type_dict, optional=is_optional) 188 + 189 + # Optional fields need a default value of None 190 + if is_optional: 191 + dataclass_fields.append((field_name, python_type, field(default=None))) 192 + else: 193 + dataclass_fields.append((field_name, python_type)) 194 + 195 + # Create the dataclass dynamically 196 + # We need to make it inherit from PackableSample and call __post_init__ 197 + generated_class = make_dataclass( 198 + name, 199 + dataclass_fields, 200 + bases=(PackableSample,), 201 + namespace={ 202 + "__post_init__": lambda self: PackableSample.__post_init__(self), 203 + "__schema_version__": version, 204 + "__schema_ref__": schema.get("$ref", None), # Store original ref if available 205 + }, 206 + ) 207 + 208 + # Cache the generated type 209 + if use_cache: 210 + cache_key = _schema_cache_key(schema) 211 + _type_cache[cache_key] = generated_class 212 + 213 + return generated_class 214 + 215 + 216 + def clear_type_cache() -> None: 217 + """Clear the cached generated types. 218 + 219 + Useful for testing or when schema definitions change. 220 + """ 221 + _type_cache.clear() 222 + 223 + 224 + def get_cached_types() -> dict[str, Type[PackableSample]]: 225 + """Get a copy of the current type cache. 226 + 227 + Returns: 228 + Dictionary mapping cache keys to generated types. 229 + """ 230 + return dict(_type_cache) 231 + 232 + 233 + __all__ = [ 234 + "schema_to_type", 235 + "clear_type_cache", 236 + "get_cached_types", 237 + ]
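A self-contained sketch of the codec in use; the simplified 'local#...' $type strings match what the schema builder in local.py (below) emits, and the field names are illustrative:

import numpy as np
from atdata._schema_codec import schema_to_type

schema = {
    "name": "PointSample",
    "version": "1.0.0",
    "fields": [
        {"name": "coords", "fieldType": {"$type": "local#ndarray", "dtype": "float32"}, "optional": False},
        {"name": "label", "fieldType": {"$type": "local#primitive", "primitive": "str"}, "optional": False},
        {"name": "weight", "fieldType": {"$type": "local#primitive", "primitive": "float"}, "optional": True},
    ],
}

PointSample = schema_to_type(schema)
p = PointSample(coords=np.zeros(3, dtype=np.float32), label="origin")  # weight defaults to None
assert schema_to_type(schema) is PointSample  # identical schemas reuse the cached type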
+36
src/atdata/_type_utils.py
··· 1 + """Shared type conversion utilities for schema handling. 2 + 3 + This module provides common type mapping functions used by both local.py 4 + and atmosphere/schema.py to avoid code duplication. 5 + """ 6 + 7 + from typing import Any 8 + 9 + # Mapping from numpy dtype strings to schema dtype names 10 + NUMPY_DTYPE_MAP = { 11 + "float16": "float16", "float32": "float32", "float64": "float64", 12 + "int8": "int8", "int16": "int16", "int32": "int32", "int64": "int64", 13 + "uint8": "uint8", "uint16": "uint16", "uint32": "uint32", "uint64": "uint64", 14 + "bool": "bool", "complex64": "complex64", "complex128": "complex128", 15 + } 16 + 17 + # Mapping from Python primitive types to schema type names 18 + PRIMITIVE_TYPE_MAP = { 19 + str: "str", int: "int", float: "float", bool: "bool", bytes: "bytes", 20 + } 21 + 22 + 23 + def numpy_dtype_to_string(dtype: Any) -> str: 24 + """Convert a numpy dtype annotation to a schema dtype string. 25 + 26 + Args: 27 + dtype: A numpy dtype or type annotation containing dtype info. 28 + 29 + Returns: 30 + Schema dtype string (e.g., "float32", "int64"). Defaults to "float32". 31 + """ 32 + dtype_str = str(dtype) 33 + # Prefer the longest matching dtype name so "uint8" is not mistaken for "int8" 34 + matches = [key for key in NUMPY_DTYPE_MAP if key in dtype_str] 35 + best = max(matches, key=len, default=None) 36 + return NUMPY_DTYPE_MAP[best] if best is not None else "float32"
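Quick usage sketch for the shared helpers, following the maps defined above and the longest-match lookup:

import numpy as np
from atdata._type_utils import NUMPY_DTYPE_MAP, PRIMITIVE_TYPE_MAP, numpy_dtype_to_string

assert numpy_dtype_to_string(np.dtype("float16")) == "float16"
assert numpy_dtype_to_string(np.uint8) == "uint8"   # longest match wins, not "int8"
assert PRIMITIVE_TYPE_MAP[bytes] == "bytes"
assert "complex128" in NUMPY_DTYPE_MAP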
+216
src/atdata/atmosphere/__init__.py
··· 30 30 pip install atproto 31 31 """ 32 32 33 + from typing import Iterator, Optional, Type, TYPE_CHECKING 34 + 33 35 from .client import AtmosphereClient 34 36 from .schema import SchemaPublisher, SchemaLoader 35 37 from .records import DatasetPublisher, DatasetLoader ··· 41 43 LensRecord, 42 44 ) 43 45 46 + if TYPE_CHECKING: 47 + from ..dataset import PackableSample, Dataset 48 + 49 + 50 + class AtmosphereIndexEntry: 51 + """Entry wrapper for ATProto dataset records implementing IndexEntry protocol. 52 + 53 + Attributes: 54 + _uri: AT URI of the record. 55 + _record: Raw record dictionary. 56 + """ 57 + 58 + def __init__(self, uri: str, record: dict): 59 + self._uri = uri 60 + self._record = record 61 + 62 + @property 63 + def name(self) -> str: 64 + """Human-readable dataset name.""" 65 + return self._record.get("name", "") 66 + 67 + @property 68 + def schema_ref(self) -> str: 69 + """AT URI of the schema record.""" 70 + return self._record.get("schemaRef", "") 71 + 72 + @property 73 + def data_urls(self) -> list[str]: 74 + """WebDataset URLs from external storage.""" 75 + storage = self._record.get("storage", {}) 76 + storage_type = storage.get("$type", "") 77 + if "storageExternal" in storage_type: 78 + return storage.get("urls", []) 79 + return [] 80 + 81 + @property 82 + def metadata(self) -> Optional[dict]: 83 + """Metadata from the record, if any.""" 84 + import msgpack 85 + metadata_bytes = self._record.get("metadata") 86 + if metadata_bytes is None: 87 + return None 88 + return msgpack.unpackb(metadata_bytes, raw=False) 89 + 90 + @property 91 + def uri(self) -> str: 92 + """AT URI of this record.""" 93 + return self._uri 94 + 95 + 96 + class AtmosphereIndex: 97 + """ATProto index implementing AbstractIndex protocol. 98 + 99 + Wraps SchemaPublisher/Loader and DatasetPublisher/Loader to provide 100 + a unified interface compatible with LocalIndex. 101 + 102 + Example: 103 + >>> client = AtmosphereClient() 104 + >>> client.login("handle.bsky.social", "app-password") 105 + >>> 106 + >>> index = AtmosphereIndex(client) 107 + >>> schema_ref = index.publish_schema(MySample, version="1.0.0") 108 + >>> entry = index.insert_dataset(dataset, name="my-data") 109 + """ 110 + 111 + def __init__(self, client: AtmosphereClient): 112 + """Initialize the atmosphere index. 113 + 114 + Args: 115 + client: Authenticated AtmosphereClient instance. 116 + """ 117 + self.client = client 118 + self._schema_publisher = SchemaPublisher(client) 119 + self._schema_loader = SchemaLoader(client) 120 + self._dataset_publisher = DatasetPublisher(client) 121 + self._dataset_loader = DatasetLoader(client) 122 + 123 + # Dataset operations 124 + 125 + def insert_dataset( 126 + self, 127 + ds: "Dataset", 128 + *, 129 + name: str, 130 + schema_ref: Optional[str] = None, 131 + **kwargs, 132 + ) -> AtmosphereIndexEntry: 133 + """Insert a dataset into ATProto. 134 + 135 + Args: 136 + ds: The Dataset to publish. 137 + name: Human-readable name. 138 + schema_ref: Optional schema AT URI. If None, auto-publishes schema. 139 + **kwargs: Additional options (description, tags, license). 140 + 141 + Returns: 142 + AtmosphereIndexEntry for the inserted dataset. 
143 + """ 144 + uri = self._dataset_publisher.publish( 145 + ds, 146 + name=name, 147 + schema_uri=schema_ref, 148 + description=kwargs.get("description"), 149 + tags=kwargs.get("tags"), 150 + license=kwargs.get("license"), 151 + auto_publish_schema=(schema_ref is None), 152 + ) 153 + record = self._dataset_loader.get(uri) 154 + return AtmosphereIndexEntry(str(uri), record) 155 + 156 + def get_dataset(self, ref: str) -> AtmosphereIndexEntry: 157 + """Get a dataset by AT URI. 158 + 159 + Args: 160 + ref: AT URI of the dataset record. 161 + 162 + Returns: 163 + AtmosphereIndexEntry for the dataset. 164 + 165 + Raises: 166 + ValueError: If record is not a dataset. 167 + """ 168 + record = self._dataset_loader.get(ref) 169 + return AtmosphereIndexEntry(ref, record) 170 + 171 + def list_datasets(self, repo: Optional[str] = None) -> Iterator[AtmosphereIndexEntry]: 172 + """List dataset entries from a repository. 173 + 174 + Args: 175 + repo: DID of repository. Defaults to authenticated user. 176 + 177 + Yields: 178 + AtmosphereIndexEntry for each dataset. 179 + """ 180 + records = self._dataset_loader.list_all(repo=repo) 181 + for rec in records: 182 + uri = rec.get("uri", "") 183 + yield AtmosphereIndexEntry(uri, rec.get("value", rec)) 184 + 185 + # Schema operations 186 + 187 + def publish_schema( 188 + self, 189 + sample_type: "Type[PackableSample]", 190 + *, 191 + version: str = "1.0.0", 192 + **kwargs, 193 + ) -> str: 194 + """Publish a schema to ATProto. 195 + 196 + Args: 197 + sample_type: The PackableSample subclass to publish. 198 + version: Semantic version string. 199 + **kwargs: Additional options (description, metadata). 200 + 201 + Returns: 202 + AT URI of the schema record. 203 + """ 204 + uri = self._schema_publisher.publish( 205 + sample_type, 206 + version=version, 207 + description=kwargs.get("description"), 208 + metadata=kwargs.get("metadata"), 209 + ) 210 + return str(uri) 211 + 212 + def get_schema(self, ref: str) -> dict: 213 + """Get a schema record by AT URI. 214 + 215 + Args: 216 + ref: AT URI of the schema record. 217 + 218 + Returns: 219 + Schema record dictionary. 220 + 221 + Raises: 222 + ValueError: If record is not a schema. 223 + """ 224 + return self._schema_loader.get(ref) 225 + 226 + def list_schemas(self, repo: Optional[str] = None) -> Iterator[dict]: 227 + """List schema records from a repository. 228 + 229 + Args: 230 + repo: DID of repository. Defaults to authenticated user. 231 + 232 + Yields: 233 + Schema records. 234 + """ 235 + records = self._schema_loader.list_all(repo=repo) 236 + for rec in records: 237 + yield rec.get("value", rec) 238 + 239 + def decode_schema(self, ref: str) -> "Type[PackableSample]": 240 + """Reconstruct a Python type from a schema record. 241 + 242 + Args: 243 + ref: AT URI of the schema record. 244 + 245 + Returns: 246 + Dynamically generated PackableSample subclass. 247 + 248 + Raises: 249 + ValueError: If schema cannot be decoded. 250 + """ 251 + from .._schema_codec import schema_to_type 252 + 253 + schema = self.get_schema(ref) 254 + return schema_to_type(schema) 255 + 256 + 44 257 __all__ = [ 45 258 # Client 46 259 "AtmosphereClient", 260 + # Unified index (AbstractIndex protocol) 261 + "AtmosphereIndex", 262 + "AtmosphereIndexEntry", 47 263 # Schema operations 48 264 "SchemaPublisher", 49 265 "SchemaLoader",
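For the read side, a hedged sketch of loading a published dataset back without knowing its sample type in advance; handle, app password, and dataset name are placeholders:

from atdata.dataset import Dataset
from atdata.atmosphere import AtmosphereClient, AtmosphereIndex

client = AtmosphereClient()
client.login("handle.bsky.social", "app-password")
index = AtmosphereIndex(client)

for entry in index.list_datasets():
    if entry.name == "my-data":
        SampleType = index.decode_schema(entry.schema_ref)  # dynamic type from the schema record
        ds = Dataset[SampleType](entry.data_urls[0])        # first URL from storageExternal
        for sample in ds.ordered():
            ...  # each sample is an instance of SampleType
        break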
+2 -24
src/atdata/atmosphere/schema.py
··· 17 17 FieldType, 18 18 LEXICON_NAMESPACE, 19 19 ) 20 + from .._type_utils import numpy_dtype_to_string 20 21 21 22 # Import for type checking only to avoid circular imports 22 23 from typing import TYPE_CHECKING ··· 205 206 206 207 def _numpy_dtype_to_string(self, dtype) -> str: 207 208 """Convert a numpy dtype annotation to a string.""" 208 - dtype_str = str(dtype) 209 - # Handle common numpy dtypes 210 - dtype_map = { 211 - "float16": "float16", 212 - "float32": "float32", 213 - "float64": "float64", 214 - "int8": "int8", 215 - "int16": "int16", 216 - "int32": "int32", 217 - "int64": "int64", 218 - "uint8": "uint8", 219 - "uint16": "uint16", 220 - "uint32": "uint32", 221 - "uint64": "uint64", 222 - "bool": "bool", 223 - "complex64": "complex64", 224 - "complex128": "complex128", 225 - } 226 - 227 - for key, value in dtype_map.items(): 228 - if key in dtype_str: 229 - return value 230 - 231 - return "float32" # Default fallback 209 + return numpy_dtype_to_string(dtype) 232 210 233 211 234 212 class SchemaLoader:
+11 -17
src/atdata/dataset.py
··· 54 54 Sequence, 55 55 Iterable, 56 56 Callable, 57 - Union, 58 - # 59 57 Self, 60 58 Generic, 61 59 Type, ··· 187 185 continue 188 186 189 187 elif isinstance( var_cur_value, bytes ): 190 - # TODO This does create a constraint that serialized bytes 191 - # in a field that might be an NDArray are always interpreted 192 - # as being the NDArray interpretation 188 + # Design note: bytes in NDArray-typed fields are always interpreted 189 + # as serialized arrays. This means raw bytes fields must not be 190 + # annotated as NDArray. 193 191 setattr( self, var_name, eh.bytes_to_array( var_cur_value ) ) 194 192 195 193 def __post_init__( self ): ··· 202 200 """Create a sample instance from unpacked msgpack data. 203 201 204 202 Args: 205 - data: A dictionary of unpacked msgpack data with keys matching 206 - the sample's field names. 203 + data: Dictionary with keys matching the sample's field names. 207 204 208 205 Returns: 209 - A new instance of this sample class with fields populated from 210 - the data dictionary and NDArray fields auto-converted from bytes. 206 + New instance with NDArray fields auto-converted from bytes. 211 207 """ 212 - ret = cls( **data ) 213 - ret._ensure_good() 214 - return ret 208 + return cls( **data ) 215 209 216 210 @classmethod 217 211 def from_bytes( cls, bs: bytes ) -> Self: ··· 253 247 254 248 return ret 255 249 256 - # TODO Expand to allow for specifying explicit __key__ 257 250 @property 258 251 def as_wds( self ) -> WDSRawSample: 259 252 """Pack this sample's data for writing to WebDataset. ··· 263 256 ``msgpack`` (packed sample data) fields suitable for WebDataset. 264 257 265 258 Note: 266 - TODO: Expand to allow specifying explicit ``__key__`` values. 259 + Keys are auto-generated as UUID v1 for time-sortable ordering. 260 + Custom key specification is not currently supported. 267 261 """ 268 262 return { 269 263 # Generates a UUID that is timelike-sortable ··· 575 569 wds.filters.map( self.wrap_batch ), 576 570 ) 577 571 578 - # TODO Rewrite to eliminate `pandas` dependency directly calling 579 - # `fastparquet` 572 + # Design note: Uses pandas for parquet export. Could be replaced with 573 + # direct fastparquet calls to reduce dependencies if needed. 580 574 def to_parquet( self, path: Pathlike, 581 575 sample_map: Optional[SampleExportMap] = None, 582 576 maxcount: Optional[int] = None, ··· 721 715 def __post_init__( self ): 722 716 return PackableSample.__post_init__( self ) 723 717 724 - # TODO This doesn't properly carry over the original 718 + # Restore original class identity for better repr/debugging 725 719 as_packable.__name__ = class_name 726 720 as_packable.__annotations__ = class_annotations 727 721
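A minimal sketch of a hand-written sample type affected by the design notes above. Declaring a dataclass on top of PackableSample mirrors what schema_to_type generates; the usual public route is the @packable decorator, so the exact spelling here is illustrative rather than prescriptive:

from dataclasses import dataclass

import numpy as np
from numpy.typing import NDArray

from atdata.dataset import PackableSample


@dataclass
class ImageSample(PackableSample):
    image: NDArray   # bytes assigned here are decoded back into an array on init
    label: str
    raw: bytes       # genuinely raw payloads must be annotated as bytes, not NDArray


sample = ImageSample(image=np.zeros((2, 2), dtype=np.float32), label="cat", raw=b"\x00")
packed = sample.as_wds   # {'__key__': <time-sortable uuid>, 'msgpack': <packed sample data>}
again = ImageSample.from_data({"image": sample.image, "label": "cat", "raw": b"\x00"})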
+715 -163
src/atdata/local.py
··· 6 6 7 7 The main classes are: 8 8 - Repo: Manages dataset storage in S3 with Redis indexing 9 - - Index: Redis-backed index for tracking dataset metadata 10 - - BasicIndexEntry: Index entry representing a stored dataset 9 + - LocalIndex: Redis-backed index for tracking dataset metadata 10 + - LocalDatasetEntry: Index entry representing a stored dataset 11 11 12 12 This is intended for development and small-scale deployment before 13 - migrating to the full atproto PDS infrastructure. 13 + migrating to the full atproto PDS infrastructure. The implementation 14 + uses ATProto-compatible CIDs for content addressing, enabling seamless 15 + promotion from local storage to the atmosphere (ATProto network). 14 16 """ 15 17 16 18 ## ··· 20 22 PackableSample, 21 23 Dataset, 22 24 ) 25 + from atdata._cid import generate_cid 26 + from atdata._protocols import IndexEntry 27 + from atdata._type_utils import numpy_dtype_to_string, PRIMITIVE_TYPE_MAP 23 28 24 - import os 25 29 from pathlib import Path 26 30 from uuid import uuid4 27 31 from tempfile import TemporaryDirectory ··· 44 48 from typing import ( 45 49 Any, 46 50 Optional, 47 - Dict, 48 51 Type, 49 52 TypeVar, 50 53 Generator, 54 + Iterator, 51 55 BinaryIO, 56 + Union, 52 57 cast, 58 + get_type_hints, 59 + get_origin, 60 + get_args, 53 61 ) 62 + import types 63 + from dataclasses import fields, is_dataclass 64 + from datetime import datetime, timezone 65 + import json 54 66 55 67 T = TypeVar( 'T', bound = PackableSample ) 56 68 69 + # Redis key prefixes for index entries and schemas 70 + REDIS_KEY_DATASET_ENTRY = "LocalDatasetEntry" 71 + REDIS_KEY_SCHEMA = "LocalSchema" 72 + 57 73 58 74 ## 59 75 # Helpers 60 76 61 77 def _kind_str_for_sample_type( st: Type[PackableSample] ) -> str: 62 - """Convert a sample type to a fully-qualified string identifier. 78 + """Return fully-qualified 'module.name' string for a sample type.""" 79 + return f'{st.__module__}.{st.__name__}' 80 + 81 + 82 + def _create_s3_write_callbacks( 83 + credentials: dict[str, Any], 84 + temp_dir: str, 85 + written_shards: list[str], 86 + fs: S3FileSystem | None, 87 + cache_local: bool, 88 + add_s3_prefix: bool = False, 89 + ) -> tuple: 90 + """Create opener and post callbacks for ShardWriter with S3 upload. 63 91 64 92 Args: 65 - st: The sample type class. 93 + credentials: S3 credentials dict. 94 + temp_dir: Temporary directory for local caching. 95 + written_shards: List to append written shard paths to. 96 + fs: S3FileSystem for direct writes (used when cache_local=False). 97 + cache_local: If True, write locally then copy to S3. 98 + add_s3_prefix: If True, prepend 's3://' to shard paths. 66 99 67 100 Returns: 68 - A string in the format 'module.name' identifying the sample type. 101 + Tuple of (writer_opener, writer_post) callbacks. 69 102 """ 70 - return f'{st.__module__}.{st.__name__}' 103 + if cache_local: 104 + import boto3 71 105 72 - def _decode_bytes_dict( d: dict[bytes, bytes] ) -> dict[str, str]: 73 - """Decode a dictionary with byte keys and values to strings. 106 + s3_client_kwargs = { 107 + 'aws_access_key_id': credentials['AWS_ACCESS_KEY_ID'], 108 + 'aws_secret_access_key': credentials['AWS_SECRET_ACCESS_KEY'] 109 + } 110 + if 'AWS_ENDPOINT' in credentials: 111 + s3_client_kwargs['endpoint_url'] = credentials['AWS_ENDPOINT'] 112 + s3_client = boto3.client('s3', **s3_client_kwargs) 74 113 75 - Redis returns dictionaries with bytes keys/values, this converts them to strings. 
114 + def _writer_opener(p: str): 115 + local_path = Path(temp_dir) / p 116 + local_path.parent.mkdir(parents=True, exist_ok=True) 117 + return open(local_path, 'wb') 118 + 119 + def _writer_post(p: str): 120 + local_path = Path(temp_dir) / p 121 + path_parts = Path(p).parts 122 + bucket = path_parts[0] 123 + key = str(Path(*path_parts[1:])) 124 + 125 + with open(local_path, 'rb') as f_in: 126 + s3_client.put_object(Bucket=bucket, Key=key, Body=f_in.read()) 127 + 128 + local_path.unlink() 129 + if add_s3_prefix: 130 + written_shards.append(f"s3://{p}") 131 + else: 132 + written_shards.append(p) 133 + 134 + return _writer_opener, _writer_post 135 + else: 136 + assert fs is not None, "S3FileSystem required when cache_local=False" 137 + 138 + def _direct_opener(s: str): 139 + return cast(BinaryIO, fs.open(f's3://{s}', 'wb')) 140 + 141 + def _direct_post(s: str): 142 + if add_s3_prefix: 143 + written_shards.append(f"s3://{s}") 144 + else: 145 + written_shards.append(s) 146 + 147 + return _direct_opener, _direct_post 148 + 149 + ## 150 + # Schema helpers 151 + 152 + def _schema_ref_from_type(sample_type: Type[PackableSample], version: str = "1.0.0") -> str: 153 + """Generate 'local://schemas/{module.Class}@{version}' reference.""" 154 + kind_str = _kind_str_for_sample_type(sample_type) 155 + return f"local://schemas/{kind_str}@{version}" 156 + 157 + 158 + def _parse_schema_ref(ref: str) -> tuple[str, str]: 159 + """Parse 'local://schemas/{module.Class}@{version}' into (module.Class, version).""" 160 + if not ref.startswith("local://schemas/"): 161 + raise ValueError(f"Invalid local schema reference: {ref}") 162 + 163 + path = ref[len("local://schemas/"):] 164 + if "@" not in path: 165 + raise ValueError(f"Schema reference must include version (@version): {ref}") 166 + 167 + kind_str, version = path.rsplit("@", 1) 168 + return kind_str, version 169 + 170 + 171 + def _python_type_to_field_type(python_type: Any) -> dict: 172 + """Convert Python type annotation to schema field type dict.""" 173 + # Handle primitives 174 + if python_type in PRIMITIVE_TYPE_MAP: 175 + return {"$type": "local#primitive", "primitive": PRIMITIVE_TYPE_MAP[python_type]} 176 + 177 + # Check for NDArray 178 + type_str = str(python_type) 179 + if "NDArray" in type_str or "ndarray" in type_str.lower(): 180 + dtype = "float32" # Default 181 + args = get_args(python_type) 182 + if args: 183 + dtype_arg = args[-1] if args else None 184 + if dtype_arg is not None: 185 + dtype = numpy_dtype_to_string(dtype_arg) 186 + return {"$type": "local#ndarray", "dtype": dtype} 187 + 188 + # Check for list/array types 189 + origin = get_origin(python_type) 190 + if origin is list: 191 + args = get_args(python_type) 192 + if args: 193 + items = _python_type_to_field_type(args[0]) 194 + return {"$type": "local#array", "items": items} 195 + else: 196 + return {"$type": "local#array", "items": {"$type": "local#primitive", "primitive": "str"}} 197 + 198 + # Check for nested dataclass (not yet supported) 199 + if is_dataclass(python_type): 200 + raise TypeError( 201 + f"Nested dataclass types not yet supported: {python_type.__name__}. " 202 + "Publish nested types separately and use references." 203 + ) 204 + 205 + raise TypeError(f"Unsupported type for schema field: {python_type}") 206 + 207 + 208 + def _build_schema_record( 209 + sample_type: Type[PackableSample], 210 + *, 211 + version: str = "1.0.0", 212 + description: str | None = None, 213 + ) -> dict: 214 + """Build a schema record dict from a PackableSample type. 
76 215 77 216 Args: 78 - d: Dictionary with bytes keys and values. 217 + sample_type: The PackableSample subclass to introspect. 218 + version: Semantic version string. 219 + description: Optional human-readable description. 79 220 80 221 Returns: 81 - Dictionary with UTF-8 decoded string keys and values. 222 + Schema record dict suitable for Redis storage. 223 + 224 + Raises: 225 + ValueError: If sample_type is not a dataclass. 226 + TypeError: If a field type is not supported. 82 227 """ 228 + if not is_dataclass(sample_type): 229 + raise ValueError(f"{sample_type.__name__} must be a dataclass (use @packable)") 230 + 231 + field_defs = [] 232 + type_hints = get_type_hints(sample_type) 233 + 234 + for f in fields(sample_type): 235 + field_type = type_hints.get(f.name, f.type) 236 + 237 + # Check for Optional types (Union with None) 238 + is_optional = False 239 + origin = get_origin(field_type) 240 + 241 + if origin is Union or isinstance(field_type, types.UnionType): 242 + args = get_args(field_type) 243 + non_none_args = [a for a in args if a is not type(None)] 244 + if type(None) in args or len(non_none_args) < len(args): 245 + is_optional = True 246 + if len(non_none_args) == 1: 247 + field_type = non_none_args[0] 248 + elif len(non_none_args) > 1: 249 + raise TypeError(f"Complex union types not supported: {field_type}") 250 + 251 + field_type_dict = _python_type_to_field_type(field_type) 252 + 253 + field_defs.append({ 254 + "name": f.name, 255 + "fieldType": field_type_dict, 256 + "optional": is_optional, 257 + }) 258 + 83 259 return { 84 - k.decode('utf-8'): v.decode('utf-8') 85 - for k, v in d.items() 260 + "name": sample_type.__name__, 261 + "version": version, 262 + "fields": field_defs, 263 + "description": description, 264 + "createdAt": datetime.now(timezone.utc).isoformat(), 86 265 } 87 266 88 267 ··· 90 269 # Redis object model 91 270 92 271 @dataclass 93 - class BasicIndexEntry: 94 - """Index entry for a dataset stored in the repository. 272 + class LocalDatasetEntry: 273 + """Index entry for a dataset stored in the local repository. 274 + 275 + Implements the IndexEntry protocol for compatibility with AbstractIndex. 276 + Uses dual identity: a content-addressable CID (ATProto-compatible) and 277 + a human-readable name. 95 278 96 - Tracks metadata about a dataset stored in S3, including its location, 97 - type, and unique identifier. 279 + The CID is generated from the entry's content (schema_ref + data_urls), 280 + ensuring the same data produces the same CID whether stored locally or 281 + in the atmosphere. This enables seamless promotion from local to ATProto. 98 282 """ 99 283 ## 100 284 101 - wds_url: str 102 - """WebDataset URL for the dataset tar files, for use with atdata.Dataset.""" 285 + _name: str 286 + """Human-readable name for this dataset.""" 287 + 288 + _schema_ref: str 289 + """Reference to the schema for this dataset (local:// path).""" 290 + 291 + _data_urls: list[str] 292 + """WebDataset URLs for the data.""" 293 + 294 + _metadata: dict | None = None 295 + """Arbitrary metadata dictionary, or None if not set.""" 296 + 297 + _cid: str | None = field(default=None, repr=False) 298 + """Content identifier (ATProto-compatible CID). 
Generated from content if not provided.""" 299 + 300 + # Legacy field for backwards compatibility during migration 301 + _legacy_uuid: str | None = field(default=None, repr=False) 302 + """Legacy UUID for backwards compatibility with existing Redis entries.""" 303 + 304 + def __post_init__(self): 305 + """Generate CID from content if not provided.""" 306 + if self._cid is None: 307 + self._cid = self._generate_cid() 308 + 309 + def _generate_cid(self) -> str: 310 + """Generate ATProto-compatible CID from entry content.""" 311 + # CID is based on schema_ref and data_urls - the identity of the dataset 312 + content = { 313 + "schema_ref": self._schema_ref, 314 + "data_urls": self._data_urls, 315 + } 316 + return generate_cid(content) 103 317 104 - sample_kind: str 105 - """Fully-qualified sample type name (e.g., 'module.ClassName').""" 318 + # IndexEntry protocol properties 106 319 107 - metadata_url: str | None 108 - """S3 URL to the dataset's metadata msgpack file, if any.""" 320 + @property 321 + def name(self) -> str: 322 + """Human-readable dataset name.""" 323 + return self._name 109 324 110 - uuid: str = field( default_factory = lambda: str( uuid4() ) ) 111 - """Unique identifier for this dataset entry. Defaults to a new UUID if not provided.""" 325 + @property 326 + def schema_ref(self) -> str: 327 + """Reference to the schema for this dataset.""" 328 + return self._schema_ref 112 329 113 - def write_to( self, redis: Redis ): 330 + @property 331 + def data_urls(self) -> list[str]: 332 + """WebDataset URLs for the data.""" 333 + return self._data_urls 334 + 335 + @property 336 + def metadata(self) -> dict | None: 337 + """Arbitrary metadata dictionary, or None if not set.""" 338 + return self._metadata 339 + 340 + # Additional properties 341 + 342 + @property 343 + def cid(self) -> str: 344 + """Content identifier (ATProto-compatible CID).""" 345 + assert self._cid is not None 346 + return self._cid 347 + 348 + # Legacy compatibility 349 + 350 + @property 351 + def wds_url(self) -> str: 352 + """Legacy property: returns first data URL for backwards compatibility.""" 353 + return self._data_urls[0] if self._data_urls else "" 354 + 355 + @property 356 + def sample_kind(self) -> str: 357 + """Legacy property: returns schema_ref for backwards compatibility.""" 358 + return self._schema_ref 359 + 360 + def write_to(self, redis: Redis): 114 361 """Persist this index entry to Redis. 115 362 116 - Stores the entry as a Redis hash with key 'BasicIndexEntry:{uuid}'. 363 + Stores the entry as a Redis hash with key '{REDIS_KEY_DATASET_ENTRY}:{cid}'. 117 364 118 365 Args: 119 366 redis: Redis connection to write to. 
120 367 """ 121 - save_key = f'BasicIndexEntry:{self.uuid}' 122 - # Filter out None values - Redis doesn't accept None 123 - data = {k: v for k, v in asdict(self).items() if v is not None} 124 - # redis-py typing uses untyped dict, so type checker complains about dict[str, Any] 125 - redis.hset( save_key, mapping = data ) # type: ignore[arg-type] 368 + save_key = f'{REDIS_KEY_DATASET_ENTRY}:{self.cid}' 369 + data = { 370 + 'name': self._name, 371 + 'schema_ref': self._schema_ref, 372 + 'data_urls': msgpack.packb(self._data_urls), # Serialize list 373 + 'cid': self.cid, 374 + } 375 + if self._metadata is not None: 376 + data['metadata'] = msgpack.packb(self._metadata) 377 + if self._legacy_uuid is not None: 378 + data['legacy_uuid'] = self._legacy_uuid 379 + 380 + redis.hset(save_key, mapping=data) # type: ignore[arg-type] 381 + 382 + @classmethod 383 + def from_redis(cls, redis: Redis, cid: str) -> "LocalDatasetEntry": 384 + """Load an entry from Redis by CID. 385 + 386 + Args: 387 + redis: Redis connection to read from. 388 + cid: Content identifier of the entry to load. 389 + 390 + Returns: 391 + LocalDatasetEntry loaded from Redis. 392 + 393 + Raises: 394 + KeyError: If entry not found. 395 + """ 396 + save_key = f'{REDIS_KEY_DATASET_ENTRY}:{cid}' 397 + raw_data = redis.hgetall(save_key) 398 + if not raw_data: 399 + raise KeyError(f"{REDIS_KEY_DATASET_ENTRY} not found: {cid}") 126 400 127 - def _s3_env( credentials_path: str | Path ) -> dict[str, Any]: 128 - """Load S3 credentials from a .env file. 401 + # Decode string fields, keep binary fields as bytes for msgpack 402 + raw_data_typed = cast(dict[bytes, bytes], raw_data) 403 + name = raw_data_typed[b'name'].decode('utf-8') 404 + schema_ref = raw_data_typed[b'schema_ref'].decode('utf-8') 405 + cid_value = raw_data_typed.get(b'cid', b'').decode('utf-8') or None 406 + legacy_uuid = raw_data_typed.get(b'legacy_uuid', b'').decode('utf-8') or None 407 + 408 + # Deserialize msgpack fields (stored as raw bytes) 409 + data_urls = msgpack.unpackb(raw_data_typed[b'data_urls']) 410 + metadata = None 411 + if b'metadata' in raw_data_typed: 412 + metadata = msgpack.unpackb(raw_data_typed[b'metadata']) 413 + 414 + return cls( 415 + _name=name, 416 + _schema_ref=schema_ref, 417 + _data_urls=data_urls, 418 + _metadata=metadata, 419 + _cid=cid_value, 420 + _legacy_uuid=legacy_uuid, 421 + ) 129 422 130 - Args: 131 - credentials_path: Path to .env file containing S3 credentials. 132 423 133 - Returns: 134 - Dictionary with AWS_ENDPOINT, AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY. 424 + # Backwards compatibility alias 425 + BasicIndexEntry = LocalDatasetEntry 135 426 136 - Raises: 137 - AssertionError: If required credentials are missing from the file. 138 - """ 139 - ## 427 + def _s3_env( credentials_path: str | Path ) -> dict[str, Any]: 428 + """Load S3 credentials (AWS_ENDPOINT, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY) from .env file.""" 140 429 credentials_path = Path( credentials_path ) 141 430 env_values = dotenv_values( credentials_path ) 142 431 assert 'AWS_ENDPOINT' in env_values ··· 153 442 } 154 443 155 444 def _s3_from_credentials( creds: str | Path | dict ) -> S3FileSystem: 156 - """Create an S3FileSystem from credentials. 157 - 158 - Args: 159 - creds: Either a path to a .env file with credentials, or a dict 160 - containing AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and optionally 161 - AWS_ENDPOINT. 162 - 163 - Returns: 164 - Configured S3FileSystem instance. 
165 - """ 166 - ## 445 + """Create S3FileSystem from credentials dict or .env file path.""" 167 446 if not isinstance( creds, dict ): 168 447 creds = _s3_env( creds ) 169 448 ··· 248 527 249 528 ## 250 529 251 - def insert( self, ds: Dataset[T], 252 - # 530 + def insert(self, 531 + ds: Dataset[T], 532 + *, 533 + name: str, 253 534 cache_local: bool = False, 254 - # 255 - **kwargs 256 - ) -> tuple[BasicIndexEntry, Dataset[T]]: 535 + schema_ref: str | None = None, 536 + **kwargs 537 + ) -> tuple[LocalDatasetEntry, Dataset[T]]: 257 538 """Insert a dataset into the repository. 258 539 259 540 Writes the dataset to S3 as WebDataset tar files, stores metadata, ··· 261 542 262 543 Args: 263 544 ds: The dataset to insert. 545 + name: Human-readable name for the dataset. 264 546 cache_local: If True, write to local temporary storage first, then 265 547 copy to S3. This can be faster for some workloads. 548 + schema_ref: Optional schema reference. If None, generates from sample type. 266 549 **kwargs: Additional arguments passed to wds.ShardWriter. 267 550 268 551 Returns: 269 552 A tuple of (index_entry, new_dataset) where: 270 - - index_entry: BasicIndexEntry for the stored dataset 553 + - index_entry: LocalDatasetEntry for the stored dataset 271 554 - new_dataset: Dataset object pointing to the stored copy 272 555 273 556 Raises: 274 - AssertionError: If S3 credentials or hive_path are not configured. 557 + ValueError: If S3 credentials or hive_path are not configured. 275 558 RuntimeError: If no shards were written. 276 559 """ 277 - 278 - assert self.s3_credentials is not None 279 - assert self.hive_bucket is not None 280 - assert self.hive_path is not None 560 + if self.s3_credentials is None: 561 + raise ValueError("S3 credentials required for insert(). Initialize Repo with s3_credentials.") 562 + if self.hive_bucket is None or self.hive_path is None: 563 + raise ValueError("hive_path required for insert(). 
Initialize Repo with hive_path.") 281 564 282 565 new_uuid = str( uuid4() ) 283 566 ··· 305 588 / f'atdata--{new_uuid}--%06d.tar' 306 589 ).as_posix() 307 590 591 + written_shards: list[str] = [] 308 592 with TemporaryDirectory() as temp_dir: 309 - 310 - if cache_local: 311 - # For cache_local, we need to use boto3 directly to avoid s3fs async issues with moto 312 - import boto3 313 - 314 - # Create boto3 client from credentials 315 - s3_client_kwargs = { 316 - 'aws_access_key_id': self.s3_credentials['AWS_ACCESS_KEY_ID'], 317 - 'aws_secret_access_key': self.s3_credentials['AWS_SECRET_ACCESS_KEY'] 318 - } 319 - if 'AWS_ENDPOINT' in self.s3_credentials: 320 - s3_client_kwargs['endpoint_url'] = self.s3_credentials['AWS_ENDPOINT'] 321 - s3_client = boto3.client('s3', **s3_client_kwargs) 322 - 323 - def _writer_opener( p: str ): 324 - local_cache_path = Path( temp_dir ) / p 325 - local_cache_path.parent.mkdir( parents = True, exist_ok = True ) 326 - return open( local_cache_path, 'wb' ) 327 - writer_opener = _writer_opener 328 - 329 - def _writer_post( p: str ): 330 - local_cache_path = Path( temp_dir ) / p 331 - 332 - # Copy to S3 using boto3 client (avoids s3fs async issues) 333 - path_parts = Path( p ).parts 334 - bucket = path_parts[0] 335 - key = str( Path( *path_parts[1:] ) ) 336 - 337 - with open( local_cache_path, 'rb' ) as f_in: 338 - s3_client.put_object( Bucket=bucket, Key=key, Body=f_in.read() ) 339 - 340 - # Delete local cache file 341 - local_cache_path.unlink() 342 - 343 - written_shards.append( p ) 344 - writer_post = _writer_post 345 - 346 - else: 347 - # Use s3:// prefix to ensure s3fs treats paths as S3 paths 348 - writer_opener = lambda s: cast( BinaryIO, hive_fs.open( f's3://{s}', 'wb' ) ) 349 - writer_post = lambda s: written_shards.append( s ) 593 + writer_opener, writer_post = _create_s3_write_callbacks( 594 + credentials=self.s3_credentials, 595 + temp_dir=temp_dir, 596 + written_shards=written_shards, 597 + fs=hive_fs, 598 + cache_local=cache_local, 599 + add_s3_prefix=False, 600 + ) 350 601 351 - written_shards = [] 352 602 with wds.writer.ShardWriter( 353 603 shard_pattern, 354 - opener = writer_opener, 355 - post = writer_post, 604 + opener=writer_opener, 605 + post=writer_post, 356 606 **kwargs, 357 607 ) as sink: 358 - for sample in ds.ordered( batch_size = None ): 359 - sink.write( sample.as_wds ) 608 + for sample in ds.ordered(batch_size=None): 609 + sink.write(sample.as_wds) 360 610 361 611 # Make a new Dataset object for the written dataset copy 362 612 if len( written_shards ) == 0: ··· 379 629 new_dataset_url = shard_s3_format.format( shard_id = shard_id_braced ) 380 630 381 631 new_dataset = Dataset[ds.sample_type]( 382 - url = new_dataset_url, 383 - metadata_url = metadata_path.as_posix(), 632 + url=new_dataset_url, 633 + metadata_url=metadata_path.as_posix(), 384 634 ) 385 635 386 - # Add to index 387 - new_entry = self.index.add_entry( new_dataset, uuid = new_uuid ) 636 + # Add to index (use ds._metadata to avoid network requests) 637 + new_entry = self.index.add_entry( 638 + new_dataset, 639 + name=name, 640 + schema_ref=schema_ref, 641 + metadata=ds._metadata, 642 + ) 388 643 389 644 return new_entry, new_dataset 390 645 ··· 392 647 class Index: 393 648 """Redis-backed index for tracking datasets in a repository. 394 649 395 - Maintains a registry of BasicIndexEntry objects in Redis, allowing 650 + Maintains a registry of LocalDatasetEntry objects in Redis, allowing 396 651 enumeration and lookup of stored datasets. 
397 652 398 653 Attributes: ··· 401 656 402 657 ## 403 658 404 - def __init__( self, 405 - redis: Redis | None = None, 406 - **kwargs 407 - ) -> None: 659 + def __init__(self, 660 + redis: Redis | None = None, 661 + **kwargs 662 + ) -> None: 408 663 """Initialize an index. 409 664 410 665 Args: ··· 418 673 if redis is not None: 419 674 self._redis = redis 420 675 else: 421 - self._redis: Redis = Redis( **kwargs ) 676 + self._redis: Redis = Redis(**kwargs) 422 677 423 678 @property 424 - def all_entries( self ) -> list[BasicIndexEntry]: 679 + def all_entries(self) -> list[LocalDatasetEntry]: 425 680 """Get all index entries as a list. 426 681 427 682 Returns: 428 - List of all BasicIndexEntry objects in the index. 683 + List of all LocalDatasetEntry objects in the index. 429 684 """ 430 - return list( self.entries ) 685 + return list(self.entries) 431 686 432 687 @property 433 - def entries( self ) -> Generator[BasicIndexEntry, None, None]: 688 + def entries(self) -> Generator[LocalDatasetEntry, None, None]: 434 689 """Iterate over all index entries. 435 690 436 - Scans Redis for all BasicIndexEntry keys and yields them one at a time. 691 + Scans Redis for LocalDatasetEntry keys and yields them one at a time. 437 692 438 693 Yields: 439 - BasicIndexEntry objects from the index. 694 + LocalDatasetEntry objects from the index. 440 695 """ 441 - ## 442 - for key in self._redis.scan_iter( match = 'BasicIndexEntry:*' ): 443 - # hgetall returns dict[bytes, bytes] which we decode to dict[str, str] 444 - cur_entry_data = _decode_bytes_dict( cast(dict[bytes, bytes], self._redis.hgetall( key )) ) 445 - 446 - # Provide default None for optional fields that may be missing 447 - # Type checker complains about None in dict[str, str], but BasicIndexEntry accepts it 448 - cur_entry_data: dict[str, Any] = dict( **cur_entry_data ) 449 - cur_entry_data.setdefault('metadata_url', None) 450 - 451 - cur_entry = BasicIndexEntry( **cur_entry_data ) 452 - yield cur_entry 696 + prefix = f'{REDIS_KEY_DATASET_ENTRY}:' 697 + for key in self._redis.scan_iter(match=f'{prefix}*'): 698 + key_str = key.decode('utf-8') if isinstance(key, bytes) else key 699 + cid = key_str[len(prefix):] 700 + yield LocalDatasetEntry.from_redis(self._redis, cid) 453 701 454 - return 455 - 456 - def add_entry( self, ds: Dataset, 457 - uuid: str | None = None, 458 - ) -> BasicIndexEntry: 702 + def add_entry(self, 703 + ds: Dataset, 704 + *, 705 + name: str, 706 + schema_ref: str | None = None, 707 + metadata: dict | None = None, 708 + ) -> LocalDatasetEntry: 459 709 """Add a dataset to the index. 460 710 461 - Creates a BasicIndexEntry for the dataset and persists it to Redis. 711 + Creates a LocalDatasetEntry for the dataset and persists it to Redis. 462 712 463 713 Args: 464 714 ds: The dataset to add to the index. 465 - uuid: Optional UUID for the entry. If None, a new UUID is generated. 715 + name: Human-readable name for the dataset. 716 + schema_ref: Optional schema reference. If None, generates from sample type. 717 + metadata: Optional metadata dictionary. If None, uses ds._metadata if available. 466 718 467 719 Returns: 468 - The created BasicIndexEntry object. 720 + The created LocalDatasetEntry object. 
469 721 """ 470 722 ## 471 - temp_sample_kind = _kind_str_for_sample_type( ds.sample_type ) 723 + if schema_ref is None: 724 + schema_ref = f"local://schemas/{_kind_str_for_sample_type(ds.sample_type)}@1.0.0" 725 + 726 + # Normalize URL to list 727 + data_urls = [ds.url] 728 + 729 + # Use provided metadata, or fall back to dataset's cached metadata 730 + # (avoid triggering network requests via ds.metadata property) 731 + entry_metadata = metadata if metadata is not None else ds._metadata 732 + 733 + entry = LocalDatasetEntry( 734 + _name=name, 735 + _schema_ref=schema_ref, 736 + _data_urls=data_urls, 737 + _metadata=entry_metadata, 738 + ) 739 + 740 + entry.write_to(self._redis) 741 + 742 + return entry 743 + 744 + def get_entry(self, cid: str) -> LocalDatasetEntry: 745 + """Get an entry by its CID. 746 + 747 + Args: 748 + cid: Content identifier of the entry. 749 + 750 + Returns: 751 + LocalDatasetEntry for the given CID. 752 + 753 + Raises: 754 + KeyError: If entry not found. 755 + """ 756 + return LocalDatasetEntry.from_redis(self._redis, cid) 757 + 758 + def get_entry_by_name(self, name: str) -> LocalDatasetEntry: 759 + """Get an entry by its human-readable name. 760 + 761 + Args: 762 + name: Human-readable name of the entry. 763 + 764 + Returns: 765 + LocalDatasetEntry with the given name. 766 + 767 + Raises: 768 + KeyError: If no entry with that name exists. 769 + """ 770 + for entry in self.entries: 771 + if entry.name == name: 772 + return entry 773 + raise KeyError(f"No entry with name: {name}") 774 + 775 + # AbstractIndex protocol methods 776 + 777 + def insert_dataset( 778 + self, 779 + ds: Dataset, 780 + *, 781 + name: str, 782 + schema_ref: str | None = None, 783 + **kwargs, 784 + ) -> LocalDatasetEntry: 785 + """Insert a dataset into the index (AbstractIndex protocol). 786 + 787 + Args: 788 + ds: The Dataset to register. 789 + name: Human-readable name for the dataset. 790 + schema_ref: Optional schema reference. 791 + **kwargs: Additional options (metadata supported). 792 + 793 + Returns: 794 + IndexEntry for the inserted dataset. 795 + """ 796 + metadata = kwargs.get('metadata') 797 + return self.add_entry(ds, name=name, schema_ref=schema_ref, metadata=metadata) 798 + 799 + def get_dataset(self, ref: str) -> LocalDatasetEntry: 800 + """Get a dataset entry by name (AbstractIndex protocol). 801 + 802 + Args: 803 + ref: Dataset name. 804 + 805 + Returns: 806 + IndexEntry for the dataset. 807 + 808 + Raises: 809 + KeyError: If dataset not found. 810 + """ 811 + return self.get_entry_by_name(ref) 812 + 813 + def list_datasets(self) -> Iterator[LocalDatasetEntry]: 814 + """List all dataset entries (AbstractIndex protocol). 815 + 816 + Yields: 817 + IndexEntry for each dataset. 818 + """ 819 + return self.entries 820 + 821 + # Schema operations 822 + 823 + def publish_schema( 824 + self, 825 + sample_type: Type[PackableSample], 826 + *, 827 + version: str = "1.0.0", 828 + description: str | None = None, 829 + ) -> str: 830 + """Publish a schema for a sample type to Redis. 831 + 832 + Args: 833 + sample_type: The PackableSample subclass to publish. 834 + version: Semantic version string (e.g., '1.0.0'). 835 + description: Optional human-readable description. 836 + 837 + Returns: 838 + Schema reference string: 'local://schemas/{module.Class}@{version}'. 839 + 840 + Raises: 841 + ValueError: If sample_type is not a dataclass. 842 + TypeError: If a field type is not supported. 
843 + """ 844 + schema_record = _build_schema_record( 845 + sample_type, 846 + version=version, 847 + description=description, 848 + ) 849 + 850 + schema_ref = _schema_ref_from_type(sample_type, version) 851 + kind_str, _ = _parse_schema_ref(schema_ref) 852 + 853 + # Store in Redis 854 + redis_key = f"{REDIS_KEY_SCHEMA}:{kind_str}@{version}" 855 + schema_json = json.dumps(schema_record) 856 + self._redis.set(redis_key, schema_json) 857 + 858 + return schema_ref 859 + 860 + def get_schema(self, ref: str) -> dict: 861 + """Get a schema record by reference. 862 + 863 + Args: 864 + ref: Schema reference string (local://schemas/...). 865 + 866 + Returns: 867 + Schema record as a dictionary. 868 + 869 + Raises: 870 + KeyError: If schema not found. 871 + ValueError: If reference format is invalid. 872 + """ 873 + kind_str, version = _parse_schema_ref(ref) 874 + redis_key = f"{REDIS_KEY_SCHEMA}:{kind_str}@{version}" 875 + 876 + schema_json = self._redis.get(redis_key) 877 + if schema_json is None: 878 + raise KeyError(f"Schema not found: {ref}") 879 + 880 + if isinstance(schema_json, bytes): 881 + schema_json = schema_json.decode('utf-8') 882 + 883 + schema = json.loads(schema_json) 884 + # Add $ref for decode_schema compatibility 885 + schema['$ref'] = ref 886 + return schema 887 + 888 + def list_schemas(self) -> Generator[dict, None, None]: 889 + """List all schema records in this index. 890 + 891 + Yields: 892 + Schema records as dictionaries. 893 + """ 894 + prefix = f'{REDIS_KEY_SCHEMA}:' 895 + for key in self._redis.scan_iter(match=f'{prefix}*'): 896 + key_str = key.decode('utf-8') if isinstance(key, bytes) else key 897 + # Extract kind_str@version from key 898 + schema_id = key_str[len(prefix):] 899 + 900 + schema_json = self._redis.get(key) 901 + if schema_json is None: 902 + continue 903 + 904 + if isinstance(schema_json, bytes): 905 + schema_json = schema_json.decode('utf-8') 906 + 907 + schema = json.loads(schema_json) 908 + schema['$ref'] = f"local://schemas/{schema_id}" 909 + yield schema 910 + 911 + def decode_schema(self, ref: str) -> Type[PackableSample]: 912 + """Reconstruct a Python PackableSample type from a stored schema. 913 + 914 + This method enables loading datasets without knowing the sample type 915 + ahead of time. The index retrieves the schema record and dynamically 916 + generates a PackableSample subclass matching the schema definition. 917 + 918 + Args: 919 + ref: Schema reference string (local://schemas/...). 920 + 921 + Returns: 922 + A dynamically generated PackableSample subclass. 923 + 924 + Raises: 925 + KeyError: If schema not found. 926 + ValueError: If schema cannot be decoded. 927 + """ 928 + from atdata._schema_codec import schema_to_type 929 + 930 + schema = self.get_schema(ref) 931 + return schema_to_type(schema) 932 + 933 + 934 + # Backwards compatibility alias 935 + LocalIndex = Index 936 + 937 + 938 + class S3DataStore: 939 + """S3-compatible data store implementing AbstractDataStore protocol. 940 + 941 + Handles writing dataset shards to S3-compatible object storage and 942 + resolving URLs for reading. 472 943 473 - if uuid is None: 474 - ret_data = BasicIndexEntry( 475 - wds_url = ds.url, 476 - sample_kind = temp_sample_kind, 477 - metadata_url = ds.metadata_url, 478 - ) 944 + Attributes: 945 + credentials: S3 credentials dictionary. 946 + bucket: Target bucket name. 947 + _fs: S3FileSystem instance. 
948 + """ 949 + 950 + def __init__( 951 + self, 952 + credentials: str | Path | dict[str, Any], 953 + *, 954 + bucket: str, 955 + ) -> None: 956 + """Initialize an S3 data store. 957 + 958 + Args: 959 + credentials: Path to .env file or dict with AWS_ACCESS_KEY_ID, 960 + AWS_SECRET_ACCESS_KEY, and optionally AWS_ENDPOINT. 961 + bucket: Name of the S3 bucket for storage. 962 + """ 963 + if isinstance(credentials, dict): 964 + self.credentials = credentials 479 965 else: 480 - ret_data = BasicIndexEntry( 481 - wds_url = ds.url, 482 - sample_kind = temp_sample_kind, 483 - metadata_url = ds.metadata_url, 484 - uuid = uuid, 966 + self.credentials = _s3_env(credentials) 967 + 968 + self.bucket = bucket 969 + self._fs = _s3_from_credentials(self.credentials) 970 + 971 + def write_shards( 972 + self, 973 + ds: Dataset, 974 + *, 975 + prefix: str, 976 + cache_local: bool = False, 977 + **kwargs, 978 + ) -> list[str]: 979 + """Write dataset shards to S3. 980 + 981 + Args: 982 + ds: The Dataset to write. 983 + prefix: Path prefix within bucket (e.g., 'datasets/mnist/v1'). 984 + cache_local: If True, write locally first then copy to S3. 985 + **kwargs: Additional args passed to wds.ShardWriter (e.g., maxcount). 986 + 987 + Returns: 988 + List of S3 URLs for the written shards. 989 + 990 + Raises: 991 + RuntimeError: If no shards were written. 992 + """ 993 + new_uuid = str(uuid4()) 994 + shard_pattern = f"{self.bucket}/{prefix}/data--{new_uuid}--%06d.tar" 995 + 996 + written_shards: list[str] = [] 997 + 998 + with TemporaryDirectory() as temp_dir: 999 + writer_opener, writer_post = _create_s3_write_callbacks( 1000 + credentials=self.credentials, 1001 + temp_dir=temp_dir, 1002 + written_shards=written_shards, 1003 + fs=self._fs, 1004 + cache_local=cache_local, 1005 + add_s3_prefix=True, 485 1006 ) 486 1007 487 - ret_data.write_to( self._redis ) 1008 + with wds.writer.ShardWriter( 1009 + shard_pattern, 1010 + opener=writer_opener, 1011 + post=writer_post, 1012 + **kwargs, 1013 + ) as sink: 1014 + for sample in ds.ordered(batch_size=None): 1015 + sink.write(sample.as_wds) 1016 + 1017 + if len(written_shards) == 0: 1018 + raise RuntimeError("No shards written") 1019 + 1020 + return written_shards 1021 + 1022 + def read_url(self, url: str) -> str: 1023 + """Resolve an S3 URL for reading. 1024 + 1025 + For S3, URLs are returned as-is (WebDataset handles s3:// directly). 1026 + 1027 + Args: 1028 + url: S3 URL to resolve. 488 1029 489 - return ret_data 1030 + Returns: 1031 + The URL unchanged. 1032 + """ 1033 + return url 1034 + 1035 + def supports_streaming(self) -> bool: 1036 + """S3 supports streaming reads. 1037 + 1038 + Returns: 1039 + True. 1040 + """ 1041 + return True 490 1042 491 1043 492 1044 #
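Putting the pieces above together, an end-to-end local sketch. It assumes a Redis server on localhost, an existing Dataset `ds` of some @packable type `MySample`, and placeholder S3 credentials:

from atdata.dataset import Dataset
from atdata.local import LocalIndex, S3DataStore

index = LocalIndex()  # alias for Index; connects to Redis with default settings

# Publish the schema and register the dataset under a human-readable name
schema_ref = index.publish_schema(MySample, version="1.0.0")
entry = index.insert_dataset(ds, name="my-data", schema_ref=schema_ref)
print(entry.cid)  # content-addressed, ATProto-compatible identifier

# Optionally copy the shards into S3-compatible storage
store = S3DataStore("s3-credentials.env", bucket="my-bucket")
shard_urls = store.write_shards(ds, prefix="datasets/my-data/v1", maxcount=10_000)

# Later: recover the sample type from the stored schema and reload the data
loaded = index.get_dataset("my-data")
SampleType = index.decode_schema(loaded.schema_ref)
reloaded = Dataset[SampleType](loaded.data_urls[0])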
+197
src/atdata/promote.py
··· 1 + """Promotion workflow for migrating datasets from local to atmosphere. 2 + 3 + This module provides functionality to promote locally-indexed datasets to the 4 + ATProto atmosphere network. This enables sharing datasets with the broader 5 + federation while maintaining schema consistency. 6 + 7 + Example: 8 + >>> from atdata.local import LocalIndex, Repo 9 + >>> from atdata.atmosphere import AtmosphereClient, AtmosphereIndex 10 + >>> from atdata.promote import promote_to_atmosphere 11 + >>> 12 + >>> # Setup 13 + >>> local_index = LocalIndex() 14 + >>> client = AtmosphereClient() 15 + >>> client.login("handle.bsky.social", "app-password") 16 + >>> 17 + >>> # Promote a dataset 18 + >>> entry = local_index.get_dataset("my-dataset") 19 + >>> at_uri = promote_to_atmosphere(entry, local_index, client) 20 + """ 21 + 22 + from typing import TYPE_CHECKING, Type 23 + 24 + if TYPE_CHECKING: 25 + from .local import LocalDatasetEntry, Index as LocalIndex 26 + from .atmosphere import AtmosphereClient, AtUri 27 + from .atmosphere._types import AtUri as AtUriType 28 + from .dataset import PackableSample 29 + from ._protocols import AbstractDataStore 30 + 31 + 32 + def _find_existing_schema( 33 + client: "AtmosphereClient", 34 + name: str, 35 + version: str, 36 + ) -> str | None: 37 + """Check if a schema with the given name and version already exists. 38 + 39 + Args: 40 + client: Authenticated atmosphere client. 41 + name: Schema name to search for. 42 + version: Schema version to match. 43 + 44 + Returns: 45 + AT URI of existing schema if found, None otherwise. 46 + """ 47 + from .atmosphere import SchemaLoader 48 + 49 + loader = SchemaLoader(client) 50 + for record in loader.list_all(): 51 + rec_value = record.get("value", record) 52 + if rec_value.get("name") == name and rec_value.get("version") == version: 53 + return record.get("uri", "") 54 + return None 55 + 56 + 57 + def _find_or_publish_schema( 58 + sample_type: "Type[PackableSample]", 59 + version: str, 60 + client: "AtmosphereClient", 61 + description: str | None = None, 62 + ) -> str: 63 + """Find existing schema or publish a new one. 64 + 65 + Checks if a schema with the same name and version already exists on the 66 + user's atmosphere repository. If found, returns the existing URI to avoid 67 + duplicates. Otherwise, publishes a new schema record. 68 + 69 + Args: 70 + sample_type: The PackableSample subclass to publish. 71 + version: Semantic version string. 72 + client: Authenticated atmosphere client. 73 + description: Optional schema description. 74 + 75 + Returns: 76 + AT URI of the schema (existing or newly published). 77 + """ 78 + from .atmosphere import SchemaPublisher 79 + 80 + schema_name = f"{sample_type.__module__}.{sample_type.__name__}" 81 + 82 + # Check for existing schema 83 + existing = _find_existing_schema(client, schema_name, version) 84 + if existing: 85 + return existing 86 + 87 + # Publish new schema 88 + publisher = SchemaPublisher(client) 89 + uri = publisher.publish( 90 + sample_type, 91 + version=version, 92 + description=description, 93 + ) 94 + return str(uri) 95 + 96 + 97 + def promote_to_atmosphere( 98 + local_entry: "LocalDatasetEntry", 99 + local_index: "LocalIndex", 100 + atmosphere_client: "AtmosphereClient", 101 + *, 102 + data_store: "AbstractDataStore | None" = None, 103 + name: str | None = None, 104 + description: str | None = None, 105 + tags: list[str] | None = None, 106 + license: str | None = None, 107 + ) -> str: 108 + """Promote a local dataset to the atmosphere network. 
109 + 110 + This function takes a locally-indexed dataset and publishes it to ATProto, 111 + making it discoverable on the federated atmosphere network. 112 + 113 + Args: 114 + local_entry: The LocalDatasetEntry to promote. 115 + local_index: Local index containing the schema for this entry. 116 + atmosphere_client: Authenticated AtmosphereClient. 117 + data_store: Optional data store for copying data to new location. 118 + If None, the existing data_urls are used as-is. 119 + name: Override name for the atmosphere record. Defaults to local name. 120 + description: Optional description for the dataset. 121 + tags: Optional tags for discovery. 122 + license: Optional license identifier. 123 + 124 + Returns: 125 + AT URI of the created atmosphere dataset record. 126 + 127 + Raises: 128 + KeyError: If schema not found in local index. 129 + ValueError: If local entry has no data URLs. 130 + 131 + Example: 132 + >>> entry = local_index.get_dataset("mnist-train") 133 + >>> uri = promote_to_atmosphere(entry, local_index, client) 134 + >>> print(uri) 135 + at://did:plc:abc123/ac.foundation.dataset.datasetIndex/... 136 + """ 137 + from .atmosphere import DatasetPublisher 138 + from ._schema_codec import schema_to_type 139 + 140 + # Validate entry has data 141 + if not local_entry.data_urls: 142 + raise ValueError(f"Local entry '{local_entry.name}' has no data URLs") 143 + 144 + # Get schema from local index 145 + schema_ref = local_entry.schema_ref 146 + schema_record = local_index.get_schema(schema_ref) 147 + 148 + # Reconstruct sample type from schema 149 + sample_type = schema_to_type(schema_record) 150 + schema_version = schema_record.get("version", "1.0.0") 151 + 152 + # Find or publish schema on atmosphere (deduplication) 153 + atmosphere_schema_uri = _find_or_publish_schema( 154 + sample_type, 155 + schema_version, 156 + atmosphere_client, 157 + description=schema_record.get("description"), 158 + ) 159 + 160 + # Determine data URLs 161 + if data_store is not None: 162 + # Copy data to new storage location 163 + # Create a temporary Dataset to write through the data store 164 + from .dataset import Dataset 165 + 166 + # Build WDS URL from data_urls 167 + if len(local_entry.data_urls) == 1: 168 + wds_url = local_entry.data_urls[0] 169 + else: 170 + # Use brace notation for multiple URLs 171 + wds_url = " ".join(local_entry.data_urls) 172 + 173 + ds = Dataset[sample_type](wds_url) 174 + prefix = f"promoted/{local_entry.name}" 175 + data_urls = data_store.write_shards(ds, prefix=prefix) 176 + else: 177 + # Use existing URLs as-is 178 + data_urls = local_entry.data_urls 179 + 180 + # Publish dataset record to atmosphere 181 + publisher = DatasetPublisher(atmosphere_client) 182 + uri = publisher.publish_with_urls( 183 + urls=data_urls, 184 + schema_uri=atmosphere_schema_uri, 185 + name=name or local_entry.name, 186 + description=description, 187 + tags=tags, 188 + license=license, 189 + metadata=local_entry.metadata, 190 + ) 191 + 192 + return str(uri) 193 + 194 + 195 + __all__ = [ 196 + "promote_to_atmosphere", 197 + ]
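A minimal usage sketch of the new entry point, assuming an authenticated client and a locally indexed dataset; the handle, credentials, dataset name, and `my_store` below are placeholders (`my_store` stands for any object exposing the `write_shards(dataset, prefix=...)` method called in the module above).

```python
from atdata.local import LocalIndex
from atdata.atmosphere import AtmosphereClient
from atdata.promote import promote_to_atmosphere

local_index = LocalIndex()
client = AtmosphereClient()
client.login("handle.bsky.social", "app-password")  # placeholder credentials

entry = local_index.get_dataset("mnist-train")

# Default path: the entry's existing data_urls are published as-is.
uri = promote_to_atmosphere(
    entry,
    local_index,
    client,
    description="MNIST training split",
    tags=["vision", "mnist"],
    license="CC-BY-4.0",
)

# Optional path: copy shards into another store before publishing; shards
# land under "promoted/<dataset name>" via my_store.write_shards(...).
# uri = promote_to_atmosphere(entry, local_index, client, data_store=my_store)
```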
+31
tests/conftest.py
··· 1 1 """Pytest configuration for atdata tests.""" 2 + 3 + import pytest 4 + from redis import Redis 5 + 6 + 7 + @pytest.fixture 8 + def redis_connection(): 9 + """Provide a Redis connection, skip test if Redis is not available.""" 10 + try: 11 + redis = Redis() 12 + redis.ping() 13 + yield redis 14 + except Exception: 15 + pytest.skip("Redis server not available") 16 + 17 + 18 + @pytest.fixture 19 + def clean_redis(redis_connection): 20 + """Provide a Redis connection with automatic cleanup of test keys. 21 + 22 + Clears LocalDatasetEntry, BasicIndexEntry (legacy), and LocalSchema keys 23 + before and after each test to ensure test isolation. 24 + """ 25 + def _clear_all(): 26 + for pattern in ('LocalDatasetEntry:*', 'BasicIndexEntry:*', 'LocalSchema:*'): 27 + for key in redis_connection.scan_iter(match=pattern): 28 + redis_connection.delete(key) 29 + 30 + _clear_all() 31 + yield redis_connection 32 + _clear_all()
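A sketch of how a test consumes these shared fixtures; the entry fields are arbitrary, and `clean_redis` supplies a live connection or skips the test when Redis is unreachable.

```python
from atdata.local import LocalDatasetEntry

def test_entry_round_trips_through_redis(clean_redis):
    # clean_redis wipes LocalDatasetEntry:*, BasicIndexEntry:* and
    # LocalSchema:* keys before and after the test body runs.
    entry = LocalDatasetEntry(
        _name="fixture-demo",
        _schema_ref="local://schemas/Demo@1.0.0",
        _data_urls=["s3://bucket/demo.tar"],
    )
    entry.write_to(clean_redis)

    loaded = LocalDatasetEntry.from_redis(clean_redis, entry.cid)
    assert loaded.name == "fixture-demo"
```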
+90
tests/test_atmosphere.py
··· 19 19 import atdata 20 20 from atdata.atmosphere import ( 21 21 AtmosphereClient, 22 + AtmosphereIndex, 23 + AtmosphereIndexEntry, 22 24 SchemaPublisher, 23 25 SchemaLoader, 24 26 DatasetPublisher, ··· 1361 1363 1362 1364 with pytest.raises(TypeError, match="Unsupported type"): 1363 1365 publisher.publish(UnsupportedSample, version="1.0.0") 1366 + 1367 + 1368 + # ============================================================================= 1369 + # AtmosphereIndex Tests 1370 + # ============================================================================= 1371 + 1372 + class TestAtmosphereIndexEntry: 1373 + """Tests for AtmosphereIndexEntry wrapper.""" 1374 + 1375 + def test_entry_properties(self): 1376 + """Entry exposes record properties correctly.""" 1377 + record = { 1378 + "name": "test-dataset", 1379 + "schemaRef": "at://did:plc:abc/schema/xyz", 1380 + "storage": { 1381 + "$type": f"{LEXICON_NAMESPACE}.storageExternal", 1382 + "urls": ["s3://bucket/data.tar"], 1383 + }, 1384 + } 1385 + 1386 + entry = AtmosphereIndexEntry("at://did:plc:abc/record/123", record) 1387 + 1388 + assert entry.name == "test-dataset" 1389 + assert entry.schema_ref == "at://did:plc:abc/schema/xyz" 1390 + assert entry.data_urls == ["s3://bucket/data.tar"] 1391 + assert entry.uri == "at://did:plc:abc/record/123" 1392 + 1393 + def test_entry_empty_storage(self): 1394 + """Entry handles missing storage gracefully.""" 1395 + record = {"name": "no-storage"} 1396 + 1397 + entry = AtmosphereIndexEntry("at://uri", record) 1398 + 1399 + assert entry.data_urls == [] 1400 + 1401 + 1402 + class TestAtmosphereIndex: 1403 + """Tests for AtmosphereIndex unified interface.""" 1404 + 1405 + def test_init(self, authenticated_client): 1406 + """Index initializes with client and creates publishers/loaders.""" 1407 + index = AtmosphereIndex(authenticated_client) 1408 + 1409 + assert index.client is authenticated_client 1410 + assert index._schema_publisher is not None 1411 + assert index._schema_loader is not None 1412 + assert index._dataset_publisher is not None 1413 + assert index._dataset_loader is not None 1414 + 1415 + def test_has_protocol_methods(self, authenticated_client): 1416 + """Index has all AbstractIndex protocol methods.""" 1417 + index = AtmosphereIndex(authenticated_client) 1418 + 1419 + assert hasattr(index, 'insert_dataset') 1420 + assert hasattr(index, 'get_dataset') 1421 + assert hasattr(index, 'list_datasets') 1422 + assert hasattr(index, 'publish_schema') 1423 + assert hasattr(index, 'get_schema') 1424 + assert hasattr(index, 'list_schemas') 1425 + assert hasattr(index, 'decode_schema') 1426 + 1427 + def test_publish_schema(self, authenticated_client, mock_atproto_client): 1428 + """publish_schema delegates to SchemaPublisher.""" 1429 + mock_response = Mock() 1430 + mock_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/abc" 1431 + mock_atproto_client.com.atproto.repo.create_record.return_value = mock_response 1432 + 1433 + index = AtmosphereIndex(authenticated_client) 1434 + uri = index.publish_schema(BasicSample, version="2.0.0") 1435 + 1436 + assert uri == str(mock_response.uri) 1437 + mock_atproto_client.com.atproto.repo.create_record.assert_called_once() 1438 + 1439 + def test_get_schema(self, authenticated_client, mock_atproto_client): 1440 + """get_schema delegates to SchemaLoader.""" 1441 + mock_response = Mock() 1442 + mock_response.value = { 1443 + "$type": f"{LEXICON_NAMESPACE}.sampleSchema", 1444 + "name": "TestSchema", 1445 + "version": "1.0.0", 1446 + "fields": [], 1447 + 
} 1448 + mock_atproto_client.com.atproto.repo.get_record.return_value = mock_response 1449 + 1450 + index = AtmosphereIndex(authenticated_client) 1451 + schema = index.get_schema("at://did:plc:test/schema/abc") 1452 + 1453 + assert schema["name"] == "TestSchema"
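A condensed sketch of the unified interface these tests exercise, assuming an authenticated client (the credentials and sample class below are placeholders).

```python
import atdata
from atdata.atmosphere import AtmosphereClient, AtmosphereIndex

@atdata.packable
class DemoSample:
    name: str
    value: int

client = AtmosphereClient()
client.login("handle.bsky.social", "app-password")  # placeholder credentials

index = AtmosphereIndex(client)
schema_uri = index.publish_schema(DemoSample, version="1.0.0")
schema = index.get_schema(schema_uri)  # record dict with "name", "version", "fields"
```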
+219
tests/test_cid.py
··· 1 + """Tests for CID generation utilities.""" 2 + 3 + import pytest 4 + import libipld 5 + 6 + from atdata._cid import ( 7 + generate_cid, 8 + generate_cid_from_bytes, 9 + verify_cid, 10 + parse_cid, 11 + ) 12 + 13 + 14 + class TestGenerateCid: 15 + """Tests for generate_cid function.""" 16 + 17 + def test_generates_valid_cid_from_dict(self): 18 + """CID is generated from a dictionary.""" 19 + data = {"name": "TestSample", "version": "1.0.0"} 20 + cid = generate_cid(data) 21 + 22 + # CIDv1 base32 starts with 'bafy' 23 + assert cid.startswith("bafy") 24 + assert len(cid) > 40 # CIDs are typically 59 chars 25 + 26 + def test_deterministic_output(self): 27 + """Same data always produces same CID.""" 28 + data = {"name": "TestSample", "version": "1.0.0", "fields": []} 29 + 30 + cid1 = generate_cid(data) 31 + cid2 = generate_cid(data) 32 + 33 + assert cid1 == cid2 34 + 35 + def test_different_data_different_cid(self): 36 + """Different data produces different CIDs.""" 37 + data1 = {"name": "Sample1", "version": "1.0.0"} 38 + data2 = {"name": "Sample2", "version": "1.0.0"} 39 + 40 + cid1 = generate_cid(data1) 41 + cid2 = generate_cid(data2) 42 + 43 + assert cid1 != cid2 44 + 45 + def test_key_order_matters_in_dag_cbor(self): 46 + """DAG-CBOR has deterministic key ordering, so key order in input doesn't matter.""" 47 + # DAG-CBOR sorts keys, so these should produce the same CID 48 + data1 = {"a": 1, "b": 2} 49 + data2 = {"b": 2, "a": 1} 50 + 51 + cid1 = generate_cid(data1) 52 + cid2 = generate_cid(data2) 53 + 54 + # DAG-CBOR canonicalizes key order 55 + assert cid1 == cid2 56 + 57 + def test_handles_nested_structures(self): 58 + """CID can be generated from nested data structures.""" 59 + data = { 60 + "name": "NestedSample", 61 + "fields": [ 62 + {"name": "field1", "type": "str"}, 63 + {"name": "field2", "type": "int"}, 64 + ], 65 + "metadata": {"author": "test", "tags": ["a", "b", "c"]}, 66 + } 67 + 68 + cid = generate_cid(data) 69 + assert cid.startswith("bafy") 70 + 71 + def test_handles_various_types(self): 72 + """CID handles various Python types.""" 73 + data = { 74 + "string": "hello", 75 + "integer": 42, 76 + "float": 3.14, 77 + "boolean": True, 78 + "null": None, 79 + "bytes": b"binary data", 80 + "list": [1, 2, 3], 81 + } 82 + 83 + cid = generate_cid(data) 84 + assert cid.startswith("bafy") 85 + 86 + def test_invalid_data_raises_error(self): 87 + """Non-CBOR-serializable data raises ValueError.""" 88 + # Functions can't be serialized to CBOR 89 + data = {"func": lambda x: x} 90 + 91 + with pytest.raises(ValueError, match="Failed to encode"): 92 + generate_cid(data) 93 + 94 + 95 + class TestGenerateCidFromBytes: 96 + """Tests for generate_cid_from_bytes function.""" 97 + 98 + def test_generates_cid_from_bytes(self): 99 + """CID is generated from raw bytes.""" 100 + data_bytes = b"some raw bytes" 101 + cid = generate_cid_from_bytes(data_bytes) 102 + 103 + assert cid.startswith("bafy") 104 + 105 + def test_matches_manual_encoding(self): 106 + """CID from bytes matches CID from pre-encoded data.""" 107 + data = {"key": "value"} 108 + cbor_bytes = libipld.encode_dag_cbor(data) 109 + 110 + cid_from_data = generate_cid(data) 111 + cid_from_bytes = generate_cid_from_bytes(cbor_bytes) 112 + 113 + assert cid_from_data == cid_from_bytes 114 + 115 + 116 + class TestVerifyCid: 117 + """Tests for verify_cid function.""" 118 + 119 + def test_verify_matching_data(self): 120 + """verify_cid returns True for matching data.""" 121 + data = {"name": "test", "value": 123} 122 + cid = generate_cid(data) 
123 + 124 + assert verify_cid(cid, data) is True 125 + 126 + def test_verify_non_matching_data(self): 127 + """verify_cid returns False for non-matching data.""" 128 + data = {"name": "test", "value": 123} 129 + cid = generate_cid(data) 130 + 131 + different_data = {"name": "test", "value": 456} 132 + assert verify_cid(cid, different_data) is False 133 + 134 + def test_verify_with_complex_data(self): 135 + """verify_cid works with complex nested structures.""" 136 + data = { 137 + "schema": { 138 + "name": "ImageSample", 139 + "version": "1.0.0", 140 + "fields": [ 141 + {"name": "image", "type": "ndarray"}, 142 + {"name": "label", "type": "str"}, 143 + ], 144 + } 145 + } 146 + cid = generate_cid(data) 147 + 148 + assert verify_cid(cid, data) is True 149 + 150 + 151 + class TestParseCid: 152 + """Tests for parse_cid function.""" 153 + 154 + def test_parse_cid_components(self): 155 + """parse_cid extracts CID components.""" 156 + data = {"test": "data"} 157 + cid = generate_cid(data) 158 + 159 + parsed = parse_cid(cid) 160 + 161 + assert parsed["version"] == 1 162 + assert parsed["codec"] == 0x71 # dag-cbor 163 + assert parsed["hash"]["code"] == 0x12 # sha256 164 + assert parsed["hash"]["size"] == 32 165 + 166 + def test_parse_cid_digest_matches(self): 167 + """Parsed digest matches the SHA-256 of the data.""" 168 + import hashlib 169 + 170 + data = {"test": "data"} 171 + cid = generate_cid(data) 172 + 173 + cbor_bytes = libipld.encode_dag_cbor(data) 174 + expected_digest = hashlib.sha256(cbor_bytes).digest() 175 + 176 + parsed = parse_cid(cid) 177 + assert parsed["hash"]["digest"] == expected_digest 178 + 179 + @pytest.mark.parametrize("malformed_cid", [ 180 + "", # empty 181 + "invalid", # not a CID 182 + "bafy123", # truncated CID 183 + "Qm123", # v0 prefix but invalid 184 + ]) 185 + def test_parse_cid_malformed_raises_valueerror(self, malformed_cid): 186 + """Malformed CID strings raise ValueError.""" 187 + with pytest.raises(ValueError, match="Failed to decode CID"): 188 + parse_cid(malformed_cid) 189 + 190 + 191 + class TestAtprotoCompatibility: 192 + """Tests verifying ATProto SDK compatibility.""" 193 + 194 + def test_cid_decodable_by_atproto(self): 195 + """Generated CIDs can be decoded by atproto SDK.""" 196 + from atproto_core.cid.cid import CID 197 + 198 + data = {"name": "TestSchema", "version": "1.0.0"} 199 + cid_str = generate_cid(data) 200 + 201 + # Should not raise 202 + cid_obj = CID.decode(cid_str) 203 + 204 + assert cid_obj.version == 1 205 + assert cid_obj.codec == 0x71 206 + 207 + def test_hash_matches_atproto_decode(self): 208 + """Hash in generated CID matches when decoded by atproto.""" 209 + import hashlib 210 + from atproto_core.cid.cid import CID 211 + 212 + data = {"name": "TestSchema", "version": "1.0.0"} 213 + cid_str = generate_cid(data) 214 + 215 + cbor_bytes = libipld.encode_dag_cbor(data) 216 + expected_hash = hashlib.sha256(cbor_bytes).digest() 217 + 218 + cid_obj = CID.decode(cid_str) 219 + assert cid_obj.hash.digest == expected_hash
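A compact round-trip sketch of the helpers under test (the payload is arbitrary).

```python
from atdata._cid import generate_cid, verify_cid, parse_cid

data = {"name": "ExampleSchema", "version": "1.0.0"}

cid = generate_cid(data)  # CIDv1, base32, DAG-CBOR + SHA-256
assert cid.startswith("bafy")

assert verify_cid(cid, data) is True
assert verify_cid(cid, {"name": "ExampleSchema", "version": "2.0.0"}) is False

parsed = parse_cid(cid)
assert parsed["version"] == 1
assert parsed["codec"] == 0x71         # dag-cbor
assert parsed["hash"]["code"] == 0x12  # sha-256
```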
+95
tests/test_dataset.py
··· 585 585 assert batches_seen == n_samples // batch_size 586 586 587 587 588 + def test_from_bytes_invalid_msgpack(): 589 + """Test from_bytes raises on invalid msgpack data.""" 590 + @atdata.packable 591 + class SimpleSample: 592 + value: int 593 + 594 + with pytest.raises(Exception): # ormsgpack raises on invalid data 595 + SimpleSample.from_bytes(b"not valid msgpack data") 596 + 597 + 598 + def test_from_bytes_missing_field(): 599 + """Test from_bytes raises when required field is missing.""" 600 + @atdata.packable 601 + class RequiredFieldSample: 602 + name: str 603 + count: int 604 + 605 + import ormsgpack 606 + # Only provide 'name', missing 'count' 607 + incomplete_data = ormsgpack.packb({"name": "test"}) 608 + 609 + with pytest.raises(TypeError): # Missing required argument 610 + RequiredFieldSample.from_bytes(incomplete_data) 611 + 612 + 613 + def test_wrap_missing_msgpack_key(tmp_path): 614 + """Test wrap asserts on sample missing msgpack key.""" 615 + @atdata.packable 616 + class WrapTestSample: 617 + value: int 618 + 619 + wds_filename = (tmp_path / "wrap_test.tar").as_posix() 620 + with wds.writer.TarWriter(wds_filename) as sink: 621 + sample = WrapTestSample(value=42) 622 + sink.write(sample.as_wds) 623 + 624 + dataset = atdata.Dataset[WrapTestSample](wds_filename) 625 + 626 + # Directly call wrap with missing key 627 + with pytest.raises(AssertionError): 628 + dataset.wrap({"__key__": "test"}) # Missing 'msgpack' key 629 + 630 + 631 + def test_wrap_wrong_msgpack_type(tmp_path): 632 + """Test wrap asserts when msgpack value is not bytes.""" 633 + @atdata.packable 634 + class WrapTypeSample: 635 + value: int 636 + 637 + wds_filename = (tmp_path / "wrap_type_test.tar").as_posix() 638 + with wds.writer.TarWriter(wds_filename) as sink: 639 + sample = WrapTypeSample(value=42) 640 + sink.write(sample.as_wds) 641 + 642 + dataset = atdata.Dataset[WrapTypeSample](wds_filename) 643 + 644 + # Directly call wrap with wrong type 645 + with pytest.raises(AssertionError): 646 + dataset.wrap({"__key__": "test", "msgpack": "not bytes"}) 647 + 648 + 649 + def test_dataset_nonexistent_file(): 650 + """Test Dataset raises on nonexistent tar file during iteration.""" 651 + @atdata.packable 652 + class NonexistentSample: 653 + value: int 654 + 655 + dataset = atdata.Dataset[NonexistentSample]("/nonexistent/path/data.tar") 656 + 657 + # Dataset creation succeeds (lazy loading) 658 + assert dataset is not None 659 + 660 + # Iteration fails when file doesn't exist 661 + with pytest.raises(Exception): # FileNotFoundError or similar 662 + list(dataset.ordered(batch_size=None)) 663 + 664 + 665 + def test_dataset_invalid_batch_size(tmp_path): 666 + """Test Dataset raises on invalid batch_size values.""" 667 + @atdata.packable 668 + class BatchSizeSample: 669 + value: int 670 + 671 + wds_filename = (tmp_path / "batch_test.tar").as_posix() 672 + with wds.writer.TarWriter(wds_filename) as sink: 673 + sample = BatchSizeSample(value=42) 674 + sink.write(sample.as_wds) 675 + 676 + dataset = atdata.Dataset[BatchSizeSample](wds_filename) 677 + 678 + # batch_size=0 produces empty batches, causing IndexError in webdataset 679 + with pytest.raises((ValueError, AssertionError, IndexError)): 680 + list(dataset.ordered(batch_size=0)) 681 + 682 + 588 683 ##
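For orientation, a minimal happy-path counterpart to the error-path tests above; the file path and sample class are illustrative.

```python
import atdata
import webdataset as wds

@atdata.packable
class DemoSample:
    value: int

# Write a single sample to a local WebDataset shard...
with wds.writer.TarWriter("demo.tar") as sink:
    sink.write(DemoSample(value=42).as_wds)

# ...then read it back; batch_size=None yields individual samples.
ds = atdata.Dataset[DemoSample]("demo.tar")
for sample in ds.ordered(batch_size=None):
    assert sample.value == 42
```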
+131
tests/test_hf_api.py
··· 23 23 _resolve_shards, 24 24 _resolve_data_files, 25 25 _group_shards_by_split, 26 + _is_indexed_path, 27 + _parse_indexed_path, 26 28 ) 29 + from unittest.mock import Mock, MagicMock 27 30 28 31 from numpy.typing import NDArray 29 32 ··· 643 646 # Aggregated attributes 644 647 labels = first_batch.label 645 648 assert len(labels) == 4 649 + 650 + 651 + ## 652 + # Indexed path tests 653 + 654 + 655 + class TestIsIndexedPath: 656 + """Tests for _is_indexed_path function.""" 657 + 658 + def test_at_handle_path(self): 659 + """@handle/dataset is indexed.""" 660 + assert _is_indexed_path("@maxine.science/mnist") is True 661 + 662 + def test_at_did_path(self): 663 + """@did:plc:abc/dataset is indexed.""" 664 + assert _is_indexed_path("@did:plc:abc123/my-dataset") is True 665 + 666 + def test_local_path(self): 667 + """Local paths are not indexed.""" 668 + assert _is_indexed_path("/path/to/data.tar") is False 669 + 670 + def test_s3_path(self): 671 + """S3 URLs are not indexed.""" 672 + assert _is_indexed_path("s3://bucket/data.tar") is False 673 + 674 + def test_relative_path(self): 675 + """Relative paths are not indexed.""" 676 + assert _is_indexed_path("./data/train.tar") is False 677 + 678 + 679 + class TestParseIndexedPath: 680 + """Tests for _parse_indexed_path function.""" 681 + 682 + def test_parse_handle_dataset(self): 683 + """Parse @handle/dataset format.""" 684 + handle, name = _parse_indexed_path("@maxine.science/mnist") 685 + assert handle == "maxine.science" 686 + assert name == "mnist" 687 + 688 + def test_parse_did_dataset(self): 689 + """Parse @did:plc:xxx/dataset format.""" 690 + handle, name = _parse_indexed_path("@did:plc:abc123/my-dataset") 691 + assert handle == "did:plc:abc123" 692 + assert name == "my-dataset" 693 + 694 + def test_parse_invalid_no_slash(self): 695 + """Invalid path without slash raises ValueError.""" 696 + with pytest.raises(ValueError, match="Invalid indexed path format"): 697 + _parse_indexed_path("@handle-only") 698 + 699 + def test_parse_invalid_no_at(self): 700 + """Path without @ raises ValueError.""" 701 + with pytest.raises(ValueError, match="Not an indexed path"): 702 + _parse_indexed_path("handle/dataset") 703 + 704 + def test_parse_invalid_empty_parts(self): 705 + """Empty handle or dataset raises ValueError.""" 706 + with pytest.raises(ValueError, match="Invalid indexed path"): 707 + _parse_indexed_path("@/dataset") 708 + 709 + 710 + class TestLoadDatasetWithIndex: 711 + """Tests for load_dataset with index parameter.""" 712 + 713 + def test_indexed_path_requires_index(self): 714 + """@handle/dataset without index raises ValueError.""" 715 + with pytest.raises(ValueError, match="Index required"): 716 + load_dataset("@handle/dataset", SimpleTestSample) 717 + 718 + def test_none_sample_type_requires_index(self): 719 + """sample_type=None without index raises ValueError.""" 720 + with pytest.raises(ValueError, match="sample_type is required"): 721 + load_dataset("/path/to/data.tar", None) 722 + 723 + def test_indexed_path_with_mock_index(self): 724 + """load_dataset with indexed path uses index lookup.""" 725 + mock_index = Mock() 726 + mock_entry = Mock() 727 + mock_entry.data_urls = ["s3://bucket/data.tar"] 728 + mock_entry.schema_ref = "local://schemas/test@1.0.0" 729 + mock_index.get_dataset.return_value = mock_entry 730 + 731 + # Need to mock decode_schema since sample_type is provided 732 + ds = load_dataset( 733 + "@local/my-dataset", 734 + SimpleTestSample, 735 + index=mock_index, 736 + split="train", 737 + ) 738 + 739 + 
mock_index.get_dataset.assert_called_once_with("my-dataset") 740 + assert ds.url == "s3://bucket/data.tar" 741 + 742 + def test_indexed_path_auto_type_resolution(self): 743 + """load_dataset with sample_type=None uses decode_schema.""" 744 + mock_index = Mock() 745 + mock_entry = Mock() 746 + mock_entry.data_urls = ["s3://bucket/data.tar"] 747 + mock_entry.schema_ref = "local://schemas/test@1.0.0" 748 + mock_index.get_dataset.return_value = mock_entry 749 + mock_index.decode_schema.return_value = SimpleTestSample 750 + 751 + ds = load_dataset( 752 + "@local/my-dataset", 753 + None, 754 + index=mock_index, 755 + split="train", 756 + ) 757 + 758 + mock_index.decode_schema.assert_called_once_with("local://schemas/test@1.0.0") 759 + assert ds.sample_type == SimpleTestSample 760 + 761 + def test_indexed_path_returns_datasetdict_without_split(self): 762 + """load_dataset with indexed path returns DatasetDict when split=None.""" 763 + mock_index = Mock() 764 + mock_entry = Mock() 765 + mock_entry.data_urls = ["s3://bucket/data.tar"] 766 + mock_entry.schema_ref = "local://schemas/test@1.0.0" 767 + mock_index.get_dataset.return_value = mock_entry 768 + 769 + result = load_dataset( 770 + "@local/my-dataset", 771 + SimpleTestSample, 772 + index=mock_index, 773 + ) 774 + 775 + assert isinstance(result, DatasetDict) 776 + assert "train" in result
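A sketch of the `@handle/dataset` resolution path these tests target, assuming a populated index backed by a reachable Redis; the dataset name and sample class are placeholders.

```python
import atdata
from atdata.local import LocalIndex

@atdata.packable
class DemoSample:
    value: int

index = LocalIndex()

# Explicit sample type: the index only resolves the data URLs.
train = atdata.load_dataset("@local/demo-dataset", DemoSample,
                            index=index, split="train")

# sample_type=None: the type is rebuilt via index.decode_schema(entry.schema_ref).
train = atdata.load_dataset("@local/demo-dataset", None,
                            index=index, split="train")

# No split given: a DatasetDict keyed by split name is returned.
splits = atdata.load_dataset("@local/demo-dataset", DemoSample, index=index)
```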
+325
tests/test_integration.py
··· 1 + """Integration tests for atdata local-atmosphere workflows. 2 + 3 + These tests verify end-to-end workflows spanning local and atmosphere 4 + components, using mocks for external services (Redis, ATProto PDS). 5 + """ 6 + 7 + import pytest 8 + from unittest.mock import Mock, MagicMock, patch 9 + from dataclasses import dataclass 10 + import tempfile 11 + from pathlib import Path 12 + 13 + import numpy as np 14 + import webdataset as wds 15 + 16 + import atdata 17 + from atdata.local import LocalIndex, LocalDatasetEntry 18 + from atdata.atmosphere import AtmosphereIndex, AtmosphereIndexEntry 19 + from atdata.promote import promote_to_atmosphere 20 + 21 + 22 + @atdata.packable 23 + class IntegrationTestSample: 24 + """Sample type for integration tests.""" 25 + name: str 26 + value: int 27 + 28 + 29 + class TestLocalToAtmosphereRoundTrip: 30 + """Integration tests for local → atmosphere promotion workflow.""" 31 + 32 + def test_promote_preserves_data_urls(self, tmp_path): 33 + """Promote should preserve data URLs when no data_store provided.""" 34 + # Create a local dataset entry 35 + local_entry = LocalDatasetEntry( 36 + _name="test-dataset", 37 + _schema_ref="local://schemas/test_integration.IntegrationTestSample@1.0.0", 38 + _data_urls=["s3://bucket/data-000000.tar", "s3://bucket/data-000001.tar"], 39 + _metadata={"source": "test"}, 40 + ) 41 + 42 + # Mock local index with schema 43 + mock_local_index = Mock() 44 + mock_local_index.get_schema.return_value = { 45 + "name": "test_integration.IntegrationTestSample", 46 + "version": "1.0.0", 47 + "fields": [ 48 + {"name": "name", "fieldType": {"$type": "local#primitive", "primitive": "str"}, "optional": False}, 49 + {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 50 + ], 51 + } 52 + 53 + # Mock atmosphere client 54 + mock_client = Mock() 55 + 56 + with patch("atdata.promote._find_or_publish_schema") as mock_find_schema: 57 + mock_find_schema.return_value = "at://did:plc:test/schema/abc" 58 + 59 + with patch("atdata.atmosphere.DatasetPublisher") as MockPublisher: 60 + mock_publisher = MockPublisher.return_value 61 + mock_publisher.publish_with_urls.return_value = Mock( 62 + __str__=lambda s: "at://did:plc:test/record/xyz" 63 + ) 64 + 65 + result = promote_to_atmosphere( 66 + local_entry, 67 + mock_local_index, 68 + mock_client, 69 + ) 70 + 71 + # Verify data URLs were preserved 72 + call_kwargs = mock_publisher.publish_with_urls.call_args[1] 73 + assert call_kwargs["urls"] == [ 74 + "s3://bucket/data-000000.tar", 75 + "s3://bucket/data-000001.tar", 76 + ] 77 + 78 + # Verify metadata preserved 79 + assert call_kwargs["metadata"] == {"source": "test"} 80 + 81 + def test_promote_transfers_schema_metadata(self, tmp_path): 82 + """Promote should use schema version from local index.""" 83 + local_entry = LocalDatasetEntry( 84 + _name="versioned-dataset", 85 + _schema_ref="local://schemas/MySample@2.1.0", 86 + _data_urls=["s3://bucket/data.tar"], 87 + ) 88 + 89 + mock_local_index = Mock() 90 + mock_local_index.get_schema.return_value = { 91 + "name": "MySample", 92 + "version": "2.1.0", 93 + "description": "A sample with specific version", 94 + "fields": [ 95 + {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 96 + ], 97 + } 98 + 99 + mock_client = Mock() 100 + 101 + with patch("atdata.promote._find_or_publish_schema") as mock_find_schema: 102 + mock_find_schema.return_value = "at://schema" 103 + 104 + with 
patch("atdata.atmosphere.DatasetPublisher") as MockPublisher: 105 + mock_publisher = MockPublisher.return_value 106 + mock_publisher.publish_with_urls.return_value = Mock(__str__=lambda s: "at://result") 107 + 108 + promote_to_atmosphere(local_entry, mock_local_index, mock_client) 109 + 110 + # Verify _find_or_publish_schema was called with correct version 111 + call_args = mock_find_schema.call_args 112 + assert call_args[0][1] == "2.1.0" # version argument 113 + 114 + 115 + class TestSchemaDeduplication: 116 + """Tests for schema deduplication during promotion.""" 117 + 118 + def test_existing_schema_reused(self): 119 + """Promoting with existing schema should reuse it, not create duplicate.""" 120 + from atdata.promote import _find_or_publish_schema 121 + 122 + mock_client = Mock() 123 + 124 + # Mock finding an existing schema 125 + with patch("atdata.promote._find_existing_schema") as mock_find: 126 + mock_find.return_value = "at://did:plc:test/schema/existing" 127 + 128 + with patch("atdata.atmosphere.SchemaPublisher") as MockPublisher: 129 + result = _find_or_publish_schema( 130 + IntegrationTestSample, 131 + "1.0.0", 132 + mock_client, 133 + ) 134 + 135 + # Should return existing URI 136 + assert result == "at://did:plc:test/schema/existing" 137 + 138 + # Should NOT have called publish 139 + MockPublisher.return_value.publish.assert_not_called() 140 + 141 + def test_new_schema_published_when_not_found(self): 142 + """Promoting without existing schema should publish new one.""" 143 + from atdata.promote import _find_or_publish_schema 144 + 145 + mock_client = Mock() 146 + 147 + with patch("atdata.promote._find_existing_schema") as mock_find: 148 + mock_find.return_value = None # No existing schema 149 + 150 + with patch("atdata.atmosphere.SchemaPublisher") as MockPublisher: 151 + mock_publisher = MockPublisher.return_value 152 + mock_publisher.publish.return_value = Mock( 153 + __str__=lambda s: "at://did:plc:test/schema/new" 154 + ) 155 + 156 + result = _find_or_publish_schema( 157 + IntegrationTestSample, 158 + "1.0.0", 159 + mock_client, 160 + ) 161 + 162 + # Should return new URI 163 + assert result == "at://did:plc:test/schema/new" 164 + 165 + # Should have called publish 166 + mock_publisher.publish.assert_called_once() 167 + 168 + def test_version_mismatch_creates_new_schema(self): 169 + """Different version should create new schema even if name matches.""" 170 + from atdata.promote import _find_existing_schema 171 + 172 + mock_client = Mock() 173 + 174 + with patch("atdata.atmosphere.SchemaLoader") as MockLoader: 175 + mock_loader = MockLoader.return_value 176 + mock_loader.list_all.return_value = [ 177 + { 178 + "uri": "at://did:plc:test/schema/v1", 179 + "value": { 180 + "name": "test_integration.IntegrationTestSample", 181 + "version": "1.0.0", # Different version 182 + } 183 + } 184 + ] 185 + 186 + # Looking for version 2.0.0 187 + result = _find_existing_schema( 188 + mock_client, 189 + "test_integration.IntegrationTestSample", 190 + "2.0.0", 191 + ) 192 + 193 + # Should not find the v1 schema 194 + assert result is None 195 + 196 + 197 + class TestLoadDatasetWithIndex: 198 + """Integration tests for load_dataset with index parameter.""" 199 + 200 + def test_load_from_local_index(self, tmp_path): 201 + """load_dataset with LocalIndex should resolve dataset.""" 202 + # Create actual test data 203 + wds_file = tmp_path / "test-data.tar" 204 + with wds.writer.TarWriter(str(wds_file)) as sink: 205 + sample = IntegrationTestSample(name="test", value=42) 206 + 
sink.write(sample.as_wds) 207 + 208 + # Create local index entry 209 + local_entry = LocalDatasetEntry( 210 + _name="my-dataset", 211 + _schema_ref="local://schemas/IntegrationTestSample@1.0.0", 212 + _data_urls=[str(wds_file)], 213 + ) 214 + 215 + # Mock index 216 + mock_index = Mock() 217 + mock_index.get_dataset.return_value = local_entry 218 + mock_index.decode_schema.return_value = IntegrationTestSample 219 + 220 + # Load via index 221 + ds = atdata.load_dataset( 222 + "@local/my-dataset", 223 + index=mock_index, 224 + split="train", 225 + ) 226 + 227 + # Should return a Dataset 228 + assert isinstance(ds, atdata.Dataset) 229 + 230 + # Should be able to iterate (batch_size=None for individual samples) 231 + samples = list(ds.ordered(batch_size=None)) 232 + assert len(samples) == 1 233 + assert samples[0].name == "test" 234 + assert samples[0].value == 42 235 + 236 + def test_load_with_explicit_sample_type(self, tmp_path): 237 + """load_dataset with explicit sample_type should use it.""" 238 + wds_file = tmp_path / "typed-data.tar" 239 + with wds.writer.TarWriter(str(wds_file)) as sink: 240 + sample = IntegrationTestSample(name="explicit", value=100) 241 + sink.write(sample.as_wds) 242 + 243 + local_entry = LocalDatasetEntry( 244 + _name="typed-dataset", 245 + _schema_ref="local://schemas/IntegrationTestSample@1.0.0", 246 + _data_urls=[str(wds_file)], 247 + ) 248 + 249 + mock_index = Mock() 250 + mock_index.get_dataset.return_value = local_entry 251 + 252 + # Load with explicit type (should not call decode_schema) 253 + ds = atdata.load_dataset( 254 + "@local/typed-dataset", 255 + IntegrationTestSample, 256 + index=mock_index, 257 + split="train", 258 + ) 259 + 260 + # decode_schema should not be called when type is explicit 261 + mock_index.decode_schema.assert_not_called() 262 + 263 + samples = list(ds.ordered(batch_size=None)) 264 + assert len(samples) == 1 265 + assert isinstance(samples[0], IntegrationTestSample) 266 + 267 + 268 + class TestIndexEntryRoundTrip: 269 + """Tests for index entry serialization round-trips.""" 270 + 271 + def test_local_entry_redis_round_trip(self, clean_redis): 272 + """LocalDatasetEntry should round-trip through Redis correctly.""" 273 + original = LocalDatasetEntry( 274 + _name="roundtrip-test", 275 + _schema_ref="local://schemas/Test@1.0.0", 276 + _data_urls=["s3://bucket/shard-000.tar", "s3://bucket/shard-001.tar"], 277 + _metadata={"key": "value", "count": 42}, 278 + ) 279 + 280 + # Write to Redis 281 + original.write_to(clean_redis) 282 + 283 + # Read back 284 + loaded = LocalDatasetEntry.from_redis(clean_redis, original.cid) 285 + 286 + # Verify all fields match 287 + assert loaded.name == original.name 288 + assert loaded.schema_ref == original.schema_ref 289 + assert loaded.data_urls == original.data_urls 290 + assert loaded.metadata == original.metadata 291 + assert loaded.cid == original.cid 292 + 293 + def test_local_entry_cid_deterministic(self): 294 + """Same content should produce same CID.""" 295 + entry1 = LocalDatasetEntry( 296 + _name="deterministic", 297 + _schema_ref="local://schemas/Test@1.0.0", 298 + _data_urls=["s3://bucket/data.tar"], 299 + ) 300 + 301 + entry2 = LocalDatasetEntry( 302 + _name="deterministic", 303 + _schema_ref="local://schemas/Test@1.0.0", 304 + _data_urls=["s3://bucket/data.tar"], 305 + ) 306 + 307 + # CIDs should match (based on schema_ref and data_urls) 308 + assert entry1.cid == entry2.cid 309 + 310 + def test_local_entry_cid_differs_with_content(self): 311 + """Different content should produce different 
CID.""" 312 + entry1 = LocalDatasetEntry( 313 + _name="same-name", 314 + _schema_ref="local://schemas/Test@1.0.0", 315 + _data_urls=["s3://bucket/data-v1.tar"], 316 + ) 317 + 318 + entry2 = LocalDatasetEntry( 319 + _name="same-name", 320 + _schema_ref="local://schemas/Test@1.0.0", 321 + _data_urls=["s3://bucket/data-v2.tar"], # Different URL 322 + ) 323 + 324 + # CIDs should differ 325 + assert entry1.cid != entry2.cid
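A small sketch of the content-addressing rule the round-trip tests pin down: the CID is derived from `schema_ref` and `data_urls`, so renaming an entry does not change its identity (field values below are arbitrary).

```python
from atdata.local import LocalDatasetEntry

a = LocalDatasetEntry(
    _name="nightly",
    _schema_ref="local://schemas/Demo@1.0.0",
    _data_urls=["s3://bucket/demo.tar"],
)
b = LocalDatasetEntry(
    _name="nightly-renamed",  # the name does not enter the CID
    _schema_ref="local://schemas/Demo@1.0.0",
    _data_urls=["s3://bucket/demo.tar"],
)
c = LocalDatasetEntry(
    _name="nightly",
    _schema_ref="local://schemas/Demo@1.0.0",
    _data_urls=["s3://bucket/demo-v2.tar"],  # different data
)

assert a.cid == b.cid   # same schema_ref + data_urls
assert a.cid != c.cid   # different data URLs, different CID
```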
+796 -217
tests/test_local.py
··· 26 26 27 27 28 28 ## 29 - # Test fixtures 30 - 31 - @pytest.fixture 32 - def redis_connection(): 33 - """Provide a Redis connection, skip test if Redis is not available.""" 34 - try: 35 - redis = Redis() 36 - redis.ping() 37 - yield redis 38 - except Exception: 39 - pytest.skip("Redis server not available") 40 - 41 - 42 - @pytest.fixture 43 - def clean_redis(redis_connection): 44 - """Provide a Redis connection with automatic BasicIndexEntry cleanup. 45 - 46 - Clears all BasicIndexEntry keys before and after each test to ensure 47 - test isolation. 48 - """ 49 - def _clear_entries(): 50 - for key in redis_connection.scan_iter(match='BasicIndexEntry:*'): 51 - redis_connection.delete(key) 52 - 53 - _clear_entries() 54 - yield redis_connection 55 - _clear_entries() 56 - 29 + # Test fixtures (redis_connection and clean_redis are in conftest.py) 57 30 58 31 @pytest.fixture 59 32 def mock_s3(): ··· 155 128 assert result2 == f"{ArrayTestSample.__module__}.ArrayTestSample" 156 129 157 130 158 - def test_decode_bytes_dict(): 159 - """Test that byte dictionaries from Redis are correctly decoded to strings. 160 - 161 - Should handle UTF-8 decoding of both keys and values from Redis response format. 162 - """ 163 - bytes_dict = { 164 - b'wds_url': b's3://bucket/dataset.tar', 165 - b'sample_kind': b'module.Sample', 166 - b'metadata_url': b's3://bucket/metadata.msgpack', 167 - b'uuid': b'12345678-1234-1234-1234-123456789abc' 168 - } 169 - 170 - result = atlocal._decode_bytes_dict(bytes_dict) 171 - 172 - assert result == { 173 - 'wds_url': 's3://bucket/dataset.tar', 174 - 'sample_kind': 'module.Sample', 175 - 'metadata_url': 's3://bucket/metadata.msgpack', 176 - 'uuid': '12345678-1234-1234-1234-123456789abc' 177 - } 178 - assert all(isinstance(k, str) for k in result.keys()) 179 - assert all(isinstance(v, str) for v in result.values()) 180 - 181 - 182 131 def test_s3_env_valid_credentials(tmp_path): 183 132 """Test loading S3 credentials from a valid .env file. 184 133 ··· 259 208 260 209 261 210 ## 262 - # BasicIndexEntry tests 211 + # LocalDatasetEntry tests 263 212 264 - def test_basic_index_entry_creation(): 265 - """Test creating a BasicIndexEntry with explicit values. 213 + def test_local_dataset_entry_creation(): 214 + """Test creating a LocalDatasetEntry with explicit values. 266 215 267 - Should create an entry with provided wds_url, sample_kind, metadata_url, and uuid. 216 + Should create an entry with provided name, schema_ref, data_urls, and generate CID. 
268 217 """ 269 - entry = atlocal.BasicIndexEntry( 270 - wds_url="s3://bucket/dataset.tar", 271 - sample_kind="test_module.TestSample", 272 - metadata_url="s3://bucket/metadata.msgpack", 273 - uuid="12345678-1234-1234-1234-123456789abc" 218 + entry = atlocal.LocalDatasetEntry( 219 + _name="test-dataset", 220 + _schema_ref="local://schemas/test_module.TestSample@1.0.0", 221 + _data_urls=["s3://bucket/dataset.tar"], 222 + _metadata={"description": "test"}, 274 223 ) 275 224 276 - assert entry.wds_url == "s3://bucket/dataset.tar" 277 - assert entry.sample_kind == "test_module.TestSample" 278 - assert entry.metadata_url == "s3://bucket/metadata.msgpack" 279 - assert entry.uuid == "12345678-1234-1234-1234-123456789abc" 225 + assert entry.name == "test-dataset" 226 + assert entry.schema_ref == "local://schemas/test_module.TestSample@1.0.0" 227 + assert entry.data_urls == ["s3://bucket/dataset.tar"] 228 + assert entry.metadata == {"description": "test"} 229 + # CID should be auto-generated 230 + assert entry.cid is not None 231 + assert entry.cid.startswith("bafy") 280 232 281 233 282 - def test_basic_index_entry_default_uuid(): 283 - """Test that BasicIndexEntry generates a valid UUID by default. 234 + def test_local_dataset_entry_cid_generation(): 235 + """Test that LocalDatasetEntry generates deterministic CIDs. 284 236 285 - Should auto-generate a unique UUID when none is provided, and it should be 286 - parsable as a valid UUID. 237 + Same content should produce the same CID. 287 238 """ 288 - entry = atlocal.BasicIndexEntry( 289 - wds_url="s3://bucket/dataset.tar", 290 - sample_kind="test_module.TestSample", 291 - metadata_url="s3://bucket/metadata.msgpack" 239 + entry1 = atlocal.LocalDatasetEntry( 240 + _name="test-dataset", 241 + _schema_ref="local://schemas/test_module.TestSample@1.0.0", 242 + _data_urls=["s3://bucket/dataset.tar"], 243 + ) 244 + entry2 = atlocal.LocalDatasetEntry( 245 + _name="test-dataset", # Name doesn't affect CID 246 + _schema_ref="local://schemas/test_module.TestSample@1.0.0", 247 + _data_urls=["s3://bucket/dataset.tar"], 292 248 ) 293 249 294 - assert entry.uuid is not None 295 - # Verify it's a valid UUID by parsing it 296 - parsed_uuid = UUID(entry.uuid) 297 - assert str(parsed_uuid) == entry.uuid 250 + # Same schema_ref and data_urls = same CID 251 + assert entry1.cid == entry2.cid 252 + 253 + 254 + def test_local_dataset_entry_different_content_different_cid(): 255 + """Test that different content produces different CIDs.""" 256 + entry1 = atlocal.LocalDatasetEntry( 257 + _name="dataset1", 258 + _schema_ref="local://schemas/test_module.TestSample@1.0.0", 259 + _data_urls=["s3://bucket/dataset1.tar"], 260 + ) 261 + entry2 = atlocal.LocalDatasetEntry( 262 + _name="dataset2", 263 + _schema_ref="local://schemas/test_module.TestSample@1.0.0", 264 + _data_urls=["s3://bucket/dataset2.tar"], # Different URL 265 + ) 298 266 267 + assert entry1.cid != entry2.cid 299 268 300 - def test_basic_index_entry_write_to_redis(clean_redis): 301 - """Test persisting a BasicIndexEntry to Redis. 302 269 303 - Should write the entry to Redis as a hash with key 'BasicIndexEntry:{uuid}' 270 + def test_local_dataset_entry_write_to_redis(clean_redis): 271 + """Test persisting a LocalDatasetEntry to Redis. 272 + 273 + Should write the entry to Redis as a hash with key 'LocalDatasetEntry:{cid}' 304 274 and all fields should be retrievable with correct values. 
305 275 """ 306 - test_uuid = "12345678-1234-1234-1234-123456789abc" 307 - 308 - entry = atlocal.BasicIndexEntry( 309 - wds_url="s3://bucket/dataset.tar", 310 - sample_kind="test_module.TestSample", 311 - metadata_url="s3://bucket/metadata.msgpack", 312 - uuid=test_uuid 276 + entry = atlocal.LocalDatasetEntry( 277 + _name="test-dataset", 278 + _schema_ref="local://schemas/test_module.TestSample@1.0.0", 279 + _data_urls=["s3://bucket/dataset.tar"], 280 + _metadata={"version": "1.0"}, 313 281 ) 314 282 315 283 entry.write_to(clean_redis) 316 284 317 - # Retrieve and verify actual stored values 318 - stored_data = atlocal._decode_bytes_dict(clean_redis.hgetall(f"BasicIndexEntry:{test_uuid}")) 319 - assert stored_data['wds_url'] == "s3://bucket/dataset.tar" 320 - assert stored_data['sample_kind'] == "test_module.TestSample" 321 - assert stored_data['metadata_url'] == "s3://bucket/metadata.msgpack" 322 - assert stored_data['uuid'] == test_uuid 285 + # Verify key exists 286 + assert clean_redis.exists(f"LocalDatasetEntry:{entry.cid}") 323 287 288 + # Load back and verify 289 + loaded = atlocal.LocalDatasetEntry.from_redis(clean_redis, entry.cid) 290 + assert loaded.name == entry.name 291 + assert loaded.schema_ref == entry.schema_ref 292 + assert loaded.data_urls == entry.data_urls 293 + assert loaded.metadata == entry.metadata 294 + assert loaded.cid == entry.cid 324 295 325 - def test_basic_index_entry_round_trip_redis(clean_redis): 326 - """Test writing and reading a BasicIndexEntry from Redis. 296 + 297 + def test_local_dataset_entry_round_trip_redis(clean_redis): 298 + """Test writing and reading a LocalDatasetEntry from Redis. 327 299 328 300 Should be able to write an entry to Redis and read it back with all fields 329 301 intact and matching the original values. 
330 302 """ 331 - test_uuid = "12345678-1234-1234-1234-123456789abc" 332 - 333 - original_entry = atlocal.BasicIndexEntry( 334 - wds_url="s3://bucket/dataset.tar", 335 - sample_kind="test_module.TestSample", 336 - metadata_url="s3://bucket/metadata.msgpack", 337 - uuid=test_uuid 303 + original_entry = atlocal.LocalDatasetEntry( 304 + _name="my-dataset", 305 + _schema_ref="local://schemas/module.Sample@2.0.0", 306 + _data_urls=["s3://bucket/data-{000000..000009}.tar"], 307 + _metadata={"author": "test", "tags": ["a", "b"]}, 338 308 ) 339 309 340 310 original_entry.write_to(clean_redis) 341 311 342 312 # Read back from Redis 343 - stored_data = atlocal._decode_bytes_dict(clean_redis.hgetall(f"BasicIndexEntry:{test_uuid}")) 344 - retrieved_entry = atlocal.BasicIndexEntry(**stored_data) 313 + retrieved_entry = atlocal.LocalDatasetEntry.from_redis(clean_redis, original_entry.cid) 314 + 315 + assert retrieved_entry.name == original_entry.name 316 + assert retrieved_entry.schema_ref == original_entry.schema_ref 317 + assert retrieved_entry.data_urls == original_entry.data_urls 318 + assert retrieved_entry.metadata == original_entry.metadata 319 + assert retrieved_entry.cid == original_entry.cid 320 + 321 + 322 + def test_local_dataset_entry_legacy_properties(): 323 + """Test that legacy properties work for backwards compatibility.""" 324 + entry = atlocal.LocalDatasetEntry( 325 + _name="test-dataset", 326 + _schema_ref="local://schemas/test_module.TestSample@1.0.0", 327 + _data_urls=["s3://bucket/dataset.tar"], 328 + ) 329 + 330 + # Legacy properties should work 331 + assert entry.wds_url == "s3://bucket/dataset.tar" 332 + assert entry.sample_kind == "local://schemas/test_module.TestSample@1.0.0" 333 + 334 + 335 + def test_local_dataset_entry_implements_index_entry_protocol(): 336 + """Test that LocalDatasetEntry implements the IndexEntry protocol.""" 337 + from atdata._protocols import IndexEntry 338 + 339 + entry = atlocal.LocalDatasetEntry( 340 + _name="test-dataset", 341 + _schema_ref="local://schemas/test_module.TestSample@1.0.0", 342 + _data_urls=["s3://bucket/dataset.tar"], 343 + ) 344 + 345 + # Should satisfy the protocol 346 + assert isinstance(entry, IndexEntry) 347 + 348 + 349 + def test_index_implements_abstract_index_protocol(): 350 + """Test that Index has all AbstractIndex protocol methods.""" 351 + index = atlocal.Index() 352 + 353 + # Check protocol methods exist 354 + assert hasattr(index, 'insert_dataset') 355 + assert hasattr(index, 'get_dataset') 356 + assert hasattr(index, 'list_datasets') 357 + assert hasattr(index, 'publish_schema') 358 + assert hasattr(index, 'get_schema') 359 + assert hasattr(index, 'list_schemas') 360 + assert hasattr(index, 'decode_schema') 345 361 346 - assert retrieved_entry.wds_url == original_entry.wds_url 347 - assert retrieved_entry.sample_kind == original_entry.sample_kind 348 - assert retrieved_entry.metadata_url == original_entry.metadata_url 349 - assert retrieved_entry.uuid == original_entry.uuid 362 + # Check they are callable 363 + assert callable(index.insert_dataset) 364 + assert callable(index.get_dataset) 365 + assert callable(index.list_datasets) 350 366 351 367 352 368 ## ··· 386 402 assert isinstance(index._redis, Redis) 387 403 388 404 389 - def test_index_add_entry_without_uuid(clean_redis): 390 - """Test adding a dataset entry to the index without specifying UUID. 405 + def test_index_add_entry(clean_redis): 406 + """Test adding a dataset entry to the index. 
391 407 392 - Should create a BasicIndexEntry with auto-generated UUID and persist it to Redis. 408 + Should create a LocalDatasetEntry with auto-generated CID and persist it to Redis. 393 409 """ 394 410 index = atlocal.Index(redis=clean_redis) 395 411 ··· 398 414 metadata_url="s3://bucket/metadata.msgpack" 399 415 ) 400 416 401 - entry = index.add_entry(ds) 417 + entry = index.add_entry(ds, name="test-dataset") 402 418 403 - assert entry.uuid is not None 404 - assert entry.wds_url == ds.url 405 - assert entry.sample_kind == f"{SimpleTestSample.__module__}.SimpleTestSample" 406 - assert entry.metadata_url == ds.metadata_url 419 + assert entry.cid is not None 420 + assert entry.cid.startswith("bafy") 421 + assert entry.name == "test-dataset" 422 + assert entry.data_urls == ["s3://bucket/dataset.tar"] 423 + assert "SimpleTestSample" in entry.schema_ref 407 424 408 425 # Verify it was persisted to Redis 409 - stored_data = clean_redis.hgetall(f"BasicIndexEntry:{entry.uuid}") 426 + stored_data = clean_redis.hgetall(f"LocalDatasetEntry:{entry.cid}") 410 427 assert len(stored_data) > 0 411 428 412 429 413 - def test_index_add_entry_with_uuid(clean_redis): 414 - """Test adding a dataset entry to the index with a specified UUID. 430 + def test_index_add_entry_with_schema_ref(clean_redis): 431 + """Test adding a dataset entry with explicit schema_ref. 415 432 416 - Should create a BasicIndexEntry with the provided UUID and persist it to Redis. 433 + Should use the provided schema_ref instead of auto-generating. 417 434 """ 418 435 index = atlocal.Index(redis=clean_redis) 419 - test_uuid = "12345678-1234-1234-1234-123456789abc" 420 436 421 - ds = atdata.Dataset[SimpleTestSample]( 422 - url="s3://bucket/dataset.tar", 423 - metadata_url="s3://bucket/metadata.msgpack" 437 + ds = atdata.Dataset[SimpleTestSample](url="s3://bucket/dataset.tar") 438 + 439 + entry = index.add_entry( 440 + ds, 441 + name="test-dataset", 442 + schema_ref="local://schemas/custom.Schema@2.0.0" 424 443 ) 425 444 426 - entry = index.add_entry(ds, uuid=test_uuid) 445 + assert entry.schema_ref == "local://schemas/custom.Schema@2.0.0" 427 446 428 - assert entry.uuid == test_uuid 429 - assert entry.wds_url == ds.url 430 - assert entry.sample_kind == f"{SimpleTestSample.__module__}.SimpleTestSample" 431 - assert entry.metadata_url == ds.metadata_url 447 + 448 + def test_index_add_entry_with_metadata(clean_redis): 449 + """Test adding a dataset entry with metadata. 450 + 451 + Should store the provided metadata. 452 + """ 453 + index = atlocal.Index(redis=clean_redis) 454 + 455 + ds = atdata.Dataset[SimpleTestSample](url="s3://bucket/dataset.tar") 456 + 457 + entry = index.add_entry( 458 + ds, 459 + name="test-dataset", 460 + metadata={"version": "1.0", "author": "test"} 461 + ) 462 + 463 + assert entry.metadata == {"version": "1.0", "author": "test"} 432 464 433 465 434 466 def test_index_entries_generator_empty(clean_redis): ··· 445 477 def test_index_entries_generator_multiple(clean_redis): 446 478 """Test iterating over multiple entries in the index. 447 479 448 - Should yield all BasicIndexEntry objects that have been added to the index. 480 + Should yield all LocalDatasetEntry objects that have been added to the index. 
449 481 """ 450 482 index = atlocal.Index(redis=clean_redis) 451 483 452 484 ds1 = atdata.Dataset[SimpleTestSample](url="s3://bucket/dataset1.tar") 453 485 ds2 = atdata.Dataset[ArrayTestSample](url="s3://bucket/dataset2.tar") 454 486 455 - entry1 = index.add_entry(ds1) 456 - entry2 = index.add_entry(ds2) 487 + entry1 = index.add_entry(ds1, name="dataset1") 488 + entry2 = index.add_entry(ds2, name="dataset2") 457 489 458 490 entries = list(index.entries) 459 491 assert len(entries) == 2 460 492 461 - uuids = {entry.uuid for entry in entries} 462 - assert entry1.uuid in uuids 463 - assert entry2.uuid in uuids 493 + cids = {entry.cid for entry in entries} 494 + assert entry1.cid in cids 495 + assert entry2.cid in cids 464 496 465 497 466 498 def test_index_all_entries_empty(clean_redis): ··· 478 510 def test_index_all_entries_multiple(clean_redis): 479 511 """Test getting all entries as a list with multiple entries. 480 512 481 - Should return a list containing all BasicIndexEntry objects in the index. 513 + Should return a list containing all LocalDatasetEntry objects in the index. 482 514 """ 483 515 index = atlocal.Index(redis=clean_redis) 484 516 485 517 ds1 = atdata.Dataset[SimpleTestSample](url="s3://bucket/dataset1.tar") 486 518 ds2 = atdata.Dataset[ArrayTestSample](url="s3://bucket/dataset2.tar") 487 519 488 - entry1 = index.add_entry(ds1) 489 - entry2 = index.add_entry(ds2) 520 + index.add_entry(ds1, name="dataset1") 521 + index.add_entry(ds2, name="dataset2") 490 522 491 523 entries = index.all_entries 492 524 assert isinstance(entries, list) ··· 494 526 495 527 496 528 def test_index_entries_filtering(clean_redis): 497 - """Test that index only returns BasicIndexEntry objects. 529 + """Test that index only returns LocalDatasetEntry objects. 498 530 499 - Should only iterate over keys matching 'BasicIndexEntry:*' pattern and 531 + Should only iterate over keys matching 'LocalDatasetEntry:*' pattern and 500 532 ignore any other Redis keys. 
501 533 """ 502 534 index = atlocal.Index(redis=clean_redis) 503 535 504 - # Add a BasicIndexEntry 536 + # Add a LocalDatasetEntry 505 537 ds = atdata.Dataset[SimpleTestSample](url="s3://bucket/dataset.tar") 506 - entry = index.add_entry(ds) 538 + entry = index.add_entry(ds, name="test-dataset") 507 539 508 540 # Add some other Redis keys that should be ignored 509 541 clean_redis.set("other_key", "value") ··· 511 543 512 544 entries = list(index.entries) 513 545 assert len(entries) == 1 514 - assert entries[0].uuid == entry.uuid 546 + assert entries[0].cid == entry.cid 515 547 516 - # Clean up non-BasicIndexEntry keys (fixture only cleans BasicIndexEntry:*) 548 + # Clean up non-LocalDatasetEntry keys (fixture only cleans LocalDatasetEntry:*) 517 549 clean_redis.delete("other_key") 518 550 clean_redis.delete("other_hash") 519 551 520 552 553 + def test_index_get_entry_by_cid(clean_redis): 554 + """Test retrieving an entry by its CID.""" 555 + index = atlocal.Index(redis=clean_redis) 556 + 557 + ds = atdata.Dataset[SimpleTestSample](url="s3://bucket/dataset.tar") 558 + entry = index.add_entry(ds, name="test-dataset") 559 + 560 + retrieved = index.get_entry(entry.cid) 561 + 562 + assert retrieved.cid == entry.cid 563 + assert retrieved.name == entry.name 564 + assert retrieved.data_urls == entry.data_urls 565 + 566 + 567 + def test_index_get_entry_by_name(clean_redis): 568 + """Test retrieving an entry by its name.""" 569 + index = atlocal.Index(redis=clean_redis) 570 + 571 + ds = atdata.Dataset[SimpleTestSample](url="s3://bucket/dataset.tar") 572 + entry = index.add_entry(ds, name="my-special-dataset") 573 + 574 + retrieved = index.get_entry_by_name("my-special-dataset") 575 + 576 + assert retrieved.cid == entry.cid 577 + assert retrieved.name == "my-special-dataset" 578 + 579 + 580 + def test_index_get_entry_by_name_not_found(clean_redis): 581 + """Test that get_entry_by_name raises KeyError for unknown name.""" 582 + index = atlocal.Index(redis=clean_redis) 583 + 584 + with pytest.raises(KeyError, match="No entry with name"): 585 + index.get_entry_by_name("nonexistent") 586 + 587 + 588 + ## 589 + # AbstractIndex protocol method tests 590 + 591 + def test_index_insert_dataset(clean_redis): 592 + """Test insert_dataset protocol method.""" 593 + index = atlocal.Index(redis=clean_redis) 594 + ds = atdata.Dataset[SimpleTestSample](url="s3://bucket/dataset.tar") 595 + 596 + entry = index.insert_dataset(ds, name="protocol-test") 597 + 598 + assert entry.name == "protocol-test" 599 + assert entry.cid is not None 600 + 601 + 602 + def test_index_get_dataset(clean_redis): 603 + """Test get_dataset protocol method.""" 604 + index = atlocal.Index(redis=clean_redis) 605 + ds = atdata.Dataset[SimpleTestSample](url="s3://bucket/dataset.tar") 606 + index.insert_dataset(ds, name="my-dataset") 607 + 608 + entry = index.get_dataset("my-dataset") 609 + 610 + assert entry.name == "my-dataset" 611 + 612 + 613 + def test_index_get_dataset_not_found(clean_redis): 614 + """Test get_dataset raises KeyError for unknown name.""" 615 + index = atlocal.Index(redis=clean_redis) 616 + 617 + with pytest.raises(KeyError): 618 + index.get_dataset("nonexistent") 619 + 620 + 621 + def test_index_list_datasets(clean_redis): 622 + """Test list_datasets protocol method.""" 623 + index = atlocal.Index(redis=clean_redis) 624 + ds1 = atdata.Dataset[SimpleTestSample](url="s3://bucket/ds1.tar") 625 + ds2 = atdata.Dataset[SimpleTestSample](url="s3://bucket/ds2.tar") 626 + 627 + index.insert_dataset(ds1, name="dataset-1") 628 + 
index.insert_dataset(ds2, name="dataset-2") 629 + 630 + datasets = list(index.list_datasets()) 631 + 632 + assert len(datasets) == 2 633 + names = {d.name for d in datasets} 634 + assert names == {"dataset-1", "dataset-2"} 635 + 636 + 521 637 ## 522 638 # Repo tests - Initialization 523 639 ··· 624 740 # Repo tests - Insert functionality 625 741 626 742 def test_repo_insert_without_s3(): 627 - """Test that inserting a dataset without S3 configured raises AssertionError. 743 + """Test that inserting a dataset without S3 configured raises ValueError. 628 744 629 - Should fail with assertion error when trying to insert without S3 credentials. 745 + Should fail with ValueError when trying to insert without S3 credentials. 630 746 """ 631 747 repo = atlocal.Repo() 632 748 ds = atdata.Dataset[SimpleTestSample](url="s3://bucket/dataset.tar") 633 749 634 - with pytest.raises(AssertionError): 635 - repo.insert(ds) 750 + with pytest.raises(ValueError, match="S3 credentials required"): 751 + repo.insert(ds, name="test-dataset") 636 752 637 753 638 754 @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") ··· 649 765 redis=clean_redis 650 766 ) 651 767 652 - entry, new_ds = repo.insert(sample_dataset, maxcount=100) 768 + entry, new_ds = repo.insert(sample_dataset, name="single-shard-dataset", maxcount=100) 653 769 654 - assert entry.uuid is not None 655 - assert entry.wds_url is not None 656 - assert entry.sample_kind == f"{SimpleTestSample.__module__}.SimpleTestSample" 770 + assert entry.cid is not None 771 + assert entry.cid.startswith("bafy") 772 + assert entry.name == "single-shard-dataset" 773 + assert len(entry.data_urls) > 0 774 + assert "SimpleTestSample" in entry.schema_ref 657 775 assert len(repo.index.all_entries) == 1 658 776 assert '.tar' in new_ds.url 659 777 assert new_ds.url.startswith(mock_s3['hive_path']) ··· 674 792 redis=clean_redis 675 793 ) 676 794 677 - entry, new_ds = repo.insert(ds, maxcount=10) 795 + entry, new_ds = repo.insert(ds, name="multi-shard-dataset", maxcount=10) 678 796 679 - assert entry.uuid is not None 680 - assert entry.wds_url is not None 797 + assert entry.cid is not None 798 + assert len(entry.data_urls) > 0 681 799 assert '{' in new_ds.url and '}' in new_ds.url 682 800 683 801 ··· 686 804 def test_repo_insert_with_metadata(mock_s3, clean_redis, tmp_path): 687 805 """Test inserting a dataset with metadata. 688 806 689 - Should write metadata as msgpack to S3 and include metadata_url in the 690 - returned Dataset and BasicIndexEntry. 807 + Should write metadata as msgpack to S3 and store metadata in the entry. 
691 808 """ 692 809 ds = make_simple_dataset(tmp_path, num_samples=5) 693 810 ds._metadata = {"description": "test dataset", "version": "1.0"} ··· 698 815 redis=clean_redis 699 816 ) 700 817 701 - entry, new_ds = repo.insert(ds, maxcount=100) 818 + entry, new_ds = repo.insert(ds, name="metadata-dataset", maxcount=100) 702 819 703 - assert entry.metadata_url is not None 820 + assert entry.metadata is not None 821 + assert entry.metadata.get("description") == "test dataset" 704 822 assert new_ds.metadata_url is not None 705 - assert 'metadata' in entry.metadata_url 706 823 707 824 708 825 @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") ··· 719 836 redis=clean_redis 720 837 ) 721 838 722 - entry, new_ds = repo.insert(ds, maxcount=100) 839 + entry, new_ds = repo.insert(ds, name="no-metadata-dataset", maxcount=100) 723 840 724 - assert entry.uuid is not None 841 + assert entry.cid is not None 725 842 assert len(repo.index.all_entries) == 1 726 843 727 844 ··· 738 855 redis=clean_redis 739 856 ) 740 857 741 - entry, new_ds = repo.insert(sample_dataset, cache_local=False, maxcount=100) 858 + entry, new_ds = repo.insert(sample_dataset, name="direct-write", cache_local=False, maxcount=100) 742 859 743 - assert entry.uuid is not None 744 - assert entry.wds_url is not None 860 + assert entry.cid is not None 861 + assert len(entry.data_urls) > 0 745 862 746 863 747 864 @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") ··· 758 875 redis=clean_redis 759 876 ) 760 877 761 - entry, new_ds = repo.insert(sample_dataset, cache_local=True, maxcount=100) 878 + entry, new_ds = repo.insert(sample_dataset, name="cached-write", cache_local=True, maxcount=100) 762 879 763 - assert entry.uuid is not None 764 - assert entry.wds_url is not None 880 + assert entry.cid is not None 881 + assert len(entry.data_urls) > 0 765 882 766 883 767 884 @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") ··· 769 886 def test_repo_insert_creates_index_entry(mock_s3, clean_redis, sample_dataset): 770 887 """Test that insert() creates a valid index entry. 771 888 772 - Should add a BasicIndexEntry to the index with correct wds_url, sample_kind, 773 - metadata_url, and UUID. 889 + Should add a LocalDatasetEntry to the index with correct data_urls, schema_ref, 890 + and CID. 774 891 """ 775 892 repo = atlocal.Repo( 776 893 s3_credentials=mock_s3['credentials'], ··· 778 895 redis=clean_redis 779 896 ) 780 897 781 - entry, new_ds = repo.insert(sample_dataset, maxcount=100) 898 + entry, new_ds = repo.insert(sample_dataset, name="indexed-dataset", maxcount=100) 782 899 783 - assert entry.uuid is not None 784 - assert entry.wds_url == new_ds.url 785 - assert entry.sample_kind == f"{SimpleTestSample.__module__}.SimpleTestSample" 900 + assert entry.cid is not None 901 + assert entry.data_urls == [new_ds.url] 902 + assert "SimpleTestSample" in entry.schema_ref 786 903 787 904 all_entries = repo.index.all_entries 788 905 assert len(all_entries) == 1 789 - assert all_entries[0].uuid == entry.uuid 906 + assert all_entries[0].cid == entry.cid 790 907 791 908 792 909 @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") 793 910 @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") 794 - def test_repo_insert_uuid_generation(mock_s3, clean_redis, sample_dataset): 795 - """Test that insert() generates a unique UUID for each dataset. 
911 + def test_repo_insert_cid_generation(mock_s3, clean_redis, sample_dataset): 912 + """Test that insert() generates unique CIDs for each dataset. 796 913 797 - Should create a new UUID for the dataset and use it consistently in filenames, 798 - index entry, and returned Dataset. 914 + Should create different CIDs for datasets with different URLs. 799 915 """ 800 916 repo = atlocal.Repo( 801 917 s3_credentials=mock_s3['credentials'], ··· 803 919 redis=clean_redis 804 920 ) 805 921 806 - entry1, new_ds1 = repo.insert(sample_dataset, maxcount=100) 807 - entry2, new_ds2 = repo.insert(sample_dataset, maxcount=100) 922 + entry1, new_ds1 = repo.insert(sample_dataset, name="dataset1", maxcount=100) 923 + entry2, new_ds2 = repo.insert(sample_dataset, name="dataset2", maxcount=100) 808 924 809 - assert entry1.uuid != entry2.uuid 810 - assert entry1.uuid in new_ds1.url 811 - assert entry2.uuid in new_ds2.url 925 + # Different URLs should produce different CIDs 926 + assert entry1.cid != entry2.cid 812 927 assert len(repo.index.all_entries) == 2 813 928 814 929 ··· 833 948 ) 834 949 835 950 # Empty datasets succeed because WebDataset creates a shard file regardless 836 - entry, new_ds = repo.insert(ds, maxcount=100) 837 - assert entry.uuid is not None 951 + entry, new_ds = repo.insert(ds, name="empty-dataset", maxcount=100) 952 + assert entry.cid is not None 838 953 assert '.tar' in new_ds.url 839 954 840 955 ··· 851 966 redis=clean_redis 852 967 ) 853 968 854 - entry, new_ds = repo.insert(sample_dataset, maxcount=100) 969 + entry, new_ds = repo.insert(sample_dataset, name="typed-dataset", maxcount=100) 855 970 856 971 assert new_ds.sample_type == SimpleTestSample 857 - assert entry.sample_kind == f"{SimpleTestSample.__module__}.SimpleTestSample" 972 + assert "SimpleTestSample" in entry.schema_ref 858 973 859 974 860 975 @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") ··· 882 997 redis=clean_redis 883 998 ) 884 999 885 - entry, new_ds = repo.insert(ds, maxcount=5) 1000 + entry, new_ds = repo.insert(ds, name="sharded-dataset", maxcount=5) 886 1001 887 1002 assert '{' in new_ds.url and '}' in new_ds.url 888 1003 ··· 901 1016 redis=clean_redis 902 1017 ) 903 1018 904 - entry, new_ds = repo.insert(ds, maxcount=100) 1019 + entry, new_ds = repo.insert(ds, name="array-dataset", maxcount=100) 905 1020 906 - assert entry.uuid is not None 907 - assert entry.sample_kind == f"{ArrayTestSample.__module__}.ArrayTestSample" 1021 + assert entry.cid is not None 1022 + assert "ArrayTestSample" in entry.schema_ref 908 1023 909 1024 910 1025 ## ··· 924 1039 redis=clean_redis 925 1040 ) 926 1041 927 - entry, new_ds = repo.insert(sample_dataset, maxcount=100) 1042 + entry, new_ds = repo.insert(sample_dataset, name="integrated-dataset", maxcount=100) 928 1043 929 1044 all_entries = repo.index.all_entries 930 1045 assert len(all_entries) == 1 931 - assert all_entries[0].uuid == entry.uuid 932 - assert all_entries[0].wds_url == entry.wds_url 1046 + assert all_entries[0].cid == entry.cid 1047 + assert all_entries[0].data_urls == entry.data_urls 933 1048 934 1049 935 1050 @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") ··· 937 1052 def test_multiple_datasets_same_type(mock_s3, clean_redis, sample_dataset): 938 1053 """Test inserting multiple datasets of the same sample type. 
939 1054 940 - Should create separate entries with different UUIDs and all should be 1055 + Should create separate entries with different CIDs and all should be 941 1056 retrievable from the index. 942 1057 """ 943 1058 repo = atlocal.Repo( ··· 946 1061 redis=clean_redis 947 1062 ) 948 1063 949 - entry1, _ = repo.insert(sample_dataset, maxcount=100) 950 - entry2, _ = repo.insert(sample_dataset, maxcount=100) 951 - entry3, _ = repo.insert(sample_dataset, maxcount=100) 1064 + entry1, _ = repo.insert(sample_dataset, name="dataset-a", maxcount=100) 1065 + entry2, _ = repo.insert(sample_dataset, name="dataset-b", maxcount=100) 1066 + entry3, _ = repo.insert(sample_dataset, name="dataset-c", maxcount=100) 952 1067 953 - uuids = {entry1.uuid, entry2.uuid, entry3.uuid} 954 - assert len(uuids) == 3 1068 + cids = {entry1.cid, entry2.cid, entry3.cid} 1069 + assert len(cids) == 3 955 1070 956 1071 all_entries = repo.index.all_entries 957 1072 assert len(all_entries) == 3 958 1073 959 1074 for entry in all_entries: 960 - assert entry.sample_kind == f"{SimpleTestSample.__module__}.SimpleTestSample" 1075 + assert "SimpleTestSample" in entry.schema_ref 961 1076 962 1077 963 1078 @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") ··· 965 1080 def test_multiple_datasets_different_types(mock_s3, clean_redis, tmp_path): 966 1081 """Test inserting datasets with different sample types. 967 1082 968 - Should correctly track sample_kind for each dataset and create distinct 1083 + Should correctly track schema_ref for each dataset and create distinct 969 1084 index entries. 970 1085 """ 971 1086 simple_ds = make_simple_dataset(tmp_path, num_samples=3, name="simple") ··· 977 1092 redis=clean_redis 978 1093 ) 979 1094 980 - entry1, _ = repo.insert(simple_ds, maxcount=100) 981 - entry2, _ = repo.insert(array_ds, maxcount=100) 1095 + entry1, _ = repo.insert(simple_ds, name="simple-dataset", maxcount=100) 1096 + entry2, _ = repo.insert(array_ds, name="array-dataset", maxcount=100) 982 1097 983 - assert entry1.sample_kind == f"{SimpleTestSample.__module__}.SimpleTestSample" 984 - assert entry2.sample_kind == f"{ArrayTestSample.__module__}.ArrayTestSample" 985 - assert entry1.sample_kind != entry2.sample_kind 1098 + assert "SimpleTestSample" in entry1.schema_ref 1099 + assert "ArrayTestSample" in entry2.schema_ref 1100 + assert entry1.schema_ref != entry2.schema_ref 986 1101 assert len(repo.index.all_entries) == 2 987 1102 988 1103 ··· 994 1109 """ 995 1110 index1 = atlocal.Index(redis=clean_redis) 996 1111 ds = atdata.Dataset[SimpleTestSample](url="s3://bucket/dataset.tar") 997 - entry1 = index1.add_entry(ds) 1112 + entry1 = index1.add_entry(ds, name="persistent-dataset") 998 1113 999 1114 index2 = atlocal.Index(redis=clean_redis) 1000 1115 entries = index2.all_entries 1001 1116 1002 1117 assert len(entries) == 1 1003 - assert entries[0].uuid == entry1.uuid 1004 - assert entries[0].wds_url == entry1.wds_url 1118 + assert entries[0].cid == entry1.cid 1119 + assert entries[0].data_urls == entry1.data_urls 1005 1120 1006 1121 1007 1122 def test_concurrent_index_access(clean_redis): ··· 1016 1131 ds1 = atdata.Dataset[SimpleTestSample](url="s3://bucket/dataset1.tar") 1017 1132 ds2 = atdata.Dataset[ArrayTestSample](url="s3://bucket/dataset2.tar") 1018 1133 1019 - entry1 = index1.add_entry(ds1) 1020 - entry2 = index2.add_entry(ds2) 1134 + entry1 = index1.add_entry(ds1, name="dataset1") 1135 + entry2 = index2.add_entry(ds2, name="dataset2") 1021 1136 1022 1137 entries1 = index1.all_entries 1023 1138 
entries2 = index2.all_entries ··· 1025 1140 assert len(entries1) == 2 1026 1141 assert len(entries2) == 2 1027 1142 1028 - uuids1 = {e.uuid for e in entries1} 1029 - uuids2 = {e.uuid for e in entries2} 1143 + cids1 = {e.cid for e in entries1} 1144 + cids2 = {e.cid for e in entries2} 1145 + 1146 + assert entry1.cid in cids1 and entry2.cid in cids1 1147 + assert entry1.cid in cids2 and entry2.cid in cids2 1148 + 1149 + 1150 + ## 1151 + # S3DataStore tests 1152 + 1153 + def test_s3_datastore_init(): 1154 + """Test creating an S3DataStore.""" 1155 + creds = { 1156 + 'AWS_ENDPOINT': 'http://localhost:9000', 1157 + 'AWS_ACCESS_KEY_ID': 'minioadmin', 1158 + 'AWS_SECRET_ACCESS_KEY': 'minioadmin' 1159 + } 1160 + 1161 + store = atlocal.S3DataStore(credentials=creds, bucket="test-bucket") 1162 + 1163 + assert store.bucket == "test-bucket" 1164 + assert store.credentials == creds 1165 + assert store._fs is not None 1166 + 1167 + 1168 + def test_s3_datastore_supports_streaming(): 1169 + """Test that S3DataStore reports streaming support.""" 1170 + creds = { 1171 + 'AWS_ACCESS_KEY_ID': 'test', 1172 + 'AWS_SECRET_ACCESS_KEY': 'test' 1173 + } 1174 + 1175 + store = atlocal.S3DataStore(credentials=creds, bucket="test") 1176 + 1177 + assert store.supports_streaming() is True 1178 + 1179 + 1180 + def test_s3_datastore_read_url(): 1181 + """Test that read_url returns URL unchanged.""" 1182 + creds = { 1183 + 'AWS_ACCESS_KEY_ID': 'test', 1184 + 'AWS_SECRET_ACCESS_KEY': 'test' 1185 + } 1186 + 1187 + store = atlocal.S3DataStore(credentials=creds, bucket="test") 1188 + 1189 + url = "s3://bucket/path/to/data.tar" 1190 + assert store.read_url(url) == url 1191 + 1192 + 1193 + @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") 1194 + @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") 1195 + def test_s3_datastore_write_shards(mock_s3, tmp_path): 1196 + """Test writing shards with S3DataStore.""" 1197 + ds = make_simple_dataset(tmp_path, num_samples=5) 1198 + 1199 + store = atlocal.S3DataStore( 1200 + credentials=mock_s3['credentials'], 1201 + bucket=mock_s3['bucket'] 1202 + ) 1203 + 1204 + urls = store.write_shards(ds, prefix="test/data", maxcount=100) 1205 + 1206 + assert len(urls) >= 1 1207 + assert all(url.startswith("s3://") for url in urls) 1208 + assert all(mock_s3['bucket'] in url for url in urls) 1209 + 1210 + 1211 + @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") 1212 + @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") 1213 + def test_s3_datastore_write_shards_cache_local(mock_s3, tmp_path): 1214 + """Test writing shards with cache_local=True.""" 1215 + ds = make_simple_dataset(tmp_path, num_samples=5) 1216 + 1217 + store = atlocal.S3DataStore( 1218 + credentials=mock_s3['credentials'], 1219 + bucket=mock_s3['bucket'] 1220 + ) 1221 + 1222 + urls = store.write_shards(ds, prefix="cached/data", cache_local=True, maxcount=100) 1223 + 1224 + assert len(urls) >= 1 1225 + assert all(url.startswith("s3://") for url in urls) 1226 + 1227 + 1228 + ## 1229 + # Schema storage tests 1230 + 1231 + def test_publish_schema(clean_redis): 1232 + """Test publishing a schema to Redis.""" 1233 + index = atlocal.Index(redis=clean_redis) 1234 + 1235 + schema_ref = index.publish_schema(SimpleTestSample, version="1.0.0") 1236 + 1237 + assert schema_ref.startswith("local://schemas/") 1238 + assert "SimpleTestSample" in schema_ref 1239 + assert "@1.0.0" in schema_ref 1240 + 1241 + 1242 + def 
test_publish_schema_with_description(clean_redis): 1243 + """Test publishing a schema with a description.""" 1244 + index = atlocal.Index(redis=clean_redis) 1245 + 1246 + schema_ref = index.publish_schema( 1247 + SimpleTestSample, 1248 + version="2.0.0", 1249 + description="A simple test sample type" 1250 + ) 1251 + 1252 + schema = index.get_schema(schema_ref) 1253 + assert schema['description'] == "A simple test sample type" 1254 + 1255 + 1256 + def test_get_schema(clean_redis): 1257 + """Test retrieving a published schema.""" 1258 + index = atlocal.Index(redis=clean_redis) 1259 + 1260 + schema_ref = index.publish_schema(SimpleTestSample, version="1.0.0") 1261 + schema = index.get_schema(schema_ref) 1262 + 1263 + assert schema['name'] == 'SimpleTestSample' 1264 + assert schema['version'] == '1.0.0' 1265 + assert len(schema['fields']) == 2 # name and value fields 1266 + assert schema['$ref'] == schema_ref 1267 + 1268 + 1269 + def test_get_schema_not_found(clean_redis): 1270 + """Test that get_schema raises KeyError for missing schema.""" 1271 + index = atlocal.Index(redis=clean_redis) 1272 + 1273 + with pytest.raises(KeyError, match="Schema not found"): 1274 + index.get_schema("local://schemas/nonexistent.Sample@1.0.0") 1275 + 1276 + 1277 + def test_get_schema_invalid_ref(clean_redis): 1278 + """Test that get_schema raises ValueError for invalid reference.""" 1279 + index = atlocal.Index(redis=clean_redis) 1280 + 1281 + with pytest.raises(ValueError, match="Invalid local schema reference"): 1282 + index.get_schema("invalid://schemas/Sample@1.0.0") 1283 + 1284 + 1285 + def test_list_schemas_empty(clean_redis): 1286 + """Test listing schemas when none exist.""" 1287 + index = atlocal.Index(redis=clean_redis) 1288 + 1289 + schemas = list(index.list_schemas()) 1290 + assert len(schemas) == 0 1291 + 1030 1292 1031 - assert entry1.uuid in uuids1 and entry2.uuid in uuids1 1032 - assert entry1.uuid in uuids2 and entry2.uuid in uuids2 1293 + def test_list_schemas_multiple(clean_redis): 1294 + """Test listing multiple schemas.""" 1295 + index = atlocal.Index(redis=clean_redis) 1296 + 1297 + index.publish_schema(SimpleTestSample, version="1.0.0") 1298 + index.publish_schema(ArrayTestSample, version="1.0.0") 1299 + 1300 + schemas = list(index.list_schemas()) 1301 + assert len(schemas) == 2 1302 + 1303 + names = {s['name'] for s in schemas} 1304 + assert 'SimpleTestSample' in names 1305 + assert 'ArrayTestSample' in names 1306 + 1307 + 1308 + def test_schema_field_types(clean_redis): 1309 + """Test that schema correctly captures field types.""" 1310 + index = atlocal.Index(redis=clean_redis) 1311 + 1312 + schema_ref = index.publish_schema(SimpleTestSample, version="1.0.0") 1313 + schema = index.get_schema(schema_ref) 1314 + 1315 + # Find name field (should be str) 1316 + name_field = next(f for f in schema['fields'] if f['name'] == 'name') 1317 + assert name_field['fieldType']['primitive'] == 'str' 1318 + assert name_field['optional'] is False 1319 + 1320 + # Find value field (should be int) 1321 + value_field = next(f for f in schema['fields'] if f['name'] == 'value') 1322 + assert value_field['fieldType']['primitive'] == 'int' 1323 + 1324 + 1325 + def test_schema_ndarray_field(clean_redis): 1326 + """Test that schema correctly captures NDArray fields.""" 1327 + index = atlocal.Index(redis=clean_redis) 1328 + 1329 + schema_ref = index.publish_schema(ArrayTestSample, version="1.0.0") 1330 + schema = index.get_schema(schema_ref) 1331 + 1332 + # Find data field (should be ndarray) 1333 + data_field = 
next(f for f in schema['fields'] if f['name'] == 'data') 1334 + assert 'ndarray' in data_field['fieldType']['$type'] 1335 + assert data_field['fieldType']['dtype'] == 'float32' 1336 + 1337 + 1338 + def test_decode_schema(clean_redis): 1339 + """Test reconstructing a Python type from a schema.""" 1340 + index = atlocal.Index(redis=clean_redis) 1341 + 1342 + schema_ref = index.publish_schema(SimpleTestSample, version="1.0.0") 1343 + ReconstructedType = index.decode_schema(schema_ref) 1344 + 1345 + # Should be able to create instances 1346 + instance = ReconstructedType(name="test", value=42) 1347 + assert instance.name == "test" 1348 + assert instance.value == 42 1349 + 1350 + 1351 + def test_decode_schema_preserves_structure(clean_redis): 1352 + """Test that decoded schema matches original type structure.""" 1353 + index = atlocal.Index(redis=clean_redis) 1354 + 1355 + schema_ref = index.publish_schema(ArrayTestSample, version="1.0.0") 1356 + ReconstructedType = index.decode_schema(schema_ref) 1357 + 1358 + # Check fields exist 1359 + import numpy as np 1360 + instance = ReconstructedType(label="test", data=np.zeros((3, 3))) 1361 + assert instance.label == "test" 1362 + assert instance.data.shape == (3, 3) 1363 + 1364 + 1365 + def test_schema_version_handling(clean_redis): 1366 + """Test publishing multiple versions of the same schema.""" 1367 + index = atlocal.Index(redis=clean_redis) 1368 + 1369 + ref_v1 = index.publish_schema(SimpleTestSample, version="1.0.0") 1370 + ref_v2 = index.publish_schema(SimpleTestSample, version="2.0.0") 1371 + 1372 + assert ref_v1 != ref_v2 1373 + assert "@1.0.0" in ref_v1 1374 + assert "@2.0.0" in ref_v2 1375 + 1376 + # Both should be retrievable 1377 + schema_v1 = index.get_schema(ref_v1) 1378 + schema_v2 = index.get_schema(ref_v2) 1379 + 1380 + assert schema_v1['version'] == '1.0.0' 1381 + assert schema_v2['version'] == '2.0.0' 1382 + 1383 + 1384 + ## 1385 + # Schema codec tests 1386 + 1387 + def test_schema_codec_type_caching(): 1388 + """Test that schema_to_type caches generated types.""" 1389 + from atdata._schema_codec import schema_to_type, clear_type_cache, get_cached_types 1390 + 1391 + clear_type_cache() 1392 + assert len(get_cached_types()) == 0 1393 + 1394 + schema = { 1395 + "name": "CacheTestSample", 1396 + "version": "1.0.0", 1397 + "fields": [{"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}], 1398 + } 1399 + 1400 + # First call creates and caches type 1401 + Type1 = schema_to_type(schema) 1402 + cached = get_cached_types() 1403 + assert len(cached) == 1 1404 + 1405 + # Second call returns cached type 1406 + Type2 = schema_to_type(schema) 1407 + assert Type1 is Type2 1408 + 1409 + clear_type_cache() 1410 + assert len(get_cached_types()) == 0 1411 + 1412 + 1413 + def test_schema_to_type_missing_name(): 1414 + """Test schema_to_type raises on schema without name.""" 1415 + from atdata._schema_codec import schema_to_type, clear_type_cache 1416 + 1417 + clear_type_cache() 1418 + schema = { 1419 + "version": "1.0.0", 1420 + "fields": [{"name": "value", "fieldType": {"$type": "#primitive", "primitive": "int"}, "optional": False}], 1421 + } 1422 + 1423 + with pytest.raises(ValueError, match="must have a 'name' field"): 1424 + schema_to_type(schema) 1425 + 1426 + 1427 + def test_schema_to_type_empty_fields(): 1428 + """Test schema_to_type raises on schema with no fields.""" 1429 + from atdata._schema_codec import schema_to_type, clear_type_cache 1430 + 1431 + clear_type_cache() 1432 + schema = { 1433 + 
"name": "EmptySample", 1434 + "version": "1.0.0", 1435 + "fields": [], 1436 + } 1437 + 1438 + with pytest.raises(ValueError, match="must have at least one field"): 1439 + schema_to_type(schema) 1440 + 1441 + 1442 + def test_schema_to_type_field_missing_name(): 1443 + """Test schema_to_type raises on field without name.""" 1444 + from atdata._schema_codec import schema_to_type, clear_type_cache 1445 + 1446 + clear_type_cache() 1447 + schema = { 1448 + "name": "BadFieldSample", 1449 + "version": "1.0.0", 1450 + "fields": [{"fieldType": {"$type": "#primitive", "primitive": "int"}, "optional": False}], 1451 + } 1452 + 1453 + # Raises KeyError from cache key generation (accesses f['name']) or 1454 + # ValueError from validation - both indicate invalid schema is rejected 1455 + with pytest.raises((KeyError, ValueError)): 1456 + schema_to_type(schema) 1457 + 1458 + 1459 + def test_schema_to_type_unknown_primitive(): 1460 + """Test schema_to_type raises on unknown primitive type.""" 1461 + from atdata._schema_codec import schema_to_type, clear_type_cache 1462 + 1463 + clear_type_cache() 1464 + schema = { 1465 + "name": "UnknownPrimitiveSample", 1466 + "version": "1.0.0", 1467 + "fields": [{"name": "value", "fieldType": {"$type": "#primitive", "primitive": "unknown_type"}, "optional": False}], 1468 + } 1469 + 1470 + with pytest.raises(ValueError, match="Unknown primitive type"): 1471 + schema_to_type(schema) 1472 + 1473 + 1474 + def test_schema_to_type_unknown_field_kind(): 1475 + """Test schema_to_type raises on unknown field type kind.""" 1476 + from atdata._schema_codec import schema_to_type, clear_type_cache 1477 + 1478 + clear_type_cache() 1479 + schema = { 1480 + "name": "UnknownKindSample", 1481 + "version": "1.0.0", 1482 + "fields": [{"name": "value", "fieldType": {"$type": "#unknown_kind"}, "optional": False}], 1483 + } 1484 + 1485 + with pytest.raises(ValueError, match="Unknown field type kind"): 1486 + schema_to_type(schema) 1487 + 1488 + 1489 + def test_schema_to_type_ref_not_supported(): 1490 + """Test schema_to_type raises on ref field types (not yet supported).""" 1491 + from atdata._schema_codec import schema_to_type, clear_type_cache 1492 + 1493 + clear_type_cache() 1494 + schema = { 1495 + "name": "RefSample", 1496 + "version": "1.0.0", 1497 + "fields": [{"name": "other", "fieldType": {"$type": "#ref", "ref": "other.Schema"}, "optional": False}], 1498 + } 1499 + 1500 + with pytest.raises(ValueError, match="Schema references.*not yet supported"): 1501 + schema_to_type(schema) 1502 + 1503 + 1504 + def test_schema_to_type_all_primitives(): 1505 + """Test schema_to_type handles all primitive types correctly.""" 1506 + from atdata._schema_codec import schema_to_type, clear_type_cache 1507 + 1508 + clear_type_cache() 1509 + schema = { 1510 + "name": "AllPrimitivesSample", 1511 + "version": "1.0.0", 1512 + "fields": [ 1513 + {"name": "s", "fieldType": {"$type": "#primitive", "primitive": "str"}, "optional": False}, 1514 + {"name": "i", "fieldType": {"$type": "#primitive", "primitive": "int"}, "optional": False}, 1515 + {"name": "f", "fieldType": {"$type": "#primitive", "primitive": "float"}, "optional": False}, 1516 + {"name": "b", "fieldType": {"$type": "#primitive", "primitive": "bool"}, "optional": False}, 1517 + {"name": "by", "fieldType": {"$type": "#primitive", "primitive": "bytes"}, "optional": False}, 1518 + ], 1519 + } 1520 + 1521 + SampleType = schema_to_type(schema) 1522 + instance = SampleType(s="hello", i=42, f=3.14, b=True, by=b"data") 1523 + 1524 + assert instance.s == 
"hello" 1525 + assert instance.i == 42 1526 + assert instance.f == 3.14 1527 + assert instance.b is True 1528 + assert instance.by == b"data" 1529 + 1530 + 1531 + def test_schema_to_type_optional_fields(): 1532 + """Test schema_to_type handles optional fields with None defaults.""" 1533 + from atdata._schema_codec import schema_to_type, clear_type_cache 1534 + 1535 + clear_type_cache() 1536 + schema = { 1537 + "name": "OptionalSample", 1538 + "version": "1.0.0", 1539 + "fields": [ 1540 + {"name": "required", "fieldType": {"$type": "#primitive", "primitive": "str"}, "optional": False}, 1541 + {"name": "optional_str", "fieldType": {"$type": "#primitive", "primitive": "str"}, "optional": True}, 1542 + ], 1543 + } 1544 + 1545 + SampleType = schema_to_type(schema) 1546 + 1547 + # Can create with only required field 1548 + instance1 = SampleType(required="test") 1549 + assert instance1.required == "test" 1550 + assert instance1.optional_str is None 1551 + 1552 + # Can provide optional field 1553 + instance2 = SampleType(required="test", optional_str="value") 1554 + assert instance2.optional_str == "value" 1555 + 1556 + 1557 + def test_schema_to_type_ndarray_field(): 1558 + """Test schema_to_type handles NDArray fields.""" 1559 + from atdata._schema_codec import schema_to_type, clear_type_cache 1560 + 1561 + clear_type_cache() 1562 + schema = { 1563 + "name": "ArraySample", 1564 + "version": "1.0.0", 1565 + "fields": [ 1566 + {"name": "data", "fieldType": {"$type": "#ndarray", "dtype": "float32"}, "optional": False}, 1567 + ], 1568 + } 1569 + 1570 + SampleType = schema_to_type(schema) 1571 + arr = np.zeros((3, 3), dtype=np.float32) 1572 + instance = SampleType(data=arr) 1573 + 1574 + assert instance.data.shape == (3, 3) 1575 + 1576 + 1577 + def test_schema_to_type_array_field(): 1578 + """Test schema_to_type handles array (list) fields.""" 1579 + from atdata._schema_codec import schema_to_type, clear_type_cache 1580 + 1581 + clear_type_cache() 1582 + schema = { 1583 + "name": "ListSample", 1584 + "version": "1.0.0", 1585 + "fields": [ 1586 + {"name": "tags", "fieldType": {"$type": "#array", "items": {"$type": "#primitive", "primitive": "str"}}, "optional": False}, 1587 + ], 1588 + } 1589 + 1590 + SampleType = schema_to_type(schema) 1591 + instance = SampleType(tags=["a", "b", "c"]) 1592 + 1593 + assert instance.tags == ["a", "b", "c"] 1594 + 1595 + 1596 + def test_schema_to_type_use_cache_false(): 1597 + """Test schema_to_type with use_cache=False creates new types.""" 1598 + from atdata._schema_codec import schema_to_type, clear_type_cache 1599 + 1600 + clear_type_cache() 1601 + schema = { 1602 + "name": "NoCacheSample", 1603 + "version": "1.0.0", 1604 + "fields": [{"name": "value", "fieldType": {"$type": "#primitive", "primitive": "int"}, "optional": False}], 1605 + } 1606 + 1607 + Type1 = schema_to_type(schema, use_cache=False) 1608 + Type2 = schema_to_type(schema, use_cache=False) 1609 + 1610 + # Different instances since caching is disabled 1611 + assert Type1 is not Type2
+281
tests/test_promote.py
··· 1 + """Tests for the promote module.""" 2 + 3 + import pytest 4 + from unittest.mock import Mock, MagicMock, patch 5 + from dataclasses import dataclass 6 + 7 + import atdata 8 + from atdata.promote import ( 9 + promote_to_atmosphere, 10 + _find_existing_schema, 11 + _find_or_publish_schema, 12 + ) 13 + from atdata.local import LocalDatasetEntry 14 + 15 + 16 + @atdata.packable 17 + class PromoteTestSample: 18 + """Sample type for promotion tests.""" 19 + name: str 20 + value: int 21 + 22 + 23 + class TestFindExistingSchema: 24 + """Tests for _find_existing_schema helper.""" 25 + 26 + def test_finds_matching_schema(self): 27 + """Test finding an existing schema by name and version.""" 28 + mock_client = Mock() 29 + 30 + # Mock SchemaLoader.list_all to return a matching schema 31 + with patch("atdata.atmosphere.SchemaLoader") as MockLoader: 32 + mock_loader = MockLoader.return_value 33 + mock_loader.list_all.return_value = [ 34 + { 35 + "uri": "at://did:plc:test/ac.foundation.dataset.sampleSchema/abc", 36 + "value": { 37 + "name": "test_promote.PromoteTestSample", 38 + "version": "1.0.0", 39 + } 40 + } 41 + ] 42 + 43 + result = _find_existing_schema( 44 + mock_client, 45 + "test_promote.PromoteTestSample", 46 + "1.0.0" 47 + ) 48 + 49 + assert result == "at://did:plc:test/ac.foundation.dataset.sampleSchema/abc" 50 + 51 + def test_returns_none_when_not_found(self): 52 + """Test returns None when no matching schema exists.""" 53 + mock_client = Mock() 54 + 55 + with patch("atdata.atmosphere.SchemaLoader") as MockLoader: 56 + mock_loader = MockLoader.return_value 57 + mock_loader.list_all.return_value = [ 58 + { 59 + "uri": "at://did:plc:test/ac.foundation.dataset.sampleSchema/abc", 60 + "value": { 61 + "name": "other.OtherSample", 62 + "version": "1.0.0", 63 + } 64 + } 65 + ] 66 + 67 + result = _find_existing_schema( 68 + mock_client, 69 + "test_promote.PromoteTestSample", 70 + "1.0.0" 71 + ) 72 + 73 + assert result is None 74 + 75 + def test_returns_none_when_version_mismatch(self): 76 + """Test returns None when version doesn't match.""" 77 + mock_client = Mock() 78 + 79 + with patch("atdata.atmosphere.SchemaLoader") as MockLoader: 80 + mock_loader = MockLoader.return_value 81 + mock_loader.list_all.return_value = [ 82 + { 83 + "uri": "at://did:plc:test/ac.foundation.dataset.sampleSchema/abc", 84 + "value": { 85 + "name": "test_promote.PromoteTestSample", 86 + "version": "2.0.0", # Different version 87 + } 88 + } 89 + ] 90 + 91 + result = _find_existing_schema( 92 + mock_client, 93 + "test_promote.PromoteTestSample", 94 + "1.0.0" 95 + ) 96 + 97 + assert result is None 98 + 99 + 100 + class TestFindOrPublishSchema: 101 + """Tests for _find_or_publish_schema helper.""" 102 + 103 + def test_returns_existing_schema(self): 104 + """Test returns existing schema URI without publishing.""" 105 + mock_client = Mock() 106 + 107 + with patch("atdata.promote._find_existing_schema") as mock_find: 108 + mock_find.return_value = "at://existing/schema/uri" 109 + 110 + with patch("atdata.atmosphere.SchemaPublisher") as MockPublisher: 111 + result = _find_or_publish_schema( 112 + PromoteTestSample, 113 + "1.0.0", 114 + mock_client, 115 + ) 116 + 117 + assert result == "at://existing/schema/uri" 118 + MockPublisher.return_value.publish.assert_not_called() 119 + 120 + def test_publishes_new_schema_when_not_found(self): 121 + """Test publishes new schema when none exists.""" 122 + mock_client = Mock() 123 + 124 + with patch("atdata.promote._find_existing_schema") as mock_find: 125 + mock_find.return_value = 
None # No existing schema 126 + 127 + with patch("atdata.atmosphere.SchemaPublisher") as MockPublisher: 128 + mock_publisher = MockPublisher.return_value 129 + mock_publisher.publish.return_value = Mock(__str__=lambda s: "at://new/schema/uri") 130 + 131 + result = _find_or_publish_schema( 132 + PromoteTestSample, 133 + "1.0.0", 134 + mock_client, 135 + ) 136 + 137 + assert result == "at://new/schema/uri" 138 + mock_publisher.publish.assert_called_once() 139 + 140 + 141 + class TestPromoteToAtmosphere: 142 + """Tests for promote_to_atmosphere function.""" 143 + 144 + def test_raises_on_empty_data_urls(self): 145 + """Test raises ValueError when local entry has no data URLs.""" 146 + entry = LocalDatasetEntry( 147 + _name="test-dataset", 148 + _schema_ref="local://schemas/test@1.0.0", 149 + _data_urls=[], # Empty! 150 + ) 151 + mock_index = Mock() 152 + mock_client = Mock() 153 + 154 + with pytest.raises(ValueError, match="has no data URLs"): 155 + promote_to_atmosphere(entry, mock_index, mock_client) 156 + 157 + def test_promotes_with_existing_urls(self): 158 + """Test promotion using existing data URLs.""" 159 + entry = LocalDatasetEntry( 160 + _name="test-dataset", 161 + _schema_ref="local://schemas/test@1.0.0", 162 + _data_urls=["s3://bucket/data-000000.tar"], 163 + _metadata={"key": "value"}, 164 + ) 165 + 166 + mock_index = Mock() 167 + mock_index.get_schema.return_value = { 168 + "name": "test_promote.PromoteTestSample", 169 + "version": "1.0.0", 170 + "fields": [ 171 + {"name": "name", "fieldType": {"$type": "local#primitive", "primitive": "str"}, "optional": False}, 172 + {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 173 + ], 174 + } 175 + 176 + mock_client = Mock() 177 + 178 + with patch("atdata.promote._find_or_publish_schema") as mock_find_schema: 179 + mock_find_schema.return_value = "at://did:plc:test/schema/abc" 180 + 181 + with patch("atdata.atmosphere.DatasetPublisher") as MockPublisher: 182 + mock_publisher = MockPublisher.return_value 183 + mock_uri = Mock(__str__=lambda s: "at://did:plc:test/record/xyz") 184 + mock_publisher.publish_with_urls.return_value = mock_uri 185 + 186 + result = promote_to_atmosphere(entry, mock_index, mock_client) 187 + 188 + assert result == "at://did:plc:test/record/xyz" 189 + 190 + # Verify publish_with_urls was called with correct args 191 + mock_publisher.publish_with_urls.assert_called_once() 192 + call_kwargs = mock_publisher.publish_with_urls.call_args[1] 193 + assert call_kwargs["urls"] == ["s3://bucket/data-000000.tar"] 194 + assert call_kwargs["schema_uri"] == "at://did:plc:test/schema/abc" 195 + assert call_kwargs["name"] == "test-dataset" 196 + assert call_kwargs["metadata"] == {"key": "value"} 197 + 198 + def test_promotes_with_custom_name(self): 199 + """Test promotion with overridden name.""" 200 + entry = LocalDatasetEntry( 201 + _name="original-name", 202 + _schema_ref="local://schemas/test@1.0.0", 203 + _data_urls=["s3://bucket/data.tar"], 204 + ) 205 + 206 + mock_index = Mock() 207 + mock_index.get_schema.return_value = { 208 + "name": "TestSample", 209 + "version": "1.0.0", 210 + "fields": [ 211 + {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 212 + ], 213 + } 214 + 215 + mock_client = Mock() 216 + 217 + with patch("atdata.promote._find_or_publish_schema") as mock_find_schema: 218 + mock_find_schema.return_value = "at://schema" 219 + 220 + with patch("atdata.atmosphere.DatasetPublisher") as MockPublisher: 221 + 
mock_publisher = MockPublisher.return_value 222 + mock_publisher.publish_with_urls.return_value = Mock(__str__=lambda s: "at://result") 223 + 224 + result = promote_to_atmosphere( 225 + entry, 226 + mock_index, 227 + mock_client, 228 + name="custom-name", 229 + tags=["tag1", "tag2"], 230 + license="MIT", 231 + ) 232 + 233 + call_kwargs = mock_publisher.publish_with_urls.call_args[1] 234 + assert call_kwargs["name"] == "custom-name" 235 + assert call_kwargs["tags"] == ["tag1", "tag2"] 236 + assert call_kwargs["license"] == "MIT" 237 + 238 + def test_promotes_with_data_store(self): 239 + """Test promotion with data store for copying data.""" 240 + entry = LocalDatasetEntry( 241 + _name="test-dataset", 242 + _schema_ref="local://schemas/test@1.0.0", 243 + _data_urls=["s3://old-bucket/data.tar"], 244 + ) 245 + 246 + mock_index = Mock() 247 + mock_index.get_schema.return_value = { 248 + "name": "TestSample", 249 + "version": "1.0.0", 250 + "fields": [ 251 + {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 252 + ], 253 + } 254 + 255 + mock_client = Mock() 256 + mock_data_store = Mock() 257 + mock_data_store.write_shards.return_value = [ 258 + "s3://new-bucket/promoted/test-dataset/shard-000000.tar" 259 + ] 260 + 261 + with patch("atdata.promote._find_or_publish_schema") as mock_find_schema: 262 + mock_find_schema.return_value = "at://schema" 263 + 264 + with patch("atdata.atmosphere.DatasetPublisher") as MockPublisher: 265 + mock_publisher = MockPublisher.return_value 266 + mock_publisher.publish_with_urls.return_value = Mock(__str__=lambda s: "at://result") 267 + 268 + with patch("atdata.dataset.Dataset"): 269 + result = promote_to_atmosphere( 270 + entry, 271 + mock_index, 272 + mock_client, 273 + data_store=mock_data_store, 274 + ) 275 + 276 + # Verify data_store.write_shards was called 277 + mock_data_store.write_shards.assert_called_once() 278 + 279 + # Verify new URLs were used 280 + call_kwargs = mock_publisher.publish_with_urls.call_args[1] 281 + assert "s3://new-bucket/" in call_kwargs["urls"][0]
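Taken together, these promotion tests sketch the intended workflow: insert a dataset locally, then hand the resulting entry, the local index, and an authenticated ATProto client to promote_to_atmosphere. A rough end-to-end sketch follows; `s3_creds`, `hive_path`, `redis_client`, `client`, and `ds` are assumed to exist, and the keyword arguments mirror the ones exercised above.

from atdata import local as atlocal
from atdata.promote import promote_to_atmosphere

# Write shards and register the dataset locally (see the Repo.insert tests earlier).
repo = atlocal.Repo(s3_credentials=s3_creds, hive_path=hive_path, redis=redis_client)
entry, stored_ds = repo.insert(ds, name="my-dataset", maxcount=100)

# Promote the local entry to the atmosphere; returns the new record's at:// URI as a string.
at_uri = promote_to_atmosphere(
    entry,          # LocalDatasetEntry produced by repo.insert()
    repo.index,     # local Index that can resolve entry.schema_ref
    client,         # authenticated ATProto client (assumed)
    tags=["demo"],
    license="MIT",
)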
+238
tests/test_protocols.py
··· 1 + """Protocol compliance tests for atdata abstractions. 2 + 3 + These tests verify that concrete implementations satisfy their protocol 4 + definitions, ensuring interoperability between local and atmosphere backends. 5 + """ 6 + 7 + import pytest 8 + from unittest.mock import Mock, MagicMock 9 + from dataclasses import dataclass 10 + 11 + import atdata 12 + from atdata._protocols import ( 13 + IndexEntry, 14 + AbstractIndex, 15 + AbstractDataStore, 16 + ) 17 + from atdata.local import LocalDatasetEntry, Index as LocalIndex, S3DataStore 18 + from atdata.atmosphere import AtmosphereIndex, AtmosphereIndexEntry 19 + 20 + 21 + class TestIndexEntryProtocol: 22 + """Tests for IndexEntry protocol compliance.""" 23 + 24 + def test_local_dataset_entry_is_index_entry(self): 25 + """LocalDatasetEntry should satisfy IndexEntry protocol.""" 26 + entry = LocalDatasetEntry( 27 + _name="test-dataset", 28 + _schema_ref="local://schemas/test@1.0.0", 29 + _data_urls=["s3://bucket/data.tar"], 30 + _metadata={"key": "value"}, 31 + ) 32 + 33 + # Protocol compliance via isinstance (runtime_checkable) 34 + assert isinstance(entry, IndexEntry) 35 + 36 + # Verify required properties exist and work 37 + assert entry.name == "test-dataset" 38 + assert entry.schema_ref == "local://schemas/test@1.0.0" 39 + assert entry.data_urls == ["s3://bucket/data.tar"] 40 + assert entry.metadata == {"key": "value"} 41 + 42 + def test_atmosphere_index_entry_is_index_entry(self): 43 + """AtmosphereIndexEntry should satisfy IndexEntry protocol.""" 44 + record = { 45 + "name": "atmo-dataset", 46 + "schemaRef": "at://did:plc:test/schema/abc", 47 + "storage": { 48 + "$type": "ac.foundation.dataset.storageExternal", 49 + "urls": ["s3://bucket/data.tar"], 50 + }, 51 + } 52 + entry = AtmosphereIndexEntry("at://did:plc:test/record/xyz", record) 53 + 54 + # Protocol compliance 55 + assert isinstance(entry, IndexEntry) 56 + 57 + # Verify properties 58 + assert entry.name == "atmo-dataset" 59 + assert entry.schema_ref == "at://did:plc:test/schema/abc" 60 + assert entry.data_urls == ["s3://bucket/data.tar"] 61 + 62 + def test_index_entry_with_none_metadata(self): 63 + """IndexEntry should handle None metadata.""" 64 + entry = LocalDatasetEntry( 65 + _name="no-meta", 66 + _schema_ref="local://schemas/test@1.0.0", 67 + _data_urls=["s3://bucket/data.tar"], 68 + _metadata=None, 69 + ) 70 + 71 + assert entry.metadata is None 72 + 73 + 74 + class TestAbstractIndexProtocol: 75 + """Tests for AbstractIndex protocol compliance.""" 76 + 77 + def test_local_index_has_required_methods(self): 78 + """LocalIndex should have all AbstractIndex methods.""" 79 + # Can't use isinstance with non-runtime_checkable Protocol 80 + # So we verify methods exist 81 + index = LocalIndex() 82 + 83 + assert hasattr(index, "insert_dataset") 84 + assert hasattr(index, "get_dataset") 85 + assert hasattr(index, "list_datasets") 86 + assert hasattr(index, "publish_schema") 87 + assert hasattr(index, "get_schema") 88 + assert hasattr(index, "list_schemas") 89 + assert hasattr(index, "decode_schema") 90 + 91 + # Verify methods are callable 92 + assert callable(index.insert_dataset) 93 + assert callable(index.get_dataset) 94 + assert callable(index.list_datasets) 95 + assert callable(index.publish_schema) 96 + assert callable(index.get_schema) 97 + assert callable(index.list_schemas) 98 + assert callable(index.decode_schema) 99 + 100 + def test_atmosphere_index_has_required_methods(self): 101 + """AtmosphereIndex should have all AbstractIndex methods.""" 102 + mock_client = 
Mock() 103 + mock_client.did = "did:plc:test" 104 + index = AtmosphereIndex(mock_client) 105 + 106 + assert hasattr(index, "insert_dataset") 107 + assert hasattr(index, "get_dataset") 108 + assert hasattr(index, "list_datasets") 109 + assert hasattr(index, "publish_schema") 110 + assert hasattr(index, "get_schema") 111 + assert hasattr(index, "list_schemas") 112 + assert hasattr(index, "decode_schema") 113 + 114 + assert callable(index.insert_dataset) 115 + assert callable(index.get_dataset) 116 + assert callable(index.list_datasets) 117 + assert callable(index.publish_schema) 118 + assert callable(index.get_schema) 119 + assert callable(index.list_schemas) 120 + assert callable(index.decode_schema) 121 + 122 + 123 + class TestAbstractDataStoreProtocol: 124 + """Tests for AbstractDataStore protocol compliance.""" 125 + 126 + def test_s3_datastore_has_required_methods(self): 127 + """S3DataStore should have all AbstractDataStore methods.""" 128 + # Create with mock credentials 129 + mock_creds = { 130 + "AWS_ENDPOINT": "http://localhost:9000", 131 + "AWS_ACCESS_KEY_ID": "test", 132 + "AWS_SECRET_ACCESS_KEY": "test", 133 + } 134 + 135 + store = S3DataStore(mock_creds, bucket="test-bucket") 136 + 137 + assert hasattr(store, "write_shards") 138 + assert hasattr(store, "read_url") 139 + assert hasattr(store, "supports_streaming") 140 + 141 + assert callable(store.write_shards) 142 + assert callable(store.read_url) 143 + assert callable(store.supports_streaming) 144 + 145 + def test_s3_datastore_supports_streaming(self): 146 + """S3DataStore should report streaming support.""" 147 + mock_creds = { 148 + "AWS_ENDPOINT": "http://localhost:9000", 149 + "AWS_ACCESS_KEY_ID": "test", 150 + "AWS_SECRET_ACCESS_KEY": "test", 151 + } 152 + 153 + store = S3DataStore(mock_creds, bucket="test-bucket") 154 + assert store.supports_streaming() is True 155 + 156 + def test_s3_datastore_read_url_passthrough(self): 157 + """S3DataStore.read_url should return URL unchanged.""" 158 + mock_creds = { 159 + "AWS_ENDPOINT": "http://localhost:9000", 160 + "AWS_ACCESS_KEY_ID": "test", 161 + "AWS_SECRET_ACCESS_KEY": "test", 162 + } 163 + 164 + store = S3DataStore(mock_creds, bucket="test-bucket") 165 + url = "s3://bucket/path/data.tar" 166 + assert store.read_url(url) == url 167 + 168 + 169 + class TestProtocolInteroperability: 170 + """Tests verifying different implementations can be used interchangeably.""" 171 + 172 + def test_function_accepts_any_index_entry(self): 173 + """Functions typed with IndexEntry should accept any implementation.""" 174 + 175 + def get_dataset_name(entry: IndexEntry) -> str: 176 + return entry.name 177 + 178 + # LocalDatasetEntry 179 + local_entry = LocalDatasetEntry( 180 + _name="local-data", 181 + _schema_ref="local://schemas/test@1.0.0", 182 + _data_urls=["s3://bucket/data.tar"], 183 + ) 184 + assert get_dataset_name(local_entry) == "local-data" 185 + 186 + # AtmosphereIndexEntry 187 + atmo_entry = AtmosphereIndexEntry( 188 + "at://did:plc:test/record/xyz", 189 + {"name": "atmo-data", "schemaRef": "at://schema", "storage": {}}, 190 + ) 191 + assert get_dataset_name(atmo_entry) == "atmo-data" 192 + 193 + def test_function_accepts_any_index(self): 194 + """Functions typed with AbstractIndex should accept any implementation.""" 195 + 196 + def count_datasets(index) -> int: 197 + """Count datasets in an index.""" 198 + return sum(1 for _ in index.list_datasets()) 199 + 200 + # LocalIndex with mock redis 201 + local_index = LocalIndex() 202 + # Empty index returns 0 203 + assert 
count_datasets(local_index) == 0 204 + 205 + def test_index_entry_properties_consistent(self): 206 + """All IndexEntry implementations should have consistent property types.""" 207 + local_entry = LocalDatasetEntry( 208 + _name="test", 209 + _schema_ref="local://schemas/test@1.0.0", 210 + _data_urls=["url1", "url2"], 211 + _metadata={"k": "v"}, 212 + ) 213 + 214 + atmo_entry = AtmosphereIndexEntry( 215 + "at://test", 216 + { 217 + "name": "test", 218 + "schemaRef": "at://schema", 219 + "storage": { 220 + "$type": "ac.foundation.dataset.storageExternal", 221 + "urls": ["url1", "url2"], 222 + }, 223 + }, 224 + ) 225 + 226 + # Both should return str for name 227 + assert isinstance(local_entry.name, str) 228 + assert isinstance(atmo_entry.name, str) 229 + 230 + # Both should return str for schema_ref 231 + assert isinstance(local_entry.schema_ref, str) 232 + assert isinstance(atmo_entry.schema_ref, str) 233 + 234 + # Both should return list[str] for data_urls 235 + assert isinstance(local_entry.data_urls, list) 236 + assert isinstance(atmo_entry.data_urls, list) 237 + assert all(isinstance(u, str) for u in local_entry.data_urls) 238 + assert all(isinstance(u, str) for u in atmo_entry.data_urls)
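The protocol tests above are the payoff of the abstraction: code written against IndexEntry or AbstractIndex works unchanged with either backend. A small sketch, relying only on the properties and methods verified above; the helper names are illustrative, not library API.

from atdata._protocols import IndexEntry, AbstractIndex

def describe_entry(entry: IndexEntry) -> str:
    # Accepts LocalDatasetEntry and AtmosphereIndexEntry alike.
    return f"{entry.name}: {len(entry.data_urls)} shard(s), schema {entry.schema_ref}"

def dataset_names(index: AbstractIndex) -> list[str]:
    # Any implementation exposing list_datasets() (LocalIndex, AtmosphereIndex) works here.
    return [e.name for e in index.list_datasets()]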