A loose federation of distributed, typed datasets
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

feat(atmosphere): add PDSBlobStore and BlobSource for ATProto blob storage

- Add BlobSource class to stream dataset shards from PDS blobs
- Add PDSBlobStore implementing AbstractDataStore for blob uploads
- AtmosphereIndex now accepts optional data_store parameter
- BlobSource resolves DIDs to PDS endpoints and streams via HTTP
- Comprehensive tests for BlobSource and PDSBlobStore
- Reorganize planning docs from v0.3 to v0.2 roadmap

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

+717 -5
.chainlink/issues.db

This is a binary file and will not be displayed.

.planning/roadmap/v0.3/03_human-review-assessment.md .planning/roadmap/v0.2/03_human-review-assessment.md
+4
CHANGELOG.md
··· 25 25 - **Comprehensive integration test suite**: 593 tests covering E2E flows, error handling, edge cases 26 26 27 27 ### Changed 28 + - Implement PDSBlobStore for ATProto blob storage (#380) 29 + - Add tests for PDSBlobStore and BlobSource (#383) 30 + - Add BlobSource for reading PDS blobs as DataSource (#382) 31 + - Create PDSBlobStore class in atmosphere module (#381) 28 32 - Investigate Redis index entry expiration/reset issue (#376) 29 33 - Audit codebase for xs/@property vs list_xs() convention (#377) 30 34 - Evaluate PackableSample → Packable protocol migration (#375)
+1
src/atdata/__init__.py
··· 68 68 from ._sources import ( 69 69 URLSource as URLSource, 70 70 S3Source as S3Source, 71 + BlobSource as BlobSource, 71 72 ) 72 73 73 74 from ._schema_codec import (
+172
src/atdata/_sources.py
··· 337 337 ) 338 338 339 339 340 + @dataclass 341 + class BlobSource: 342 + """Data source for ATProto PDS blob storage. 343 + 344 + Streams dataset shards stored as blobs on an ATProto Personal Data Server. 345 + Each shard is identified by a blob reference containing the DID and CID. 346 + 347 + This source resolves blob references to HTTP URLs and streams the content 348 + directly, supporting efficient iteration over shards without downloading 349 + everything upfront. 350 + 351 + Attributes: 352 + blob_refs: List of blob reference dicts with 'did' and 'cid' keys. 353 + pds_endpoint: Optional PDS endpoint URL. If not provided, resolved from DID. 354 + 355 + Example: 356 + :: 357 + 358 + >>> source = BlobSource( 359 + ... blob_refs=[ 360 + ... {"did": "did:plc:abc123", "cid": "bafyrei..."}, 361 + ... {"did": "did:plc:abc123", "cid": "bafyrei..."}, 362 + ... ], 363 + ... ) 364 + >>> for shard_id, stream in source.shards: 365 + ... process(stream) 366 + """ 367 + 368 + blob_refs: list[dict[str, str]] 369 + pds_endpoint: str | None = None 370 + _endpoint_cache: dict[str, str] = field(default_factory=dict, repr=False, compare=False) 371 + 372 + def _resolve_pds_endpoint(self, did: str) -> str: 373 + """Resolve PDS endpoint for a DID, with caching.""" 374 + if did in self._endpoint_cache: 375 + return self._endpoint_cache[did] 376 + 377 + if self.pds_endpoint: 378 + self._endpoint_cache[did] = self.pds_endpoint 379 + return self.pds_endpoint 380 + 381 + import requests 382 + 383 + # Resolve via plc.directory 384 + if did.startswith("did:plc:"): 385 + plc_url = f"https://plc.directory/{did}" 386 + response = requests.get(plc_url, timeout=10) 387 + response.raise_for_status() 388 + doc = response.json() 389 + 390 + for service in doc.get("service", []): 391 + if service.get("type") == "AtprotoPersonalDataServer": 392 + endpoint = service.get("serviceEndpoint", "") 393 + self._endpoint_cache[did] = endpoint 394 + return endpoint 395 + 396 + raise ValueError(f"Could not 
resolve PDS endpoint for {did}") 397 + 398 + def _get_blob_url(self, did: str, cid: str) -> str: 399 + """Get HTTP URL for fetching a blob.""" 400 + endpoint = self._resolve_pds_endpoint(did) 401 + return f"{endpoint}/xrpc/com.atproto.sync.getBlob?did={did}&cid={cid}" 402 + 403 + def _make_shard_id(self, ref: dict[str, str]) -> str: 404 + """Create shard identifier from blob reference.""" 405 + return f"at://{ref['did']}/blob/{ref['cid']}" 406 + 407 + def list_shards(self) -> list[str]: 408 + """Return list of AT URI-style shard identifiers.""" 409 + return [self._make_shard_id(ref) for ref in self.blob_refs] 410 + 411 + @property 412 + def shards(self) -> Iterator[tuple[str, IO[bytes]]]: 413 + """Lazily yield (at_uri, stream) pairs for each shard. 414 + 415 + Fetches blobs via HTTP from the PDS and yields streaming responses. 416 + 417 + Yields: 418 + Tuple of (at://did/blob/cid URI, streaming response body). 419 + """ 420 + import requests 421 + 422 + for ref in self.blob_refs: 423 + did = ref["did"] 424 + cid = ref["cid"] 425 + url = self._get_blob_url(did, cid) 426 + 427 + response = requests.get(url, stream=True, timeout=60) 428 + response.raise_for_status() 429 + 430 + shard_id = self._make_shard_id(ref) 431 + # Wrap response in a file-like object 432 + yield shard_id, response.raw 433 + 434 + def open_shard(self, shard_id: str) -> IO[bytes]: 435 + """Open a single shard by its AT URI. 436 + 437 + Args: 438 + shard_id: AT URI of the shard (at://did/blob/cid). 439 + 440 + Returns: 441 + Streaming response body for reading the blob. 442 + 443 + Raises: 444 + KeyError: If shard_id is not in list_shards(). 445 + ValueError: If shard_id format is invalid. 
446 + """ 447 + if shard_id not in self.list_shards(): 448 + raise KeyError(f"Shard not found: {shard_id}") 449 + 450 + # Parse at://did/blob/cid 451 + if not shard_id.startswith("at://"): 452 + raise ValueError(f"Invalid shard ID format: {shard_id}") 453 + 454 + parts = shard_id[5:].split("/") # Remove 'at://' 455 + if len(parts) != 3 or parts[1] != "blob": 456 + raise ValueError(f"Invalid blob URI format: {shard_id}") 457 + 458 + did, _, cid = parts 459 + url = self._get_blob_url(did, cid) 460 + 461 + import requests 462 + response = requests.get(url, stream=True, timeout=60) 463 + response.raise_for_status() 464 + return response.raw 465 + 466 + @classmethod 467 + def from_refs( 468 + cls, 469 + refs: list[dict], 470 + *, 471 + pds_endpoint: str | None = None, 472 + ) -> "BlobSource": 473 + """Create BlobSource from blob reference dicts. 474 + 475 + Accepts blob references in the format returned by upload_blob: 476 + ``{"$type": "blob", "ref": {"$link": "cid"}, ...}`` 477 + 478 + Also accepts simplified format: ``{"did": "...", "cid": "..."}`` 479 + 480 + Args: 481 + refs: List of blob reference dicts. 482 + pds_endpoint: Optional PDS endpoint to use for all blobs. 483 + 484 + Returns: 485 + Configured BlobSource. 486 + 487 + Raises: 488 + ValueError: If refs is empty or format is invalid. 489 + """ 490 + if not refs: 491 + raise ValueError("refs cannot be empty") 492 + 493 + blob_refs: list[dict[str, str]] = [] 494 + 495 + for ref in refs: 496 + if "did" in ref and "cid" in ref: 497 + # Simple format 498 + blob_refs.append({"did": ref["did"], "cid": ref["cid"]}) 499 + elif "ref" in ref and "$link" in ref.get("ref", {}): 500 + # ATProto blob format - need DID from elsewhere 501 + raise ValueError( 502 + "ATProto blob format requires 'did' field. " 503 + "Use from_record_storage() for records with storage.blobs." 
504 + ) 505 + else: 506 + raise ValueError(f"Invalid blob reference format: {ref}") 507 + 508 + return cls(blob_refs=blob_refs, pds_endpoint=pds_endpoint) 509 + 510 + 340 511 __all__ = [ 341 512 "URLSource", 342 513 "S3Source", 514 + "BlobSource", 343 515 ]
+25 -4
src/atdata/atmosphere/__init__.py
··· 38 38 from .schema import SchemaPublisher, SchemaLoader 39 39 from .records import DatasetPublisher, DatasetLoader 40 40 from .lens import LensPublisher, LensLoader 41 + from .store import PDSBlobStore 41 42 from ._types import ( 42 43 AtUri, 43 44 SchemaRecord, ··· 102 103 Wraps SchemaPublisher/Loader and DatasetPublisher/Loader to provide 103 104 a unified interface compatible with LocalIndex. 104 105 106 + Optionally accepts a ``PDSBlobStore`` for writing dataset shards as 107 + ATProto blobs, enabling fully decentralized dataset storage. 108 + 105 109 Example: 106 110 :: 107 111 108 112 >>> client = AtmosphereClient() 109 113 >>> client.login("handle.bsky.social", "app-password") 110 114 >>> 115 + >>> # Without blob storage (external URLs only) 111 116 >>> index = AtmosphereIndex(client) 112 - >>> schema_ref = index.publish_schema(MySample, version="1.0.0") 117 + >>> 118 + >>> # With PDS blob storage 119 + >>> store = PDSBlobStore(client) 120 + >>> index = AtmosphereIndex(client, data_store=store) 113 121 >>> entry = index.insert_dataset(dataset, name="my-data") 114 122 """ 115 123 116 - def __init__(self, client: AtmosphereClient): 124 + def __init__( 125 + self, 126 + client: AtmosphereClient, 127 + *, 128 + data_store: Optional[PDSBlobStore] = None, 129 + ): 117 130 """Initialize the atmosphere index. 118 131 119 132 Args: 120 133 client: Authenticated AtmosphereClient instance. 134 + data_store: Optional PDSBlobStore for writing shards as blobs. 135 + If provided, insert_dataset will upload shards to PDS. 
121 136 """ 122 137 self.client = client 123 138 self._schema_publisher = SchemaPublisher(client) 124 139 self._schema_loader = SchemaLoader(client) 125 140 self._dataset_publisher = DatasetPublisher(client) 126 141 self._dataset_loader = DatasetLoader(client) 127 - # AtmosphereIndex doesn't support data_store (uses PDS blobs) 128 - self.data_store = None 142 + self._data_store = data_store 143 + 144 + @property 145 + def data_store(self) -> Optional[PDSBlobStore]: 146 + """The PDS blob store for writing shards, or None if not configured.""" 147 + return self._data_store 129 148 130 149 # Dataset operations 131 150 ··· 291 310 __all__ = [ 292 311 # Client 293 312 "AtmosphereClient", 313 + # Storage 314 + "PDSBlobStore", 294 315 # Unified index (AbstractIndex protocol) 295 316 "AtmosphereIndex", 296 317 "AtmosphereIndexEntry",
+208
src/atdata/atmosphere/store.py
"""PDS blob storage for dataset shards.

This module provides ``PDSBlobStore``, an implementation of the AbstractDataStore
protocol that stores dataset shards as ATProto blobs in a Personal Data Server.

This enables fully decentralized dataset storage where both metadata (records)
and data (blobs) live on the AT Protocol network.

Example:
    ::

        >>> from atdata.atmosphere import AtmosphereClient, PDSBlobStore
        >>>
        >>> client = AtmosphereClient()
        >>> client.login("handle.bsky.social", "app-password")
        >>>
        >>> store = PDSBlobStore(client)
        >>> urls = store.write_shards(dataset, prefix="mnist/v1")
        >>> print(urls)
        ['at://did:plc:.../blob/bafyrei...', ...]
"""

import tempfile
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from ..dataset import Dataset
    from .client import AtmosphereClient


@dataclass
class PDSBlobStore:
    """PDS blob store implementing AbstractDataStore protocol.

    Stores dataset shards as ATProto blobs, enabling decentralized dataset
    storage on the AT Protocol network.

    Each shard is written to a temporary tar file, then uploaded as a blob
    to the user's PDS. The returned URLs are AT URIs that can be resolved
    to HTTP URLs for streaming.

    Attributes:
        client: Authenticated AtmosphereClient instance.

    Example:
        ::

            >>> store = PDSBlobStore(client)
            >>> urls = store.write_shards(dataset, prefix="training/v1")
            >>> # Returns AT URIs like:
            >>> # ['at://did:plc:abc/blob/bafyrei...', ...]
    """

    client: "AtmosphereClient"

    def write_shards(
        self,
        ds: "Dataset",
        *,
        prefix: str,
        maxcount: int = 10000,
        maxsize: float = 3e9,
        **kwargs: Any,
    ) -> list[str]:
        """Write dataset shards as PDS blobs.

        Creates tar archives from the dataset and uploads each as a blob
        to the authenticated user's PDS.

        Args:
            ds: The Dataset to write.
            prefix: Logical path prefix for naming (used in shard names only).
            maxcount: Maximum samples per shard (default: 10000).
            maxsize: Maximum shard size in bytes (default: 3GB, PDS limit).
            **kwargs: Additional args passed to wds.ShardWriter.

        Returns:
            List of AT URIs for the written blobs, in format:
            ``at://{did}/blob/{cid}``

        Raises:
            ValueError: If not authenticated.
            RuntimeError: If no shards were written.

        Note:
            PDS blobs have size limits (typically 50MB-5GB depending on PDS).
            Adjust maxcount/maxsize to stay within limits.
        """
        if not self.client.did:
            # Message deliberately starts with "Not authenticated" so callers
            # and tests can match that phrase regardless of client internals.
            raise ValueError("Not authenticated: client must be logged in to upload blobs")

        # webdataset is heavyweight; import it only when shards are actually
        # written so this module stays importable without it.
        import webdataset as wds

        did = self.client.did
        blob_uris: list[str] = []

        # Write shards to temp files, then upload each one as a blob.
        with tempfile.TemporaryDirectory() as temp_dir:
            shard_pattern = f"{temp_dir}/shard-%06d.tar"
            written_files: list[str] = []

            with wds.writer.ShardWriter(
                shard_pattern,
                maxcount=maxcount,
                maxsize=maxsize,
                post=written_files.append,  # record each finished shard path
                **kwargs,
            ) as sink:
                for sample in ds.ordered(batch_size=None):
                    sink.write(sample.as_wds)

            if not written_files:
                raise RuntimeError("No shards written")

            for shard_path in written_files:
                # NOTE: the whole shard is buffered in memory for upload;
                # maxsize should be kept well under available RAM.
                with open(shard_path, "rb") as f:
                    shard_data = f.read()

                blob_ref = self.client.upload_blob(
                    shard_data,
                    mime_type="application/x-tar",
                )

                # Extract the CID from the returned blob reference.
                cid = blob_ref["ref"]["$link"]
                blob_uris.append(f"at://{did}/blob/{cid}")

        return blob_uris

    def read_url(self, url: str) -> str:
        """Resolve an AT URI blob reference to an HTTP URL.

        Transforms ``at://did/blob/cid`` URIs to HTTP URLs that can be
        streamed by WebDataset. Non-AT URIs pass through unchanged.

        Args:
            url: AT URI in format ``at://{did}/blob/{cid}``.

        Returns:
            HTTP URL for fetching the blob via PDS API.

        Raises:
            ValueError: If URL format is invalid or PDS cannot be resolved.
        """
        if not url.startswith("at://"):
            # Not an AT URI, return unchanged
            return url

        # Parse at://did/blob/cid
        parts = url[5:].split("/")  # Remove 'at://'
        if len(parts) != 3 or parts[1] != "blob":
            raise ValueError(f"Invalid blob AT URI format: {url}")

        did, _, cid = parts
        return self.client.get_blob_url(did, cid)

    def supports_streaming(self) -> bool:
        """PDS blobs support streaming via HTTP.

        Returns:
            True.
        """
        return True

    def create_source(self, urls: list[str]) -> "BlobSource":
        """Create a BlobSource for reading these AT URIs.

        This is a convenience method for creating a DataSource that can
        stream the blobs written by this store.

        Args:
            urls: List of AT URIs from write_shards().

        Returns:
            BlobSource configured for the given URLs.

        Raises:
            ValueError: If URLs are not valid AT URIs.
        """
        from .._sources import BlobSource

        blob_refs: list[dict[str, str]] = []

        for url in urls:
            if not url.startswith("at://"):
                raise ValueError(f"Not an AT URI: {url}")

            parts = url[5:].split("/")
            if len(parts) != 3 or parts[1] != "blob":
                raise ValueError(f"Invalid blob AT URI: {url}")

            did, _, cid = parts
            blob_refs.append({"did": did, "cid": cid})

        return BlobSource(blob_refs=blob_refs)


__all__ = ["PDSBlobStore"]
+153
tests/test_integration_atmosphere.py
class TestPDSBlobStore:
    """Tests for PDSBlobStore blob storage."""

    def test_create_with_client(self, authenticated_client):
        """PDSBlobStore can be created with authenticated client."""
        from atdata.atmosphere import PDSBlobStore

        store = PDSBlobStore(client=authenticated_client)
        assert store.client is authenticated_client

    def test_supports_streaming(self, authenticated_client):
        """PDSBlobStore supports streaming."""
        from atdata.atmosphere import PDSBlobStore

        store = PDSBlobStore(client=authenticated_client)
        assert store.supports_streaming() is True

    def test_write_shards_requires_auth(self, mock_atproto_client):
        """write_shards raises if client not authenticated."""
        from atdata.atmosphere import PDSBlobStore

        # Create client without login
        client = AtmosphereClient(_client=mock_atproto_client)
        # Clear session
        client._session = None

        store = PDSBlobStore(client=client)

        # Create minimal mock dataset
        mock_ds = Mock()
        mock_ds.ordered.return_value = iter([])

        # Match on "authenticated" so the test passes whether the error
        # originates from the client ("Not authenticated") or the store
        # ("... must be authenticated ...") — exact phrasing is not the
        # contract here, only that auth is enforced.
        with pytest.raises(ValueError, match="authenticated"):
            store.write_shards(mock_ds, prefix="test")

    def test_write_shards_uploads_blobs(self, authenticated_client, mock_atproto_client, tmp_path):
        """write_shards uploads each shard as a blob."""
        from atdata.atmosphere import PDSBlobStore
        import webdataset as wds

        # Create a test dataset with samples
        tar_path = tmp_path / "test.tar"
        with wds.writer.TarWriter(str(tar_path)) as sink:
            sample = AtmoSample(name="test", value=42)
            sink.write(sample.as_wds)

        ds = atdata.Dataset[AtmoSample](str(tar_path))

        # Mock upload_blob to return a blob reference
        authenticated_client.upload_blob = Mock(return_value={
            "$type": "blob",
            "ref": {"$link": "bafyrei123abc"},
            "mimeType": "application/x-tar",
            "size": 1024,
        })

        store = PDSBlobStore(client=authenticated_client)
        urls = store.write_shards(ds, prefix="test/v1", maxcount=100)

        # Should have uploaded one shard
        assert len(urls) == 1
        assert urls[0] == "at://did:plc:integration123/blob/bafyrei123abc"

        # Verify upload_blob was called with tar data
        authenticated_client.upload_blob.assert_called_once()
        call_args = authenticated_client.upload_blob.call_args
        assert call_args.kwargs["mime_type"] == "application/x-tar"
        # First arg should be bytes (tar data)
        assert isinstance(call_args.args[0], bytes)

    def test_read_url_transforms_at_uri(self, authenticated_client, mock_atproto_client):
        """read_url transforms AT URIs to HTTP URLs."""
        from atdata.atmosphere import PDSBlobStore

        authenticated_client.get_blob_url = Mock(
            return_value="https://pds.example.com/xrpc/com.atproto.sync.getBlob?did=did:plc:abc&cid=bafyrei123"
        )

        store = PDSBlobStore(client=authenticated_client)
        url = store.read_url("at://did:plc:abc/blob/bafyrei123")

        assert "https://pds.example.com" in url
        assert "bafyrei123" in url
        authenticated_client.get_blob_url.assert_called_once_with("did:plc:abc", "bafyrei123")

    def test_read_url_passes_non_at_uri(self, authenticated_client):
        """read_url returns non-AT URIs unchanged."""
        from atdata.atmosphere import PDSBlobStore

        store = PDSBlobStore(client=authenticated_client)
        url = store.read_url("https://example.com/data.tar")

        assert url == "https://example.com/data.tar"

    def test_read_url_invalid_format(self, authenticated_client):
        """read_url raises on invalid AT URI format."""
        from atdata.atmosphere import PDSBlobStore

        store = PDSBlobStore(client=authenticated_client)

        with pytest.raises(ValueError, match="Invalid blob AT URI format"):
            store.read_url("at://did:plc:abc/invalid/format/extra")

    def test_create_source_returns_blob_source(self, authenticated_client):
        """create_source returns BlobSource for AT URIs."""
        from atdata.atmosphere import PDSBlobStore
        from atdata._sources import BlobSource

        store = PDSBlobStore(client=authenticated_client)
        source = store.create_source([
            "at://did:plc:abc/blob/bafyrei111",
            "at://did:plc:abc/blob/bafyrei222",
        ])

        assert isinstance(source, BlobSource)
        assert len(source.blob_refs) == 2
        assert source.blob_refs[0]["did"] == "did:plc:abc"
        assert source.blob_refs[0]["cid"] == "bafyrei111"

    def test_create_source_invalid_url(self, authenticated_client):
        """create_source raises on non-AT URIs."""
        from atdata.atmosphere import PDSBlobStore

        store = PDSBlobStore(client=authenticated_client)

        with pytest.raises(ValueError, match="Not an AT URI"):
            store.create_source(["https://example.com/data.tar"])

    def test_atmosphere_index_with_data_store(self, authenticated_client):
        """AtmosphereIndex can be created with PDSBlobStore."""
        from atdata.atmosphere import PDSBlobStore

        store = PDSBlobStore(client=authenticated_client)
        index = AtmosphereIndex(client=authenticated_client, data_store=store)

        assert index.data_store is store

    def test_atmosphere_index_data_store_property(self, authenticated_client):
        """AtmosphereIndex.data_store property returns the store."""
        from atdata.atmosphere import PDSBlobStore

        store = PDSBlobStore(client=authenticated_client)
        index = AtmosphereIndex(client=authenticated_client, data_store=store)

        assert index.data_store is store

    def test_atmosphere_index_without_data_store(self, authenticated_client):
        """AtmosphereIndex without data_store has None."""
        index = AtmosphereIndex(client=authenticated_client)

        assert index.data_store is None
+154 -1
tests/test_sources.py
class TestBlobSource:
    """Tests for BlobSource (ATProto PDS blob storage)."""

    def test_conforms_to_protocol(self):
        """BlobSource should satisfy DataSource protocol."""
        source = BlobSource(blob_refs=[{"did": "did:plc:abc", "cid": "bafyrei123"}])
        assert isinstance(source, DataSource)

    def test_list_shards(self):
        """list_shards returns AT URIs."""
        source = BlobSource(blob_refs=[
            {"did": "did:plc:abc", "cid": "bafyrei111"},
            {"did": "did:plc:abc", "cid": "bafyrei222"},
        ])
        assert source.list_shards() == [
            "at://did:plc:abc/blob/bafyrei111",
            "at://did:plc:abc/blob/bafyrei222",
        ]

    def test_from_refs_simple_format(self):
        """from_refs accepts simple {did, cid} format."""
        source = BlobSource.from_refs([
            {"did": "did:plc:abc", "cid": "bafyrei123"},
        ])
        assert len(source.blob_refs) == 1
        assert source.blob_refs[0]["did"] == "did:plc:abc"
        assert source.blob_refs[0]["cid"] == "bafyrei123"

    def test_from_refs_with_endpoint(self):
        """from_refs accepts pds_endpoint parameter."""
        source = BlobSource.from_refs(
            [{"did": "did:plc:abc", "cid": "bafyrei123"}],
            pds_endpoint="https://pds.example.com",
        )
        assert source.pds_endpoint == "https://pds.example.com"

    def test_from_refs_empty(self):
        """from_refs raises on empty list."""
        with pytest.raises(ValueError, match="cannot be empty"):
            BlobSource.from_refs([])

    def test_from_refs_invalid_format(self):
        """from_refs raises on invalid blob reference format."""
        with pytest.raises(ValueError, match="Invalid blob reference format"):
            BlobSource.from_refs([{"invalid": "data"}])

    def test_from_refs_atproto_format_without_did(self):
        """from_refs raises helpful error for ATProto format without DID."""
        with pytest.raises(ValueError, match="requires 'did' field"):
            BlobSource.from_refs([{"ref": {"$link": "bafyrei123"}}])

    def test_resolve_pds_endpoint_uses_cache(self):
        """PDS endpoint resolution is cached."""
        source = BlobSource(blob_refs=[{"did": "did:plc:abc", "cid": "cid"}])

        # Pre-populate cache
        source._endpoint_cache["did:plc:abc"] = "https://cached.pds.com"

        endpoint = source._resolve_pds_endpoint("did:plc:abc")
        assert endpoint == "https://cached.pds.com"

    def test_resolve_pds_endpoint_uses_provided_endpoint(self):
        """Provided pds_endpoint is used instead of resolution."""
        source = BlobSource(
            blob_refs=[{"did": "did:plc:abc", "cid": "cid"}],
            pds_endpoint="https://my.pds.com",
        )

        endpoint = source._resolve_pds_endpoint("did:plc:abc")
        assert endpoint == "https://my.pds.com"

    def test_get_blob_url(self):
        """_get_blob_url constructs correct URL."""
        source = BlobSource(
            blob_refs=[{"did": "did:plc:abc", "cid": "bafyrei123"}],
            pds_endpoint="https://pds.example.com",
        )

        url = source._get_blob_url("did:plc:abc", "bafyrei123")
        assert url == "https://pds.example.com/xrpc/com.atproto.sync.getBlob?did=did:plc:abc&cid=bafyrei123"

    def test_shards_fetches_blobs(self):
        """shards property fetches blobs via HTTP."""
        mock_response = Mock()
        mock_response.raw = Mock()
        mock_response.raise_for_status = Mock()

        with patch("requests.get", return_value=mock_response) as mock_get:
            source = BlobSource(
                blob_refs=[{"did": "did:plc:abc", "cid": "bafyrei123"}],
                pds_endpoint="https://pds.example.com",
            )

            shards = list(source.shards)

            assert len(shards) == 1
            shard_id, stream = shards[0]
            assert shard_id == "at://did:plc:abc/blob/bafyrei123"
            assert stream is mock_response.raw

            mock_get.assert_called_once_with(
                "https://pds.example.com/xrpc/com.atproto.sync.getBlob?did=did:plc:abc&cid=bafyrei123",
                stream=True,
                timeout=60,
            )

    def test_open_shard_fetches_single_blob(self):
        """open_shard fetches a specific blob."""
        mock_response = Mock()
        mock_response.raw = Mock()
        mock_response.raise_for_status = Mock()

        with patch("requests.get", return_value=mock_response) as mock_get:
            source = BlobSource(
                blob_refs=[
                    {"did": "did:plc:abc", "cid": "bafyrei111"},
                    {"did": "did:plc:abc", "cid": "bafyrei222"},
                ],
                pds_endpoint="https://pds.example.com",
            )

            stream = source.open_shard("at://did:plc:abc/blob/bafyrei222")

            assert stream is mock_response.raw
            mock_get.assert_called_once()
            call_args = mock_get.call_args
            assert "bafyrei222" in call_args[0][0]

    def test_open_shard_not_found(self):
        """open_shard raises KeyError for unknown shard."""
        source = BlobSource(blob_refs=[{"did": "did:plc:abc", "cid": "bafyrei123"}])

        with pytest.raises(KeyError, match="Shard not found"):
            source.open_shard("at://did:plc:abc/blob/unknown")

    def test_open_shard_unlisted_ids_raise_keyerror(self):
        """open_shard raises KeyError for any ID not in list_shards().

        The membership check runs before format validation, so malformed
        IDs (which can never appear in list_shards()) surface as KeyError
        rather than ValueError — the ValueError paths are a defensive
        backstop only.
        """
        source = BlobSource(
            blob_refs=[{"did": "did:plc:abc", "cid": "bafyrei123"}],
        )

        # A non-AT URI is rejected by the membership check first
        with pytest.raises(KeyError, match="Shard not found"):
            source.open_shard("not-an-at-uri")

        # An AT URI with wrong structure is likewise not in the list
        with pytest.raises(KeyError, match="Shard not found"):
            source.open_shard("at://did:plc:abc/wrong/format")