A loose federation of distributed, typed datasets
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

test: add blob storage tests and shared sample fixtures

Add comprehensive unit tests for blob operations:
- AtmosphereClient: upload_blob, get_blob, get_blob_url, _resolve_pds_endpoint
- DatasetPublisher: publish_with_blobs with metadata
- DatasetLoader: get_storage_type, get_blobs, get_blob_urls

Add shared sample type definitions in conftest.py:
- SharedBasicSample, SharedNumpySample, SharedOptionalSample, SharedAllTypesSample

Also trim verbose docstrings on internal helper functions in dataset.py.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

+378 -44
+4
CHANGELOG.md
··· 11 11 ### Fixed 12 12 13 13 ### Changed 14 + - Add shared sample type definitions to conftest.py (#219) 15 + - Add blob operation tests for DatasetLoader and DatasetPublisher (#220) 16 + - Trim verbose docstrings on internal helper functions (#222) 17 + - Remove commented debug code from dataset.py (#221) 14 18 - Review and fix tar writing in examples to use as_wds pattern (#217) 15 19 - Add blob storage demo to atmosphere_demo.py example (#216) 16 20 - Implement full blob storage support for atmosphere datasets (#211)
+5 -44
src/atdata/dataset.py
··· 89 89 90 90 91 91 def _make_packable( x ): 92 - """Convert a value to a msgpack-compatible format. 93 - 94 - Args: 95 - x: A value to convert. If it's a numpy array, converts to bytes. 96 - Otherwise returns the value unchanged. 97 - 98 - Returns: 99 - The value in a format suitable for msgpack serialization. 100 - """ 92 + """Convert numpy arrays to bytes; pass through other values unchanged.""" 101 93 if isinstance( x, np.ndarray ): 102 94 return eh.array_to_bytes( x ) 103 95 return x 104 96 105 - def _is_possibly_ndarray_type( t ): 106 - """Check if a type annotation is or contains NDArray. 107 97 108 - Args: 109 - t: A type annotation to check. 110 - 111 - Returns: 112 - ``True`` if the type is ``NDArray`` or a union containing ``NDArray`` 113 - (e.g., ``NDArray | None``), ``False`` otherwise. 114 - """ 115 - 116 - # Directly an NDArray 98 + def _is_possibly_ndarray_type( t ): 99 + """Return True if type annotation is NDArray or Optional[NDArray].""" 117 100 if t == NDArray: 118 - # print( 'is an NDArray' ) 119 101 return True 120 - 121 - # Check for Optionals (i.e., NDArray | None) 122 102 if isinstance( t, types.UnionType ): 123 - t_parts = t.__args__ 124 - if any( x == NDArray 125 - for x in t_parts ): 126 - return True 127 - 128 - # Not an NDArray 103 + return any( x == NDArray for x in t.__args__ ) 129 104 return False 130 105 131 106 @dataclass ··· 266 241 } 267 242 268 243 def _batch_aggregate( xs: Sequence ): 269 - """Aggregate a sequence of values into a batch-appropriate format. 270 - 271 - Args: 272 - xs: A sequence of values to aggregate. If the first element is a numpy 273 - array, all elements are stacked into a single array. Otherwise, 274 - returns a list. 275 - 276 - Returns: 277 - A numpy array (if elements are arrays) or a list (otherwise). 278 - """ 279 - 244 + """Stack arrays into numpy array with batch dim; otherwise return list.""" 280 245 if not xs: 281 - # Empty sequence 282 246 return [] 283 - 284 - # Aggregate 285 247 if isinstance( xs[0], np.ndarray ): 286 248 return np.array( list( xs ) ) 287 - 288 249 return list( xs ) 289 250 290 251 class SampleBatch( Generic[DT] ):
+46
tests/conftest.py
··· 2 2 3 3 import pytest 4 4 from redis import Redis 5 + from typing import Optional 5 6 7 + import numpy as np 8 + from numpy.typing import NDArray 9 + 10 + import atdata 11 + 12 + 13 + # ============================================================================= 14 + # Shared sample types for testing 15 + # ============================================================================= 16 + 17 + @atdata.packable 18 + class SharedBasicSample: 19 + """Basic sample with primitive fields for general testing.""" 20 + name: str 21 + value: int 22 + 23 + 24 + @atdata.packable 25 + class SharedNumpySample: 26 + """Sample with NDArray field for array serialization testing.""" 27 + data: NDArray 28 + label: str 29 + 30 + 31 + @atdata.packable 32 + class SharedOptionalSample: 33 + """Sample with optional fields for null handling testing.""" 34 + required: str 35 + optional_int: Optional[int] = None 36 + optional_array: Optional[NDArray] = None 37 + 38 + 39 + @atdata.packable 40 + class SharedAllTypesSample: 41 + """Sample with all supported primitive types.""" 42 + str_field: str 43 + int_field: int 44 + float_field: float 45 + bool_field: bool 46 + bytes_field: bytes 47 + 48 + 49 + # ============================================================================= 50 + # Fixtures 51 + # ============================================================================= 6 52 7 53 @pytest.fixture 8 54 def redis_connection():
+323
tests/test_atmosphere.py
··· 642 642 643 643 mock_atproto_client.com.atproto.repo.delete_record.assert_called_once() 644 644 645 + def test_upload_blob(self, authenticated_client, mock_atproto_client): 646 + """Upload blob returns proper blob reference dict.""" 647 + mock_blob_ref = Mock() 648 + mock_blob_ref.ref = Mock(link="bafkreitest123") 649 + mock_blob_ref.mime_type = "application/x-tar" 650 + mock_blob_ref.size = 1024 651 + 652 + mock_response = Mock() 653 + mock_response.blob = mock_blob_ref 654 + mock_atproto_client.upload_blob.return_value = mock_response 655 + 656 + result = authenticated_client.upload_blob(b"test data", mime_type="application/x-tar") 657 + 658 + assert result["$type"] == "blob" 659 + assert result["ref"]["$link"] == "bafkreitest123" 660 + assert result["mimeType"] == "application/x-tar" 661 + assert result["size"] == 1024 662 + 663 + def test_upload_blob_not_authenticated(self, mock_atproto_client): 664 + """Upload blob raises when not authenticated.""" 665 + client = AtmosphereClient(_client=mock_atproto_client) 666 + 667 + with pytest.raises(ValueError, match="must be authenticated"): 668 + client.upload_blob(b"data") 669 + 670 + def test_get_blob(self, authenticated_client): 671 + """Get blob fetches from resolved PDS endpoint.""" 672 + with patch("requests.get") as mock_get: 673 + mock_did_response = Mock() 674 + mock_did_response.json.return_value = { 675 + "service": [ 676 + {"type": "AtprotoPersonalDataServer", "serviceEndpoint": "https://pds.example.com"} 677 + ] 678 + } 679 + mock_did_response.raise_for_status = Mock() 680 + 681 + mock_blob_response = Mock() 682 + mock_blob_response.content = b"blob data here" 683 + mock_blob_response.raise_for_status = Mock() 684 + 685 + mock_get.side_effect = [mock_did_response, mock_blob_response] 686 + 687 + result = authenticated_client.get_blob("did:plc:abc123", "bafkreitest") 688 + 689 + assert result == b"blob data here" 690 + assert mock_get.call_count == 2 691 + 692 + def test_get_blob_pds_not_found(self, authenticated_client): 693 + """Get blob raises when PDS cannot be resolved.""" 694 + import requests as req_module 695 + with patch("requests.get") as mock_get: 696 + mock_get.side_effect = req_module.RequestException("Network error") 697 + 698 + with pytest.raises(ValueError, match="Could not resolve PDS"): 699 + authenticated_client.get_blob("did:plc:unknown", "cid123") 700 + 701 + def test_get_blob_url(self, authenticated_client): 702 + """Get blob URL constructs proper URL.""" 703 + with patch("requests.get") as mock_get: 704 + mock_response = Mock() 705 + mock_response.json.return_value = { 706 + "service": [ 707 + {"type": "AtprotoPersonalDataServer", "serviceEndpoint": "https://pds.example.com"} 708 + ] 709 + } 710 + mock_response.raise_for_status = Mock() 711 + mock_get.return_value = mock_response 712 + 713 + url = authenticated_client.get_blob_url("did:plc:abc", "bafkreitest") 714 + 715 + assert url == "https://pds.example.com/xrpc/com.atproto.sync.getBlob?did=did:plc:abc&cid=bafkreitest" 716 + 717 + def test_get_blob_url_pds_not_found(self, authenticated_client): 718 + """Get blob URL raises when PDS cannot be resolved.""" 719 + import requests as req_module 720 + with patch("requests.get") as mock_get: 721 + mock_get.side_effect = req_module.RequestException("Network error") 722 + 723 + with pytest.raises(ValueError, match="Could not resolve PDS"): 724 + authenticated_client.get_blob_url("did:plc:unknown", "cid123") 725 + 726 + def test_resolve_pds_endpoint_did_web(self, authenticated_client): 727 + """PDS resolution returns None for did:web (not implemented).""" 728 + result = authenticated_client._resolve_pds_endpoint("did:web:example.com") 729 + assert result is None 730 + 645 731 def test_list_records(self, authenticated_client, mock_atproto_client): 646 732 """List records in a collection.""" 647 733 mock_record1 = Mock() ··· 926 1012 auto_publish_schema=False, 927 1013 ) 928 1014 1015 + def test_publish_with_blobs(self, authenticated_client, mock_atproto_client): 1016 + """Publish with blob storage uploads blobs and creates record.""" 1017 + # Mock blob upload response 1018 + mock_blob_ref = Mock() 1019 + mock_blob_ref.ref = Mock(link="bafkreiblob123") 1020 + mock_blob_ref.mime_type = "application/x-tar" 1021 + mock_blob_ref.size = 2048 1022 + 1023 + mock_upload_response = Mock() 1024 + mock_upload_response.blob = mock_blob_ref 1025 + mock_atproto_client.upload_blob.return_value = mock_upload_response 1026 + 1027 + # Mock create_record response 1028 + mock_create_response = Mock() 1029 + mock_create_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.record/blobds" 1030 + mock_atproto_client.com.atproto.repo.create_record.return_value = mock_create_response 1031 + 1032 + publisher = DatasetPublisher(authenticated_client) 1033 + uri = publisher.publish_with_blobs( 1034 + blobs=[b"tar data 1", b"tar data 2"], 1035 + schema_uri="at://did:plc:test/schema/xyz", 1036 + name="BlobStoredDataset", 1037 + description="Dataset stored in blobs", 1038 + tags=["blob", "test"], 1039 + ) 1040 + 1041 + assert isinstance(uri, AtUri) 1042 + # Should have uploaded 2 blobs 1043 + assert mock_atproto_client.upload_blob.call_count == 2 1044 + # Should have created one record 1045 + assert mock_atproto_client.com.atproto.repo.create_record.call_count == 1 1046 + 1047 + # Verify record structure 1048 + call_args = mock_atproto_client.com.atproto.repo.create_record.call_args 1049 + record = call_args.kwargs["data"]["record"] 1050 + assert record["name"] == "BlobStoredDataset" 1051 + assert "storageBlobs" in record["storage"]["$type"] 1052 + 1053 + def test_publish_with_blobs_with_metadata(self, authenticated_client, mock_atproto_client): 1054 + """Publish with blobs includes metadata when provided.""" 1055 + mock_blob_ref = Mock() 1056 + mock_blob_ref.ref = Mock(link="bafkreiblob456") 1057 + mock_blob_ref.mime_type = "application/x-tar" 1058 + mock_blob_ref.size = 1024 1059 + 1060 + mock_upload_response = Mock() 1061 + mock_upload_response.blob = mock_blob_ref 1062 + mock_atproto_client.upload_blob.return_value = mock_upload_response 1063 + 1064 + mock_create_response = Mock() 1065 + mock_create_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.record/metads" 1066 + mock_atproto_client.com.atproto.repo.create_record.return_value = mock_create_response 1067 + 1068 + publisher = DatasetPublisher(authenticated_client) 1069 + publisher.publish_with_blobs( 1070 + blobs=[b"data"], 1071 + schema_uri="at://schema", 1072 + name="MetaBlobDataset", 1073 + metadata={"samples": 100, "split": "train"}, 1074 + ) 1075 + 1076 + call_args = mock_atproto_client.com.atproto.repo.create_record.call_args 1077 + record = call_args.kwargs["data"]["record"] 1078 + assert "metadata" in record 1079 + 929 1080 930 1081 class TestDatasetLoader: 931 1082 """Tests for DatasetLoader.""" ··· 1054 1205 datasets = loader.list_all() 1055 1206 1056 1207 assert len(datasets) == 1 1208 + 1209 + def test_get_storage_type_external(self, authenticated_client, mock_atproto_client): 1210 + """Get storage type returns 'external' for external storage.""" 1211 + mock_response = Mock() 1212 + mock_response.value = { 1213 + "$type": f"{LEXICON_NAMESPACE}.record", 1214 + "name": "ExternalDataset", 1215 + "schemaRef": "at://schema", 1216 + "storage": { 1217 + "$type": f"{LEXICON_NAMESPACE}.storageExternal", 1218 + "urls": ["s3://bucket/data.tar"], 1219 + }, 1220 + } 1221 + mock_atproto_client.com.atproto.repo.get_record.return_value = mock_response 1222 + 1223 + loader = DatasetLoader(authenticated_client) 1224 + storage_type = loader.get_storage_type(f"at://did:plc:abc/{LEXICON_NAMESPACE}.record/xyz") 1225 + 1226 + assert storage_type == "external" 1227 + 1228 + def test_get_storage_type_blobs(self, authenticated_client, mock_atproto_client): 1229 + """Get storage type returns 'blobs' for blob storage.""" 1230 + mock_response = Mock() 1231 + mock_response.value = { 1232 + "$type": f"{LEXICON_NAMESPACE}.record", 1233 + "name": "BlobDataset", 1234 + "schemaRef": "at://schema", 1235 + "storage": { 1236 + "$type": f"{LEXICON_NAMESPACE}.storageBlobs", 1237 + "blobs": [{"ref": {"$link": "bafkreitest"}}], 1238 + }, 1239 + } 1240 + mock_atproto_client.com.atproto.repo.get_record.return_value = mock_response 1241 + 1242 + loader = DatasetLoader(authenticated_client) 1243 + storage_type = loader.get_storage_type(f"at://did:plc:abc/{LEXICON_NAMESPACE}.record/xyz") 1244 + 1245 + assert storage_type == "blobs" 1246 + 1247 + def test_get_storage_type_unknown(self, authenticated_client, mock_atproto_client): 1248 + """Get storage type raises for unknown storage type.""" 1249 + mock_response = Mock() 1250 + mock_response.value = { 1251 + "$type": f"{LEXICON_NAMESPACE}.record", 1252 + "name": "UnknownStorageDataset", 1253 + "schemaRef": "at://schema", 1254 + "storage": { 1255 + "$type": "some.unknown.storage", 1256 + }, 1257 + } 1258 + mock_atproto_client.com.atproto.repo.get_record.return_value = mock_response 1259 + 1260 + loader = DatasetLoader(authenticated_client) 1261 + 1262 + with pytest.raises(ValueError, match="Unknown storage type"): 1263 + loader.get_storage_type(f"at://did:plc:abc/{LEXICON_NAMESPACE}.record/xyz") 1264 + 1265 + def test_get_blobs(self, authenticated_client, mock_atproto_client): 1266 + """Get blobs returns blob references from storage.""" 1267 + blob_refs = [ 1268 + {"ref": {"$link": "bafkreitest1"}, "mimeType": "application/x-tar", "size": 1024}, 1269 + {"ref": {"$link": "bafkreitest2"}, "mimeType": "application/x-tar", "size": 2048}, 1270 + ] 1271 + mock_response = Mock() 1272 + mock_response.value = { 1273 + "$type": f"{LEXICON_NAMESPACE}.record", 1274 + "name": "BlobDataset", 1275 + "schemaRef": "at://schema", 1276 + "storage": { 1277 + "$type": f"{LEXICON_NAMESPACE}.storageBlobs", 1278 + "blobs": blob_refs, 1279 + }, 1280 + } 1281 + mock_atproto_client.com.atproto.repo.get_record.return_value = mock_response 1282 + 1283 + loader = DatasetLoader(authenticated_client) 1284 + blobs = loader.get_blobs(f"at://did:plc:abc/{LEXICON_NAMESPACE}.record/xyz") 1285 + 1286 + assert len(blobs) == 2 1287 + assert blobs[0]["ref"]["$link"] == "bafkreitest1" 1288 + assert blobs[1]["ref"]["$link"] == "bafkreitest2" 1289 + 1290 + def test_get_blobs_external_storage_error(self, authenticated_client, mock_atproto_client): 1291 + """Get blobs raises for external URL storage datasets.""" 1292 + mock_response = Mock() 1293 + mock_response.value = { 1294 + "$type": f"{LEXICON_NAMESPACE}.record", 1295 + "name": "ExternalDataset", 1296 + "schemaRef": "at://schema", 1297 + "storage": { 1298 + "$type": f"{LEXICON_NAMESPACE}.storageExternal", 1299 + "urls": ["s3://bucket/data.tar"], 1300 + }, 1301 + } 1302 + mock_atproto_client.com.atproto.repo.get_record.return_value = mock_response 1303 + 1304 + loader = DatasetLoader(authenticated_client) 1305 + 1306 + with pytest.raises(ValueError, match="external URL storage"): 1307 + loader.get_blobs(f"at://did:plc:abc/{LEXICON_NAMESPACE}.record/xyz") 1308 + 1309 + def test_get_blobs_unknown_storage_error(self, authenticated_client, mock_atproto_client): 1310 + """Get blobs raises for unknown storage type.""" 1311 + mock_response = Mock() 1312 + mock_response.value = { 1313 + "$type": f"{LEXICON_NAMESPACE}.record", 1314 + "name": "UnknownDataset", 1315 + "schemaRef": "at://schema", 1316 + "storage": { 1317 + "$type": "some.unknown.storage", 1318 + }, 1319 + } 1320 + mock_atproto_client.com.atproto.repo.get_record.return_value = mock_response 1321 + 1322 + loader = DatasetLoader(authenticated_client) 1323 + 1324 + with pytest.raises(ValueError, match="Unknown storage type"): 1325 + loader.get_blobs(f"at://did:plc:abc/{LEXICON_NAMESPACE}.record/xyz") 1326 + 1327 + def test_get_blob_urls(self, authenticated_client, mock_atproto_client): 1328 + """Get blob URLs resolves PDS and constructs download URLs.""" 1329 + mock_response = Mock() 1330 + mock_response.value = { 1331 + "$type": f"{LEXICON_NAMESPACE}.record", 1332 + "name": "BlobDataset", 1333 + "schemaRef": "at://schema", 1334 + "storage": { 1335 + "$type": f"{LEXICON_NAMESPACE}.storageBlobs", 1336 + "blobs": [ 1337 + {"ref": {"$link": "bafkreitest1"}}, 1338 + {"ref": {"$link": "bafkreitest2"}}, 1339 + ], 1340 + }, 1341 + } 1342 + mock_atproto_client.com.atproto.repo.get_record.return_value = mock_response 1343 + 1344 + # Mock PDS resolution 1345 + with patch("requests.get") as mock_get: 1346 + mock_did_response = Mock() 1347 + mock_did_response.json.return_value = { 1348 + "service": [ 1349 + {"type": "AtprotoPersonalDataServer", "serviceEndpoint": "https://pds.example.com"} 1350 + ] 1351 + } 1352 + mock_did_response.raise_for_status = Mock() 1353 + mock_get.return_value = mock_did_response 1354 + 1355 + loader = DatasetLoader(authenticated_client) 1356 + urls = loader.get_blob_urls(f"at://did:plc:abc123/{LEXICON_NAMESPACE}.record/xyz") 1357 + 1358 + assert len(urls) == 2 1359 + assert "bafkreitest1" in urls[0] 1360 + assert "bafkreitest2" in urls[1] 1361 + assert "did:plc:abc123" in urls[0] 1362 + 1363 + def test_get_urls_unknown_storage_error(self, authenticated_client, mock_atproto_client): 1364 + """Get URLs raises for unknown storage type.""" 1365 + mock_response = Mock() 1366 + mock_response.value = { 1367 + "$type": f"{LEXICON_NAMESPACE}.record", 1368 + "name": "UnknownDataset", 1369 + "schemaRef": "at://schema", 1370 + "storage": { 1371 + "$type": "some.unknown.storage", 1372 + }, 1373 + } 1374 + mock_atproto_client.com.atproto.repo.get_record.return_value = mock_response 1375 + 1376 + loader = DatasetLoader(authenticated_client) 1377 + 1378 + with pytest.raises(ValueError, match="Unknown storage type"): 1379 + loader.get_urls(f"at://did:plc:abc/{LEXICON_NAMESPACE}.record/xyz") 1057 1380 1058 1381 1059 1382 # =============================================================================