A loose federation of distributed, typed datasets
feat(atmosphere): add ATProto integration module with examples and reference docs

Add new atmosphere module for ATProto SDK integration including client,
schema, records, and lens support. Include reference documentation for
atproto lexicons and Python SDK. Clean up dead code and TODO comments
across dataset.py, lens.py, and local.py. Update tests with PutPut law
verification and localized warning suppression convention.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

+4136 -323
.chainlink/issues.db

This is a binary file and will not be displayed.

+336
.reference/atproto_lexicon_guide.md
··· 1 + # AT Protocol Lexicon Guide 2 + 3 + > **Source**: [AT Protocol Lexicon Documentation](https://atproto.com/guides/lexicon) 4 + 5 + ## Overview 6 + 7 + Lexicon is a JSON-based schema language that defines RPC methods and record types for AT Protocol. It enables interoperability by establishing agreed-upon behaviors and semantics across the open network. 8 + 9 + ## Key Concepts 10 + 11 + ### NSIDs (Namespaced Identifiers) 12 + 13 + Schemas use reverse-DNS format identifiers indicating ownership: 14 + 15 + ``` 16 + com.atproto.repo.createRecord # Core ATProto API 17 + app.bsky.feed.post # Bluesky app record type 18 + ac.foundation.dataset.sampleSchema # Our custom namespace 19 + ``` 20 + 21 + Format: `authority.name` where authority is reverse-DNS 22 + 23 + ### Why Not RDF? 24 + 25 + Lexicon prioritizes: 26 + - Schema enforcement (not optional metadata) 27 + - Code generation with types and validation 28 + - Practical developer experience 29 + 30 + ## Schema Types 31 + 32 + ### 1. Record Types 33 + 34 + Define the structure of data stored in repositories: 35 + 36 + ```json 37 + { 38 + "lexicon": 1, 39 + "id": "com.example.follow", 40 + "defs": { 41 + "main": { 42 + "type": "record", 43 + "key": "tid", 44 + "record": { 45 + "type": "object", 46 + "required": ["subject", "createdAt"], 47 + "properties": { 48 + "subject": { "type": "string", "format": "did" }, 49 + "createdAt": { "type": "string", "format": "datetime" } 50 + } 51 + } 52 + } 53 + } 54 + } 55 + ``` 56 + 57 + Records stored in repos have a `$type` field mapping to their schema. 58 + 59 + ### 2. 
Query Methods 60 + 61 + Define HTTP GET endpoints: 62 + 63 + ```json 64 + { 65 + "lexicon": 1, 66 + "id": "com.example.getProfile", 67 + "defs": { 68 + "main": { 69 + "type": "query", 70 + "parameters": { 71 + "type": "params", 72 + "required": ["actor"], 73 + "properties": { 74 + "actor": { "type": "string", "format": "at-identifier" } 75 + } 76 + }, 77 + "output": { 78 + "encoding": "application/json", 79 + "schema": { "$ref": "#/defs/profileView" } 80 + } 81 + } 82 + } 83 + } 84 + ``` 85 + 86 + Maps to: `GET /xrpc/com.example.getProfile?actor=...` 87 + 88 + ### 3. Procedure Methods 89 + 90 + Define HTTP POST endpoints: 91 + 92 + ```json 93 + { 94 + "lexicon": 1, 95 + "id": "com.example.updateProfile", 96 + "defs": { 97 + "main": { 98 + "type": "procedure", 99 + "input": { 100 + "encoding": "application/json", 101 + "schema": { ... } 102 + }, 103 + "output": { 104 + "encoding": "application/json", 105 + "schema": { ... } 106 + } 107 + } 108 + } 109 + } 110 + ``` 111 + 112 + ### 4. Tokens 113 + 114 + Declare reusable global identifiers for extensible enums: 115 + 116 + ```json 117 + { 118 + "lexicon": 1, 119 + "id": "com.example.status.active", 120 + "defs": { 121 + "main": { 122 + "type": "token", 123 + "description": "User is active" 124 + } 125 + } 126 + } 127 + ``` 128 + 129 + Instead of hardcoding enum values, use tokens. Teams can add values without collisions. 
130 + 131 + ## Field Types 132 + 133 + ### Primitives 134 + 135 + | Type | Description | 136 + |------|-------------| 137 + | `string` | Text, with optional format/length constraints | 138 + | `integer` | Whole numbers | 139 + | `boolean` | true/false | 140 + | `bytes` | Binary data (base64 encoded in JSON) | 141 + | `cid-link` | Content identifier reference | 142 + | `unknown` | Any JSON value | 143 + 144 + ### String Formats 145 + 146 + | Format | Description | 147 + |--------|-------------| 148 + | `at-uri` | AT Protocol URI | 149 + | `at-identifier` | Handle or DID | 150 + | `did` | Decentralized identifier | 151 + | `handle` | User handle | 152 + | `datetime` | ISO 8601 timestamp | 153 + | `uri` | Generic URI | 154 + | `language` | BCP 47 language tag | 155 + 156 + ### Complex Types 157 + 158 + ```json 159 + // Object 160 + { 161 + "type": "object", 162 + "required": ["field1"], 163 + "properties": { 164 + "field1": { "type": "string" }, 165 + "field2": { "type": "integer" } 166 + } 167 + } 168 + 169 + // Array 170 + { 171 + "type": "array", 172 + "items": { "type": "string" }, 173 + "maxLength": 100 174 + } 175 + 176 + // Union (discriminated) 177 + { 178 + "type": "union", 179 + "refs": [ 180 + "#defs/typeA", 181 + "#defs/typeB" 182 + ] 183 + } 184 + 185 + // Reference to another schema 186 + { 187 + "type": "ref", 188 + "ref": "com.example.otherSchema#defs/thing" 189 + } 190 + ``` 191 + 192 + ### Blob References 193 + 194 + For binary data stored separately: 195 + 196 + ```json 197 + { 198 + "type": "blob", 199 + "accept": ["image/jpeg", "image/png"], 200 + "maxSize": 1000000 201 + } 202 + ``` 203 + 204 + ## Versioning Rules 205 + 206 + **Published schemas are immutable regarding constraints.** 207 + 208 + - Loosening constraints breaks old software validation 209 + - Tightening constraints breaks new software validation 210 + - Only **optional** constraints may be added to previously unconstrained fields 211 + - Major changes require **new NSIDs** 212 + 
213 + ## Schema Distribution 214 + 215 + Schemas should be published as machine-readable, network-accessible resources: 216 + 217 + 1. Host at well-known URL: `https://authority.com/.well-known/lexicons/` 218 + 2. Or embed in documentation 219 + 3. Ensure canonical representation exists for consumers 220 + 221 + ## Record Keys (rkeys) 222 + 223 + Records in collections are identified by keys: 224 + 225 + | Key Type | Description | 226 + |----------|-------------| 227 + | `tid` | Timestamp-based ID (sortable, unique) | 228 + | `literal:self` | Singleton record (e.g., profile) | 229 + | `any` | Any valid string | 230 + 231 + TID format: 13-character base32-sortable timestamp 232 + 233 + ## Example: Complete Lexicon 234 + 235 + ```json 236 + { 237 + "lexicon": 1, 238 + "id": "ac.foundation.dataset.sampleSchema", 239 + "revision": 1, 240 + "description": "Schema definition for a PackableSample type", 241 + "defs": { 242 + "main": { 243 + "type": "record", 244 + "key": "tid", 245 + "description": "A sample schema record", 246 + "record": { 247 + "type": "object", 248 + "required": ["name", "version", "fields"], 249 + "properties": { 250 + "name": { 251 + "type": "string", 252 + "description": "Human-readable schema name" 253 + }, 254 + "version": { 255 + "type": "string", 256 + "description": "Semantic version" 257 + }, 258 + "fields": { 259 + "type": "array", 260 + "items": { "type": "ref", "ref": "#defs/fieldDef" } 261 + }, 262 + "createdAt": { 263 + "type": "string", 264 + "format": "datetime" 265 + } 266 + } 267 + } 268 + }, 269 + "fieldDef": { 270 + "type": "object", 271 + "required": ["name", "fieldType"], 272 + "properties": { 273 + "name": { "type": "string" }, 274 + "fieldType": { "type": "ref", "ref": "#defs/fieldType" }, 275 + "optional": { "type": "boolean", "default": false } 276 + } 277 + }, 278 + "fieldType": { 279 + "type": "union", 280 + "refs": [ 281 + "#defs/primitiveType", 282 + "#defs/arrayType" 283 + ] 284 + }, 285 + "primitiveType": { 286 + 
"type": "object", 287 + "required": ["kind"], 288 + "properties": { 289 + "kind": { 290 + "type": "string", 291 + "knownValues": ["string", "int", "float", "bool", "bytes"] 292 + } 293 + } 294 + }, 295 + "arrayType": { 296 + "type": "object", 297 + "required": ["kind", "elementType"], 298 + "properties": { 299 + "kind": { "type": "string", "const": "ndarray" }, 300 + "elementType": { "type": "string" }, 301 + "shape": { 302 + "type": "array", 303 + "items": { "type": "integer" } 304 + } 305 + } 306 + } 307 + } 308 + } 309 + ``` 310 + 311 + ## XRPC (Cross-Server RPC) 312 + 313 + Lexicons map to HTTP endpoints: 314 + 315 + ``` 316 + com.example.getProfile() 317 + → GET /xrpc/com.example.getProfile 318 + 319 + com.example.createPost() 320 + → POST /xrpc/com.example.createPost 321 + ``` 322 + 323 + ## Validation Behavior 324 + 325 + The PDS can validate records against lexicons, but: 326 + 327 + 1. PDS is lexicon-agnostic by default 328 + 2. Validation can be disabled: `validate: false` 329 + 3. Unknown lexicons are stored without validation 330 + 4. Rate limits prevent abuse (not schema enforcement) 331 + 332 + ## Resources 333 + 334 + - **Lexicon Specification**: https://atproto.com/specs/lexicon 335 + - **Lexicon Guide**: https://atproto.com/guides/lexicon 336 + - **Bluesky Lexicons**: https://github.com/bluesky-social/atproto/tree/main/lexicons
+347
.reference/python_atproto_sdk.md
··· 1 + # Python ATProto SDK Reference 2 + 3 + > **Source**: [MarshalX/atproto](https://github.com/MarshalX/atproto) | [Documentation](https://atproto.blue/) | [PyPI](https://pypi.org/project/atproto/) 4 + 5 + ## Overview 6 + 7 + The `atproto` package is the community Python SDK for AT Protocol (Bluesky). It provides: 8 + 9 + - Autogenerated models from lexicons with full type hints 10 + - Synchronous and asynchronous XRPC clients 11 + - Firehose data streaming 12 + - Identity resolution (DID/Handle) 13 + - Cryptographic utilities 14 + - **Code generator for custom lexicon schemes** 15 + 16 + **Version**: 0.0.65 (Dec 2025) 17 + **Python**: 3.9 - 3.14 18 + **License**: MIT 19 + 20 + > Note: Until 1.0.0, compatibility between versions is not guaranteed. 21 + 22 + ## Installation 23 + 24 + ```bash 25 + pip install atproto 26 + ``` 27 + 28 + ## Package Structure 29 + 30 + | Package | Purpose | 31 + |---------|---------| 32 + | `atproto_client` | XRPC client, data models, utilities | 33 + | `atproto_core` | NSID, AT URIs, CID, CAR files, DID documents | 34 + | `atproto_crypto` | Multibase, signature verification, DID keys | 35 + | `atproto_firehose` | Real-time data streaming | 36 + | `atproto_identity` | DID and handle resolution | 37 + | `atproto_lexicon` | Lexicon parsing (parser, models) | 38 + | `atproto_codegen` | Code generator for models/clients from lexicons | 39 + | `atproto_cli` | CLI tool for code generation | 40 + | `atproto_server` | Server-side JWT utilities | 41 + 42 + ## Authentication 43 + 44 + ### Basic Login 45 + 46 + ```python 47 + from atproto import Client 48 + 49 + # Synchronous 50 + client = Client() 51 + client.login('handle.bsky.social', 'app-password') 52 + 53 + # Asynchronous 54 + from atproto import AsyncClient 55 + client = AsyncClient() 56 + await client.login('handle.bsky.social', 'app-password') 57 + ``` 58 + 59 + ### Session Persistence 60 + 61 + Sessions can be exported/imported to avoid repeated authentication: 62 + 63 + ```python 64 
+ # Export session 65 + session_string = client.export_session_string() 66 + 67 + # Import session later 68 + client = Client() 69 + client.login(session_string=session_string) 70 + ``` 71 + 72 + ### Custom PDS 73 + 74 + ```python 75 + client = Client(base_url='https://my-pds.example.com') 76 + ``` 77 + 78 + ## Namespaced API Access 79 + 80 + The SDK mirrors AT Protocol's NSID structure: 81 + 82 + ```python 83 + # Core atproto methods 84 + client.com.atproto.server.create_session(...) 85 + client.com.atproto.repo.create_record(...) 86 + client.com.atproto.repo.put_record(...) 87 + client.com.atproto.repo.get_record(...) 88 + client.com.atproto.repo.delete_record(...) 89 + 90 + # Bluesky app methods 91 + client.app.bsky.feed.get_timeline(...) 92 + client.app.bsky.actor.get_profile(...) 93 + 94 + # Chat methods 95 + client.chat.bsky.convo.send_message(...) 96 + ``` 97 + 98 + ## Creating Custom Records 99 + 100 + This is the key functionality for atdata's ATProto integration. 101 + 102 + ### Using com.atproto.repo.createRecord 103 + 104 + ```python 105 + from atproto import Client 106 + 107 + client = Client() 108 + client.login('handle', 'password') 109 + 110 + # Create a record with a custom collection 111 + response = client.com.atproto.repo.create_record( 112 + data={ 113 + 'repo': client.me.did, # Your DID 114 + 'collection': 'ac.foundation.dataset.sampleSchema', # Custom NSID 115 + 'record': { 116 + '$type': 'ac.foundation.dataset.sampleSchema', 117 + # ... your record fields 118 + }, 119 + 'validate': False # Skip lexicon validation for custom schemas 120 + } 121 + ) 122 + 123 + # Response contains: 124 + # - uri: AT URI for the record (at://did:plc:.../ac.foundation.dataset.sampleSchema/...) 
125 + # - cid: Content hash of the record 126 + ``` 127 + 128 + ### Using com.atproto.repo.putRecord (Create or Update) 129 + 130 + ```python 131 + response = client.com.atproto.repo.put_record( 132 + data={ 133 + 'repo': client.me.did, 134 + 'collection': 'ac.foundation.dataset.sampleSchema', 135 + 'rkey': 'my-schema-key', # Explicit record key 136 + 'record': { 137 + '$type': 'ac.foundation.dataset.sampleSchema', 138 + # ... fields 139 + }, 140 + 'validate': False 141 + } 142 + ) 143 + ``` 144 + 145 + ### Getting a Record 146 + 147 + ```python 148 + response = client.com.atproto.repo.get_record( 149 + params={ 150 + 'repo': 'did:plc:...', 151 + 'collection': 'ac.foundation.dataset.sampleSchema', 152 + 'rkey': 'my-schema-key' 153 + } 154 + ) 155 + # response.value contains the record data 156 + ``` 157 + 158 + ### Listing Records in a Collection 159 + 160 + ```python 161 + response = client.com.atproto.repo.list_records( 162 + params={ 163 + 'repo': 'did:plc:...', 164 + 'collection': 'ac.foundation.dataset.sampleSchema', 165 + 'limit': 100 166 + } 167 + ) 168 + # response.records is a list of records 169 + ``` 170 + 171 + ### Deleting a Record 172 + 173 + ```python 174 + client.com.atproto.repo.delete_record( 175 + data={ 176 + 'repo': client.me.did, 177 + 'collection': 'ac.foundation.dataset.sampleSchema', 178 + 'rkey': 'my-schema-key' 179 + } 180 + ) 181 + ``` 182 + 183 + ## Key Insight: PDS is Lexicon-Agnostic 184 + 185 + From [GitHub Discussion #3116](https://github.com/bluesky-social/atproto/discussions/3116): 186 + 187 + > "You don't need the lexicon to parse a record, only to validate the schema. Validation can be disabled." 188 + 189 + The PDS stores any JSON data in any collection without requiring prior knowledge of the schema. This means: 190 + 191 + 1. We can publish `ac.foundation.dataset.*` records immediately 192 + 2. Set `validate: False` to bypass lexicon validation 193 + 3. 
Rate limits and account bans prevent abuse, not schema enforcement 194 + 195 + ## AT URIs 196 + 197 + Records are addressed using AT URIs: 198 + 199 + ``` 200 + at://did:plc:abcd1234/ac.foundation.dataset.sampleSchema/record-key 201 + └──────────────────────┘ └──────────────────────────────────┘ └────────┘ 202 + authority collection rkey 203 + ``` 204 + 205 + ### Parsing AT URIs 206 + 207 + ```python 208 + from atproto_core import AtUri 209 + 210 + uri = AtUri.from_str('at://did:plc:abc/com.example.record/key123') 211 + print(uri.hostname) # did:plc:abc 212 + print(uri.collection) # com.example.record 213 + print(uri.rkey) # key123 214 + ``` 215 + 216 + ## Core Utilities (atproto_core) 217 + 218 + ### NSID (Namespaced Identifier) 219 + 220 + ```python 221 + from atproto_core import NSID 222 + 223 + nsid = NSID.from_str('ac.foundation.dataset.sampleSchema') 224 + print(nsid.authority) # ac.foundation.dataset 225 + print(nsid.name) # sampleSchema 226 + ``` 227 + 228 + ### CID (Content Identifier) 229 + 230 + ```python 231 + from atproto_core import CID 232 + 233 + cid = CID.decode('bafyrei...') 234 + print(cid.version) 235 + print(cid.codec) 236 + ``` 237 + 238 + ### DID Document 239 + 240 + ```python 241 + from atproto_core import DidDocument 242 + 243 + doc = DidDocument(...) 
244 + pds_endpoint = doc.get_pds_endpoint() 245 + handle = doc.get_handle() 246 + ``` 247 + 248 + ## Identity Resolution 249 + 250 + ```python 251 + from atproto_identity import IdentityResolver 252 + 253 + resolver = IdentityResolver() 254 + 255 + # Resolve handle to DID 256 + did = await resolver.resolve_handle('handle.bsky.social') 257 + 258 + # Resolve DID to DID document 259 + doc = await resolver.resolve_did('did:plc:...') 260 + ``` 261 + 262 + ## Firehose Streaming 263 + 264 + ```python 265 + from atproto import FirehoseSubscribeReposClient, parse_subscribe_repos_message 266 + 267 + client = FirehoseSubscribeReposClient() 268 + 269 + def on_message(message): 270 + commit = parse_subscribe_repos_message(message) 271 + # Process commits... 272 + 273 + client.start(on_message) 274 + ``` 275 + 276 + ## Blob Upload 277 + 278 + ```python 279 + # Upload binary data 280 + with open('image.jpg', 'rb') as f: 281 + upload = client.upload_blob(f.read()) 282 + 283 + # upload.blob can be used in record fields 284 + ``` 285 + 286 + ## Error Handling 287 + 288 + ```python 289 + from atproto import exceptions 290 + 291 + try: 292 + client.com.atproto.repo.get_record(...) 
293 + except exceptions.AtProtocolError as e: 294 + print(f"AT Protocol error: {e}") 295 + except exceptions.NetworkError as e: 296 + print(f"Network error: {e}") 297 + ``` 298 + 299 + ## Code Generation for Custom Lexicons 300 + 301 + The SDK supports generating Python models from custom lexicon schemas: 302 + 303 + ```bash 304 + # Install CLI 305 + pip install atproto[cli] 306 + 307 + # Generate code from lexicons (exact CLI usage TBD) 308 + atproto codegen --lexicons ./my-lexicons --output ./generated 309 + ``` 310 + 311 + The `atproto_codegen` package can generate: 312 + - Data models for records 313 + - Client namespaces for queries/procedures 314 + - Validation functions 315 + 316 + ## Relevant API Endpoints for atdata 317 + 318 + | Endpoint | Purpose | 319 + |----------|---------| 320 + | `com.atproto.repo.createRecord` | Publish new schema/dataset/lens record | 321 + | `com.atproto.repo.putRecord` | Create or update by explicit rkey | 322 + | `com.atproto.repo.getRecord` | Fetch a specific record | 323 + | `com.atproto.repo.listRecords` | List all records in a collection | 324 + | `com.atproto.repo.deleteRecord` | Remove a record | 325 + | `com.atproto.sync.getRepo` | Download full repository (CAR file) | 326 + | `com.atproto.identity.resolveHandle` | Resolve handle to DID | 327 + 328 + ## Resources 329 + 330 + - **Documentation**: https://atproto.blue/ 331 + - **GitHub**: https://github.com/MarshalX/atproto 332 + - **Examples**: https://github.com/MarshalX/atproto/tree/main/examples 333 + - **PyPI**: https://pypi.org/project/atproto/ 334 + - **Discord**: https://discord.gg/PCyVJXU9jN 335 + 336 + ## AT Protocol Specification 337 + 338 + - **Lexicon Guide**: https://atproto.com/guides/lexicon 339 + - **Application Guide**: https://atproto.com/guides/applications 340 + - **SDK List**: https://atproto.com/sdks 341 + - **API Reference**: https://docs.bsky.app/docs/api/ 342 + 343 + ## Version History 344 + 345 + - 0.0.65 (Dec 8, 2025) - Latest 346 + - 0.0.64 
(Dec 1, 2025) 347 + - 0.0.63 (Oct 22, 2025)
+23 -6
CLAUDE.md
··· 23 23 ### Testing 24 24 ```bash 25 25 # Run all tests with coverage 26 - pytest 26 + uv run pytest 27 27 28 28 # Run specific test file 29 - pytest tests/test_dataset.py 30 - pytest tests/test_lens.py 29 + uv run pytest tests/test_dataset.py 30 + uv run pytest tests/test_lens.py 31 31 32 32 # Run single test 33 - pytest tests/test_dataset.py::test_create_sample 34 - pytest tests/test_lens.py::test_lens 33 + uv run pytest tests/test_dataset.py::test_create_sample 34 + uv run pytest tests/test_lens.py::test_lens 35 35 ``` 36 36 37 37 ### Building ··· 148 148 - Test cases cover both decorator and inheritance syntax 149 149 - Temporary WebDataset tar files created in `tmp_path` fixture 150 150 - Tests verify both serialization and batch aggregation behavior 151 - - Lens tests verify well-behavedness (GetPut/PutGet laws) 151 + - Lens tests verify well-behavedness (GetPut/PutGet/PutPut laws) 152 + 153 + ### Warning Suppression Convention 154 + 155 + **Keep warning suppression local to individual tests, not global.** 156 + 157 + When tests generate expected warnings (e.g., from third-party library incompatibilities), suppress them using `@pytest.mark.filterwarnings` decorators on each affected test rather than global suppression in `conftest.py`. This: 158 + - Documents which specific tests have known warning behaviors 159 + - Makes it easier to track when warnings appear in unexpected places 160 + - Avoids masking genuine warnings from new code 161 + 162 + Example for s3fs/moto async incompatibility warnings: 163 + ```python 164 + @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") 165 + @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") 166 + def test_repo_insert_with_s3(mock_s3, clean_redis): 167 + ... 168 + ``` 152 169 153 170 ## Git Workflow 154 171
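The GetPut/PutGet/PutPut laws named in the testing notes above have standard statements: `put(s, get(s)) == s`, `get(put(s, v)) == v`, and `put(put(s, v1), v2) == put(s, v2)`. A self-contained sketch using a toy field lens — the `Sample` type and free-function `get`/`put` here are illustrative, not atdata's actual lens API:

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Sample:
    label: str
    score: float


# A toy lens focused on the `label` field.
def get(s: Sample) -> str:
    return s.label


def put(s: Sample, v: str) -> Sample:
    return replace(s, label=v)


s = Sample(label="cat", score=0.95)

assert put(s, get(s)) == s                            # GetPut: putting back what you got is a no-op
assert get(put(s, "dog")) == "dog"                    # PutGet: you get out what you put in
assert put(put(s, "dog"), "bird") == put(s, "bird")   # PutPut: the last put wins
print("all three lens laws hold")
```

A lens satisfying all three laws is "very well-behaved"; in a test suite each law typically becomes its own parametrized test case.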
+368
examples/atmosphere_demo.py
··· 1 + #!/usr/bin/env python3 2 + """Demonstration of atdata.atmosphere ATProto integration. 3 + 4 + This script demonstrates how to use the atmosphere module to publish 5 + and discover datasets on the AT Protocol network. 6 + 7 + Usage: 8 + # Dry run (no actual ATProto connection): 9 + python atmosphere_demo.py 10 + 11 + # With actual ATProto connection: 12 + python atmosphere_demo.py --handle your.handle.social --password your-app-password 13 + 14 + Requirements: 15 + pip install atdata[atmosphere] 16 + 17 + Note: 18 + Use an app-specific password, not your main Bluesky password. 19 + Create app passwords at: https://bsky.app/settings/app-passwords 20 + """ 21 + 22 + import argparse 23 + import sys 24 + from dataclasses import asdict, fields, is_dataclass 25 + from datetime import datetime 26 + 27 + import numpy as np 28 + from numpy.typing import NDArray 29 + 30 + import atdata 31 + from atdata.atmosphere import ( 32 + AtmosphereClient, 33 + SchemaPublisher, 34 + SchemaLoader, 35 + DatasetPublisher, 36 + DatasetLoader, 37 + AtUri, 38 + ) 39 + 40 + 41 + # ============================================================================= 42 + # Define sample types using @packable decorator 43 + # ============================================================================= 44 + 45 + @atdata.packable 46 + class ImageSample: 47 + """A sample containing image data with metadata.""" 48 + image: NDArray 49 + label: str 50 + confidence: float 51 + 52 + 53 + @atdata.packable 54 + class TextEmbeddingSample: 55 + """A sample containing text with embedding vectors.""" 56 + text: str 57 + embedding: NDArray 58 + source: str 59 + 60 + 61 + # ============================================================================= 62 + # Demo functions 63 + # ============================================================================= 64 + 65 + def demo_type_introspection(): 66 + """Demonstrate how atmosphere introspects PackableSample types.""" 67 + print("\n" + "=" * 60) 68 + print("Type 
Introspection Demo") 69 + print("=" * 60) 70 + 71 + # Show what information is available from a PackableSample type 72 + print(f"\nSample type: {ImageSample.__name__}") 73 + print(f"Is dataclass: {is_dataclass(ImageSample)}") 74 + 75 + print("\nFields:") 76 + for field in fields(ImageSample): 77 + print(f" - {field.name}: {field.type}") 78 + 79 + # Create a sample instance 80 + sample = ImageSample( 81 + image=np.random.rand(224, 224, 3).astype(np.float32), 82 + label="cat", 83 + confidence=0.95, 84 + ) 85 + 86 + print(f"\nSample instance:") 87 + print(f" image shape: {sample.image.shape}") 88 + print(f" image dtype: {sample.image.dtype}") 89 + print(f" label: {sample.label}") 90 + print(f" confidence: {sample.confidence}") 91 + 92 + # Demonstrate serialization 93 + packed = sample.packed 94 + print(f"\nSerialized size: {len(packed):,} bytes") 95 + 96 + # Round-trip 97 + restored = ImageSample.from_bytes(packed) 98 + print(f"Round-trip successful: {np.allclose(sample.image, restored.image)}") 99 + 100 + 101 + def demo_at_uri_parsing(): 102 + """Demonstrate AT URI parsing.""" 103 + print("\n" + "=" * 60) 104 + print("AT URI Parsing Demo") 105 + print("=" * 60) 106 + 107 + # Example AT URIs 108 + uris = [ 109 + "at://did:plc:abc123/ac.foundation.dataset.sampleSchema/xyz789", 110 + "at://alice.bsky.social/ac.foundation.dataset.record/my-dataset", 111 + ] 112 + 113 + for uri_str in uris: 114 + print(f"\nParsing: {uri_str}") 115 + uri = AtUri.parse(uri_str) 116 + print(f" Authority: {uri.authority}") 117 + print(f" Collection: {uri.collection}") 118 + print(f" Rkey: {uri.rkey}") 119 + print(f" Roundtrip: {str(uri)}") 120 + 121 + 122 + def demo_schema_record_building(): 123 + """Demonstrate building schema records from PackableSample types.""" 124 + print("\n" + "=" * 60) 125 + print("Schema Record Building Demo") 126 + print("=" * 60) 127 + 128 + from atdata.atmosphere._types import SchemaRecord, FieldDef, FieldType 129 + 130 + # Build a schema record manually (what 
SchemaPublisher does internally) 131 + schema = SchemaRecord( 132 + name="ImageSample", 133 + version="1.0.0", 134 + description="A sample containing image data with metadata", 135 + fields=[ 136 + FieldDef( 137 + name="image", 138 + field_type=FieldType(kind="ndarray", dtype="float32", shape=[224, 224, 3]), 139 + optional=False, 140 + ), 141 + FieldDef( 142 + name="label", 143 + field_type=FieldType(kind="primitive", primitive="str"), 144 + optional=False, 145 + ), 146 + FieldDef( 147 + name="confidence", 148 + field_type=FieldType(kind="primitive", primitive="float"), 149 + optional=False, 150 + ), 151 + ], 152 + ) 153 + 154 + # Convert to ATProto record format 155 + record = schema.to_record() 156 + 157 + print("\nSchema record structure:") 158 + print(f" $type: {record['$type']}") 159 + print(f" name: {record['name']}") 160 + print(f" version: {record['version']}") 161 + print(f" description: {record.get('description', 'N/A')}") 162 + print(f" fields: {len(record['fields'])} fields") 163 + 164 + for field in record["fields"]: 165 + print(f" - {field['name']}: {field['fieldType']}") 166 + 167 + 168 + def demo_mock_client(): 169 + """Demonstrate the AtmosphereClient interface with a mock.""" 170 + print("\n" + "=" * 60) 171 + print("Mock Client Demo (no network)") 172 + print("=" * 60) 173 + 174 + from unittest.mock import Mock, MagicMock 175 + 176 + # Create a mock atproto client 177 + mock_atproto = Mock() 178 + mock_atproto.me = MagicMock() 179 + mock_atproto.me.did = "did:plc:demo123456789" 180 + mock_atproto.me.handle = "demo.bsky.social" 181 + 182 + # Mock the login response 183 + mock_profile = Mock() 184 + mock_profile.did = "did:plc:demo123456789" 185 + mock_profile.handle = "demo.bsky.social" 186 + mock_atproto.login.return_value = mock_profile 187 + 188 + # Mock create_record response 189 + mock_response = Mock() 190 + mock_response.uri = "at://did:plc:demo123456789/ac.foundation.dataset.sampleSchema/abc123" 191 + 
mock_atproto.com.atproto.repo.create_record.return_value = mock_response 192 + 193 + # Create our client with the mock 194 + client = AtmosphereClient(_client=mock_atproto) 195 + client.login("demo.bsky.social", "fake-password") 196 + 197 + print(f"\nAuthenticated as: {client.handle}") 198 + print(f"DID: {client.did}") 199 + 200 + # Demonstrate schema publishing with mock 201 + publisher = SchemaPublisher(client) 202 + uri = publisher.publish( 203 + ImageSample, 204 + name="ImageSample", 205 + version="1.0.0", 206 + description="Demo image sample type", 207 + ) 208 + 209 + print(f"\nPublished schema at: {uri}") 210 + print(f" Authority: {uri.authority}") 211 + print(f" Collection: {uri.collection}") 212 + print(f" Rkey: {uri.rkey}") 213 + 214 + 215 + def demo_live_connection(handle: str, password: str): 216 + """Demonstrate actual ATProto connection. 217 + 218 + Args: 219 + handle: Bluesky handle (e.g., 'alice.bsky.social') 220 + password: App-specific password 221 + """ 222 + print("\n" + "=" * 60) 223 + print("Live ATProto Connection Demo") 224 + print("=" * 60) 225 + 226 + # Create client and authenticate 227 + print(f"\nConnecting as {handle}...") 228 + client = AtmosphereClient() 229 + client.login(handle, password) 230 + 231 + print(f"Authenticated!") 232 + print(f" DID: {client.did}") 233 + print(f" Handle: {client.handle}") 234 + 235 + # Publish a schema 236 + print("\nPublishing ImageSample schema...") 237 + schema_publisher = SchemaPublisher(client) 238 + schema_uri = schema_publisher.publish( 239 + ImageSample, 240 + name="ImageSample", 241 + version="1.0.0", 242 + description="Demo: Image sample with label and confidence", 243 + ) 244 + print(f" Schema URI: {schema_uri}") 245 + 246 + # List schemas we've published 247 + print("\nListing your published schemas...") 248 + schema_loader = SchemaLoader(client) 249 + schemas = schema_loader.list_all(limit=10) 250 + print(f" Found {len(schemas)} schema(s)") 251 + for schema in schemas: 252 + print(f" - 
{schema.get('name', 'Unknown')}: v{schema.get('version', '?')}") 253 + 254 + # Publish a dataset record (pointing to example URLs) 255 + print("\nPublishing dataset record...") 256 + dataset_publisher = DatasetPublisher(client) 257 + dataset_uri = dataset_publisher.publish_with_urls( 258 + urls=["s3://example-bucket/demo-data-{000000..000009}.tar"], 259 + schema_uri=str(schema_uri), 260 + name="Demo Image Dataset", 261 + description="Example dataset demonstrating atmosphere publishing", 262 + tags=["demo", "images", "atdata"], 263 + license="MIT", 264 + ) 265 + print(f" Dataset URI: {dataset_uri}") 266 + 267 + # List datasets 268 + print("\nListing your published datasets...") 269 + dataset_loader = DatasetLoader(client) 270 + datasets = dataset_loader.list_all(limit=10) 271 + print(f" Found {len(datasets)} dataset(s)") 272 + for ds in datasets: 273 + print(f" - {ds.get('name', 'Unknown')}") 274 + print(f" Schema: {ds.get('schemaRef', 'N/A')}") 275 + tags = ds.get('tags', []) 276 + if tags: 277 + print(f" Tags: {', '.join(tags)}") 278 + 279 + 280 + def demo_dataset_loading(): 281 + """Demonstrate loading a dataset from an ATProto record.""" 282 + print("\n" + "=" * 60) 283 + print("Dataset Loading Demo (conceptual)") 284 + print("=" * 60) 285 + 286 + print(""" 287 + Once you have published a dataset, others can load it like this: 288 + 289 + from atdata.atmosphere import AtmosphereClient, DatasetLoader 290 + 291 + client = AtmosphereClient() 292 + # Note: reading public records doesn't require authentication 293 + 294 + loader = DatasetLoader(client) 295 + 296 + # Get the dataset record 297 + record = loader.get("at://did:plc:abc123/ac.foundation.dataset.record/xyz") 298 + 299 + # Get the WebDataset URLs 300 + urls = loader.get_urls("at://did:plc:abc123/ac.foundation.dataset.record/xyz") 301 + print(f"Dataset URLs: {urls}") 302 + 303 + # If you have the sample type class, create a Dataset directly 304 + dataset = loader.to_dataset( 305 + 
"at://did:plc:abc123/ac.foundation.dataset.record/xyz", 306 + sample_type=ImageSample, 307 + ) 308 + 309 + # Now iterate as usual 310 + for batch in dataset.shuffled(batch_size=32): 311 + images = batch.image # (32, 224, 224, 3) 312 + labels = batch.label # list of 32 strings 313 + process(images, labels) 314 + """) 315 + 316 + 317 + # ============================================================================= 318 + # Main 319 + # ============================================================================= 320 + 321 + def main(): 322 + parser = argparse.ArgumentParser( 323 + description="Demonstrate atdata.atmosphere ATProto integration", 324 + formatter_class=argparse.RawDescriptionHelpFormatter, 325 + epilog=__doc__, 326 + ) 327 + parser.add_argument( 328 + "--handle", 329 + help="Bluesky handle for live demo (e.g., alice.bsky.social)", 330 + ) 331 + parser.add_argument( 332 + "--password", 333 + help="App-specific password for live demo", 334 + ) 335 + 336 + args = parser.parse_args() 337 + 338 + print("=" * 60) 339 + print("atdata.atmosphere Demo") 340 + print("=" * 60) 341 + print(f"\nTime: {datetime.now().isoformat()}") 342 + print(f"atdata version: {atdata.__name__}") 343 + 344 + # Always run these demos (no network required) 345 + demo_type_introspection() 346 + demo_at_uri_parsing() 347 + demo_schema_record_building() 348 + demo_mock_client() 349 + demo_dataset_loading() 350 + 351 + # Run live demo if credentials provided 352 + if args.handle and args.password: 353 + demo_live_connection(args.handle, args.password) 354 + else: 355 + print("\n" + "=" * 60) 356 + print("Live Demo Skipped") 357 + print("=" * 60) 358 + print("\nTo run with actual ATProto connection:") 359 + print(" python atmosphere_demo.py --handle your.handle --password your-app-password") 360 + print("\nCreate app passwords at: https://bsky.app/settings/app-passwords") 361 + 362 + print("\n" + "=" * 60) 363 + print("Demo Complete!") 364 + print("=" * 60) 365 + 366 + 367 + if __name__ == 
"__main__": 368 + main()
+7
pyproject.toml
··· 8 8 ] 9 9 requires-python = ">=3.12" 10 10 dependencies = [ 11 + "atproto>=0.0.65", 11 12 "fastparquet>=2024.11.0", 12 13 "msgpack>=1.1.2", 13 14 "numpy>=2.3.4", ··· 16 17 "pydantic>=2.12.5", 17 18 "python-dotenv>=1.2.1", 18 19 "redis-om>=0.3.5", 20 + "requests>=2.32.5", 19 21 "s3fs>=2025.12.0", 20 22 "schemamodels>=0.9.1", 21 23 "tqdm>=4.67.1", 22 24 "webdataset>=1.0.2", 25 + ] 26 + 27 + [project.optional-dependencies] 28 + atmosphere = [ 29 + "atproto>=0.0.65", 23 30 ] 24 31 25 32 [project.scripts]
+3
src/atdata/__init__.py
··· 51 51 lens, 52 52 ) 53 53 54 + # ATProto integration (the atproto package itself is imported lazily inside the module) 55 + from . import atmosphere 56 + 54 57 55 58 #
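Note that `from . import atmosphere` imports the submodule eagerly; only the `atproto` dependency inside it is deferred. If the intent were to defer the submodule import itself until first use, the usual pattern is a module-level `__getattr__` (PEP 562). A sketch, demonstrated on a dynamically built module so it runs standalone; this is not what the diff implements:

```python
import types

# Sketch of a package that defers building a submodule until first
# attribute access (PEP 562 module-level __getattr__). "demo_atdata"
# is a stand-in name; a real __init__.py would call
# importlib.import_module(".atmosphere", __name__) instead.
pkg = types.ModuleType("demo_atdata")

def _lazy_getattr(name: str):
    if name == "atmosphere":
        sub = types.ModuleType("demo_atdata.atmosphere")
        sub.LOADED = True
        return sub
    raise AttributeError(f"module {pkg.__name__!r} has no attribute {name!r}")

# Python falls back to __getattr__ in the module's __dict__ when
# normal attribute lookup fails (PEP 562).
pkg.__getattr__ = _lazy_getattr

atmosphere = pkg.atmosphere  # submodule materialized only here
print(atmosphere.LOADED)  # True
```

A production version would also cache the result (e.g. `globals()[name] = module`) so repeated accesses do not rebuild the submodule.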
+61
src/atdata/atmosphere/__init__.py
··· 1 + """ATProto integration for distributed dataset federation. 2 + 3 + This module provides ATProto publishing and discovery capabilities for atdata, 4 + enabling a loose federation of distributed, typed datasets on the AT Protocol 5 + network. 6 + 7 + Key components: 8 + 9 + - ``AtmosphereClient``: Authentication and session management for ATProto 10 + - ``SchemaPublisher``: Publish PackableSample schemas as ATProto records 11 + - ``DatasetPublisher``: Publish dataset index records with WebDataset URLs 12 + - ``LensPublisher``: Publish lens transformation records 13 + 14 + The ATProto integration is additive - existing atdata functionality continues 15 + to work unchanged. These features are opt-in for users who want to publish 16 + or discover datasets on the ATProto network. 17 + 18 + Example: 19 + >>> from atdata.atmosphere import AtmosphereClient, SchemaPublisher 20 + >>> 21 + >>> client = AtmosphereClient() 22 + >>> client.login("handle.bsky.social", "app-password") 23 + >>> 24 + >>> publisher = SchemaPublisher(client) 25 + >>> schema_uri = publisher.publish(MySampleType, version="1.0.0") 26 + 27 + Note: 28 + This module requires the ``atproto`` package to be installed:: 29 + 30 + pip install atproto 31 + """ 32 + 33 + from .client import AtmosphereClient 34 + from .schema import SchemaPublisher, SchemaLoader 35 + from .records import DatasetPublisher, DatasetLoader 36 + from .lens import LensPublisher, LensLoader 37 + from ._types import ( 38 + AtUri, 39 + SchemaRecord, 40 + DatasetRecord, 41 + LensRecord, 42 + ) 43 + 44 + __all__ = [ 45 + # Client 46 + "AtmosphereClient", 47 + # Schema operations 48 + "SchemaPublisher", 49 + "SchemaLoader", 50 + # Dataset operations 51 + "DatasetPublisher", 52 + "DatasetLoader", 53 + # Lens operations 54 + "LensPublisher", 55 + "LensLoader", 56 + # Types 57 + "AtUri", 58 + "SchemaRecord", 59 + "DatasetRecord", 60 + "LensRecord", 61 + ]
+329
src/atdata/atmosphere/_types.py
··· 1 + """Type definitions for ATProto record structures. 2 + 3 + This module defines the data structures used to represent ATProto records 4 + for schemas, datasets, and lenses. These types map to the Lexicon definitions 5 + in the ``ac.foundation.dataset.*`` namespace. 6 + """ 7 + 8 + from dataclasses import dataclass, field 9 + from datetime import datetime, timezone 10 + from typing import Optional, Literal, Any 11 + 12 + # Lexicon namespace for atdata records 13 + LEXICON_NAMESPACE = "ac.foundation.dataset" 14 + 15 + 16 + @dataclass 17 + class AtUri: 18 + """Parsed AT Protocol URI. 19 + 20 + AT URIs follow the format: at://<authority>/<collection>/<rkey> 21 + 22 + Example: 23 + >>> uri = AtUri.parse("at://did:plc:abc123/ac.foundation.dataset.sampleSchema/xyz") 24 + >>> uri.authority 25 + 'did:plc:abc123' 26 + >>> uri.collection 27 + 'ac.foundation.dataset.sampleSchema' 28 + >>> uri.rkey 29 + 'xyz' 30 + """ 31 + 32 + authority: str 33 + """The DID or handle of the repository owner.""" 34 + 35 + collection: str 36 + """The NSID of the record collection.""" 37 + 38 + rkey: str 39 + """The record key within the collection.""" 40 + 41 + @classmethod 42 + def parse(cls, uri: str) -> "AtUri": 43 + """Parse an AT URI string into components. 44 + 45 + Args: 46 + uri: AT URI string in format ``at://<authority>/<collection>/<rkey>`` 47 + 48 + Returns: 49 + Parsed AtUri instance. 50 + 51 + Raises: 52 + ValueError: If the URI format is invalid. 
53 + """ 54 + if not uri.startswith("at://"): 55 + raise ValueError(f"Invalid AT URI: must start with 'at://': {uri}") 56 + 57 + parts = uri[5:].split("/") 58 + if len(parts) < 3: 59 + raise ValueError(f"Invalid AT URI: expected authority/collection/rkey: {uri}") 60 + 61 + return cls( 62 + authority=parts[0], 63 + collection=parts[1], 64 + rkey="/".join(parts[2:]), # rkey may contain slashes 65 + ) 66 + 67 + def __str__(self) -> str: 68 + """Format as AT URI string.""" 69 + return f"at://{self.authority}/{self.collection}/{self.rkey}" 70 + 71 + 72 + @dataclass 73 + class FieldType: 74 + """Schema field type definition. 75 + 76 + Represents a type in the schema type system, supporting primitives, 77 + ndarrays, and references to other schemas. 78 + """ 79 + 80 + kind: Literal["primitive", "ndarray", "ref", "array"] 81 + """The category of type.""" 82 + 83 + primitive: Optional[str] = None 84 + """For kind='primitive': one of 'str', 'int', 'float', 'bool', 'bytes'.""" 85 + 86 + dtype: Optional[str] = None 87 + """For kind='ndarray': numpy dtype string (e.g., 'float32').""" 88 + 89 + shape: Optional[list[int | None]] = None 90 + """For kind='ndarray': shape constraints (None for any dimension).""" 91 + 92 + ref: Optional[str] = None 93 + """For kind='ref': AT URI of referenced schema.""" 94 + 95 + items: Optional["FieldType"] = None 96 + """For kind='array': type of array elements.""" 97 + 98 + 99 + @dataclass 100 + class FieldDef: 101 + """Schema field definition.""" 102 + 103 + name: str 104 + """Field name.""" 105 + 106 + field_type: FieldType 107 + """Type of this field.""" 108 + 109 + optional: bool = False 110 + """Whether this field can be None.""" 111 + 112 + description: Optional[str] = None 113 + """Human-readable description.""" 114 + 115 + 116 + @dataclass 117 + class SchemaRecord: 118 + """ATProto record for a PackableSample schema. 119 + 120 + Maps to the ``ac.foundation.dataset.sampleSchema`` Lexicon. 
121 + """ 122 + 123 + name: str 124 + """Human-readable schema name.""" 125 + 126 + version: str 127 + """Semantic version string (e.g., '1.0.0').""" 128 + 129 + fields: list[FieldDef] 130 + """List of field definitions.""" 131 + 132 + description: Optional[str] = None 133 + """Human-readable description.""" 134 + 135 + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) 136 + """When this record was created.""" 137 + 138 + metadata: Optional[dict] = None 139 + """Arbitrary metadata as msgpack-encoded bytes.""" 140 + 141 + def to_record(self) -> dict: 142 + """Convert to ATProto record dict for publishing.""" 143 + record = { 144 + "$type": f"{LEXICON_NAMESPACE}.sampleSchema", 145 + "name": self.name, 146 + "version": self.version, 147 + "fields": [self._field_to_dict(f) for f in self.fields], 148 + "createdAt": self.created_at.isoformat(), 149 + } 150 + if self.description: 151 + record["description"] = self.description 152 + if self.metadata: 153 + record["metadata"] = self.metadata 154 + return record 155 + 156 + def _field_to_dict(self, field_def: FieldDef) -> dict: 157 + """Convert a field definition to dict.""" 158 + result = { 159 + "name": field_def.name, 160 + "fieldType": self._type_to_dict(field_def.field_type), 161 + "optional": field_def.optional, 162 + } 163 + if field_def.description: 164 + result["description"] = field_def.description 165 + return result 166 + 167 + def _type_to_dict(self, field_type: FieldType) -> dict: 168 + """Convert a field type to dict.""" 169 + result: dict = {"$type": f"{LEXICON_NAMESPACE}.schemaType#{field_type.kind}"} 170 + 171 + if field_type.kind == "primitive": 172 + result["primitive"] = field_type.primitive 173 + elif field_type.kind == "ndarray": 174 + result["dtype"] = field_type.dtype 175 + if field_type.shape: 176 + result["shape"] = field_type.shape 177 + elif field_type.kind == "ref": 178 + result["ref"] = field_type.ref 179 + elif field_type.kind == "array": 180 + if 
field_type.items: 181 + result["items"] = self._type_to_dict(field_type.items) 182 + 183 + return result 184 + 185 + 186 + @dataclass 187 + class StorageLocation: 188 + """Dataset storage location specification.""" 189 + 190 + kind: Literal["external", "blobs"] 191 + """Storage type: external URLs or ATProto blobs.""" 192 + 193 + urls: Optional[list[str]] = None 194 + """For kind='external': WebDataset URLs with brace notation.""" 195 + 196 + blob_refs: Optional[list[dict]] = None 197 + """For kind='blobs': ATProto blob references.""" 198 + 199 + 200 + @dataclass 201 + class DatasetRecord: 202 + """ATProto record for a dataset index. 203 + 204 + Maps to the ``ac.foundation.dataset.record`` Lexicon. 205 + """ 206 + 207 + name: str 208 + """Human-readable dataset name.""" 209 + 210 + schema_ref: str 211 + """AT URI of the schema record.""" 212 + 213 + storage: StorageLocation 214 + """Where the dataset data is stored.""" 215 + 216 + description: Optional[str] = None 217 + """Human-readable description.""" 218 + 219 + tags: list[str] = field(default_factory=list) 220 + """Searchable tags.""" 221 + 222 + license: Optional[str] = None 223 + """SPDX license identifier.""" 224 + 225 + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) 226 + """When this record was created.""" 227 + 228 + metadata: Optional[bytes] = None 229 + """Arbitrary metadata as msgpack-encoded bytes.""" 230 + 231 + def to_record(self) -> dict: 232 + """Convert to ATProto record dict for publishing.""" 233 + record = { 234 + "$type": f"{LEXICON_NAMESPACE}.record", 235 + "name": self.name, 236 + "schemaRef": self.schema_ref, 237 + "storage": self._storage_to_dict(), 238 + "createdAt": self.created_at.isoformat(), 239 + } 240 + if self.description: 241 + record["description"] = self.description 242 + if self.tags: 243 + record["tags"] = self.tags 244 + if self.license: 245 + record["license"] = self.license 246 + if self.metadata: 247 + record["metadata"] = self.metadata 
248 + return record 249 + 250 + def _storage_to_dict(self) -> dict: 251 + """Convert storage location to dict.""" 252 + if self.storage.kind == "external": 253 + return { 254 + "$type": f"{LEXICON_NAMESPACE}.storageExternal", 255 + "urls": self.storage.urls or [], 256 + } 257 + else: 258 + return { 259 + "$type": f"{LEXICON_NAMESPACE}.storageBlobs", 260 + "blobs": self.storage.blob_refs or [], 261 + } 262 + 263 + 264 + @dataclass 265 + class CodeReference: 266 + """Reference to lens code in a git repository.""" 267 + 268 + repository: str 269 + """Git repository URL.""" 270 + 271 + commit: str 272 + """Git commit hash.""" 273 + 274 + path: str 275 + """Path to the code file/function.""" 276 + 277 + 278 + @dataclass 279 + class LensRecord: 280 + """ATProto record for a lens transformation. 281 + 282 + Maps to the ``ac.foundation.dataset.lens`` Lexicon. 283 + """ 284 + 285 + name: str 286 + """Human-readable lens name.""" 287 + 288 + source_schema: str 289 + """AT URI of the source schema.""" 290 + 291 + target_schema: str 292 + """AT URI of the target schema.""" 293 + 294 + description: Optional[str] = None 295 + """What this transformation does.""" 296 + 297 + getter_code: Optional[CodeReference] = None 298 + """Reference to getter function code.""" 299 + 300 + putter_code: Optional[CodeReference] = None 301 + """Reference to putter function code.""" 302 + 303 + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) 304 + """When this record was created.""" 305 + 306 + def to_record(self) -> dict: 307 + """Convert to ATProto record dict for publishing.""" 308 + record: dict[str, Any] = { 309 + "$type": f"{LEXICON_NAMESPACE}.lens", 310 + "name": self.name, 311 + "sourceSchema": self.source_schema, 312 + "targetSchema": self.target_schema, 313 + "createdAt": self.created_at.isoformat(), 314 + } 315 + if self.description: 316 + record["description"] = self.description 317 + if self.getter_code: 318 + record["getterCode"] = { 319 + 
"repository": self.getter_code.repository, 320 + "commit": self.getter_code.commit, 321 + "path": self.getter_code.path, 322 + } 323 + if self.putter_code: 324 + record["putterCode"] = { 325 + "repository": self.putter_code.repository, 326 + "commit": self.putter_code.commit, 327 + "path": self.putter_code.path, 328 + } 329 + return record
+393
src/atdata/atmosphere/client.py
··· 1 + """ATProto client wrapper for atdata. 2 + 3 + This module provides the ``AtmosphereClient`` class which wraps the atproto SDK 4 + client with atdata-specific helpers for publishing and querying records. 5 + """ 6 + 7 + from typing import Optional, Any 8 + 9 + from ._types import AtUri, LEXICON_NAMESPACE 10 + 11 + # Lazy import to avoid requiring atproto if not using atmosphere features 12 + _atproto_client_class: Optional[type] = None 13 + 14 + 15 + def _get_atproto_client_class(): 16 + """Lazily import the atproto Client class.""" 17 + global _atproto_client_class 18 + if _atproto_client_class is None: 19 + try: 20 + from atproto import Client 21 + _atproto_client_class = Client 22 + except ImportError as e: 23 + raise ImportError( 24 + "The 'atproto' package is required for ATProto integration. " 25 + "Install it with: pip install atproto" 26 + ) from e 27 + return _atproto_client_class 28 + 29 + 30 + class AtmosphereClient: 31 + """ATProto client wrapper for atdata operations. 32 + 33 + This class wraps the atproto SDK client and provides higher-level methods 34 + for working with atdata records (schemas, datasets, lenses). 35 + 36 + Example: 37 + >>> client = AtmosphereClient() 38 + >>> client.login("alice.bsky.social", "app-password") 39 + >>> print(client.did) 40 + 'did:plc:...' 41 + 42 + Note: 43 + The password should be an app-specific password, not your main account 44 + password. Create app passwords in your Bluesky account settings. 45 + """ 46 + 47 + def __init__( 48 + self, 49 + base_url: Optional[str] = None, 50 + *, 51 + _client: Optional[Any] = None, 52 + ): 53 + """Initialize the ATProto client. 54 + 55 + Args: 56 + base_url: Optional PDS base URL. Defaults to bsky.social. 57 + _client: Optional pre-configured atproto Client for testing. 
58 + """ 59 + if _client is not None: 60 + self._client = _client 61 + else: 62 + Client = _get_atproto_client_class() 63 + self._client = Client(base_url=base_url) if base_url else Client() 64 + 65 + self._session: Optional[dict] = None 66 + 67 + def login(self, handle: str, password: str) -> None: 68 + """Authenticate with the ATProto PDS. 69 + 70 + Args: 71 + handle: Your Bluesky handle (e.g., 'alice.bsky.social'). 72 + password: App-specific password (not your main password). 73 + 74 + Raises: 75 + atproto.exceptions.AtProtocolError: If authentication fails. 76 + """ 77 + profile = self._client.login(handle, password) 78 + self._session = { 79 + "did": profile.did, 80 + "handle": profile.handle, 81 + } 82 + 83 + def login_with_session(self, session_string: str) -> None: 84 + """Authenticate using an exported session string. 85 + 86 + This allows reusing a session without re-authenticating, which helps 87 + avoid rate limits on session creation. 88 + 89 + Args: 90 + session_string: Session string from ``export_session()``. 91 + """ 92 + self._client.login(session_string=session_string) 93 + self._session = { 94 + "did": self._client.me.did, 95 + "handle": self._client.me.handle, 96 + } 97 + 98 + def export_session(self) -> str: 99 + """Export the current session for later reuse. 100 + 101 + Returns: 102 + Session string that can be passed to ``login_with_session()``. 103 + 104 + Raises: 105 + ValueError: If not authenticated. 106 + """ 107 + if not self.is_authenticated: 108 + raise ValueError("Not authenticated") 109 + return self._client.export_session_string() 110 + 111 + @property 112 + def is_authenticated(self) -> bool: 113 + """Check if the client has a valid session.""" 114 + return self._session is not None 115 + 116 + @property 117 + def did(self) -> str: 118 + """Get the DID of the authenticated user. 119 + 120 + Returns: 121 + The DID string (e.g., 'did:plc:...'). 122 + 123 + Raises: 124 + ValueError: If not authenticated. 
125 + """ 126 + if not self._session: 127 + raise ValueError("Not authenticated") 128 + return self._session["did"] 129 + 130 + @property 131 + def handle(self) -> str: 132 + """Get the handle of the authenticated user. 133 + 134 + Returns: 135 + The handle string (e.g., 'alice.bsky.social'). 136 + 137 + Raises: 138 + ValueError: If not authenticated. 139 + """ 140 + if not self._session: 141 + raise ValueError("Not authenticated") 142 + return self._session["handle"] 143 + 144 + def _ensure_authenticated(self) -> None: 145 + """Raise if not authenticated.""" 146 + if not self.is_authenticated: 147 + raise ValueError("Client must be authenticated to perform this operation") 148 + 149 + # Low-level record operations 150 + 151 + def create_record( 152 + self, 153 + collection: str, 154 + record: dict, 155 + *, 156 + rkey: Optional[str] = None, 157 + validate: bool = False, 158 + ) -> AtUri: 159 + """Create a record in the user's repository. 160 + 161 + Args: 162 + collection: The NSID of the record collection 163 + (e.g., 'ac.foundation.dataset.sampleSchema'). 164 + record: The record data. Must include a '$type' field. 165 + rkey: Optional explicit record key. If not provided, a TID is generated. 166 + validate: Whether to validate against the Lexicon schema. Set to False 167 + for custom lexicons that the PDS doesn't know about. 168 + 169 + Returns: 170 + The AT URI of the created record. 171 + 172 + Raises: 173 + ValueError: If not authenticated. 174 + atproto.exceptions.AtProtocolError: If record creation fails. 
175 + """ 176 + self._ensure_authenticated() 177 + 178 + response = self._client.com.atproto.repo.create_record( 179 + data={ 180 + "repo": self.did, 181 + "collection": collection, 182 + "record": record, 183 + "rkey": rkey, 184 + "validate": validate, 185 + } 186 + ) 187 + 188 + return AtUri.parse(response.uri) 189 + 190 + def put_record( 191 + self, 192 + collection: str, 193 + rkey: str, 194 + record: dict, 195 + *, 196 + validate: bool = False, 197 + swap_commit: Optional[str] = None, 198 + ) -> AtUri: 199 + """Create or update a record at a specific key. 200 + 201 + Args: 202 + collection: The NSID of the record collection. 203 + rkey: The record key. 204 + record: The record data. Must include a '$type' field. 205 + validate: Whether to validate against the Lexicon schema. 206 + swap_commit: Optional CID for compare-and-swap update. 207 + 208 + Returns: 209 + The AT URI of the record. 210 + 211 + Raises: 212 + ValueError: If not authenticated. 213 + atproto.exceptions.AtProtocolError: If operation fails. 214 + """ 215 + self._ensure_authenticated() 216 + 217 + data: dict[str, Any] = { 218 + "repo": self.did, 219 + "collection": collection, 220 + "rkey": rkey, 221 + "record": record, 222 + "validate": validate, 223 + } 224 + if swap_commit: 225 + data["swapCommit"] = swap_commit 226 + 227 + response = self._client.com.atproto.repo.put_record(data=data) 228 + 229 + return AtUri.parse(response.uri) 230 + 231 + def get_record( 232 + self, 233 + uri: str | AtUri, 234 + ) -> dict: 235 + """Fetch a record by AT URI. 236 + 237 + Args: 238 + uri: The AT URI of the record. 239 + 240 + Returns: 241 + The record data as a dictionary. 242 + 243 + Raises: 244 + atproto.exceptions.AtProtocolError: If record not found. 
245 + """ 246 + if isinstance(uri, str): 247 + uri = AtUri.parse(uri) 248 + 249 + response = self._client.com.atproto.repo.get_record( 250 + params={ 251 + "repo": uri.authority, 252 + "collection": uri.collection, 253 + "rkey": uri.rkey, 254 + } 255 + ) 256 + 257 + return response.value 258 + 259 + def delete_record( 260 + self, 261 + uri: str | AtUri, 262 + *, 263 + swap_commit: Optional[str] = None, 264 + ) -> None: 265 + """Delete a record. 266 + 267 + Args: 268 + uri: The AT URI of the record to delete. 269 + swap_commit: Optional CID for compare-and-swap delete. 270 + 271 + Raises: 272 + ValueError: If not authenticated. 273 + atproto.exceptions.AtProtocolError: If deletion fails. 274 + """ 275 + self._ensure_authenticated() 276 + 277 + if isinstance(uri, str): 278 + uri = AtUri.parse(uri) 279 + 280 + data: dict[str, Any] = { 281 + "repo": self.did, 282 + "collection": uri.collection, 283 + "rkey": uri.rkey, 284 + } 285 + if swap_commit: 286 + data["swapCommit"] = swap_commit 287 + 288 + self._client.com.atproto.repo.delete_record(data=data) 289 + 290 + def list_records( 291 + self, 292 + collection: str, 293 + *, 294 + repo: Optional[str] = None, 295 + limit: int = 100, 296 + cursor: Optional[str] = None, 297 + ) -> tuple[list[dict], Optional[str]]: 298 + """List records in a collection. 299 + 300 + Args: 301 + collection: The NSID of the record collection. 302 + repo: The DID of the repository to query. Defaults to the 303 + authenticated user's repository. 304 + limit: Maximum number of records to return (default 100). 305 + cursor: Pagination cursor from a previous call. 306 + 307 + Returns: 308 + A tuple of (records, next_cursor). The cursor is None if there 309 + are no more records. 310 + 311 + Raises: 312 + ValueError: If repo is None and not authenticated. 
313 + """ 314 + if repo is None: 315 + self._ensure_authenticated() 316 + repo = self.did 317 + 318 + response = self._client.com.atproto.repo.list_records( 319 + params={ 320 + "repo": repo, 321 + "collection": collection, 322 + "limit": limit, 323 + "cursor": cursor, 324 + } 325 + ) 326 + 327 + records = [r.value for r in response.records] 328 + return records, response.cursor 329 + 330 + # Convenience methods for atdata collections 331 + 332 + def list_schemas( 333 + self, 334 + repo: Optional[str] = None, 335 + limit: int = 100, 336 + ) -> list[dict]: 337 + """List schema records. 338 + 339 + Args: 340 + repo: The DID to query. Defaults to authenticated user. 341 + limit: Maximum number to return. 342 + 343 + Returns: 344 + List of schema records. 345 + """ 346 + records, _ = self.list_records( 347 + f"{LEXICON_NAMESPACE}.sampleSchema", 348 + repo=repo, 349 + limit=limit, 350 + ) 351 + return records 352 + 353 + def list_datasets( 354 + self, 355 + repo: Optional[str] = None, 356 + limit: int = 100, 357 + ) -> list[dict]: 358 + """List dataset records. 359 + 360 + Args: 361 + repo: The DID to query. Defaults to authenticated user. 362 + limit: Maximum number to return. 363 + 364 + Returns: 365 + List of dataset records. 366 + """ 367 + records, _ = self.list_records( 368 + f"{LEXICON_NAMESPACE}.record", 369 + repo=repo, 370 + limit=limit, 371 + ) 372 + return records 373 + 374 + def list_lenses( 375 + self, 376 + repo: Optional[str] = None, 377 + limit: int = 100, 378 + ) -> list[dict]: 379 + """List lens records. 380 + 381 + Args: 382 + repo: The DID to query. Defaults to authenticated user. 383 + limit: Maximum number to return. 384 + 385 + Returns: 386 + List of lens records. 387 + """ 388 + records, _ = self.list_records( 389 + f"{LEXICON_NAMESPACE}.lens", 390 + repo=repo, 391 + limit=limit, 392 + ) 393 + return records
+280
src/atdata/atmosphere/lens.py
··· 1 + """Lens transformation publishing for ATProto. 2 + 3 + This module provides classes for publishing Lens transformation records to 4 + ATProto. Lenses are published as ``ac.foundation.dataset.lens`` records. 5 + 6 + Note: 7 + For security reasons, lens code is stored as references to git repositories 8 + rather than inline code. Users must manually install and import lens 9 + implementations. 10 + """ 11 + 12 + from typing import Optional, Callable 13 + 14 + from .client import AtmosphereClient 15 + from ._types import ( 16 + AtUri, 17 + LensRecord, 18 + CodeReference, 19 + LEXICON_NAMESPACE, 20 + ) 21 + 22 + # Import for type checking only 23 + from typing import TYPE_CHECKING 24 + if TYPE_CHECKING: 25 + from ..lens import Lens 26 + 27 + 28 + class LensPublisher: 29 + """Publishes Lens transformation records to ATProto. 30 + 31 + This class creates lens records that reference source and target schemas 32 + and point to the transformation code in a git repository. 33 + 34 + Example: 35 + >>> @atdata.lens 36 + ... def my_lens(source: SourceType) -> TargetType: 37 + ... return TargetType(field=source.other_field) 38 + >>> 39 + >>> client = AtmosphereClient() 40 + >>> client.login("handle", "password") 41 + >>> 42 + >>> publisher = LensPublisher(client) 43 + >>> uri = publisher.publish( 44 + ... name="my_lens", 45 + ... source_schema_uri="at://did:plc:abc/ac.foundation.dataset.sampleSchema/source", 46 + ... target_schema_uri="at://did:plc:abc/ac.foundation.dataset.sampleSchema/target", 47 + ... code_repository="https://github.com/user/repo", 48 + ... code_commit="abc123def456", 49 + ... getter_path="mymodule.lenses:my_lens", 50 + ... putter_path="mymodule.lenses:my_lens_putter", 51 + ... ) 52 + 53 + Security Note: 54 + Lens code is stored as references to git repositories rather than 55 + inline code. This prevents arbitrary code execution from ATProto 56 + records. Users must manually install and trust lens implementations. 
57 + """ 58 + 59 + def __init__(self, client: AtmosphereClient): 60 + """Initialize the lens publisher. 61 + 62 + Args: 63 + client: Authenticated AtmosphereClient instance. 64 + """ 65 + self.client = client 66 + 67 + def publish( 68 + self, 69 + *, 70 + name: str, 71 + source_schema_uri: str, 72 + target_schema_uri: str, 73 + description: Optional[str] = None, 74 + code_repository: Optional[str] = None, 75 + code_commit: Optional[str] = None, 76 + getter_path: Optional[str] = None, 77 + putter_path: Optional[str] = None, 78 + rkey: Optional[str] = None, 79 + ) -> AtUri: 80 + """Publish a lens transformation record to ATProto. 81 + 82 + Args: 83 + name: Human-readable lens name. 84 + source_schema_uri: AT URI of the source schema. 85 + target_schema_uri: AT URI of the target schema. 86 + description: What this transformation does. 87 + code_repository: Git repository URL containing the lens code. 88 + code_commit: Git commit hash for reproducibility. 89 + getter_path: Module path to the getter function 90 + (e.g., 'mymodule.lenses:my_getter'). 91 + putter_path: Module path to the putter function 92 + (e.g., 'mymodule.lenses:my_putter'). 93 + rkey: Optional explicit record key. 94 + 95 + Returns: 96 + The AT URI of the created lens record. 97 + 98 + Raises: 99 + ValueError: If code references are incomplete. 
100 + """ 101 + # Build code references if provided 102 + getter_code: Optional[CodeReference] = None 103 + putter_code: Optional[CodeReference] = None 104 + 105 + if code_repository and code_commit: 106 + if getter_path: 107 + getter_code = CodeReference( 108 + repository=code_repository, 109 + commit=code_commit, 110 + path=getter_path, 111 + ) 112 + if putter_path: 113 + putter_code = CodeReference( 114 + repository=code_repository, 115 + commit=code_commit, 116 + path=putter_path, 117 + ) 118 + 119 + lens_record = LensRecord( 120 + name=name, 121 + source_schema=source_schema_uri, 122 + target_schema=target_schema_uri, 123 + description=description, 124 + getter_code=getter_code, 125 + putter_code=putter_code, 126 + ) 127 + 128 + return self.client.create_record( 129 + collection=f"{LEXICON_NAMESPACE}.lens", 130 + record=lens_record.to_record(), 131 + rkey=rkey, 132 + validate=False, 133 + ) 134 + 135 + def publish_from_lens( 136 + self, 137 + lens_obj: "Lens", 138 + *, 139 + name: str, 140 + source_schema_uri: str, 141 + target_schema_uri: str, 142 + code_repository: str, 143 + code_commit: str, 144 + description: Optional[str] = None, 145 + rkey: Optional[str] = None, 146 + ) -> AtUri: 147 + """Publish a lens record from an existing Lens object. 148 + 149 + This method extracts the getter and putter function names from 150 + the Lens object and publishes a record referencing them. 151 + 152 + Args: 153 + lens_obj: The Lens object to publish. 154 + name: Human-readable lens name. 155 + source_schema_uri: AT URI of the source schema. 156 + target_schema_uri: AT URI of the target schema. 157 + code_repository: Git repository URL. 158 + code_commit: Git commit hash. 159 + description: What this transformation does. 160 + rkey: Optional explicit record key. 161 + 162 + Returns: 163 + The AT URI of the created lens record. 
164 + """ 165 + # Extract function names from the lens 166 + getter_name = lens_obj._getter.__name__ 167 + putter_name = lens_obj._putter.__name__ 168 + 169 + # Get module info if available 170 + getter_module = getattr(lens_obj._getter, "__module__", "") 171 + putter_module = getattr(lens_obj._putter, "__module__", "") 172 + 173 + getter_path = f"{getter_module}:{getter_name}" if getter_module else getter_name 174 + putter_path = f"{putter_module}:{putter_name}" if putter_module else putter_name 175 + 176 + return self.publish( 177 + name=name, 178 + source_schema_uri=source_schema_uri, 179 + target_schema_uri=target_schema_uri, 180 + description=description, 181 + code_repository=code_repository, 182 + code_commit=code_commit, 183 + getter_path=getter_path, 184 + putter_path=putter_path, 185 + rkey=rkey, 186 + ) 187 + 188 + 189 + class LensLoader: 190 + """Loads lens records from ATProto. 191 + 192 + This class fetches lens transformation records. Note that actually 193 + using a lens requires installing the referenced code and importing 194 + it manually. 195 + 196 + Example: 197 + >>> client = AtmosphereClient() 198 + >>> loader = LensLoader(client) 199 + >>> 200 + >>> record = loader.get("at://did:plc:abc/ac.foundation.dataset.lens/xyz") 201 + >>> print(record["name"]) 202 + >>> print(record["sourceSchema"]) 203 + >>> print(record.get("getterCode", {}).get("repository")) 204 + """ 205 + 206 + def __init__(self, client: AtmosphereClient): 207 + """Initialize the lens loader. 208 + 209 + Args: 210 + client: AtmosphereClient instance. 211 + """ 212 + self.client = client 213 + 214 + def get(self, uri: str | AtUri) -> dict: 215 + """Fetch a lens record by AT URI. 216 + 217 + Args: 218 + uri: The AT URI of the lens record. 219 + 220 + Returns: 221 + The lens record as a dictionary. 222 + 223 + Raises: 224 + ValueError: If the record is not a lens record. 
225 + """ 226 + record = self.client.get_record(uri) 227 + 228 + expected_type = f"{LEXICON_NAMESPACE}.lens" 229 + if record.get("$type") != expected_type: 230 + raise ValueError( 231 + f"Record at {uri} is not a lens record. " 232 + f"Expected $type='{expected_type}', got '{record.get('$type')}'" 233 + ) 234 + 235 + return record 236 + 237 + def list_all( 238 + self, 239 + repo: Optional[str] = None, 240 + limit: int = 100, 241 + ) -> list[dict]: 242 + """List lens records from a repository. 243 + 244 + Args: 245 + repo: The DID of the repository. Defaults to authenticated user. 246 + limit: Maximum number of records to return. 247 + 248 + Returns: 249 + List of lens records. 250 + """ 251 + return self.client.list_lenses(repo=repo, limit=limit) 252 + 253 + def find_by_schemas( 254 + self, 255 + source_schema_uri: str, 256 + target_schema_uri: Optional[str] = None, 257 + repo: Optional[str] = None, 258 + ) -> list[dict]: 259 + """Find lenses that transform between specific schemas. 260 + 261 + Args: 262 + source_schema_uri: AT URI of the source schema. 263 + target_schema_uri: Optional AT URI of the target schema. 264 + If not provided, returns all lenses from the source. 265 + repo: The DID of the repository to search. 266 + 267 + Returns: 268 + List of matching lens records. 269 + """ 270 + all_lenses = self.list_all(repo=repo, limit=1000) 271 + 272 + matches = [] 273 + for lens_record in all_lenses: 274 + if lens_record.get("sourceSchema") == source_schema_uri: 275 + if target_schema_uri is None: 276 + matches.append(lens_record) 277 + elif lens_record.get("targetSchema") == target_schema_uri: 278 + matches.append(lens_record) 279 + 280 + return matches
+342
src/atdata/atmosphere/records.py
··· 1 + """Dataset record publishing and loading for ATProto. 2 + 3 + This module provides classes for publishing dataset index records to ATProto 4 + and loading them back. Dataset records are published as 5 + ``ac.foundation.dataset.record`` records. 6 + """ 7 + 8 + from typing import Type, TypeVar, Optional 9 + import msgpack 10 + 11 + from .client import AtmosphereClient 12 + from .schema import SchemaPublisher 13 + from ._types import ( 14 + AtUri, 15 + DatasetRecord, 16 + StorageLocation, 17 + LEXICON_NAMESPACE, 18 + ) 19 + 20 + # Import for type checking only to avoid circular imports 21 + from typing import TYPE_CHECKING 22 + if TYPE_CHECKING: 23 + from ..dataset import PackableSample, Dataset 24 + 25 + ST = TypeVar("ST", bound="PackableSample") 26 + 27 + 28 + class DatasetPublisher: 29 + """Publishes dataset index records to ATProto. 30 + 31 + This class creates dataset records that reference a schema and point to 32 + external storage (WebDataset URLs) or ATProto blobs. 33 + 34 + Example: 35 + >>> dataset = atdata.Dataset[MySample]("s3://bucket/data-{000000..000009}.tar") 36 + >>> 37 + >>> client = AtmosphereClient() 38 + >>> client.login("handle", "password") 39 + >>> 40 + >>> publisher = DatasetPublisher(client) 41 + >>> uri = publisher.publish( 42 + ... dataset, 43 + ... name="My Training Data", 44 + ... description="Training data for my model", 45 + ... tags=["computer-vision", "training"], 46 + ... ) 47 + """ 48 + 49 + def __init__(self, client: AtmosphereClient): 50 + """Initialize the dataset publisher. 51 + 52 + Args: 53 + client: Authenticated AtmosphereClient instance. 
54 + """ 55 + self.client = client 56 + self._schema_publisher = SchemaPublisher(client) 57 + 58 + def publish( 59 + self, 60 + dataset: "Dataset[ST]", 61 + *, 62 + name: str, 63 + schema_uri: Optional[str] = None, 64 + description: Optional[str] = None, 65 + tags: Optional[list[str]] = None, 66 + license: Optional[str] = None, 67 + auto_publish_schema: bool = True, 68 + schema_version: str = "1.0.0", 69 + rkey: Optional[str] = None, 70 + ) -> AtUri: 71 + """Publish a dataset index record to ATProto. 72 + 73 + Args: 74 + dataset: The Dataset to publish. 75 + name: Human-readable dataset name. 76 + schema_uri: AT URI of the schema record. If not provided and 77 + auto_publish_schema is True, the schema will be published. 78 + description: Human-readable description. 79 + tags: Searchable tags for discovery. 80 + license: SPDX license identifier (e.g., 'MIT', 'Apache-2.0'). 81 + auto_publish_schema: If True and schema_uri not provided, 82 + automatically publish the schema first. 83 + schema_version: Version for auto-published schema. 84 + rkey: Optional explicit record key. 85 + 86 + Returns: 87 + The AT URI of the created dataset record. 88 + 89 + Raises: 90 + ValueError: If schema_uri is not provided and auto_publish_schema is False. 
91 + """ 92 + # Ensure we have a schema reference 93 + if schema_uri is None: 94 + if not auto_publish_schema: 95 + raise ValueError( 96 + "schema_uri is required when auto_publish_schema=False" 97 + ) 98 + # Auto-publish the schema 99 + schema_uri_obj = self._schema_publisher.publish( 100 + dataset.sample_type, 101 + version=schema_version, 102 + ) 103 + schema_uri = str(schema_uri_obj) 104 + 105 + # Build the storage location 106 + storage = StorageLocation( 107 + kind="external", 108 + urls=[dataset.url], 109 + ) 110 + 111 + # Build dataset record 112 + metadata_bytes: Optional[bytes] = None 113 + if dataset.metadata is not None: 114 + metadata_bytes = msgpack.packb(dataset.metadata) 115 + 116 + dataset_record = DatasetRecord( 117 + name=name, 118 + schema_ref=schema_uri, 119 + storage=storage, 120 + description=description, 121 + tags=tags or [], 122 + license=license, 123 + metadata=metadata_bytes, 124 + ) 125 + 126 + # Publish to ATProto 127 + return self.client.create_record( 128 + collection=f"{LEXICON_NAMESPACE}.record", 129 + record=dataset_record.to_record(), 130 + rkey=rkey, 131 + validate=False, 132 + ) 133 + 134 + def publish_with_urls( 135 + self, 136 + urls: list[str], 137 + schema_uri: str, 138 + *, 139 + name: str, 140 + description: Optional[str] = None, 141 + tags: Optional[list[str]] = None, 142 + license: Optional[str] = None, 143 + metadata: Optional[dict] = None, 144 + rkey: Optional[str] = None, 145 + ) -> AtUri: 146 + """Publish a dataset record with explicit URLs. 147 + 148 + This method allows publishing a dataset record without having a 149 + Dataset object, useful for registering existing WebDataset files. 150 + 151 + Args: 152 + urls: List of WebDataset URLs with brace notation. 153 + schema_uri: AT URI of the schema record. 154 + name: Human-readable dataset name. 155 + description: Human-readable description. 156 + tags: Searchable tags for discovery. 157 + license: SPDX license identifier. 
158 + metadata: Arbitrary metadata dictionary. 159 + rkey: Optional explicit record key. 160 + 161 + Returns: 162 + The AT URI of the created dataset record. 163 + """ 164 + storage = StorageLocation( 165 + kind="external", 166 + urls=urls, 167 + ) 168 + 169 + metadata_bytes: Optional[bytes] = None 170 + if metadata is not None: 171 + metadata_bytes = msgpack.packb(metadata) 172 + 173 + dataset_record = DatasetRecord( 174 + name=name, 175 + schema_ref=schema_uri, 176 + storage=storage, 177 + description=description, 178 + tags=tags or [], 179 + license=license, 180 + metadata=metadata_bytes, 181 + ) 182 + 183 + return self.client.create_record( 184 + collection=f"{LEXICON_NAMESPACE}.record", 185 + record=dataset_record.to_record(), 186 + rkey=rkey, 187 + validate=False, 188 + ) 189 + 190 + 191 + class DatasetLoader: 192 + """Loads dataset records from ATProto. 193 + 194 + This class fetches dataset index records and can create Dataset objects 195 + from them. Note that loading a dataset requires having the corresponding 196 + Python class for the sample type. 197 + 198 + Example: 199 + >>> client = AtmosphereClient() 200 + >>> loader = DatasetLoader(client) 201 + >>> 202 + >>> # List available datasets 203 + >>> datasets = loader.list() 204 + >>> for ds in datasets: 205 + ... print(ds["name"], ds["schemaRef"]) 206 + >>> 207 + >>> # Get a specific dataset record 208 + >>> record = loader.get("at://did:plc:abc/ac.foundation.dataset.record/xyz") 209 + """ 210 + 211 + def __init__(self, client: AtmosphereClient): 212 + """Initialize the dataset loader. 213 + 214 + Args: 215 + client: AtmosphereClient instance. 216 + """ 217 + self.client = client 218 + 219 + def get(self, uri: str | AtUri) -> dict: 220 + """Fetch a dataset record by AT URI. 221 + 222 + Args: 223 + uri: The AT URI of the dataset record. 224 + 225 + Returns: 226 + The dataset record as a dictionary. 227 + 228 + Raises: 229 + ValueError: If the record is not a dataset record. 
230 + """ 231 + record = self.client.get_record(uri) 232 + 233 + expected_type = f"{LEXICON_NAMESPACE}.record" 234 + if record.get("$type") != expected_type: 235 + raise ValueError( 236 + f"Record at {uri} is not a dataset record. " 237 + f"Expected $type='{expected_type}', got '{record.get('$type')}'" 238 + ) 239 + 240 + return record 241 + 242 + def list_all( 243 + self, 244 + repo: Optional[str] = None, 245 + limit: int = 100, 246 + ) -> list[dict]: 247 + """List dataset records from a repository. 248 + 249 + Args: 250 + repo: The DID of the repository. Defaults to authenticated user. 251 + limit: Maximum number of records to return. 252 + 253 + Returns: 254 + List of dataset records. 255 + """ 256 + return self.client.list_datasets(repo=repo, limit=limit) 257 + 258 + def get_urls(self, uri: str | AtUri) -> list[str]: 259 + """Get the WebDataset URLs from a dataset record. 260 + 261 + Args: 262 + uri: The AT URI of the dataset record. 263 + 264 + Returns: 265 + List of WebDataset URLs. 266 + 267 + Raises: 268 + ValueError: If the storage type is not external URLs. 269 + """ 270 + record = self.get(uri) 271 + storage = record.get("storage", {}) 272 + 273 + storage_type = storage.get("$type", "") 274 + if "storageExternal" in storage_type: 275 + return storage.get("urls", []) 276 + elif "storageBlobs" in storage_type: 277 + raise ValueError( 278 + "Dataset uses blob storage, not external URLs. " 279 + "Use get_blobs() instead." 280 + ) 281 + else: 282 + raise ValueError(f"Unknown storage type: {storage_type}") 283 + 284 + def get_metadata(self, uri: str | AtUri) -> Optional[dict]: 285 + """Get the metadata from a dataset record. 286 + 287 + Args: 288 + uri: The AT URI of the dataset record. 289 + 290 + Returns: 291 + The metadata dictionary, or None if no metadata. 
292 + """ 293 + record = self.get(uri) 294 + metadata_bytes = record.get("metadata") 295 + 296 + if metadata_bytes is None: 297 + return None 298 + 299 + return msgpack.unpackb(metadata_bytes, raw=False) 300 + 301 + def to_dataset( 302 + self, 303 + uri: str | AtUri, 304 + sample_type: Type[ST], 305 + ) -> "Dataset[ST]": 306 + """Create a Dataset object from an ATProto record. 307 + 308 + This method creates a Dataset instance from a published record. 309 + You must provide the sample type class, which should match the 310 + schema referenced by the record. 311 + 312 + Args: 313 + uri: The AT URI of the dataset record. 314 + sample_type: The Python class for the sample type. 315 + 316 + Returns: 317 + A Dataset instance configured from the record. 318 + 319 + Raises: 320 + ValueError: If the storage type is not external URLs. 321 + 322 + Example: 323 + >>> loader = DatasetLoader(client) 324 + >>> dataset = loader.to_dataset(uri, MySampleType) 325 + >>> for batch in dataset.shuffled(batch_size=32): 326 + ... process(batch) 327 + """ 328 + # Import here to avoid circular import 329 + from ..dataset import Dataset 330 + 331 + urls = self.get_urls(uri) 332 + if not urls: 333 + raise ValueError("Dataset record has no URLs") 334 + 335 + # Use the first URL (multi-URL support could be added later) 336 + url = urls[0] 337 + 338 + # Get metadata URL if available 339 + record = self.get(uri) 340 + metadata_url = record.get("metadataUrl") 341 + 342 + return Dataset[sample_type](url, metadata_url=metadata_url)
+296
src/atdata/atmosphere/schema.py
··· 1 + """Schema publishing and loading for ATProto. 2 + 3 + This module provides classes for publishing PackableSample schemas to ATProto 4 + and loading them back. Schemas are published as ``ac.foundation.dataset.sampleSchema`` 5 + records. 6 + """ 7 + 8 + from dataclasses import fields, is_dataclass 9 + from typing import Type, TypeVar, Optional, Union, get_type_hints, get_origin, get_args 10 + import types 11 + 12 + from .client import AtmosphereClient 13 + from ._types import ( 14 + AtUri, 15 + SchemaRecord, 16 + FieldDef, 17 + FieldType, 18 + LEXICON_NAMESPACE, 19 + ) 20 + 21 + # Import for type checking only to avoid circular imports 22 + from typing import TYPE_CHECKING 23 + if TYPE_CHECKING: 24 + from ..dataset import PackableSample 25 + 26 + ST = TypeVar("ST", bound="PackableSample") 27 + 28 + 29 + class SchemaPublisher: 30 + """Publishes PackableSample schemas to ATProto. 31 + 32 + This class introspects a PackableSample class to extract its field 33 + definitions and publishes them as an ATProto schema record. 34 + 35 + Example: 36 + >>> @atdata.packable 37 + ... class MySample: 38 + ... image: NDArray 39 + ... label: str 40 + ... 41 + >>> client = AtmosphereClient() 42 + >>> client.login("handle", "password") 43 + >>> 44 + >>> publisher = SchemaPublisher(client) 45 + >>> uri = publisher.publish(MySample, version="1.0.0") 46 + >>> print(uri) 47 + at://did:plc:.../ac.foundation.dataset.sampleSchema/... 48 + """ 49 + 50 + def __init__(self, client: AtmosphereClient): 51 + """Initialize the schema publisher. 52 + 53 + Args: 54 + client: Authenticated AtmosphereClient instance. 55 + """ 56 + self.client = client 57 + 58 + def publish( 59 + self, 60 + sample_type: Type[ST], 61 + *, 62 + name: Optional[str] = None, 63 + version: str = "1.0.0", 64 + description: Optional[str] = None, 65 + metadata: Optional[dict] = None, 66 + rkey: Optional[str] = None, 67 + ) -> AtUri: 68 + """Publish a PackableSample schema to ATProto. 
69 + 70 + Args: 71 + sample_type: The PackableSample class to publish. 72 + name: Human-readable name. Defaults to the class name. 73 + version: Semantic version string (e.g., '1.0.0'). 74 + description: Human-readable description. 75 + metadata: Arbitrary metadata dictionary. 76 + rkey: Optional explicit record key. If not provided, a TID is generated. 77 + 78 + Returns: 79 + The AT URI of the created schema record. 80 + 81 + Raises: 82 + ValueError: If sample_type is not a dataclass or client is not authenticated. 83 + TypeError: If a field type is not supported. 84 + """ 85 + if not is_dataclass(sample_type): 86 + raise ValueError(f"{sample_type.__name__} must be a dataclass (use @packable)") 87 + 88 + # Build the schema record 89 + schema_record = self._build_schema_record( 90 + sample_type, 91 + name=name, 92 + version=version, 93 + description=description, 94 + metadata=metadata, 95 + ) 96 + 97 + # Publish to ATProto 98 + return self.client.create_record( 99 + collection=f"{LEXICON_NAMESPACE}.sampleSchema", 100 + record=schema_record.to_record(), 101 + rkey=rkey, 102 + validate=False, # PDS doesn't know our lexicon 103 + ) 104 + 105 + def _build_schema_record( 106 + self, 107 + sample_type: Type[ST], 108 + *, 109 + name: Optional[str], 110 + version: str, 111 + description: Optional[str], 112 + metadata: Optional[dict], 113 + ) -> SchemaRecord: 114 + """Build a SchemaRecord from a PackableSample class.""" 115 + field_defs = [] 116 + type_hints = get_type_hints(sample_type) 117 + 118 + for f in fields(sample_type): 119 + field_type = type_hints.get(f.name, f.type) 120 + field_def = self._field_to_def(f.name, field_type) 121 + field_defs.append(field_def) 122 + 123 + return SchemaRecord( 124 + name=name or sample_type.__name__, 125 + version=version, 126 + description=description, 127 + fields=field_defs, 128 + metadata=metadata, 129 + ) 130 + 131 + def _field_to_def(self, name: str, python_type) -> FieldDef: 132 + """Convert a Python field to a FieldDef.""" 
133 + # Check for Optional types (Union with None) 134 + is_optional = False 135 + origin = get_origin(python_type) 136 + 137 + # Handle Union types (including Optional which is Union[T, None]) 138 + if origin is Union or isinstance(python_type, types.UnionType): 139 + args = get_args(python_type) 140 + non_none_args = [a for a in args if a is not type(None)] 141 + if type(None) in args or len(non_none_args) < len(args): 142 + is_optional = True 143 + if len(non_none_args) == 1: 144 + python_type = non_none_args[0] 145 + elif len(non_none_args) > 1: 146 + # Complex union type - not fully supported yet 147 + raise TypeError(f"Complex union types not supported: {python_type}") 148 + 149 + field_type = self._python_type_to_field_type(python_type) 150 + 151 + return FieldDef( 152 + name=name, 153 + field_type=field_type, 154 + optional=is_optional, 155 + ) 156 + 157 + def _python_type_to_field_type(self, python_type) -> FieldType: 158 + """Map a Python type to a FieldType.""" 159 + # Handle primitives 160 + if python_type is str: 161 + return FieldType(kind="primitive", primitive="str") 162 + elif python_type is int: 163 + return FieldType(kind="primitive", primitive="int") 164 + elif python_type is float: 165 + return FieldType(kind="primitive", primitive="float") 166 + elif python_type is bool: 167 + return FieldType(kind="primitive", primitive="bool") 168 + elif python_type is bytes: 169 + return FieldType(kind="primitive", primitive="bytes") 170 + 171 + # Check for NDArray 172 + # NDArray from numpy.typing is a special generic alias 173 + type_str = str(python_type) 174 + if "NDArray" in type_str or "ndarray" in type_str.lower(): 175 + # Try to extract dtype info if available 176 + dtype = "float32" # Default 177 + args = get_args(python_type) 178 + if args: 179 + # NDArray[np.float64] or similar 180 + dtype_arg = args[-1] if args else None 181 + if dtype_arg is not None: 182 + dtype = self._numpy_dtype_to_string(dtype_arg) 183 + 184 + return 
FieldType(kind="ndarray", dtype=dtype, shape=None) 185 + 186 + # Check for list/array types 187 + origin = get_origin(python_type) 188 + if origin is list: 189 + args = get_args(python_type) 190 + if args: 191 + items = self._python_type_to_field_type(args[0]) 192 + return FieldType(kind="array", items=items) 193 + else: 194 + # Untyped list 195 + return FieldType(kind="array", items=FieldType(kind="primitive", primitive="str")) 196 + 197 + # Check for nested PackableSample (not yet supported) 198 + if is_dataclass(python_type): 199 + raise TypeError( 200 + f"Nested dataclass types not yet supported: {python_type.__name__}. " 201 + "Publish nested types separately and use references." 202 + ) 203 + 204 + raise TypeError(f"Unsupported type for schema field: {python_type}") 205 + 206 + def _numpy_dtype_to_string(self, dtype) -> str: 207 + """Convert a numpy dtype annotation to a string.""" 208 + dtype_str = str(dtype) 209 + # Handle common numpy dtypes 210 + dtype_map = { 211 + "float16": "float16", 212 + "float32": "float32", 213 + "float64": "float64", 214 + "int8": "int8", 215 + "int16": "int16", 216 + "int32": "int32", 217 + "int64": "int64", 218 + "uint8": "uint8", 219 + "uint16": "uint16", 220 + "uint32": "uint32", 221 + "uint64": "uint64", 222 + "bool": "bool", 223 + "complex64": "complex64", 224 + "complex128": "complex128", 225 + } 226 + 227 + for key, value in dtype_map.items(): 228 + if key in dtype_str: 229 + return value 230 + 231 + return "float32" # Default fallback 232 + 233 + 234 + class SchemaLoader: 235 + """Loads PackableSample schemas from ATProto. 236 + 237 + This class fetches schema records from ATProto and can list available 238 + schemas from a repository. 
239 + 240 + Example: 241 + >>> client = AtmosphereClient() 242 + >>> client.login("handle", "password") 243 + >>> 244 + >>> loader = SchemaLoader(client) 245 + >>> schema = loader.get("at://did:plc:.../ac.foundation.dataset.sampleSchema/...") 246 + >>> print(schema["name"]) 247 + 'MySample' 248 + """ 249 + 250 + def __init__(self, client: AtmosphereClient): 251 + """Initialize the schema loader. 252 + 253 + Args: 254 + client: AtmosphereClient instance (authentication optional for reads). 255 + """ 256 + self.client = client 257 + 258 + def get(self, uri: str | AtUri) -> dict: 259 + """Fetch a schema record by AT URI. 260 + 261 + Args: 262 + uri: The AT URI of the schema record. 263 + 264 + Returns: 265 + The schema record as a dictionary. 266 + 267 + Raises: 268 + ValueError: If the record is not a schema record. 269 + atproto.exceptions.AtProtocolError: If record not found. 270 + """ 271 + record = self.client.get_record(uri) 272 + 273 + expected_type = f"{LEXICON_NAMESPACE}.sampleSchema" 274 + if record.get("$type") != expected_type: 275 + raise ValueError( 276 + f"Record at {uri} is not a schema record. " 277 + f"Expected $type='{expected_type}', got '{record.get('$type')}'" 278 + ) 279 + 280 + return record 281 + 282 + def list_all( 283 + self, 284 + repo: Optional[str] = None, 285 + limit: int = 100, 286 + ) -> list[dict]: 287 + """List schema records from a repository. 288 + 289 + Args: 290 + repo: The DID of the repository. Defaults to authenticated user. 291 + limit: Maximum number of records to return. 292 + 293 + Returns: 294 + List of schema records. 295 + """ 296 + return self.client.list_schemas(repo=repo, limit=limit)
+12 -169
src/atdata/dataset.py
··· 32 32 33 33 from pathlib import Path 34 34 import uuid 35 - import functools 36 35 37 36 import dataclasses 38 37 import types ··· 40 39 dataclass, 41 40 asdict, 42 41 ) 43 - from abc import ( 44 - ABC, 45 - abstractmethod, 46 - ) 42 + from abc import ABC 47 43 48 44 from tqdm import tqdm 49 45 import numpy as np ··· 66 62 TypeVar, 67 63 TypeAlias, 68 64 ) 69 - # from typing_inspect import get_bound, get_parameters 70 - from numpy.typing import ( 71 - NDArray, 72 - ArrayLike, 73 - ) 74 - 75 - # 76 - 77 - # import ekumen.atmosphere as eat 65 + from numpy.typing import NDArray 78 66 79 67 import msgpack 80 68 import ormsgpack ··· 97 85 ## 98 86 # Main base classes 99 87 100 - # TODO Check for best way to ensure this typevar is used as a dataclass type 101 - # DT = TypeVar( 'DT', bound = dataclass.__class__ ) 102 88 DT = TypeVar( 'DT' ) 103 89 104 90 MsgpackRawSample: TypeAlias = Dict[str, Any] 105 91 106 - # @dataclass 107 - # class ArrayBytes: 108 - # """Annotates bytes that should be interpreted as the raw contents of a 109 - # numpy NDArray""" 110 - 111 - # raw_bytes: bytes 112 - # """The raw bytes of the corresponding NDArray""" 113 - 114 - # def __init__( self, 115 - # array: Optional[ArrayLike] = None, 116 - # raw: Optional[bytes] = None, 117 - # ): 118 - # """TODO""" 119 - 120 - # if array is not None: 121 - # array = np.array( array ) 122 - # self.raw_bytes = eh.array_to_bytes( array ) 123 - 124 - # elif raw is not None: 125 - # self.raw_bytes = raw 126 - 127 - # else: 128 - # raise ValueError( 'Must provide either `array` or `raw` bytes' ) 129 - 130 - # @property 131 - # def to_numpy( self ) -> NDArray: 132 - # """Return the `raw_bytes` data as an NDArray""" 133 - # return eh.bytes_to_array( self.raw_bytes ) 134 92 135 93 def _make_packable( x ): 136 94 """Convert a value to a msgpack-compatible format. ··· 142 100 Returns: 143 101 The value in a format suitable for msgpack serialization. 
144 102 """ 145 - # if isinstance( x, ArrayBytes ): 146 - # return x.raw_bytes 147 103 if isinstance( x, np.ndarray ): 148 104 return eh.array_to_bytes( x ) 149 105 return x ··· 227 183 # based on what is provided 228 184 229 185 if isinstance( var_cur_value, np.ndarray ): 230 - # we're good! 231 - pass 232 - 233 - # elif isinstance( var_cur_value, ArrayBytes ): 234 - # setattr( self, var_name, var_cur_value.to_numpy ) 186 + # Already the correct type, no conversion needed 187 + continue 235 188 236 189 elif isinstance( var_cur_value, bytes ): 237 190 # TODO This does create a constraint that serialized bytes ··· 412 365 raise AttributeError( f'No sample attribute named {name}' ) 413 366 414 367 415 - # class AnySample( BaseModel ): 416 - # """A sample that can hold anything""" 417 - # value: Any 418 - 419 - # class AnyBatch( BaseModel ): 420 - # """A batch of `AnySample`s""" 421 - # values: list[AnySample] 422 - 423 - 424 368 ST = TypeVar( 'ST', bound = PackableSample ) 425 - # BT = TypeVar( 'BT' ) 426 - 427 369 RT = TypeVar( 'RT', bound = PackableSample ) 428 370 429 - # TODO For python 3.13 430 - # BT = TypeVar( 'BT', default = None ) 431 - # IT = TypeVar( 'IT', default = Any ) 432 - 433 371 class Dataset( Generic[ST] ): 434 372 """A typed dataset built on WebDataset with lens transformations. 435 373 ··· 458 396 >>> # Transform to a different view 459 397 >>> ds_view = ds.as_type(MyDataView) 460 398 461 - TODO Expand this to show information on the `metadata_url` field 462 399 """ 463 400 464 - # sample_class: Type = get_parameters( ) 465 - # """The type of each returned sample from this `Dataset`'s iterator""" 466 - # batch_class: Type = get_bound( BT ) 467 - # """The type of a batch built from `sample_class`""" 468 - 469 401 @property 470 402 def sample_type( self ) -> Type: 471 403 """The type of each returned sample from this dataset's iterator. ··· 485 417 Returns: 486 418 ``SampleBatch[ST]`` where ``ST`` is this dataset's sample type. 
487 419 """ 488 - # return self.__orig_class__.__args__[1] 489 420 return SampleBatch[self.sample_type] 490 421 491 - 492 - # _schema_registry_sample: dict[str, Type] 493 - # _schema_registry_batch: dict[str, Type | None] 494 - 495 - # 496 - 497 422 def __init__( self, url: str, 498 423 metadata_url: str | None = None, 499 424 ) -> None: ··· 513 438 514 439 self._metadata: dict[str, Any] | None = None 515 440 self.metadata_url: str | None = metadata_url 516 - """TODO""" 441 + """Optional URL to msgpack-encoded metadata for this dataset.""" 517 442 518 443 # Allow addition of automatic transformation of raw underlying data 519 444 self._output_lens: Lens | None = None ··· 540 465 ret._output_lens = lenses.transform( self.sample_type, ret.sample_type ) 541 466 return ret 542 467 543 - # @classmethod 544 - # def register( cls, uri: str, 545 - # sample_class: Type, 546 - # batch_class: Optional[Type] = None, 547 - # ): 548 - # """Register an `ekumen` schema to use a particular dataset sample class""" 549 - # cls._schema_registry_sample[uri] = sample_class 550 - # cls._schema_registry_batch[uri] = batch_class 551 - 552 - # @classmethod 553 - # def at( cls, uri: str ) -> 'Dataset': 554 - # """Create a Dataset for the `ekumen` index entry at `uri`""" 555 - # client = eat.Client() 556 - # return cls( ) 557 - 558 - # Common functionality 559 - 560 468 @property 561 469 def shard_list( self ) -> list[str]: 562 470 """List of individual dataset shards ··· 573 481 574 482 @property 575 483 def metadata( self ) -> dict[str, Any] | None: 576 - """TODO""" 484 + """Fetch and cache metadata from metadata_url. 485 + 486 + Returns: 487 + Deserialized metadata dictionary, or None if no metadata_url is set. 577 488 489 + Raises: 490 + requests.HTTPError: If metadata fetch fails. 
491 + """ 578 492 if self.metadata_url is None: 579 493 return None 580 494 ··· 603 517 """ 604 518 605 519 if batch_size is None: 606 - # TODO Duplication here 607 520 return wds.pipeline.DataPipeline( 608 521 wds.shardlists.SimpleShardList( self.url ), 609 522 wds.shardlists.split_by_worker, 610 - # 611 523 wds.tariterators.tarfile_to_samples(), 612 - # wds.map( self.preprocess ), 613 524 wds.filters.map( self.wrap ), 614 525 ) 615 526 616 527 return wds.pipeline.DataPipeline( 617 528 wds.shardlists.SimpleShardList( self.url ), 618 529 wds.shardlists.split_by_worker, 619 - # 620 530 wds.tariterators.tarfile_to_samples(), 621 - # wds.map( self.preprocess ), 622 531 wds.filters.batched( batch_size ), 623 532 wds.filters.map( self.wrap_batch ), 624 533 ) ··· 646 555 ``SampleBatch[ST]`` instances; otherwise yields individual ``ST`` 647 556 samples. 648 557 """ 649 - 650 558 if batch_size is None: 651 - # TODO Duplication here 652 559 return wds.pipeline.DataPipeline( 653 560 wds.shardlists.SimpleShardList( self.url ), 654 561 wds.filters.shuffle( buffer_shards ), 655 562 wds.shardlists.split_by_worker, 656 - # 657 563 wds.tariterators.tarfile_to_samples(), 658 - # wds.shuffle( buffer_samples ), 659 - # wds.map( self.preprocess ), 660 564 wds.filters.shuffle( buffer_samples ), 661 565 wds.filters.map( self.wrap ), 662 566 ) ··· 665 569 wds.shardlists.SimpleShardList( self.url ), 666 570 wds.filters.shuffle( buffer_shards ), 667 571 wds.shardlists.split_by_worker, 668 - # 669 572 wds.tariterators.tarfile_to_samples(), 670 - # wds.shuffle( buffer_samples ), 671 - # wds.map( self.preprocess ), 672 573 wds.filters.shuffle( buffer_samples ), 673 574 wds.filters.batched( batch_size ), 674 575 wds.filters.map( self.wrap_batch ), ··· 731 632 df = pd.DataFrame( cur_buffer ) 732 633 df.to_parquet( cur_path, **kwargs ) 733 634 734 - 735 - # Implemented by specific subclasses 736 - 737 - # @property 738 - # @abstractmethod 739 - # def url( self ) -> str: 740 - # """str: 
Brace-notation URL of the underlying full WebDataset""" 741 - # pass 742 - 743 - # @classmethod 744 - # # TODO replace Any with IT 745 - # def preprocess( cls, sample: WDSRawSample ) -> Any: 746 - # """Pre-built preprocessor for a raw `sample` from the given dataset""" 747 - # return sample 748 - 749 - # @classmethod 750 - # TODO replace Any with IT 751 635 def wrap( self, sample: MsgpackRawSample ) -> ST: 752 636 """Wrap a raw msgpack sample into the appropriate dataset-specific type. 753 637 ··· 767 651 768 652 source_sample = self._output_lens.source_type.from_bytes( sample['msgpack'] ) 769 653 return self._output_lens( source_sample ) 770 - 771 - # try: 772 - # assert type( sample ) == dict 773 - # return cls.sample_class( **{ 774 - # k: v 775 - # for k, v in sample.items() if k != '__key__' 776 - # } ) 777 - 778 - # except Exception as e: 779 - # # Sample constructor failed -- revert to default 780 - # return AnySample( 781 - # value = sample, 782 - # ) 783 654 784 655 def wrap_batch( self, batch: WDSRawBatch ) -> SampleBatch[ST]: 785 656 """Wrap a batch of raw msgpack samples into a typed SampleBatch. 
··· 810 681 for s in batch_source ] 811 682 return SampleBatch[self.sample_type]( batch_view ) 812 683 813 - # # @classmethod 814 - # def wrap_batch( self, batch: WDSRawBatch ) -> BT: 815 - # """Wrap a `batch` of samples into the appropriate dataset-specific type 816 - 817 - # This default implementation simply creates a list one sample at a time 818 - # """ 819 - # assert cls.batch_class is not None, 'No batch class specified' 820 - # return cls.batch_class( **batch ) 821 - 822 - 823 - ## 824 - # Shortcut decorators 825 - 826 - # def packable( cls ): 827 - # """TODO""" 828 - 829 - # def decorator( cls ): 830 - # # Create a new class dynamically 831 - # # The new class inherits from the new_parent_class first, then the original cls 832 - # new_bases = (PackableSample,) + cls.__bases__ 833 - # new_cls = type(cls.__name__, new_bases, dict(cls.__dict__)) 834 - 835 - # # Optionally, update __module__ and __qualname__ for better introspection 836 - # new_cls.__module__ = cls.__module__ 837 - # new_cls.__qualname__ = cls.__qualname__ 838 - 839 - # return new_cls 840 - # return decorator 841 684 842 685 def packable( cls ): 843 686 """Decorator to convert a regular class into a ``PackableSample``.
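The class-rebasing idea behind `@packable` (visible in the commented-out draft this commit removes) can be sketched standalone. `PackableBase` below is a stand-in for the real `atdata.PackableSample`, and the `pack()` method is a hypothetical simplification:

```python
# Illustration of the dynamic-subclass approach from the removed @packable draft.
# PackableBase and pack() are stand-ins, not the real atdata.PackableSample API.
import dataclasses

class PackableBase:
    def pack(self) -> dict:
        # Flatten the dataclass fields into a plain dict (real code packs to msgpack).
        return dataclasses.asdict(self)

def packable(cls):
    cls = dataclasses.dataclass(cls)
    # Build a new class that inherits PackableBase first, then the original class,
    # so decorated classes gain packing behavior without declaring a base themselves.
    new_cls = type(cls.__name__, (PackableBase, cls), {})
    new_cls.__module__ = cls.__module__
    new_cls.__qualname__ = cls.__qualname__
    return new_cls

@packable
class Point:
    x: int
    y: int

p = Point(x=1, y=2)
print(p.pack())  # {'x': 1, 'y': 2}
```

Updating `__module__` and `__qualname__` on the rebuilt class keeps introspection (reprs, pickling error messages, docs) pointing at the user's original definition site.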
+2 -55
src/atdata/lens.py
··· 201 201 """ 202 202 return self._getter( s ) 203 203 204 - # TODO Figure out how to properly parameterize this 205 - # def _lens_factory[S, V]( register: bool = True ): 206 - # """Register the annotated function `f` as the getter of a sample lens""" 207 - 208 - # # The actual lens decorator taking a lens getter function to a lens object 209 - # def _decorator( f: LensGetter[S, V] ) -> Lens[S, V]: 210 - # ret = Lens[S, V]( f ) 211 - # if register: 212 - # _network.register( ret ) 213 - # return ret 214 - 215 - # # Return the lens decorator 216 - # return _decorator 217 - 218 - # # For convenience 219 - # lens = _lens_factory 220 204 221 205 def lens( f: LensGetter[S, V] ) -> Lens[S, V]: 222 206 """Decorator to create and register a lens transformation. ··· 245 229 _network.register( ret ) 246 230 return ret 247 231 248 - 249 - ## 250 - # Global registry of used lenses 251 - 252 - # _registered_lenses: Dict[LensSignature, Lens] = dict() 253 - # """TODO""" 254 232 255 233 class LensNetwork: 256 234 """Global registry for lens transformations between sample types. ··· 292 270 If a lens already exists for the same type pair, it will be 293 271 overwritten. 294 272 """ 295 - 296 - # sig = inspect.signature( _lens.get ) 297 - # input_types = list( sig.parameters.values() ) 298 - # assert len( input_types ) == 1, \ 299 - # 'Wrong number of input args for lens: should only have one' 300 - 301 - # input_type = input_types[0].annotation 302 - # print( input_type ) 303 - # output_type = sig.return_annotation 304 - 305 - # self._registry[input_type, output_type] = _lens 306 - # print( _lens.source_type ) 307 273 self._registry[_lens.source_type, _lens.view_type] = _lens 308 274 309 275 def transform( self, source: DatasetType, view: DatasetType ) -> Lens: ··· 323 289 Currently only supports direct transformations. Compositional 324 290 transformations (chaining multiple lenses) are not yet implemented. 
325 291 """ 326 - 327 - # TODO Handle compositional closure 328 292 ret = self._registry.get( (source, view), None ) 329 293 if ret is None: 330 294 raise ValueError( f'No registered lens from source {source} to view {view}' ) ··· 332 296 return ret 333 297 334 298 335 - # Create global singleton registry instance 336 - _network = LensNetwork() 337 - 338 - # def lens( f: LensPutter ) -> Lens: 339 - # """Register the annotated function `f` as a sample lens""" 340 - # ## 341 - 342 - # sig = inspect.signature( f ) 343 - 344 - # input_types = list( sig.parameters.values() ) 345 - # output_type = sig.return_annotation 346 - 347 - # _registered_lenses[] 348 - 349 - # f.lens = Lens( 350 - 351 - # ) 352 - 353 - # return f 299 + # Global singleton registry instance 300 + _network = LensNetwork()
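The `LensNetwork` registry that survives this cleanup is a dictionary keyed by `(source_type, view_type)` pairs with direct lookup only (the removed TODO about compositional closure notes chaining is not yet implemented). A standalone sketch of that pattern, with plain functions standing in for `Lens` objects:

```python
# Standalone version of the LensNetwork registry pattern: transformations keyed
# by (source_type, view_type), direct lookups only, no compositional chaining.
# Plain callables stand in for the real Lens objects here.

class LensNetwork:
    def __init__(self):
        self._registry = {}

    def register(self, source, view, fn):
        # Re-registering the same (source, view) pair overwrites the old lens.
        self._registry[(source, view)] = fn

    def transform(self, source, view):
        fn = self._registry.get((source, view))
        if fn is None:
            raise ValueError(f"No registered lens from source {source} to view {view}")
        return fn

network = LensNetwork()
network.register(int, str, lambda x: str(x))
print(network.transform(int, str)(42))  # 42
```

A module-level singleton (like `_network` in the diff) lets the `@lens` decorator register transformations at import time, so `Dataset.as_type` can resolve them later without explicit wiring.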
+5 -22
src/atdata/local.py
··· 25 25 from pathlib import Path 26 26 from uuid import uuid4 27 27 from tempfile import TemporaryDirectory 28 - import shutil 29 - import subprocess 30 28 from dotenv import dotenv_values 31 29 import msgpack 32 30 33 - # from redis_om import ( 34 - # EmbeddedJsonModel, 35 - # JsonModel, 36 - # Field, 37 - # Migrator, 38 - # get_redis_connection, 39 - # ) 40 - from redis import ( 41 - Redis, 42 - ) 31 + from redis import Redis 43 32 44 33 from s3fs import ( 45 34 S3FileSystem, ··· 340 329 local_cache_path = Path( temp_dir ) / p 341 330 342 331 # Copy to S3 using boto3 client (avoids s3fs async issues) 343 - print( 'Copying file to s3 ...', end = '' ) 344 - # Parse bucket and key from path (format: bucket/path/to/file.tar) 345 332 path_parts = Path( p ).parts 346 333 bucket = path_parts[0] 347 334 key = str( Path( *path_parts[1:] ) ) 348 335 349 336 with open( local_cache_path, 'rb' ) as f_in: 350 337 s3_client.put_object( Bucket=bucket, Key=key, Body=f_in.read() ) 351 - print( ' done.' ) 352 338 353 339 # Delete local cache file 354 - print( 'Deleting local cache file ...', end = '' ) 355 - os.remove( local_cache_path ) 356 - print( ' done.' ) 340 + local_cache_path.unlink() 357 341 358 342 written_shards.append( p ) 359 343 writer_post = _writer_post ··· 363 347 writer_post = lambda s: written_shards.append( s ) 364 348 365 349 written_shards = [] 366 - with wds.writer.ShardWriter( shard_pattern, 367 - # opener = lambda s: hive_fs.open( s, 'wb' ), 368 - # post = lambda s: written_shards.append( s ), 350 + with wds.writer.ShardWriter( 351 + shard_pattern, 369 352 opener = writer_opener, 370 353 post = writer_post, 371 - **kwargs 354 + **kwargs, 372 355 ) as sink: 373 356 for sample in ds.ordered( batch_size = None ): 374 357 sink.write( sample.as_wds )
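The bucket/key split performed before `s3_client.put_object` in `local.py` is worth seeing in isolation. This sketch assumes the same path format as the diff's comment, `bucket/path/to/file.tar` with no `s3://` scheme prefix, and uses `PurePosixPath` so the split is platform-independent:

```python
# Sketch of the bucket/key split used before s3_client.put_object in local.py.
# Assumes the path format "bucket/path/to/file.tar" (no s3:// scheme prefix).
from pathlib import PurePosixPath

def split_s3_path(p: str) -> tuple[str, str]:
    parts = PurePosixPath(p).parts
    # First component is the bucket; the remainder joins back into the object key.
    return parts[0], str(PurePosixPath(*parts[1:]))

print(split_s3_path("my-bucket/data/shard-000000.tar"))
# ('my-bucket', 'data/shard-000000.tar')
```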
-15
tests/conftest.py
··· 1 1 """Pytest configuration for atdata tests.""" 2 - 3 - import warnings 4 - import pytest 5 - 6 - 7 - @pytest.hookimpl(tryfirst=True) 8 - def pytest_configure(config): 9 - """Configure pytest to suppress known warnings from test infrastructure. 10 - 11 - Suppresses RuntimeWarnings from s3fs/moto async incompatibility that occur 12 - during test cleanup and coverage instrumentation. These are expected when 13 - mocking S3 operations and don't indicate real issues. 14 - """ 15 - warnings.simplefilter("ignore", RuntimeWarning) 16 - warnings.simplefilter("ignore", pytest.PytestUnraisableExceptionWarning)
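With the global `pytest_configure` hook removed, the convention is to suppress known warnings only around the code that emits them. A minimal sketch of that localized pattern using the stdlib `warnings` module (pytest's `@pytest.mark.filterwarnings` mark achieves the same scoping declaratively; the function names here are illustrative):

```python
import warnings

def noisy_cleanup():
    # Stand-in for the s3fs/moto teardown that emits a RuntimeWarning.
    warnings.warn("coroutine was never awaited", RuntimeWarning)

def test_cleanup_quietly():
    # Suppress the known warning only around the call that triggers it,
    # instead of silencing RuntimeWarning for the whole test session.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", RuntimeWarning)
        noisy_cleanup()
```

Scoping the filter this way keeps unexpected RuntimeWarnings elsewhere in the suite visible, which the deleted global `simplefilter("ignore", RuntimeWarning)` did not.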
+1224
tests/test_atmosphere.py
··· 1 + """Tests for the atdata.atmosphere module. 2 + 3 + This module contains comprehensive tests for ATProto integration including: 4 + - Type definitions (_types.py) 5 + - Client wrapper (client.py) 6 + - Schema publishing/loading (schema.py) 7 + - Dataset publishing/loading (records.py) 8 + - Lens publishing/loading (lens.py) 9 + """ 10 + 11 + from datetime import datetime, timezone 12 + from typing import Optional 13 + from unittest.mock import Mock, MagicMock, patch 14 + import pytest 15 + 16 + import numpy as np 17 + from numpy.typing import NDArray 18 + 19 + import atdata 20 + from atdata.atmosphere import ( 21 + AtmosphereClient, 22 + SchemaPublisher, 23 + SchemaLoader, 24 + DatasetPublisher, 25 + DatasetLoader, 26 + LensPublisher, 27 + LensLoader, 28 + AtUri, 29 + SchemaRecord, 30 + DatasetRecord, 31 + LensRecord, 32 + ) 33 + from atdata.atmosphere._types import ( 34 + FieldType, 35 + FieldDef, 36 + StorageLocation, 37 + CodeReference, 38 + LEXICON_NAMESPACE, 39 + ) 40 + 41 + 42 + # ============================================================================= 43 + # Test Fixtures 44 + # ============================================================================= 45 + 46 + @pytest.fixture 47 + def mock_atproto_client(): 48 + """Create a mock atproto SDK client.""" 49 + mock = Mock() 50 + mock.me = MagicMock() 51 + mock.me.did = "did:plc:test123456789" 52 + mock.me.handle = "test.bsky.social" 53 + 54 + # Mock login 55 + mock_profile = Mock() 56 + mock_profile.did = "did:plc:test123456789" 57 + mock_profile.handle = "test.bsky.social" 58 + mock.login.return_value = mock_profile 59 + 60 + # Mock export_session_string 61 + mock.export_session_string.return_value = "test-session-string" 62 + 63 + return mock 64 + 65 + 66 + @pytest.fixture 67 + def authenticated_client(mock_atproto_client): 68 + """Create an authenticated AtmosphereClient with mocked backend.""" 69 + client = AtmosphereClient(_client=mock_atproto_client) 70 + client.login("test.bsky.social", 
"test-password") 71 + return client 72 + 73 + 74 + @atdata.packable 75 + class BasicSample: 76 + """Simple sample type for testing.""" 77 + name: str 78 + value: int 79 + 80 + 81 + @atdata.packable 82 + class NumpySample: 83 + """Sample type with NDArray field.""" 84 + data: NDArray 85 + label: str 86 + 87 + 88 + @atdata.packable 89 + class OptionalSample: 90 + """Sample type with optional fields.""" 91 + required_field: str 92 + optional_field: Optional[int] 93 + optional_array: Optional[NDArray] 94 + 95 + 96 + @atdata.packable 97 + class AllTypesSample: 98 + """Sample type with all primitive types.""" 99 + str_field: str 100 + int_field: int 101 + float_field: float 102 + bool_field: bool 103 + bytes_field: bytes 104 + 105 + 106 + # ============================================================================= 107 + # Tests for _types.py - AtUri 108 + # ============================================================================= 109 + 110 + class TestAtUri: 111 + """Tests for AtUri parsing and formatting.""" 112 + 113 + def test_parse_valid_uri_with_did(self): 114 + """Parse a valid AT URI with a DID authority.""" 115 + uri = AtUri.parse("at://did:plc:abc123/com.example.record/key456") 116 + 117 + assert uri.authority == "did:plc:abc123" 118 + assert uri.collection == "com.example.record" 119 + assert uri.rkey == "key456" 120 + 121 + def test_parse_valid_uri_with_handle(self): 122 + """Parse a valid AT URI with a handle authority.""" 123 + uri = AtUri.parse("at://alice.bsky.social/app.bsky.feed.post/abc123") 124 + 125 + assert uri.authority == "alice.bsky.social" 126 + assert uri.collection == "app.bsky.feed.post" 127 + assert uri.rkey == "abc123" 128 + 129 + def test_parse_uri_with_slashes_in_rkey(self): 130 + """Parse a URI where rkey contains slashes.""" 131 + uri = AtUri.parse("at://did:plc:abc/collection/path/to/key") 132 + 133 + assert uri.authority == "did:plc:abc" 134 + assert uri.collection == "collection" 135 + assert uri.rkey == "path/to/key" 136 + 137 
+ def test_parse_invalid_uri_no_protocol(self): 138 + """Reject URIs without at:// protocol.""" 139 + with pytest.raises(ValueError, match="must start with 'at://'"): 140 + AtUri.parse("https://example.com/path") 141 + 142 + def test_parse_invalid_uri_missing_parts(self): 143 + """Reject URIs with missing components.""" 144 + with pytest.raises(ValueError, match="expected authority/collection/rkey"): 145 + AtUri.parse("at://did:plc:abc/collection") 146 + 147 + def test_str_roundtrip(self): 148 + """Verify __str__ produces valid URI that can be re-parsed.""" 149 + original = "at://did:plc:test123/ac.foundation.dataset.sampleSchema/xyz789" 150 + uri = AtUri.parse(original) 151 + assert str(uri) == original 152 + 153 + def test_parse_atdata_namespace(self): 154 + """Parse URIs in the atdata namespace.""" 155 + uri = AtUri.parse(f"at://did:plc:abc/{LEXICON_NAMESPACE}.sampleSchema/test") 156 + 157 + assert uri.collection == f"{LEXICON_NAMESPACE}.sampleSchema" 158 + 159 + 160 + # ============================================================================= 161 + # Tests for _types.py - FieldType 162 + # ============================================================================= 163 + 164 + class TestFieldType: 165 + """Tests for FieldType dataclass.""" 166 + 167 + def test_primitive_type(self): 168 + """Create a primitive field type.""" 169 + ft = FieldType(kind="primitive", primitive="str") 170 + 171 + assert ft.kind == "primitive" 172 + assert ft.primitive == "str" 173 + assert ft.dtype is None 174 + assert ft.shape is None 175 + 176 + def test_ndarray_type(self): 177 + """Create an ndarray field type.""" 178 + ft = FieldType(kind="ndarray", dtype="float32", shape=[224, 224, 3]) 179 + 180 + assert ft.kind == "ndarray" 181 + assert ft.dtype == "float32" 182 + assert ft.shape == [224, 224, 3] 183 + 184 + def test_ref_type(self): 185 + """Create a reference field type.""" 186 + ft = FieldType(kind="ref", ref="at://did:plc:abc/collection/key") 187 + 188 + assert ft.kind 
== "ref" 189 + assert ft.ref == "at://did:plc:abc/collection/key" 190 + 191 + def test_array_type(self): 192 + """Create an array field type with items.""" 193 + items = FieldType(kind="primitive", primitive="str") 194 + ft = FieldType(kind="array", items=items) 195 + 196 + assert ft.kind == "array" 197 + assert ft.items is not None 198 + assert ft.items.kind == "primitive" 199 + 200 + 201 + # ============================================================================= 202 + # Tests for _types.py - FieldDef 203 + # ============================================================================= 204 + 205 + class TestFieldDef: 206 + """Tests for FieldDef dataclass.""" 207 + 208 + def test_required_field(self): 209 + """Create a required field definition.""" 210 + fd = FieldDef( 211 + name="test_field", 212 + field_type=FieldType(kind="primitive", primitive="str"), 213 + optional=False, 214 + ) 215 + 216 + assert fd.name == "test_field" 217 + assert fd.optional is False 218 + 219 + def test_optional_field(self): 220 + """Create an optional field definition.""" 221 + fd = FieldDef( 222 + name="optional_field", 223 + field_type=FieldType(kind="primitive", primitive="int"), 224 + optional=True, 225 + ) 226 + 227 + assert fd.optional is True 228 + 229 + def test_field_with_description(self): 230 + """Create a field with description.""" 231 + fd = FieldDef( 232 + name="described_field", 233 + field_type=FieldType(kind="primitive", primitive="float"), 234 + optional=False, 235 + description="A field with a description", 236 + ) 237 + 238 + assert fd.description == "A field with a description" 239 + 240 + 241 + # ============================================================================= 242 + # Tests for _types.py - SchemaRecord 243 + # ============================================================================= 244 + 245 + class TestSchemaRecord: 246 + """Tests for SchemaRecord dataclass and to_record().""" 247 + 248 + def test_to_record_basic(self): 249 + """Convert a 
basic schema record to dict.""" 250 + schema = SchemaRecord( 251 + name="TestSchema", 252 + version="1.0.0", 253 + fields=[ 254 + FieldDef( 255 + name="field1", 256 + field_type=FieldType(kind="primitive", primitive="str"), 257 + optional=False, 258 + ), 259 + ], 260 + ) 261 + 262 + record = schema.to_record() 263 + 264 + assert record["$type"] == f"{LEXICON_NAMESPACE}.sampleSchema" 265 + assert record["name"] == "TestSchema" 266 + assert record["version"] == "1.0.0" 267 + assert len(record["fields"]) == 1 268 + assert "createdAt" in record 269 + 270 + def test_to_record_with_description(self): 271 + """Convert schema record with description.""" 272 + schema = SchemaRecord( 273 + name="DescribedSchema", 274 + version="2.0.0", 275 + description="A schema with description", 276 + fields=[], 277 + ) 278 + 279 + record = schema.to_record() 280 + 281 + assert record["description"] == "A schema with description" 282 + 283 + def test_to_record_with_metadata(self): 284 + """Convert schema record with metadata.""" 285 + schema = SchemaRecord( 286 + name="MetaSchema", 287 + version="1.0.0", 288 + fields=[], 289 + metadata={"author": "test", "tags": ["demo"]}, 290 + ) 291 + 292 + record = schema.to_record() 293 + 294 + assert record["metadata"] == {"author": "test", "tags": ["demo"]} 295 + 296 + def test_to_record_field_types(self): 297 + """Verify field type serialization in to_record().""" 298 + schema = SchemaRecord( 299 + name="TypesSchema", 300 + version="1.0.0", 301 + fields=[ 302 + FieldDef( 303 + name="primitive_field", 304 + field_type=FieldType(kind="primitive", primitive="int"), 305 + optional=False, 306 + ), 307 + FieldDef( 308 + name="array_field", 309 + field_type=FieldType(kind="ndarray", dtype="float32"), 310 + optional=True, 311 + ), 312 + ], 313 + ) 314 + 315 + record = schema.to_record() 316 + 317 + # Check primitive field 318 + prim_field = record["fields"][0] 319 + assert prim_field["name"] == "primitive_field" 320 + assert 
prim_field["fieldType"]["$type"] == f"{LEXICON_NAMESPACE}.schemaType#primitive" 321 + assert prim_field["fieldType"]["primitive"] == "int" 322 + assert prim_field["optional"] is False 323 + 324 + # Check ndarray field 325 + arr_field = record["fields"][1] 326 + assert arr_field["name"] == "array_field" 327 + assert arr_field["fieldType"]["$type"] == f"{LEXICON_NAMESPACE}.schemaType#ndarray" 328 + assert arr_field["fieldType"]["dtype"] == "float32" 329 + assert arr_field["optional"] is True 330 + 331 + 332 + # ============================================================================= 333 + # Tests for _types.py - StorageLocation 334 + # ============================================================================= 335 + 336 + class TestStorageLocation: 337 + """Tests for StorageLocation dataclass.""" 338 + 339 + def test_external_storage(self): 340 + """Create external URL storage location.""" 341 + storage = StorageLocation( 342 + kind="external", 343 + urls=["s3://bucket/data-{000000..000009}.tar"], 344 + ) 345 + 346 + assert storage.kind == "external" 347 + assert storage.urls == ["s3://bucket/data-{000000..000009}.tar"] 348 + assert storage.blob_refs is None 349 + 350 + def test_blob_storage(self): 351 + """Create ATProto blob storage location.""" 352 + storage = StorageLocation( 353 + kind="blobs", 354 + blob_refs=[{"cid": "bafyabc", "mimeType": "application/octet-stream"}], 355 + ) 356 + 357 + assert storage.kind == "blobs" 358 + assert storage.blob_refs is not None 359 + assert len(storage.blob_refs) == 1 360 + 361 + 362 + # ============================================================================= 363 + # Tests for _types.py - DatasetRecord 364 + # ============================================================================= 365 + 366 + class TestDatasetRecord: 367 + """Tests for DatasetRecord dataclass and to_record().""" 368 + 369 + def test_to_record_external_storage(self): 370 + """Convert dataset record with external storage.""" 371 + dataset = 
DatasetRecord( 372 + name="TestDataset", 373 + schema_ref="at://did:plc:abc/ac.foundation.dataset.sampleSchema/xyz", 374 + storage=StorageLocation( 375 + kind="external", 376 + urls=["s3://bucket/data.tar"], 377 + ), 378 + ) 379 + 380 + record = dataset.to_record() 381 + 382 + assert record["$type"] == f"{LEXICON_NAMESPACE}.record" 383 + assert record["name"] == "TestDataset" 384 + assert record["schemaRef"] == "at://did:plc:abc/ac.foundation.dataset.sampleSchema/xyz" 385 + assert record["storage"]["$type"] == f"{LEXICON_NAMESPACE}.storageExternal" 386 + assert record["storage"]["urls"] == ["s3://bucket/data.tar"] 387 + 388 + def test_to_record_blob_storage(self): 389 + """Convert dataset record with blob storage.""" 390 + dataset = DatasetRecord( 391 + name="BlobDataset", 392 + schema_ref="at://did:plc:abc/collection/key", 393 + storage=StorageLocation( 394 + kind="blobs", 395 + blob_refs=[{"cid": "bafytest"}], 396 + ), 397 + ) 398 + 399 + record = dataset.to_record() 400 + 401 + assert record["storage"]["$type"] == f"{LEXICON_NAMESPACE}.storageBlobs" 402 + assert record["storage"]["blobs"] == [{"cid": "bafytest"}] 403 + 404 + def test_to_record_with_tags_and_license(self): 405 + """Convert dataset record with tags and license.""" 406 + dataset = DatasetRecord( 407 + name="TaggedDataset", 408 + schema_ref="at://did:plc:abc/collection/key", 409 + storage=StorageLocation(kind="external", urls=[]), 410 + tags=["ml", "vision", "demo"], 411 + license="MIT", 412 + ) 413 + 414 + record = dataset.to_record() 415 + 416 + assert record["tags"] == ["ml", "vision", "demo"] 417 + assert record["license"] == "MIT" 418 + 419 + def test_to_record_with_metadata(self): 420 + """Convert dataset record with msgpack metadata.""" 421 + import msgpack 422 + 423 + metadata_bytes = msgpack.packb({"size": 1000, "split": "train"}) 424 + dataset = DatasetRecord( 425 + name="MetaDataset", 426 + schema_ref="at://did:plc:abc/collection/key", 427 + storage=StorageLocation(kind="external", 
urls=[]), 428 + metadata=metadata_bytes, 429 + ) 430 + 431 + record = dataset.to_record() 432 + 433 + assert record["metadata"] == metadata_bytes 434 + 435 + 436 + # ============================================================================= 437 + # Tests for _types.py - LensRecord 438 + # ============================================================================= 439 + 440 + class TestLensRecord: 441 + """Tests for LensRecord dataclass and to_record().""" 442 + 443 + def test_to_record_basic(self): 444 + """Convert basic lens record.""" 445 + lens = LensRecord( 446 + name="TestLens", 447 + source_schema="at://did:plc:abc/collection/source", 448 + target_schema="at://did:plc:abc/collection/target", 449 + ) 450 + 451 + record = lens.to_record() 452 + 453 + assert record["$type"] == f"{LEXICON_NAMESPACE}.lens" 454 + assert record["name"] == "TestLens" 455 + assert record["sourceSchema"] == "at://did:plc:abc/collection/source" 456 + assert record["targetSchema"] == "at://did:plc:abc/collection/target" 457 + assert "createdAt" in record 458 + 459 + def test_to_record_with_description(self): 460 + """Convert lens record with description.""" 461 + lens = LensRecord( 462 + name="DescribedLens", 463 + source_schema="at://a", 464 + target_schema="at://b", 465 + description="Transforms A to B", 466 + ) 467 + 468 + record = lens.to_record() 469 + 470 + assert record["description"] == "Transforms A to B" 471 + 472 + def test_to_record_with_code_references(self): 473 + """Convert lens record with code references.""" 474 + lens = LensRecord( 475 + name="CodeLens", 476 + source_schema="at://a", 477 + target_schema="at://b", 478 + getter_code=CodeReference( 479 + repository="https://github.com/user/repo", 480 + commit="abc123def456", 481 + path="module.lenses:getter_func", 482 + ), 483 + putter_code=CodeReference( 484 + repository="https://github.com/user/repo", 485 + commit="abc123def456", 486 + path="module.lenses:putter_func", 487 + ), 488 + ) 489 + 490 + record = 
lens.to_record() 491 + 492 + assert record["getterCode"]["repository"] == "https://github.com/user/repo" 493 + assert record["getterCode"]["commit"] == "abc123def456" 494 + assert record["getterCode"]["path"] == "module.lenses:getter_func" 495 + assert record["putterCode"]["path"] == "module.lenses:putter_func" 496 + 497 + 498 + # ============================================================================= 499 + # Tests for client.py - AtmosphereClient 500 + # ============================================================================= 501 + 502 + class TestAtmosphereClient: 503 + """Tests for AtmosphereClient.""" 504 + 505 + def test_init_default(self): 506 + """Initialize client with defaults.""" 507 + with patch("atdata.atmosphere.client._get_atproto_client_class") as mock_get: 508 + mock_class = Mock() 509 + mock_get.return_value = mock_class 510 + 511 + client = AtmosphereClient() 512 + 513 + mock_class.assert_called_once() 514 + assert not client.is_authenticated 515 + 516 + def test_init_with_base_url(self): 517 + """Initialize client with custom base URL.""" 518 + with patch("atdata.atmosphere.client._get_atproto_client_class") as mock_get: 519 + mock_class = Mock() 520 + mock_get.return_value = mock_class 521 + 522 + client = AtmosphereClient(base_url="https://custom.pds.example") 523 + 524 + mock_class.assert_called_once_with(base_url="https://custom.pds.example") 525 + 526 + def test_init_with_mock_client(self, mock_atproto_client): 527 + """Initialize with pre-configured mock client.""" 528 + client = AtmosphereClient(_client=mock_atproto_client) 529 + 530 + assert client._client is mock_atproto_client 531 + 532 + def test_login_success(self, mock_atproto_client): 533 + """Successful login sets session.""" 534 + client = AtmosphereClient(_client=mock_atproto_client) 535 + 536 + client.login("test.bsky.social", "password123") 537 + 538 + assert client.is_authenticated 539 + assert client.did == "did:plc:test123456789" 540 + assert client.handle == 
"test.bsky.social" 541 + mock_atproto_client.login.assert_called_once_with("test.bsky.social", "password123") 542 + 543 + def test_login_with_session(self, mock_atproto_client): 544 + """Login with exported session string.""" 545 + client = AtmosphereClient(_client=mock_atproto_client) 546 + 547 + client.login_with_session("test-session-string") 548 + 549 + assert client.is_authenticated 550 + mock_atproto_client.login.assert_called_once_with(session_string="test-session-string") 551 + 552 + def test_export_session(self, authenticated_client, mock_atproto_client): 553 + """Export session string.""" 554 + session = authenticated_client.export_session() 555 + 556 + assert session == "test-session-string" 557 + mock_atproto_client.export_session_string.assert_called_once() 558 + 559 + def test_export_session_not_authenticated(self, mock_atproto_client): 560 + """Export session raises when not authenticated.""" 561 + client = AtmosphereClient(_client=mock_atproto_client) 562 + 563 + with pytest.raises(ValueError, match="Not authenticated"): 564 + client.export_session() 565 + 566 + def test_did_not_authenticated(self, mock_atproto_client): 567 + """Accessing did raises when not authenticated.""" 568 + client = AtmosphereClient(_client=mock_atproto_client) 569 + 570 + with pytest.raises(ValueError, match="Not authenticated"): 571 + _ = client.did 572 + 573 + def test_handle_not_authenticated(self, mock_atproto_client): 574 + """Accessing handle raises when not authenticated.""" 575 + client = AtmosphereClient(_client=mock_atproto_client) 576 + 577 + with pytest.raises(ValueError, match="Not authenticated"): 578 + _ = client.handle 579 + 580 + def test_create_record(self, authenticated_client, mock_atproto_client): 581 + """Create a record via the client.""" 582 + mock_response = Mock() 583 + mock_response.uri = "at://did:plc:test123456789/collection/newkey" 584 + mock_atproto_client.com.atproto.repo.create_record.return_value = mock_response 585 + 586 + uri = 
authenticated_client.create_record( 587 + collection="collection", 588 + record={"$type": "collection", "data": "test"}, 589 + ) 590 + 591 + assert isinstance(uri, AtUri) 592 + assert uri.authority == "did:plc:test123456789" 593 + assert uri.collection == "collection" 594 + assert uri.rkey == "newkey" 595 + 596 + def test_create_record_not_authenticated(self, mock_atproto_client): 597 + """Create record raises when not authenticated.""" 598 + client = AtmosphereClient(_client=mock_atproto_client) 599 + 600 + with pytest.raises(ValueError, match="must be authenticated"): 601 + client.create_record(collection="test", record={}) 602 + 603 + def test_put_record(self, authenticated_client, mock_atproto_client): 604 + """Put (create or update) a record.""" 605 + mock_response = Mock() 606 + mock_response.uri = "at://did:plc:test123456789/collection/specific-key" 607 + mock_atproto_client.com.atproto.repo.put_record.return_value = mock_response 608 + 609 + uri = authenticated_client.put_record( 610 + collection="collection", 611 + rkey="specific-key", 612 + record={"$type": "collection", "data": "test"}, 613 + ) 614 + 615 + assert uri.rkey == "specific-key" 616 + 617 + def test_get_record(self, authenticated_client, mock_atproto_client): 618 + """Get a record by URI.""" 619 + mock_response = Mock() 620 + mock_response.value = {"$type": "test", "field": "value"} 621 + mock_atproto_client.com.atproto.repo.get_record.return_value = mock_response 622 + 623 + record = authenticated_client.get_record("at://did:plc:abc/collection/key") 624 + 625 + assert record["field"] == "value" 626 + 627 + def test_get_record_with_aturi_object(self, authenticated_client, mock_atproto_client): 628 + """Get a record using AtUri object.""" 629 + mock_response = Mock() 630 + mock_response.value = {"$type": "test", "data": 123} 631 + mock_atproto_client.com.atproto.repo.get_record.return_value = mock_response 632 + 633 + uri = AtUri(authority="did:plc:abc", collection="collection", rkey="key") 634 
+ record = authenticated_client.get_record(uri) 635 + 636 + assert record["data"] == 123 637 + 638 + def test_delete_record(self, authenticated_client, mock_atproto_client): 639 + """Delete a record.""" 640 + authenticated_client.delete_record("at://did:plc:test123456789/collection/key") 641 + 642 + mock_atproto_client.com.atproto.repo.delete_record.assert_called_once() 643 + 644 + def test_list_records(self, authenticated_client, mock_atproto_client): 645 + """List records in a collection.""" 646 + mock_record1 = Mock() 647 + mock_record1.value = {"name": "record1"} 648 + mock_record2 = Mock() 649 + mock_record2.value = {"name": "record2"} 650 + 651 + mock_response = Mock() 652 + mock_response.records = [mock_record1, mock_record2] 653 + mock_response.cursor = "next-page" 654 + mock_atproto_client.com.atproto.repo.list_records.return_value = mock_response 655 + 656 + records, cursor = authenticated_client.list_records("collection", limit=10) 657 + 658 + assert len(records) == 2 659 + assert records[0]["name"] == "record1" 660 + assert cursor == "next-page" 661 + 662 + def test_list_schemas_convenience(self, authenticated_client, mock_atproto_client): 663 + """Test list_schemas convenience method.""" 664 + mock_response = Mock() 665 + mock_response.records = [] 666 + mock_response.cursor = None 667 + mock_atproto_client.com.atproto.repo.list_records.return_value = mock_response 668 + 669 + schemas = authenticated_client.list_schemas() 670 + 671 + call_args = mock_atproto_client.com.atproto.repo.list_records.call_args 672 + assert f"{LEXICON_NAMESPACE}.sampleSchema" in str(call_args) 673 + 674 + 675 + # ============================================================================= 676 + # Tests for schema.py - SchemaPublisher 677 + # ============================================================================= 678 + 679 + class TestSchemaPublisher: 680 + """Tests for SchemaPublisher.""" 681 + 682 + def test_publish_basic_sample(self, authenticated_client, 
mock_atproto_client): 683 + """Publish a basic sample type schema.""" 684 + mock_response = Mock() 685 + mock_response.uri = f"at://did:plc:test123456789/{LEXICON_NAMESPACE}.sampleSchema/abc" 686 + mock_atproto_client.com.atproto.repo.create_record.return_value = mock_response 687 + 688 + publisher = SchemaPublisher(authenticated_client) 689 + uri = publisher.publish(BasicSample, version="1.0.0") 690 + 691 + assert isinstance(uri, AtUri) 692 + assert uri.collection == f"{LEXICON_NAMESPACE}.sampleSchema" 693 + 694 + # Verify the record structure 695 + call_args = mock_atproto_client.com.atproto.repo.create_record.call_args 696 + record = call_args.kwargs["data"]["record"] 697 + assert record["name"] == "BasicSample" 698 + assert record["version"] == "1.0.0" 699 + assert len(record["fields"]) == 2 700 + 701 + def test_publish_with_custom_name(self, authenticated_client, mock_atproto_client): 702 + """Publish with custom name override.""" 703 + mock_response = Mock() 704 + mock_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/abc" 705 + mock_atproto_client.com.atproto.repo.create_record.return_value = mock_response 706 + 707 + publisher = SchemaPublisher(authenticated_client) 708 + publisher.publish(BasicSample, name="CustomName", version="2.0.0") 709 + 710 + call_args = mock_atproto_client.com.atproto.repo.create_record.call_args 711 + record = call_args.kwargs["data"]["record"] 712 + assert record["name"] == "CustomName" 713 + 714 + def test_publish_numpy_sample(self, authenticated_client, mock_atproto_client): 715 + """Publish sample type with NDArray field.""" 716 + mock_response = Mock() 717 + mock_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/abc" 718 + mock_atproto_client.com.atproto.repo.create_record.return_value = mock_response 719 + 720 + publisher = SchemaPublisher(authenticated_client) 721 + publisher.publish(NumpySample, version="1.0.0") 722 + 723 + call_args = 
mock_atproto_client.com.atproto.repo.create_record.call_args 724 + record = call_args.kwargs["data"]["record"] 725 + 726 + # Find the data field 727 + data_field = next(f for f in record["fields"] if f["name"] == "data") 728 + assert "ndarray" in data_field["fieldType"]["$type"] 729 + 730 + def test_publish_optional_fields(self, authenticated_client, mock_atproto_client): 731 + """Publish sample type with optional fields.""" 732 + mock_response = Mock() 733 + mock_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/abc" 734 + mock_atproto_client.com.atproto.repo.create_record.return_value = mock_response 735 + 736 + publisher = SchemaPublisher(authenticated_client) 737 + publisher.publish(OptionalSample, version="1.0.0") 738 + 739 + call_args = mock_atproto_client.com.atproto.repo.create_record.call_args 740 + record = call_args.kwargs["data"]["record"] 741 + 742 + # Check optional field marking 743 + required = next(f for f in record["fields"] if f["name"] == "required_field") 744 + optional = next(f for f in record["fields"] if f["name"] == "optional_field") 745 + 746 + assert required["optional"] is False 747 + assert optional["optional"] is True 748 + 749 + def test_publish_all_primitive_types(self, authenticated_client, mock_atproto_client): 750 + """Publish sample with all primitive types.""" 751 + mock_response = Mock() 752 + mock_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/abc" 753 + mock_atproto_client.com.atproto.repo.create_record.return_value = mock_response 754 + 755 + publisher = SchemaPublisher(authenticated_client) 756 + publisher.publish(AllTypesSample, version="1.0.0") 757 + 758 + call_args = mock_atproto_client.com.atproto.repo.create_record.call_args 759 + record = call_args.kwargs["data"]["record"] 760 + 761 + # Verify each primitive type 762 + type_map = {f["name"]: f["fieldType"]["primitive"] for f in record["fields"]} 763 + assert type_map["str_field"] == "str" 764 + assert type_map["int_field"] == 
"int" 765 + assert type_map["float_field"] == "float" 766 + assert type_map["bool_field"] == "bool" 767 + assert type_map["bytes_field"] == "bytes" 768 + 769 + def test_publish_not_dataclass_error(self, authenticated_client): 770 + """Publishing non-dataclass raises error.""" 771 + publisher = SchemaPublisher(authenticated_client) 772 + 773 + class NotADataclass: 774 + pass 775 + 776 + with pytest.raises(ValueError, match="must be a dataclass"): 777 + publisher.publish(NotADataclass, version="1.0.0") 778 + 779 + 780 + class TestSchemaLoader: 781 + """Tests for SchemaLoader.""" 782 + 783 + def test_get_schema(self, authenticated_client, mock_atproto_client): 784 + """Get a schema by URI.""" 785 + mock_response = Mock() 786 + mock_response.value = { 787 + "$type": f"{LEXICON_NAMESPACE}.sampleSchema", 788 + "name": "TestSchema", 789 + "version": "1.0.0", 790 + "fields": [], 791 + } 792 + mock_atproto_client.com.atproto.repo.get_record.return_value = mock_response 793 + 794 + loader = SchemaLoader(authenticated_client) 795 + schema = loader.get(f"at://did:plc:abc/{LEXICON_NAMESPACE}.sampleSchema/xyz") 796 + 797 + assert schema["name"] == "TestSchema" 798 + 799 + def test_get_schema_wrong_type(self, authenticated_client, mock_atproto_client): 800 + """Get raises error for wrong record type.""" 801 + mock_response = Mock() 802 + mock_response.value = { 803 + "$type": "app.bsky.feed.post", 804 + "text": "Not a schema", 805 + } 806 + mock_atproto_client.com.atproto.repo.get_record.return_value = mock_response 807 + 808 + loader = SchemaLoader(authenticated_client) 809 + 810 + with pytest.raises(ValueError, match="not a schema record"): 811 + loader.get("at://did:plc:abc/app.bsky.feed.post/xyz") 812 + 813 + def test_list_all_schemas(self, authenticated_client, mock_atproto_client): 814 + """List all schemas.""" 815 + mock_record = Mock() 816 + mock_record.value = {"name": "Schema1"} 817 + 818 + mock_response = Mock() 819 + mock_response.records = [mock_record] 820 + 
mock_response.cursor = None 821 + mock_atproto_client.com.atproto.repo.list_records.return_value = mock_response 822 + 823 + loader = SchemaLoader(authenticated_client) 824 + schemas = loader.list_all() 825 + 826 + assert len(schemas) == 1 827 + assert schemas[0]["name"] == "Schema1" 828 + 829 + 830 + # ============================================================================= 831 + # Tests for records.py - DatasetPublisher 832 + # ============================================================================= 833 + 834 + class TestDatasetPublisher: 835 + """Tests for DatasetPublisher.""" 836 + 837 + def test_publish_with_urls(self, authenticated_client, mock_atproto_client): 838 + """Publish dataset with explicit URLs.""" 839 + mock_response = Mock() 840 + mock_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.record/abc" 841 + mock_atproto_client.com.atproto.repo.create_record.return_value = mock_response 842 + 843 + publisher = DatasetPublisher(authenticated_client) 844 + uri = publisher.publish_with_urls( 845 + urls=["s3://bucket/data-{000000..000009}.tar"], 846 + schema_uri="at://did:plc:abc/schema/xyz", 847 + name="TestDataset", 848 + description="A test dataset", 849 + tags=["test", "demo"], 850 + license="MIT", 851 + ) 852 + 853 + assert isinstance(uri, AtUri) 854 + 855 + call_args = mock_atproto_client.com.atproto.repo.create_record.call_args 856 + record = call_args.kwargs["data"]["record"] 857 + assert record["name"] == "TestDataset" 858 + assert record["schemaRef"] == "at://did:plc:abc/schema/xyz" 859 + assert record["tags"] == ["test", "demo"] 860 + assert record["license"] == "MIT" 861 + 862 + def test_publish_auto_schema(self, authenticated_client, mock_atproto_client): 863 + """Publish dataset with auto schema publishing.""" 864 + # Mock for schema creation 865 + schema_response = Mock() 866 + schema_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/schema123" 867 + 868 + # Mock for dataset creation 869 + dataset_response = 
Mock() 870 + dataset_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.record/dataset456" 871 + 872 + mock_atproto_client.com.atproto.repo.create_record.side_effect = [ 873 + schema_response, 874 + dataset_response, 875 + ] 876 + 877 + # Create a mock dataset 878 + mock_dataset = Mock() 879 + mock_dataset.url = "s3://bucket/data.tar" 880 + mock_dataset.sample_type = BasicSample 881 + mock_dataset.metadata = None 882 + 883 + publisher = DatasetPublisher(authenticated_client) 884 + uri = publisher.publish( 885 + mock_dataset, 886 + name="AutoSchemaDataset", 887 + auto_publish_schema=True, 888 + ) 889 + 890 + # Should have called create_record twice (schema + dataset) 891 + assert mock_atproto_client.com.atproto.repo.create_record.call_count == 2 892 + 893 + def test_publish_explicit_schema_uri(self, authenticated_client, mock_atproto_client): 894 + """Publish dataset with explicit schema URI (no auto publish).""" 895 + mock_response = Mock() 896 + mock_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.record/abc" 897 + mock_atproto_client.com.atproto.repo.create_record.return_value = mock_response 898 + 899 + mock_dataset = Mock() 900 + mock_dataset.url = "s3://bucket/data.tar" 901 + mock_dataset.metadata = None 902 + 903 + publisher = DatasetPublisher(authenticated_client) 904 + publisher.publish( 905 + mock_dataset, 906 + name="ExplicitSchemaDataset", 907 + schema_uri="at://did:plc:existing/schema/xyz", 908 + auto_publish_schema=False, 909 + ) 910 + 911 + # Should have called create_record only once (dataset only) 912 + assert mock_atproto_client.com.atproto.repo.create_record.call_count == 1 913 + 914 + def test_publish_no_schema_error(self, authenticated_client): 915 + """Publish without schema_uri and auto_publish_schema=False raises.""" 916 + mock_dataset = Mock() 917 + mock_dataset.url = "s3://bucket/data.tar" 918 + 919 + publisher = DatasetPublisher(authenticated_client) 920 + 921 + with pytest.raises(ValueError, match="schema_uri is required"): 922 
+ publisher.publish( 923 + mock_dataset, 924 + name="NoSchemaDataset", 925 + auto_publish_schema=False, 926 + ) 927 + 928 + 929 + class TestDatasetLoader: 930 + """Tests for DatasetLoader.""" 931 + 932 + def test_get_dataset(self, authenticated_client, mock_atproto_client): 933 + """Get a dataset record.""" 934 + mock_response = Mock() 935 + mock_response.value = { 936 + "$type": f"{LEXICON_NAMESPACE}.record", 937 + "name": "TestDataset", 938 + "schemaRef": "at://schema", 939 + "storage": { 940 + "$type": f"{LEXICON_NAMESPACE}.storageExternal", 941 + "urls": ["s3://bucket/data.tar"], 942 + }, 943 + } 944 + mock_atproto_client.com.atproto.repo.get_record.return_value = mock_response 945 + 946 + loader = DatasetLoader(authenticated_client) 947 + record = loader.get(f"at://did:plc:abc/{LEXICON_NAMESPACE}.record/xyz") 948 + 949 + assert record["name"] == "TestDataset" 950 + 951 + def test_get_dataset_wrong_type(self, authenticated_client, mock_atproto_client): 952 + """Get raises error for wrong record type.""" 953 + mock_response = Mock() 954 + mock_response.value = { 955 + "$type": f"{LEXICON_NAMESPACE}.sampleSchema", 956 + "name": "NotADataset", 957 + } 958 + mock_atproto_client.com.atproto.repo.get_record.return_value = mock_response 959 + 960 + loader = DatasetLoader(authenticated_client) 961 + 962 + with pytest.raises(ValueError, match="not a dataset record"): 963 + loader.get("at://did:plc:abc/collection/xyz") 964 + 965 + def test_get_urls(self, authenticated_client, mock_atproto_client): 966 + """Get WebDataset URLs from a dataset record.""" 967 + mock_response = Mock() 968 + mock_response.value = { 969 + "$type": f"{LEXICON_NAMESPACE}.record", 970 + "name": "TestDataset", 971 + "schemaRef": "at://schema", 972 + "storage": { 973 + "$type": f"{LEXICON_NAMESPACE}.storageExternal", 974 + "urls": ["s3://bucket/data-{000000..000009}.tar", "s3://bucket/extra.tar"], 975 + }, 976 + } 977 + mock_atproto_client.com.atproto.repo.get_record.return_value = mock_response 978 
+ 979 + loader = DatasetLoader(authenticated_client) 980 + urls = loader.get_urls(f"at://did:plc:abc/{LEXICON_NAMESPACE}.record/xyz") 981 + 982 + assert len(urls) == 2 983 + assert "data-{000000..000009}.tar" in urls[0] 984 + 985 + def test_get_urls_blob_storage_error(self, authenticated_client, mock_atproto_client): 986 + """Get URLs raises for blob storage datasets.""" 987 + mock_response = Mock() 988 + mock_response.value = { 989 + "$type": f"{LEXICON_NAMESPACE}.record", 990 + "name": "BlobDataset", 991 + "schemaRef": "at://schema", 992 + "storage": { 993 + "$type": f"{LEXICON_NAMESPACE}.storageBlobs", 994 + "blobs": [{"cid": "bafytest"}], 995 + }, 996 + } 997 + mock_atproto_client.com.atproto.repo.get_record.return_value = mock_response 998 + 999 + loader = DatasetLoader(authenticated_client) 1000 + 1001 + with pytest.raises(ValueError, match="blob storage"): 1002 + loader.get_urls(f"at://did:plc:abc/{LEXICON_NAMESPACE}.record/xyz") 1003 + 1004 + def test_get_metadata(self, authenticated_client, mock_atproto_client): 1005 + """Get metadata from dataset record.""" 1006 + import msgpack 1007 + 1008 + metadata_bytes = msgpack.packb({"split": "train", "samples": 10000}) 1009 + 1010 + mock_response = Mock() 1011 + mock_response.value = { 1012 + "$type": f"{LEXICON_NAMESPACE}.record", 1013 + "name": "MetaDataset", 1014 + "schemaRef": "at://schema", 1015 + "storage": {"$type": f"{LEXICON_NAMESPACE}.storageExternal", "urls": []}, 1016 + "metadata": metadata_bytes, 1017 + } 1018 + mock_atproto_client.com.atproto.repo.get_record.return_value = mock_response 1019 + 1020 + loader = DatasetLoader(authenticated_client) 1021 + metadata = loader.get_metadata(f"at://did:plc:abc/{LEXICON_NAMESPACE}.record/xyz") 1022 + 1023 + assert metadata["split"] == "train" 1024 + assert metadata["samples"] == 10000 1025 + 1026 + def test_get_metadata_none(self, authenticated_client, mock_atproto_client): 1027 + """Get metadata returns None when not present.""" 1028 + mock_response = Mock() 
1029 + mock_response.value = { 1030 + "$type": f"{LEXICON_NAMESPACE}.record", 1031 + "name": "NoMetaDataset", 1032 + "schemaRef": "at://schema", 1033 + "storage": {"$type": f"{LEXICON_NAMESPACE}.storageExternal", "urls": []}, 1034 + } 1035 + mock_atproto_client.com.atproto.repo.get_record.return_value = mock_response 1036 + 1037 + loader = DatasetLoader(authenticated_client) 1038 + metadata = loader.get_metadata(f"at://did:plc:abc/{LEXICON_NAMESPACE}.record/xyz") 1039 + 1040 + assert metadata is None 1041 + 1042 + def test_list_all(self, authenticated_client, mock_atproto_client): 1043 + """List all datasets.""" 1044 + mock_record = Mock() 1045 + mock_record.value = {"name": "Dataset1"} 1046 + 1047 + mock_response = Mock() 1048 + mock_response.records = [mock_record] 1049 + mock_response.cursor = None 1050 + mock_atproto_client.com.atproto.repo.list_records.return_value = mock_response 1051 + 1052 + loader = DatasetLoader(authenticated_client) 1053 + datasets = loader.list_all() 1054 + 1055 + assert len(datasets) == 1 1056 + 1057 + 1058 + # ============================================================================= 1059 + # Tests for lens.py - LensPublisher 1060 + # ============================================================================= 1061 + 1062 + class TestLensPublisher: 1063 + """Tests for LensPublisher.""" 1064 + 1065 + def test_publish_with_code_refs(self, authenticated_client, mock_atproto_client): 1066 + """Publish lens with code references.""" 1067 + mock_response = Mock() 1068 + mock_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.lens/abc" 1069 + mock_atproto_client.com.atproto.repo.create_record.return_value = mock_response 1070 + 1071 + publisher = LensPublisher(authenticated_client) 1072 + uri = publisher.publish( 1073 + name="TestLens", 1074 + source_schema_uri="at://did:plc:abc/schema/source", 1075 + target_schema_uri="at://did:plc:abc/schema/target", 1076 + description="Transforms source to target", 1077 + 
code_repository="https://github.com/user/repo", 1078 + code_commit="abc123def456", 1079 + getter_path="module.lenses:my_getter", 1080 + putter_path="module.lenses:my_putter", 1081 + ) 1082 + 1083 + assert isinstance(uri, AtUri) 1084 + 1085 + call_args = mock_atproto_client.com.atproto.repo.create_record.call_args 1086 + record = call_args.kwargs["data"]["record"] 1087 + assert record["name"] == "TestLens" 1088 + assert record["sourceSchema"] == "at://did:plc:abc/schema/source" 1089 + assert record["targetSchema"] == "at://did:plc:abc/schema/target" 1090 + assert record["getterCode"]["repository"] == "https://github.com/user/repo" 1091 + 1092 + def test_publish_without_code_refs(self, authenticated_client, mock_atproto_client): 1093 + """Publish lens without code references.""" 1094 + mock_response = Mock() 1095 + mock_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.lens/abc" 1096 + mock_atproto_client.com.atproto.repo.create_record.return_value = mock_response 1097 + 1098 + publisher = LensPublisher(authenticated_client) 1099 + uri = publisher.publish( 1100 + name="MetadataOnlyLens", 1101 + source_schema_uri="at://source", 1102 + target_schema_uri="at://target", 1103 + ) 1104 + 1105 + call_args = mock_atproto_client.com.atproto.repo.create_record.call_args 1106 + record = call_args.kwargs["data"]["record"] 1107 + assert "getterCode" not in record 1108 + assert "putterCode" not in record 1109 + 1110 + def test_publish_from_lens_object(self, authenticated_client, mock_atproto_client): 1111 + """Publish lens from an atdata Lens object.""" 1112 + mock_response = Mock() 1113 + mock_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.lens/abc" 1114 + mock_atproto_client.com.atproto.repo.create_record.return_value = mock_response 1115 + 1116 + # Create a real lens 1117 + @atdata.lens 1118 + def test_lens(source: BasicSample) -> NumpySample: 1119 + return NumpySample( 1120 + data=np.array([source.value]), 1121 + label=source.name, 1122 + ) 1123 + 1124 + publisher 
= LensPublisher(authenticated_client) 1125 + uri = publisher.publish_from_lens( 1126 + test_lens, 1127 + name="FromObjectLens", 1128 + source_schema_uri="at://source", 1129 + target_schema_uri="at://target", 1130 + code_repository="https://github.com/user/repo", 1131 + code_commit="abc123", 1132 + ) 1133 + 1134 + call_args = mock_atproto_client.com.atproto.repo.create_record.call_args 1135 + record = call_args.kwargs["data"]["record"] 1136 + assert "test_lens" in record["getterCode"]["path"] 1137 + 1138 + 1139 + class TestLensLoader: 1140 + """Tests for LensLoader.""" 1141 + 1142 + def test_get_lens(self, authenticated_client, mock_atproto_client): 1143 + """Get a lens record.""" 1144 + mock_response = Mock() 1145 + mock_response.value = { 1146 + "$type": f"{LEXICON_NAMESPACE}.lens", 1147 + "name": "TestLens", 1148 + "sourceSchema": "at://source", 1149 + "targetSchema": "at://target", 1150 + } 1151 + mock_atproto_client.com.atproto.repo.get_record.return_value = mock_response 1152 + 1153 + loader = LensLoader(authenticated_client) 1154 + record = loader.get(f"at://did:plc:abc/{LEXICON_NAMESPACE}.lens/xyz") 1155 + 1156 + assert record["name"] == "TestLens" 1157 + 1158 + def test_get_lens_wrong_type(self, authenticated_client, mock_atproto_client): 1159 + """Get raises error for wrong record type.""" 1160 + mock_response = Mock() 1161 + mock_response.value = { 1162 + "$type": f"{LEXICON_NAMESPACE}.record", 1163 + "name": "NotALens", 1164 + } 1165 + mock_atproto_client.com.atproto.repo.get_record.return_value = mock_response 1166 + 1167 + loader = LensLoader(authenticated_client) 1168 + 1169 + with pytest.raises(ValueError, match="not a lens record"): 1170 + loader.get("at://did:plc:abc/collection/xyz") 1171 + 1172 + def test_list_all(self, authenticated_client, mock_atproto_client): 1173 + """List all lens records.""" 1174 + mock_record = Mock() 1175 + mock_record.value = {"name": "Lens1"} 1176 + 1177 + mock_response = Mock() 1178 + mock_response.records = 
[mock_record] 1179 + mock_response.cursor = None 1180 + mock_atproto_client.com.atproto.repo.list_records.return_value = mock_response 1181 + 1182 + loader = LensLoader(authenticated_client) 1183 + lenses = loader.list_all() 1184 + 1185 + assert len(lenses) == 1 1186 + 1187 + def test_find_by_schemas_source_only(self, authenticated_client, mock_atproto_client): 1188 + """Find lenses by source schema only.""" 1189 + mock_records = [ 1190 + Mock(value={"sourceSchema": "at://schema/a", "targetSchema": "at://schema/b"}), 1191 + Mock(value={"sourceSchema": "at://schema/a", "targetSchema": "at://schema/c"}), 1192 + Mock(value={"sourceSchema": "at://schema/x", "targetSchema": "at://schema/y"}), 1193 + ] 1194 + 1195 + mock_response = Mock() 1196 + mock_response.records = mock_records 1197 + mock_response.cursor = None 1198 + mock_atproto_client.com.atproto.repo.list_records.return_value = mock_response 1199 + 1200 + loader = LensLoader(authenticated_client) 1201 + matches = loader.find_by_schemas(source_schema_uri="at://schema/a") 1202 + 1203 + assert len(matches) == 2 1204 + 1205 + def test_find_by_schemas_both(self, authenticated_client, mock_atproto_client): 1206 + """Find lenses by both source and target schema.""" 1207 + mock_records = [ 1208 + Mock(value={"sourceSchema": "at://schema/a", "targetSchema": "at://schema/b"}), 1209 + Mock(value={"sourceSchema": "at://schema/a", "targetSchema": "at://schema/c"}), 1210 + ] 1211 + 1212 + mock_response = Mock() 1213 + mock_response.records = mock_records 1214 + mock_response.cursor = None 1215 + mock_atproto_client.com.atproto.repo.list_records.return_value = mock_response 1216 + 1217 + loader = LensLoader(authenticated_client) 1218 + matches = loader.find_by_schemas( 1219 + source_schema_uri="at://schema/a", 1220 + target_schema_uri="at://schema/b", 1221 + ) 1222 + 1223 + assert len(matches) == 1 1224 + assert matches[0]["targetSchema"] == "at://schema/b"
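The `test_publish_auto_schema` test above drives two sequential `create_record` calls through a single mock by assigning a list to `side_effect`. A minimal standalone sketch of that pattern (hypothetical names, not the module under test):

```python
from unittest.mock import Mock

# Assigning a list to side_effect makes the mock return each element in turn,
# so one mocked create_record serves both the schema write and the dataset write.
# (A third call would raise StopIteration once the list is exhausted.)
client = Mock()
schema_response = Mock(uri="at://did:plc:test/schema/123")
dataset_response = Mock(uri="at://did:plc:test/record/456")
client.create_record.side_effect = [schema_response, dataset_response]

first = client.create_record()
second = client.create_record()
assert first.uri.endswith("schema/123")
assert second.uri.endswith("record/456")
assert client.create_record.call_count == 2
```

This is why the test can assert `call_count == 2` to confirm that auto-publishing created both a schema record and a dataset record.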
+2 -36
tests/test_dataset.py
··· 148 148 assert cur_assertion, \ 149 149 f'Did not properly incorporate property {k} of test type {SampleType}' 150 150 151 - # 152 - 153 - # def test_decorator_syntax(): 154 - # """Test use of decorator syntax for sample types""" 155 - 156 - # @atdata.packable 157 - # class BasicTestSampleDecorated: 158 - # name: str 159 - # position: int 160 - # value: float 161 - 162 - # @atdata.packable 163 - # class NumpyTestSampleDecorated: 164 - # label: int 165 - # image: NDArray 166 - 167 - # ## 168 - 169 - # test_create_sample( BasicTestSampleDecorated, { 170 - # 'name': 'Hello, world!', 171 - # 'position': 42, 172 - # 'value': 1024.768, 173 - # } ) 174 - 175 - # test_create_sample( NumpyTestSampleDecorated, { 176 - # 'label': 9_001, 177 - # 'image': np.random.randn( 1024, 1024 ), 178 - # } ) 179 - 180 - # 181 151 182 152 @pytest.mark.parametrize( 183 153 ('SampleType', 'sample_data', 'sample_wds_stem'), ··· 301 271 break 302 272 303 273 assert iterations_run == n_iterate, \ 304 - "Only found {iterations_run} samples, not {n_iterate}" 274 + f"Only found {iterations_run} samples, not {n_iterate}" 305 275 306 276 307 277 ## Shuffled ··· 353 323 break 354 324 355 325 assert iterations_run == n_iterate, \ 356 - "Only found {iterations_run} samples, not {n_iterate}" 326 + f"Only found {iterations_run} samples, not {n_iterate}" 357 327 358 328 # 359 329 ··· 401 371 402 372 parquet_filename = tmp_path / f'{sample_wds_stem}-segments.parquet' 403 373 dataset.to_parquet( parquet_filename, maxcount = n_per_file ) 404 - 405 - ## Double-check our `parquet` export 406 - 407 - # TODO 408 374 409 375 410 376 ##
+94
tests/test_helpers.py
··· 1 + """Tests for atdata._helpers module.""" 2 + 3 + import numpy as np 4 + import pytest 5 + 6 + from atdata._helpers import array_to_bytes, bytes_to_array 7 + 8 + 9 + class TestArraySerialization: 10 + """Test array_to_bytes and bytes_to_array round-trip serialization.""" 11 + 12 + @pytest.mark.parametrize("dtype", [ 13 + np.float32, 14 + np.float64, 15 + np.int32, 16 + np.int64, 17 + np.uint8, 18 + np.bool_, 19 + np.complex64, 20 + ]) 21 + def test_dtype_preservation(self, dtype): 22 + """Verify dtype is preserved through serialization.""" 23 + original = np.array([1, 2, 3], dtype=dtype) 24 + serialized = array_to_bytes(original) 25 + restored = bytes_to_array(serialized) 26 + 27 + assert restored.dtype == original.dtype 28 + np.testing.assert_array_equal(restored, original) 29 + 30 + @pytest.mark.parametrize("shape", [ 31 + (10,), 32 + (3, 4), 33 + (2, 3, 4), 34 + (1, 1, 1, 1), 35 + ]) 36 + def test_shape_preservation(self, shape): 37 + """Verify shape is preserved through serialization.""" 38 + original = np.random.rand(*shape).astype(np.float32) 39 + serialized = array_to_bytes(original) 40 + restored = bytes_to_array(serialized) 41 + 42 + assert restored.shape == original.shape 43 + np.testing.assert_array_almost_equal(restored, original) 44 + 45 + def test_empty_array(self): 46 + """Verify empty arrays serialize correctly.""" 47 + original = np.array([], dtype=np.float32) 48 + serialized = array_to_bytes(original) 49 + restored = bytes_to_array(serialized) 50 + 51 + assert restored.shape == (0,) 52 + assert restored.dtype == np.float32 53 + 54 + def test_scalar_array(self): 55 + """Verify 0-dimensional arrays serialize correctly.""" 56 + original = np.array(42.0) 57 + serialized = array_to_bytes(original) 58 + restored = bytes_to_array(serialized) 59 + 60 + assert restored.shape == () 61 + assert restored == 42.0 62 + 63 + def test_large_array(self): 64 + """Verify large arrays serialize correctly.""" 65 + original = np.random.rand(100, 
100).astype(np.float32) 66 + serialized = array_to_bytes(original) 67 + restored = bytes_to_array(serialized) 68 + 69 + np.testing.assert_array_almost_equal(restored, original) 70 + 71 + def test_contiguous_and_noncontiguous(self): 72 + """Verify non-contiguous arrays serialize correctly.""" 73 + original = np.random.rand(10, 10).astype(np.float32) 74 + non_contiguous = original[::2, ::2] # Strided view 75 + 76 + assert not non_contiguous.flags['C_CONTIGUOUS'] 77 + 78 + serialized = array_to_bytes(non_contiguous) 79 + restored = bytes_to_array(serialized) 80 + 81 + np.testing.assert_array_almost_equal(restored, non_contiguous) 82 + 83 + def test_bytes_output_type(self): 84 + """Verify array_to_bytes returns bytes.""" 85 + arr = np.array([1, 2, 3]) 86 + result = array_to_bytes(arr) 87 + assert isinstance(result, bytes) 88 + 89 + def test_ndarray_output_type(self): 90 + """Verify bytes_to_array returns ndarray.""" 91 + arr = np.array([1, 2, 3]) 92 + serialized = array_to_bytes(arr) 93 + result = bytes_to_array(serialized) 94 + assert isinstance(result, np.ndarray)
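The round-trip properties checked above (dtype, shape, 0-d, empty, and non-contiguous arrays) are exactly what the `.npy` container guarantees. A sketch of one common way such helpers are implemented — an illustrative stand-in, not the actual `atdata._helpers` code:

```python
import io
import numpy as np

def array_to_bytes(arr: np.ndarray) -> bytes:
    # np.save writes the .npy header (dtype, shape, memory order) plus the
    # data, so dtype, shape, 0-d, and empty arrays all round-trip; it also
    # copies strided views to contiguous memory before writing.
    buf = io.BytesIO()
    np.save(buf, arr, allow_pickle=False)
    return buf.getvalue()

def bytes_to_array(data: bytes) -> np.ndarray:
    return np.load(io.BytesIO(data), allow_pickle=False)

# Non-contiguous view, as in test_contiguous_and_noncontiguous
original = np.arange(12, dtype=np.float32).reshape(3, 4)[::2]
restored = bytes_to_array(array_to_bytes(original))
assert restored.dtype == original.dtype
assert restored.shape == original.shape
np.testing.assert_array_equal(restored, original)
```

`allow_pickle=False` keeps the format safe to load from untrusted bytes, at the cost of rejecting object-dtype arrays (which these tests never use).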
+10 -7
tests/test_lens.py
··· 78 78 y = polite.put( polite( test_source ), test_source ) 79 79 assert y == test_source, \ 80 80 f'Violation of PutGet: {y} =/= {test_source}' 81 - 82 - # TODO Test PutPut 81 + 82 + # PutPut law: put(v2, put(v1, s)) = put(v2, s) 83 + another_view = View( 84 + name = 'Different Name', 85 + height = 165.0, 86 + ) 87 + z1 = polite.put( another_view, polite.put( update_view, test_source ) ) 88 + z2 = polite.put( another_view, test_source ) 89 + assert z1 == z2, \ 90 + f'Violation of PutPut: {z1} =/= {z2}' 83 91 84 92 def test_conversion( tmp_path ): 85 93 """Test automatic interconversion between sample types""" ··· 104 112 favorite_pizza = s.favorite_pizza, 105 113 favorite_image = s.favorite_image, 106 114 ) 107 - 108 - lens_network = atdata.LensNetwork() 109 - print( lens_network._registry ) 110 115 111 116 # Map a test sample through the view 112 117 test_source = Source( ··· 156 161 157 162 assert sample.name == test_view.name, \ 158 163 f'Divergence on auto-mapped dataset: `name` should be {test_view.name}, but is {sample.name}' 159 - # assert sample.height == test_view.height, \ 160 - # f'Divergence on auto-mapped dataset: `height` should be {test_view.height}, but is {sample.height}' 161 164 assert sample.favorite_pizza == test_view.favorite_pizza, \ 162 165 f'Divergence on auto-mapped dataset: `favorite_pizza` should be {test_view.favorite_pizza}, but is {sample.favorite_pizza}' 163 166 assert np.all( sample.favorite_image == test_view.favorite_image ), \
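The PutPut check added above completes the three classic well-behaved-lens laws alongside the existing GetPut and PutGet assertions. A minimal self-contained sketch of all three, using hypothetical `Source`/`View` types and plain `get`/`put` functions rather than the `atdata` lens API:

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class Source:
    name: str
    height: float
    shoe_size: int

@dataclass(frozen=True)
class View:
    name: str
    height: float

def get(s: Source) -> View:
    # Project the source down to the view's fields
    return View(name=s.name, height=s.height)

def put(v: View, s: Source) -> Source:
    # Write the view's fields back, preserving the rest of the source
    return replace(s, name=v.name, height=v.height)

s = Source(name="Ada", height=170.0, shoe_size=38)
v1 = View(name="Grace", height=160.0)
v2 = View(name="Katherine", height=165.0)

assert put(get(s), s) == s                # GetPut: writing back what you read is a no-op
assert get(put(v1, s)) == v1              # PutGet: you read back what you wrote
assert put(v2, put(v1, s)) == put(v2, s)  # PutPut: a second put overwrites the first
```

PutPut is the law that makes a lens "very well-behaved": only the most recent write matters, so intermediate puts can be elided.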
+2 -13
tests/test_local.py
··· 3 3 ## 4 4 # Imports 5 5 6 - # Tests 7 6 import pytest 8 - import warnings 9 - 10 - # Suppress s3fs/moto async incompatibility warnings early 11 - # These occur during test cleanup and coverage instrumentation 12 - # Use simplefilter to ensure it applies to all contexts including pytest internals 13 - warnings.simplefilter("ignore", category=RuntimeWarning) 14 - warnings.simplefilter("ignore", category=pytest.PytestUnraisableExceptionWarning) 15 7 16 8 # System 17 9 from dataclasses import dataclass ··· 68 60 """Provide a mock S3 environment using moto. 69 61 70 62 Note: Tests using this fixture may generate warnings due to s3fs/moto async 71 - incompatibility. These are expected and suppressed via warnings.filterwarnings. 63 + incompatibility. These are suppressed via @pytest.mark.filterwarnings on 64 + individual tests that use this fixture. 72 65 """ 73 - # Suppress s3fs/moto async incompatibility warnings 74 - warnings.filterwarnings("ignore", message="coroutine.*was never awaited") 75 - warnings.filterwarnings("ignore", category=pytest.PytestUnraisableExceptionWarning) 76 - 77 66 with mock_aws(): 78 67 # Create S3 credentials dict (no endpoint_url for moto) 79 68 creds = {
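The diff above replaces module-level `warnings.simplefilter` calls (which silence warnings for every test in the session) with per-test `@pytest.mark.filterwarnings` markers. The principle — scope the ignore filter to the code that actually triggers the warning — can be sketched with the stdlib alone; the warning text here is an invented stand-in for the s3fs/moto "coroutine was never awaited" noise:

```python
import warnings

def noisy_cleanup():
    # Stand-in for the s3fs/moto teardown warning
    warnings.warn("coroutine 'cleanup' was never awaited", RuntimeWarning)

# Unfiltered: the warning is emitted and visible.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    noisy_cleanup()
assert len(caught) == 1

# Scoped ignore (what @pytest.mark.filterwarnings does per-test): the same
# warning is suppressed, but only within this scope, so unrelated tests
# still surface their own RuntimeWarnings.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("ignore", RuntimeWarning)
    noisy_cleanup()
assert len(caught) == 0
```

Localizing the filter this way keeps the suppression next to the fixture that causes the warning, instead of hiding genuine warnings across the whole suite.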