A loose federation of distributed, typed datasets
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

test: add comprehensive integration test suite

Add 9 integration test modules covering:
- End-to-end data flow pipeline
- Lens transformation chains
- Local storage workflows
- Atmosphere ATProto workflows
- Cross-backend interoperability
- Promotion pipeline (local to atmosphere)
- Dynamic type loading from schemas
- Error handling and recovery
- Edge cases and data type coverage

Also document chainlink issue tracking usage in CLAUDE.md.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

+5264
+10
CHANGELOG.md
··· 11 11 ### Fixed 12 12 13 13 ### Changed 14 + - Comprehensive integration test suite for atdata (#190) 15 + - Integration test: Error handling and recovery (#198) 16 + - Integration test: Edge cases and data type coverage (#199) 17 + - Integration test: Promotion pipeline (local to atmosphere) (#196) 18 + - Integration test: Cross-backend interoperability (#195) 19 + - Integration test: Atmosphere ATProto workflow (#194) 20 + - Integration test: Dynamic type loading from schemas (#197) 21 + - Integration test: Local storage complete workflow (#192) 22 + - Integration test: Lens transformation chains (#193) 23 + - Integration test: End-to-end data flow pipeline (#191) 14 24 - Adversarial review: Test suite and codebase comprehensive assessment (#181) 15 25 - Consolidate test sample type definitions into conftest.py (#184) 16 26 - Trim verbose docstrings that restate function signatures (#189)
+13
CLAUDE.md
··· 168 168 ... 169 169 ``` 170 170 171 + ## Issue Tracking 172 + 173 + This project uses **chainlink** for issue tracking. Chainlink commands do NOT need to be prefixed with `uv run`: 174 + ```bash 175 + # Correct - run chainlink directly 176 + chainlink list 177 + chainlink close 123 178 + chainlink show 123 179 + 180 + # Incorrect - don't use uv run 181 + uv run chainlink list # Not needed 182 + ``` 183 + 171 184 ## Git Workflow 172 185 173 186 ### Committing Changes
+406
tests/test_integration_atmosphere.py
··· 1 + """Integration tests for Atmosphere ATProto workflows. 2 + 3 + Tests end-to-end Atmosphere operations including: 4 + - Full publish workflow (login → publish schema → publish dataset → query) 5 + - Session persistence and restoration 6 + - Record discovery and querying 7 + - AtmosphereIndex compliance with AbstractIndex 8 + """ 9 + 10 + import pytest 11 + from dataclasses import dataclass 12 + from unittest.mock import Mock, MagicMock, patch 13 + 14 + import numpy as np 15 + from numpy.typing import NDArray 16 + import msgpack 17 + 18 + import atdata 19 + from atdata.atmosphere import ( 20 + AtmosphereClient, 21 + AtmosphereIndex, 22 + AtmosphereIndexEntry, 23 + SchemaPublisher, 24 + SchemaLoader, 25 + DatasetPublisher, 26 + DatasetLoader, 27 + AtUri, 28 + ) 29 + from atdata.atmosphere._types import LEXICON_NAMESPACE 30 + 31 + 32 + ## 33 + # Test sample types 34 + 35 + 36 + @atdata.packable 37 + class AtmoSample: 38 + """Sample for atmosphere tests.""" 39 + name: str 40 + value: int 41 + 42 + 43 + @atdata.packable 44 + class AtmoNDArraySample: 45 + """Sample with NDArray for atmosphere tests.""" 46 + label: str 47 + data: NDArray 48 + 49 + 50 + ## 51 + # Fixtures 52 + 53 + 54 + @pytest.fixture 55 + def mock_atproto_client(): 56 + """Create a mock atproto SDK client.""" 57 + mock = Mock() 58 + mock.me = MagicMock() 59 + mock.me.did = "did:plc:integration123" 60 + mock.me.handle = "integration.test.social" 61 + 62 + mock_profile = Mock() 63 + mock_profile.did = "did:plc:integration123" 64 + mock_profile.handle = "integration.test.social" 65 + mock.login.return_value = mock_profile 66 + mock.export_session_string.return_value = "test-session-export" 67 + 68 + return mock 69 + 70 + 71 + @pytest.fixture 72 + def authenticated_client(mock_atproto_client): 73 + """Create an authenticated AtmosphereClient.""" 74 + client = AtmosphereClient(_client=mock_atproto_client) 75 + client.login("integration.test.social", "test-password") 76 + return client 77 + 78 + 79 + ## 80 + # Full Workflow Tests 81 + 82 + 83 + class TestFullPublishWorkflow: 84 + """End-to-end tests for publish workflow.""" 85 + 86 + def test_login_publish_schema_publish_dataset(self, mock_atproto_client): 87 + """Full workflow: login → publish schema → publish dataset.""" 88 + # Setup mock responses 89 + schema_response = Mock() 90 + schema_response.uri = f"at://did:plc:integration123/{LEXICON_NAMESPACE}.sampleSchema/schema123" 91 + 92 + dataset_response = Mock() 93 + dataset_response.uri = f"at://did:plc:integration123/{LEXICON_NAMESPACE}.dataset/dataset456" 94 + 95 + mock_atproto_client.com.atproto.repo.create_record.side_effect = [ 96 + schema_response, 97 + dataset_response, 98 + ] 99 + 100 + # Execute workflow 101 + client = AtmosphereClient(_client=mock_atproto_client) 102 + client.login("test.social", "password") 103 + 104 + # Publish schema 105 + schema_pub = SchemaPublisher(client) 106 + schema_uri = schema_pub.publish(AtmoSample, version="1.0.0") 107 + 108 + assert isinstance(schema_uri, AtUri) 109 + assert schema_uri.collection == f"{LEXICON_NAMESPACE}.sampleSchema" 110 + 111 + # Publish dataset using correct API 112 + dataset_pub = DatasetPublisher(client) 113 + dataset_uri = dataset_pub.publish_with_urls( 114 + urls=["s3://bucket/data.tar"], 115 + schema_uri=str(schema_uri), 116 + name="test-dataset", 117 + ) 118 + 119 + assert isinstance(dataset_uri, AtUri) 120 + assert dataset_uri.collection == f"{LEXICON_NAMESPACE}.dataset" 121 + 122 + 123 + class TestSessionPersistence: 124 + """Tests for session export and restoration.""" 125 + 126 + def test_export_session_returns_string(self, authenticated_client): 127 + """Authenticated client should export session.""" 128 + session = authenticated_client.export_session() 129 + assert isinstance(session, str) 130 + assert len(session) > 0 131 + 132 + def test_login_with_session_restores_auth(self, mock_atproto_client): 133 + """Login with session string should restore authentication.""" 134 + client = AtmosphereClient(_client=mock_atproto_client) 135 + 136 + assert not client.is_authenticated 137 + 138 + client.login_with_session("saved-session-string") 139 + 140 + assert client.is_authenticated 141 + mock_atproto_client.login.assert_called_with(session_string="saved-session-string") 142 + 143 + def test_session_round_trip(self, mock_atproto_client): 144 + """Export then import session should maintain auth.""" 145 + # First client - login and export 146 + client1 = AtmosphereClient(_client=mock_atproto_client) 147 + client1.login("user@test.social", "password") 148 + session = client1.export_session() 149 + 150 + # Second client - restore from session 151 + mock_atproto_client2 = Mock() 152 + mock_atproto_client2.me = mock_atproto_client.me 153 + mock_atproto_client2.login.return_value = mock_atproto_client.login.return_value 154 + mock_atproto_client2.export_session_string.return_value = session 155 + 156 + client2 = AtmosphereClient(_client=mock_atproto_client2) 157 + client2.login_with_session(session) 158 + 159 + assert client2.is_authenticated 160 + assert client2.did == client1.did 161 + 162 + 163 + class TestRecordDiscovery: 164 + """Tests for finding and querying records.""" 165 + 166 + def test_list_schemas_returns_all(self, authenticated_client, mock_atproto_client): 167 + """list_schemas should return all schema records.""" 168 + mock_record1 = Mock() 169 + mock_record1.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/s1" 170 + mock_record1.value = {"name": "Schema1", "version": "1.0.0", "fields": []} 171 + 172 + mock_record2 = Mock() 173 + mock_record2.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/s2" 174 + mock_record2.value = {"name": "Schema2", "version": "1.0.0", "fields": []} 175 + 176 + mock_response = Mock() 177 + mock_response.records = [mock_record1, mock_record2] 178 + mock_response.cursor = None 179 + mock_atproto_client.com.atproto.repo.list_records.return_value = mock_response 180 + 181 + loader = SchemaLoader(authenticated_client) 182 + schemas = loader.list_all() 183 + 184 + assert len(schemas) == 2 185 + 186 + def test_get_schema_by_uri(self, authenticated_client, mock_atproto_client): 187 + """get should retrieve schema by URI.""" 188 + mock_response = Mock() 189 + mock_response.value = { 190 + "$type": f"{LEXICON_NAMESPACE}.sampleSchema", 191 + "name": "FoundSchema", 192 + "version": "2.0.0", 193 + "fields": [ 194 + {"name": "field1", "fieldType": {"$type": f"{LEXICON_NAMESPACE}.schemaType#primitive", "primitive": "str"}, "optional": False} 195 + ] 196 + } 197 + mock_atproto_client.com.atproto.repo.get_record.return_value = mock_response 198 + 199 + loader = SchemaLoader(authenticated_client) 200 + schema = loader.get("at://did:plc:test/schema/key") 201 + 202 + assert schema["name"] == "FoundSchema" 203 + assert schema["version"] == "2.0.0" 204 + 205 + 206 + class TestAtmosphereIndex: 207 + """Tests for AtmosphereIndex AbstractIndex compliance.""" 208 + 209 + def test_index_list_datasets_yields_entries(self, authenticated_client, mock_atproto_client): 210 + """list_datasets should yield AtmosphereIndexEntry objects.""" 211 + mock_record = Mock() 212 + mock_record.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.dataset/d1" 213 + mock_record.value = { 214 + "name": "listed-dataset", 215 + "schemaRef": "at://schema", 216 + "storage": {"$type": f"{LEXICON_NAMESPACE}.storageExternal", "urls": ["s3://data.tar"]}, 217 + } 218 + 219 + mock_response = Mock() 220 + mock_response.records = [mock_record] 221 + mock_response.cursor = None 222 + mock_atproto_client.com.atproto.repo.list_records.return_value = mock_response 223 + 224 + index = AtmosphereIndex(authenticated_client) 225 + entries = list(index.list_datasets()) 226 + 227 + assert len(entries) == 1 228 + assert isinstance(entries[0], AtmosphereIndexEntry) 229 + 230 + def test_entry_from_record_has_properties(self): 231 + """AtmosphereIndexEntry should expose IndexEntry properties.""" 232 + record = { 233 + "name": "test-dataset", 234 + "schemaRef": "at://did:plc:schema/schema/key", 235 + "storage": {"$type": f"{LEXICON_NAMESPACE}.storageExternal", "urls": ["s3://data.tar"]}, 236 + } 237 + 238 + entry = AtmosphereIndexEntry("at://test/dataset/key", record) 239 + 240 + assert entry.name == "test-dataset" 241 + assert entry.schema_ref == "at://did:plc:schema/schema/key" 242 + assert entry.data_urls == ["s3://data.tar"] 243 + assert entry.uri == "at://test/dataset/key" 244 + 245 + def test_entry_metadata_unpacking(self): 246 + """Entry should unpack msgpack metadata.""" 247 + original_meta = {"version": "1.0", "count": 100} 248 + packed_meta = msgpack.packb(original_meta) 249 + 250 + record = { 251 + "name": "meta-dataset", 252 + "schemaRef": "at://schema", 253 + "storage": {"$type": f"{LEXICON_NAMESPACE}.storageExternal", "urls": ["s3://data.tar"]}, 254 + "metadata": packed_meta, 255 + } 256 + 257 + entry = AtmosphereIndexEntry("at://test/dataset/key", record) 258 + 259 + assert entry.metadata == original_meta 260 + assert entry.metadata["version"] == "1.0" 261 + 262 + def test_entry_no_metadata_returns_none(self): 263 + """Entry without metadata should return None.""" 264 + record = { 265 + "name": "no-meta", 266 + "schemaRef": "at://schema", 267 + "storage": {"$type": f"{LEXICON_NAMESPACE}.storageExternal", "urls": ["s3://data.tar"]}, 268 + } 269 + 270 + entry = AtmosphereIndexEntry("at://test/dataset/key", record) 271 + 272 + assert entry.metadata is None 273 + 274 + 275 + class TestExternalStorageUrls: 276 + """Tests for datasets with external storage URLs.""" 277 + 278 + def test_publish_with_urls(self, authenticated_client, mock_atproto_client): 279 + """Publish dataset with external URLs.""" 280 + mock_response = Mock() 281 + mock_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.dataset/urls" 282 + mock_atproto_client.com.atproto.repo.create_record.return_value = mock_response 283 + 284 + publisher = DatasetPublisher(authenticated_client) 285 + uri = publisher.publish_with_urls( 286 + urls=["s3://bucket/shard-000.tar", "s3://bucket/shard-001.tar"], 287 + schema_uri="at://did:plc:schema/schema/key", 288 + name="multi-url-dataset", 289 + ) 290 + 291 + assert uri is not None 292 + 293 + call_args = mock_atproto_client.com.atproto.repo.create_record.call_args 294 + record = call_args.kwargs["data"]["record"] 295 + 296 + assert record["storage"]["$type"] == f"{LEXICON_NAMESPACE}.storageExternal" 297 + assert len(record["storage"]["urls"]) == 2 298 + 299 + def test_entry_extracts_external_urls(self): 300 + """Entry should extract URLs from external storage.""" 301 + record = { 302 + "name": "external-test", 303 + "schemaRef": "at://schema", 304 + "storage": { 305 + "$type": f"{LEXICON_NAMESPACE}.storageExternal", 306 + "urls": ["https://cdn.example.com/data-000.tar", "https://cdn.example.com/data-001.tar"], 307 + }, 308 + } 309 + 310 + entry = AtmosphereIndexEntry("at://test/dataset/key", record) 311 + 312 + assert entry.data_urls == [ 313 + "https://cdn.example.com/data-000.tar", 314 + "https://cdn.example.com/data-001.tar", 315 + ] 316 + 317 + 318 + class TestSchemaPublishing: 319 + """Tests for schema record publishing.""" 320 + 321 + def test_publish_basic_schema(self, authenticated_client, mock_atproto_client): 322 + """Schema should publish with correct structure.""" 323 + mock_response = Mock() 324 + mock_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/basic" 325 + mock_atproto_client.com.atproto.repo.create_record.return_value = mock_response 326 + 327 + publisher = SchemaPublisher(authenticated_client) 328 + uri = publisher.publish(AtmoSample, version="1.0.0") 329 + 330 + assert isinstance(uri, AtUri) 331 + 332 + call_args = mock_atproto_client.com.atproto.repo.create_record.call_args 333 + record = call_args.kwargs["data"]["record"] 334 + 335 + assert record["name"] == "AtmoSample" 336 + assert record["version"] == "1.0.0" 337 + 338 + field_names = {f["name"] for f in record["fields"]} 339 + assert "name" in field_names 340 + assert "value" in field_names 341 + 342 + def test_publish_ndarray_schema(self, authenticated_client, mock_atproto_client): 343 + """Schema with NDArray field should publish correctly.""" 344 + mock_response = Mock() 345 + mock_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/ndarray" 346 + mock_atproto_client.com.atproto.repo.create_record.return_value = mock_response 347 + 348 + publisher = SchemaPublisher(authenticated_client) 349 + uri = publisher.publish(AtmoNDArraySample, version="1.0.0") 350 + 351 + call_args = mock_atproto_client.com.atproto.repo.create_record.call_args 352 + record = call_args.kwargs["data"]["record"] 353 + 354 + # Find the data field 355 + data_field = next(f for f in record["fields"] if f["name"] == "data") 356 + assert "ndarray" in data_field["fieldType"]["$type"] 357 + 358 + 359 + class TestErrorHandling: 360 + """Tests for error handling in atmosphere operations.""" 361 + 362 + def test_not_authenticated_raises_on_publish(self, mock_atproto_client): 363 + """Publishing without authentication should raise.""" 364 + client = AtmosphereClient(_client=mock_atproto_client) 365 + 366 + publisher = SchemaPublisher(client) 367 + 368 + with pytest.raises(ValueError, match="authenticated"): 369 + publisher.publish(AtmoSample, version="1.0.0") 370 + 371 + def test_invalid_uri_raises(self): 372 + """Invalid AT URI should raise ValueError.""" 373 + with pytest.raises(ValueError): 374 + AtUri.parse("not-a-valid-uri") 375 + 376 + with pytest.raises(ValueError): 377 + AtUri.parse("https://example.com/path") 378 + 379 + def test_uri_missing_parts_raises(self): 380 + """AT URI with missing parts should raise.""" 381 + with pytest.raises(ValueError, match="expected authority/collection/rkey"): 382 + AtUri.parse("at://did:plc:abc/collection") 383 + 384 + 385 + class TestAtUriParsing: 386 + """Tests for AT URI parsing and formatting.""" 387 + 388 + def test_parse_valid_uri(self): 389 + """Parse a valid AT URI.""" 390 + uri = AtUri.parse("at://did:plc:abc123/com.example.record/key456") 391 + 392 + assert uri.authority == "did:plc:abc123" 393 + assert uri.collection == "com.example.record" 394 + assert uri.rkey == "key456" 395 + 396 + def test_uri_str_roundtrip(self): 397 + """String conversion should roundtrip.""" 398 + original = "at://did:plc:test123/ac.foundation.dataset.sampleSchema/xyz789" 399 + uri = AtUri.parse(original) 400 + assert str(uri) == original 401 + 402 + def test_parse_atdata_namespace(self): 403 + """Parse URIs in the atdata namespace.""" 404 + uri = AtUri.parse(f"at://did:plc:abc/{LEXICON_NAMESPACE}.sampleSchema/test") 405 + 406 + assert uri.collection == f"{LEXICON_NAMESPACE}.sampleSchema"
+557
tests/test_integration_cross_backend.py
··· 1 + """Integration tests for cross-backend interoperability. 2 + 3 + Tests that abstract protocols work consistently across: 4 + - LocalIndex and AtmosphereIndex (AbstractIndex protocol) 5 + - LocalDatasetEntry and AtmosphereIndexEntry (IndexEntry protocol) 6 + - S3DataStore (AbstractDataStore protocol) 7 + """ 8 + 9 + import pytest 10 + from dataclasses import dataclass 11 + from typing import Type 12 + from unittest.mock import Mock, MagicMock 13 + 14 + import numpy as np 15 + from numpy.typing import NDArray 16 + 17 + import atdata 18 + from atdata.local import LocalIndex, LocalDatasetEntry 19 + from atdata._protocols import IndexEntry, AbstractIndex 20 + from atdata.atmosphere import ( 21 + AtmosphereClient, 22 + AtmosphereIndex, 23 + AtmosphereIndexEntry, 24 + ) 25 + from atdata.atmosphere._types import LEXICON_NAMESPACE 26 + 27 + 28 + ## 29 + # Test sample types 30 + 31 + 32 + @atdata.packable 33 + class CrossBackendSample: 34 + """Sample for cross-backend tests.""" 35 + name: str 36 + value: int 37 + 38 + 39 + @atdata.packable 40 + class CrossBackendArraySample: 41 + """Sample with NDArray for cross-backend tests.""" 42 + label: str 43 + data: NDArray 44 + 45 + 46 + ## 47 + # Fixtures 48 + 49 + 50 + @pytest.fixture 51 + def mock_atproto_client(): 52 + """Create a mock atproto SDK client.""" 53 + mock = Mock() 54 + mock.me = MagicMock() 55 + mock.me.did = "did:plc:crossbackend123" 56 + mock.me.handle = "crossbackend.test.social" 57 + 58 + mock_profile = Mock() 59 + mock_profile.did = "did:plc:crossbackend123" 60 + mock_profile.handle = "crossbackend.test.social" 61 + mock.login.return_value = mock_profile 62 + mock.export_session_string.return_value = "test-session-export" 63 + 64 + return mock 65 + 66 + 67 + @pytest.fixture 68 + def authenticated_atmosphere_client(mock_atproto_client): 69 + """Create an authenticated AtmosphereClient.""" 70 + client = AtmosphereClient(_client=mock_atproto_client) 71 + client.login("crossbackend.test.social", "test-password") 72 + return client 73 + 74 + 75 + @pytest.fixture 76 + def local_index(clean_redis): 77 + """Create a LocalIndex backed by Redis.""" 78 + return LocalIndex(redis=clean_redis) 79 + 80 + 81 + @pytest.fixture 82 + def atmosphere_index(authenticated_atmosphere_client): 83 + """Create an AtmosphereIndex.""" 84 + return AtmosphereIndex(authenticated_atmosphere_client) 85 + 86 + 87 + ## 88 + # IndexEntry Protocol Tests 89 + 90 + 91 + class TestIndexEntryProtocol: 92 + """Tests that LocalDatasetEntry and AtmosphereIndexEntry are interchangeable.""" 93 + 94 + def test_local_entry_satisfies_protocol(self): 95 + """LocalDatasetEntry should satisfy IndexEntry protocol.""" 96 + entry = LocalDatasetEntry( 97 + _name="test-dataset", 98 + _schema_ref="local://schemas/TestSample@1.0.0", 99 + _data_urls=["s3://bucket/test.tar"], 100 + ) 101 + 102 + assert isinstance(entry, IndexEntry) 103 + assert entry.name == "test-dataset" 104 + assert entry.schema_ref == "local://schemas/TestSample@1.0.0" 105 + assert entry.data_urls == ["s3://bucket/test.tar"] 106 + assert entry.metadata is None 107 + 108 + def test_atmosphere_entry_satisfies_protocol(self): 109 + """AtmosphereIndexEntry should satisfy IndexEntry protocol.""" 110 + record = { 111 + "name": "atmo-dataset", 112 + "schemaRef": "at://did:plc:test/ac.foundation.dataset.sampleSchema/abc", 113 + "storage": { 114 + "$type": f"{LEXICON_NAMESPACE}.storageExternal", 115 + "urls": ["s3://bucket/atmo.tar"], 116 + }, 117 + } 118 + entry = AtmosphereIndexEntry("at://did:plc:test/dataset/xyz", record) 119 + 120 + assert isinstance(entry, IndexEntry) 121 + assert entry.name == "atmo-dataset" 122 + assert entry.schema_ref == "at://did:plc:test/ac.foundation.dataset.sampleSchema/abc" 123 + assert entry.data_urls == ["s3://bucket/atmo.tar"] 124 + assert entry.metadata is None 125 + 126 + def test_entries_work_with_common_function(self): 127 + """Both entry types should work with functions accepting IndexEntry.""" 128 + def process_entry(entry: IndexEntry) -> dict: 129 + return { 130 + "name": entry.name, 131 + "schema": entry.schema_ref, 132 + "url_count": len(entry.data_urls), 133 + } 134 + 135 + local_entry = LocalDatasetEntry( 136 + _name="local-ds", 137 + _schema_ref="local://schemas/Test@1.0.0", 138 + _data_urls=["s3://bucket/local.tar"], 139 + ) 140 + 141 + atmo_record = { 142 + "name": "atmo-ds", 143 + "schemaRef": "at://did:plc:test/schema/abc", 144 + "storage": { 145 + "$type": f"{LEXICON_NAMESPACE}.storageExternal", 146 + "urls": ["s3://bucket/atmo.tar", "s3://bucket/atmo2.tar"], 147 + }, 148 + } 149 + atmo_entry = AtmosphereIndexEntry("at://test", atmo_record) 150 + 151 + local_result = process_entry(local_entry) 152 + atmo_result = process_entry(atmo_entry) 153 + 154 + assert local_result["name"] == "local-ds" 155 + assert local_result["url_count"] == 1 156 + 157 + assert atmo_result["name"] == "atmo-ds" 158 + assert atmo_result["url_count"] == 2 159 + 160 + def test_entries_with_metadata(self): 161 + """Both entry types should handle metadata consistently.""" 162 + import msgpack 163 + 164 + # Local entry with metadata 165 + local_entry = LocalDatasetEntry( 166 + _name="local-meta", 167 + _schema_ref="local://schemas/Test@1.0.0", 168 + _data_urls=["s3://bucket/local.tar"], 169 + _metadata={"version": "1.0", "samples": 100}, 170 + ) 171 + 172 + # Atmosphere entry with metadata 173 + atmo_record = { 174 + "name": "atmo-meta", 175 + "schemaRef": "at://did:plc:test/schema/abc", 176 + "storage": { 177 + "$type": f"{LEXICON_NAMESPACE}.storageExternal", 178 + "urls": ["s3://bucket/atmo.tar"], 179 + }, 180 + "metadata": msgpack.packb({"version": "1.0", "samples": 200}), 181 + } 182 + atmo_entry = AtmosphereIndexEntry("at://test", atmo_record) 183 + 184 + assert local_entry.metadata["version"] == "1.0" 185 + assert local_entry.metadata["samples"] == 100 186 + 187 + assert atmo_entry.metadata["version"] == "1.0" 188 + assert atmo_entry.metadata["samples"] == 200 189 + 190 + def test_entries_with_multiple_urls(self): 191 + """Both entry types should handle multiple data URLs.""" 192 + urls = [ 193 + "s3://bucket/shard-000000.tar", 194 + "s3://bucket/shard-000001.tar", 195 + "s3://bucket/shard-000002.tar", 196 + ] 197 + 198 + local_entry = LocalDatasetEntry( 199 + _name="multi-shard-local", 200 + _schema_ref="local://schemas/Test@1.0.0", 201 + _data_urls=urls, 202 + ) 203 + 204 + atmo_record = { 205 + "name": "multi-shard-atmo", 206 + "schemaRef": "at://did:plc:test/schema/abc", 207 + "storage": { 208 + "$type": f"{LEXICON_NAMESPACE}.storageExternal", 209 + "urls": urls, 210 + }, 211 + } 212 + atmo_entry = AtmosphereIndexEntry("at://test", atmo_record) 213 + 214 + assert local_entry.data_urls == urls 215 + assert atmo_entry.data_urls == urls 216 + 217 + 218 + ## 219 + # AbstractIndex Protocol Tests 220 + 221 + 222 + class TestAbstractIndexProtocol: 223 + """Tests that LocalIndex and AtmosphereIndex share common behavior.""" 224 + 225 + def test_local_index_list_datasets_yields_entries(self, local_index, clean_redis): 226 + """LocalIndex.list_datasets should yield IndexEntry objects.""" 227 + # Insert an entry directly via Redis for testing 228 + entry = LocalDatasetEntry( 229 + _name="test-list", 230 + _schema_ref="local://schemas/Test@1.0.0", 231 + _data_urls=["s3://bucket/test.tar"], 232 + ) 233 + entry.write_to(clean_redis) 234 + 235 + entries = list(local_index.list_datasets()) 236 + 237 + assert len(entries) >= 1 238 + found = [e for e in entries if e.name == "test-list"] 239 + assert len(found) == 1 240 + assert isinstance(found[0], IndexEntry) 241 + 242 + def test_atmosphere_index_list_datasets_yields_entries( 243 + self, atmosphere_index, mock_atproto_client 244 + ): 245 + """AtmosphereIndex.list_datasets should yield IndexEntry objects.""" 246 + mock_record = Mock() 247 + mock_record.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.dataset/d1" 248 + mock_record.value = { 249 + "name": "atmo-list", 250 + "schemaRef": "at://schema", 251 + "storage": { 252 + "$type": f"{LEXICON_NAMESPACE}.storageExternal", 253 + "urls": ["s3://data.tar"], 254 + }, 255 + } 256 + 257 + mock_response = Mock() 258 + mock_response.records = [mock_record] 259 + mock_response.cursor = None 260 + mock_atproto_client.com.atproto.repo.list_records.return_value = mock_response 261 + 262 + entries = list(atmosphere_index.list_datasets()) 263 + 264 + assert len(entries) == 1 265 + assert isinstance(entries[0], IndexEntry) 266 + assert entries[0].name == "atmo-list" 267 + 268 + def test_local_index_publish_schema(self, local_index): 269 + """LocalIndex.publish_schema should return schema reference.""" 270 + schema_ref = local_index.publish_schema(CrossBackendSample, version="1.0.0") 271 + 272 + assert schema_ref is not None 273 + assert "CrossBackendSample" in schema_ref 274 + assert "1.0.0" in schema_ref 275 + 276 + def test_atmosphere_index_publish_schema( 277 + self, atmosphere_index, mock_atproto_client 278 + ): 279 + """AtmosphereIndex.publish_schema should return AT URI.""" 280 + mock_response = Mock() 281 + mock_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/abc" 282 + mock_atproto_client.com.atproto.repo.create_record.return_value = mock_response 283 + 284 + schema_ref = atmosphere_index.publish_schema( 285 + CrossBackendSample, version="1.0.0" 286 + ) 287 + 288 + assert schema_ref is not None 289 + assert "at://" in str(schema_ref) 290 + 291 + def test_local_index_get_schema(self, local_index): 292 + """LocalIndex should retrieve published schemas.""" 293 + schema_ref = local_index.publish_schema(CrossBackendSample, version="2.0.0") 294 + schema = local_index.get_schema(schema_ref) 295 + 296 + assert schema["name"] == "CrossBackendSample" 297 + assert schema["version"] == "2.0.0" 298 + assert len(schema["fields"]) == 2 299 + 300 + def test_atmosphere_index_get_schema( 301 + self, atmosphere_index, mock_atproto_client 302 + ): 303 + """AtmosphereIndex should retrieve schemas.""" 304 + mock_response = Mock() 305 + mock_response.value = { 306 + "$type": f"{LEXICON_NAMESPACE}.sampleSchema", 307 + "name": "RetrievedSchema", 308 + "version": "1.0.0", 309 + "fields": [{"name": "field1", "fieldType": {"$type": f"{LEXICON_NAMESPACE}.schemaType#primitive", "primitive": "str"}, "optional": False}], 310 + } 311 + mock_atproto_client.com.atproto.repo.get_record.return_value = mock_response 312 + 313 + schema = atmosphere_index.get_schema("at://did:plc:test/schema/key") 314 + 315 + assert schema["name"] == "RetrievedSchema" 316 + assert schema["version"] == "1.0.0" 317 + 318 + 319 + class TestSchemaPortability: 320 + """Tests that schemas can be used across backends.""" 321 + 322 + def test_schema_field_structure_matches(self, local_index): 323 + """Schema structure should be consistent regardless of backend.""" 324 + schema_ref = local_index.publish_schema(CrossBackendSample, version="1.0.0") 325 + schema = local_index.get_schema(schema_ref) 326 + 327 + # Verify schema has expected structure 328 + assert "name" in schema 329 + assert "version" in schema 330 + assert "fields" in schema 331 + 332 + field_names = {f["name"] for f in schema["fields"]} 333 + assert "name" in field_names 334 + assert "value" in field_names 335 + 336 + def test_ndarray_schema_field_structure(self, local_index): 337 + """NDArray fields should be represented consistently.""" 338 + schema_ref = local_index.publish_schema(CrossBackendArraySample, version="1.0.0") 339 + schema = local_index.get_schema(schema_ref) 340 + 341 + field_names = {f["name"] for f in schema["fields"]} 342 + assert "label" in field_names 343 + assert "data" in field_names 344 + 345 + # Find the data field and verify it's marked as ndarray type 346 + data_field = next(f for f in schema["fields"] if f["name"] == "data") 347 + field_type = data_field["fieldType"] 348 + # Field type should indicate it's an ndarray 349 + assert "ndarray" in field_type.get("$type", "").lower() or \ 350 + field_type.get("primitive") == "ndarray" 351 + 352 + 353 + class TestCrossBackendSchemaResolution: 354 + """Tests for schema resolution across different backends.""" 355 + 356 + def test_local_schema_ref_format(self, local_index): 357 + """Local schema refs should use local:// URI scheme.""" 358 + schema_ref = local_index.publish_schema(CrossBackendSample, version="1.0.0") 359 + 360 + assert schema_ref.startswith("local://") 361 + assert "schemas" in schema_ref 362 + 363 + def test_atmosphere_schema_ref_format( 364 + self, atmosphere_index, mock_atproto_client 365 + ): 366 + """Atmosphere schema refs should use at:// URI scheme.""" 367 + mock_response = Mock() 368 + mock_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/abc" 369 + mock_atproto_client.com.atproto.repo.create_record.return_value = mock_response 370 + 371 + schema_ref = atmosphere_index.publish_schema( 372 + CrossBackendSample, version="1.0.0" 373 + ) 374 + 375 + assert "at://" in str(schema_ref) 376 + 377 + 378 + class TestIndexEntryCreation: 379 + """Tests for creating index entries via different backends.""" 380 + 381 + def test_local_entry_has_cid(self): 382 + """LocalDatasetEntry should generate a CID.""" 383 + entry = LocalDatasetEntry( 384 + _name="cid-test", 385 + _schema_ref="local://schemas/Test@1.0.0", 386 + _data_urls=["s3://bucket/test.tar"], 387 + ) 388 + 389 + assert entry.cid is not None 390 + assert len(entry.cid) > 0 391 + 392 + def test_atmosphere_entry_has_uri(self): 393 + """AtmosphereIndexEntry should have a URI.""" 394 + record = { 395 + "name": "uri-test", 396 + "schemaRef": "at://schema", 397 + "storage": { 398 + "$type": f"{LEXICON_NAMESPACE}.storageExternal", 399 + "urls": ["s3://test.tar"], 400 + }, 401 + } 402 + entry = AtmosphereIndexEntry("at://did:plc:test/dataset/xyz", record) 403 + 404 + assert entry.uri == "at://did:plc:test/dataset/xyz" 405 + 406 + def test_same_content_same_local_cid(self): 407 + """Same content should produce same CID in local entries.""" 408 + entry1 = LocalDatasetEntry( 409 + _name="cid-test-1", 410 + _schema_ref="local://schemas/Test@1.0.0", 411 + _data_urls=["s3://bucket/same.tar"], 412 + ) 413 + entry2 = LocalDatasetEntry( 414 + _name="cid-test-2", 415 + _schema_ref="local://schemas/Test@1.0.0", 416 + _data_urls=["s3://bucket/same.tar"], 417 + ) 418 + 419 + # Different names but same content should produce same CID 420 + assert entry1.cid == entry2.cid 421 + 422 + def test_different_content_different_local_cid(self): 423 + """Different content should produce different CID in local entries.""" 424 + entry1 = LocalDatasetEntry( 425 + _name="cid-diff-1", 426 + _schema_ref="local://schemas/Test@1.0.0", 427 + _data_urls=["s3://bucket/file1.tar"], 428 + ) 429 + entry2 = LocalDatasetEntry( 430 + _name="cid-diff-2", 431 + _schema_ref="local://schemas/Test@1.0.0", 432 + _data_urls=["s3://bucket/file2.tar"], 433 + ) 434 + 435 + assert entry1.cid != entry2.cid 436 + 437 + 438 + class TestListingConsistency: 439 + """Tests that listing operations behave consistently.""" 440 + 441 + def test_empty_local_index_lists_no_datasets(self, clean_redis): 442 + """Empty LocalIndex should list no datasets.""" 443 + index = LocalIndex(redis=clean_redis) 444 + entries = list(index.list_datasets()) 445 + 446 + # Should be empty or contain only pre-existing entries 447 + # (clean_redis fixture should clear it) 448 + assert len(entries) == 0 449 + 450 + def test_empty_atmosphere_index_lists_no_datasets( 451 + self, atmosphere_index, mock_atproto_client 452 + ): 453 + """Empty AtmosphereIndex should list no datasets.""" 454 + mock_response = Mock() 455 + mock_response.records = [] 456 + mock_response.cursor = None 457 + mock_atproto_client.com.atproto.repo.list_records.return_value = mock_response 458 + 459 + entries = list(atmosphere_index.list_datasets()) 460 + 461 + assert len(entries) == 0 462 + 463 + 464 + class TestGenericIndexFunction: 465 + """Tests for functions that work with any AbstractIndex implementation.""" 466 + 467 + def count_datasets(self, index) -> int: 468 + """Count datasets in an index (works with any AbstractIndex).""" 469 + return sum(1 for _ in index.list_datasets()) 470 + 471 + def get_all_names(self, index) -> list[str]: 472 + """Get all dataset names from an index.""" 473 + return [entry.name for entry in index.list_datasets()] 474 + 475 + def test_count_works_with_local(self, local_index, clean_redis): 476 + """Dataset count function should work with LocalIndex.""" 477 + # Insert some entries 478 + for i in range(3): 479 + entry = LocalDatasetEntry( 480 + _name=f"count-test-{i}", 481 + _schema_ref="local://schemas/Test@1.0.0", 482 + _data_urls=[f"s3://bucket/test-{i}.tar"], 483 + ) 484 + entry.write_to(clean_redis) 485 + 486 + count = self.count_datasets(local_index) 487 + assert count >= 3 488 + 489 + def test_count_works_with_atmosphere( 490 + self, atmosphere_index, mock_atproto_client 491 + ): 492 + """Dataset count function should work with AtmosphereIndex.""" 493 + mock_records = [] 494 + for i in range(5): 495 + record = Mock() 496 + record.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.dataset/d{i}" 497 + record.value = { 498 + "name": f"count-atmo-{i}", 499 + "schemaRef": "at://schema", 500 + "storage": { 501 + "$type": f"{LEXICON_NAMESPACE}.storageExternal", 502 + "urls": [f"s3://data-{i}.tar"], 503 + }, 504 + } 505 + mock_records.append(record) 506 + 507 + mock_response = Mock() 508 + mock_response.records = mock_records 509 + mock_response.cursor = None 510 + mock_atproto_client.com.atproto.repo.list_records.return_value = mock_response 511 + 512 + count = self.count_datasets(atmosphere_index) 513 + assert count == 5 514 + 515 + def test_get_names_works_with_local(self, local_index, clean_redis): 516 + """Name retrieval function should work with LocalIndex.""" 517 + names = ["alpha", "beta", "gamma"] 518 + for name in names: 519 + entry = LocalDatasetEntry( 520 + _name=name, 521 + _schema_ref="local://schemas/Test@1.0.0", 522 + _data_urls=[f"s3://bucket/{name}.tar"], 523 + ) 524 + entry.write_to(clean_redis) 525 + 526 + retrieved_names = self.get_all_names(local_index) 527 + 528 + for name in names: 529 + assert name in retrieved_names 530 + 531 + def test_get_names_works_with_atmosphere( 532 + self, atmosphere_index, mock_atproto_client 533 + ): 534 + """Name retrieval function should work with AtmosphereIndex.""" 535 + names = ["delta", "epsilon", "zeta"] 536 + mock_records = [] 537 + for name in names: 538 + record = Mock() 539 + record.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.dataset/{name}" 540 + record.value = { 541 + "name": name, 542 + "schemaRef": "at://schema", 543 + "storage": { 544 + "$type": f"{LEXICON_NAMESPACE}.storageExternal", 545 + "urls": [f"s3://{name}.tar"], 546 + }, 547 + } 548 + mock_records.append(record) 549 + 550 + mock_response = Mock() 551 + mock_response.records = mock_records 552 + mock_response.cursor = None 553 + mock_atproto_client.com.atproto.repo.list_records.return_value = mock_response 554 + 555 + retrieved_names = self.get_all_names(atmosphere_index) 556 + 557 + assert retrieved_names == names
+582
tests/test_integration_dynamic_types.py
··· 1 + """Integration tests for dynamic type loading from schemas. 2 + 3 + Tests the schema_to_type() functionality for: 4 + - Schema → Type reconstruction 5 + - Reconstructed types working with Dataset 6 + - Complex field types (NDArray, optional, lists) 7 + - Type caching behavior 8 + - Schema from different sources (local, atmosphere) 9 + """ 10 + 11 + import pytest 12 + from dataclasses import dataclass 13 + 14 + import numpy as np 15 + from numpy.typing import NDArray 16 + import webdataset as wds 17 + 18 + import atdata 19 + from atdata._schema_codec import schema_to_type, clear_type_cache, get_cached_types 20 + import atdata.local as atlocal 21 + 22 + 23 + ## 24 + # Test sample types for comparison 25 + 26 + 27 + @dataclass 28 + class SimpleSample(atdata.PackableSample): 29 + """Simple sample for testing.""" 30 + name: str 31 + value: int 32 + score: float 33 + 34 + 35 + @dataclass 36 + class ArraySample(atdata.PackableSample): 37 + """Sample with NDArray field.""" 38 + label: str 39 + image: NDArray 40 + 41 + 42 + @dataclass 43 + class OptionalSample(atdata.PackableSample): 44 + """Sample with optional fields.""" 45 + name: str 46 + value: int 47 + extra: str | None = None 48 + embedding: NDArray | None = None 49 + 50 + 51 + @dataclass 52 + class ListSample(atdata.PackableSample): 53 + """Sample with list fields.""" 54 + tags: list[str] 55 + scores: list[float] 56 + 57 + 58 + ## 59 + # Fixtures 60 + 61 + 62 + @pytest.fixture(autouse=True) 63 + def clear_cache(): 64 + """Clear type cache before each test.""" 65 + clear_type_cache() 66 + yield 67 + clear_type_cache() 68 + 69 + 70 + ## 71 + # Basic Schema to Type Tests 72 + 73 + 74 + class TestSchemaToType: 75 + """Tests for basic schema_to_type functionality.""" 76 + 77 + def test_simple_primitive_schema(self): 78 + """Schema with primitive fields should produce usable type.""" 79 + schema = { 80 + "name": "SimpleSample", 81 + "version": "1.0.0", 82 + "fields": [ 83 + {"name": "name", "fieldType": {"$type": "local#primitive", "primitive": "str"}, "optional": False}, 84 + {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 85 + {"name": "score", "fieldType": {"$type": "local#primitive", "primitive": "float"}, "optional": False}, 86 + ] 87 + } 88 + 89 + SampleType = schema_to_type(schema) 90 + 91 + # Should be able to create instances 92 + instance = SampleType(name="test", value=42, score=0.5) 93 + assert instance.name == "test" 94 + assert instance.value == 42 95 + assert instance.score == 0.5 96 + 97 + def test_ndarray_field_schema(self): 98 + """Schema with NDArray field should produce working type.""" 99 + schema = { 100 + "name": "ArraySample", 101 + "version": "1.0.0", 102 + "fields": [ 103 + {"name": "label", "fieldType": {"$type": "local#primitive", "primitive": "str"}, "optional": False}, 104 + {"name": "image", "fieldType": {"$type": "local#ndarray", "dtype": "float32"}, "optional": False}, 105 + ] 106 + } 107 + 108 + SampleType = schema_to_type(schema) 109 + 110 + # Should work with numpy arrays 111 + arr = np.random.randn(32, 32).astype(np.float32) 112 + instance = SampleType(label="test", image=arr) 113 + assert instance.label == "test" 114 + np.testing.assert_array_equal(instance.image, arr) 115 + 116 + def test_optional_field_schema(self): 117 + """Schema with optional fields should use None as default.""" 118 + schema = { 119 + "name": "OptionalSample", 120 + "version": "1.0.0", 121 + "fields": [ 122 + {"name": "name", "fieldType": {"$type": "local#primitive", "primitive": "str"}, "optional": False}, 123 + {"name": "extra", "fieldType": {"$type": "local#primitive", "primitive": "str"}, "optional": True}, 124 + ] 125 + } 126 + 127 + SampleType = schema_to_type(schema) 128 + 129 + # Optional field should default to None 130 + instance = SampleType(name="test") 131 + assert instance.name == "test" 132 + assert instance.extra is None 133 + 134 + # Can also provide value 135 + instance2 = SampleType(name="test", extra="optional") 136 + assert instance2.extra == "optional" 137 + 138 + def test_list_field_schema(self): 139 + """Schema with list fields should produce working type.""" 140 + schema = { 141 + "name": "ListSample", 142 + "version": "1.0.0", 143 + "fields": [ 144 + {"name": "tags", "fieldType": {"$type": "local#array", "items": {"$type": "local#primitive", "primitive": "str"}}, "optional": False}, 145 + {"name": "scores", "fieldType": {"$type": "local#array", "items": {"$type": "local#primitive", "primitive": "float"}}, "optional": False}, 146 + ] 147 + } 148 + 149 + SampleType = schema_to_type(schema) 150 + 151 + instance = SampleType(tags=["a", "b", "c"], scores=[1.0, 2.0, 3.0]) 152 + assert instance.tags == ["a", "b", "c"] 153 + assert instance.scores == [1.0, 2.0, 3.0] 154 + 155 + def test_all_primitive_types(self): 156 + """All primitive types should be supported.""" 157 + schema = { 158 + "name": "AllPrimitives", 159 + "version": "1.0.0", 160 + "fields": [ 161 + {"name": "s", "fieldType": {"$type": "local#primitive", "primitive": "str"}, "optional": False}, 162 + {"name": "i", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 163 + {"name": "f", "fieldType": {"$type": "local#primitive", "primitive": "float"}, "optional": False}, 164 + {"name": "b", "fieldType": {"$type": "local#primitive", "primitive": "bool"}, "optional": False}, 165 + {"name": "raw", "fieldType": {"$type": "local#primitive", "primitive": "bytes"}, "optional": False}, 166 + ] 167 + } 168 + 169 + SampleType = schema_to_type(schema) 170 + 171 + instance = SampleType(s="hello", i=42, f=3.14, b=True, raw=b"bytes") 172 + assert instance.s == "hello" 173 + assert instance.i == 42 174 + assert instance.f == 3.14 175 + assert instance.b is True 176 + assert instance.raw == b"bytes" 177 + 178 + 179 + class TestDynamicTypeWithDataset: 180 + """Tests for using dynamically generated types with Dataset.""" 181 + 182 + def test_load_dataset_with_dynamic_type(self, tmp_path): 183 + """Dynamic type should work with Dataset loading.""" 184 + # First, create a dataset with a known type 185 + tar_path = tmp_path / "data.tar" 186 + original_samples = [ 187 + SimpleSample(name=f"item_{i}", value=i * 10, score=float(i) * 0.5) 188 + for i in range(10) 189 + ] 190 + 191 + with wds.writer.TarWriter(str(tar_path)) as sink: 192 + for sample in original_samples: 193 + sink.write(sample.as_wds) 194 + 195 + # Now create the type dynamically 196 + schema = { 197 + "name": "SimpleSample", 198 + "version": "1.0.0", 199 + "fields": [ 200 + {"name": "name", "fieldType": {"$type": "local#primitive", "primitive": "str"}, "optional": False}, 201 + {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 202 + {"name": "score", "fieldType": {"$type": "local#primitive", "primitive": "float"}, "optional": False}, 203 + ] 204 + } 205 + 206 + DynamicType = schema_to_type(schema) 207 + 208 + # Load with dynamic type 209 + dataset = atdata.Dataset[DynamicType](str(tar_path)) 210 + loaded = list(dataset.ordered(batch_size=None)) 211 + 212 + assert len(loaded) == 10 213 + for i, sample in enumerate(loaded): 214 + assert sample.name == f"item_{i}" 215 + assert sample.value == i * 10 216 + assert sample.score == float(i) * 0.5 217 + 218 + def test_load_dataset_with_ndarray_dynamic_type(self, tmp_path): 219 + """Dynamic type with NDArray should deserialize correctly.""" 220 + # Create dataset 221 + tar_path = tmp_path / "array_data.tar" 222 + original_arrays = [np.random.randn(16, 16).astype(np.float32) for _ in range(5)] 223 + 224 + with wds.writer.TarWriter(str(tar_path)) as sink: 225 + for i, arr in enumerate(original_arrays): 226 + sample = ArraySample(label=f"arr_{i}", image=arr) 227 + sink.write(sample.as_wds) 228 + 229 + # Dynamic type 230 + schema = { 231 + "name": "ArraySample", 232 + "version": "1.0.0", 233 + "fields": [ 234 + {"name": "label", "fieldType": {"$type": "local#primitive", "primitive": "str"}, "optional": False}, 235 + {"name": "image", "fieldType": {"$type": "local#ndarray", "dtype": "float32"}, "optional": False}, 236 + ] 237 + } 238 + 239 + DynamicType = schema_to_type(schema) 240 + dataset = atdata.Dataset[DynamicType](str(tar_path)) 241 + loaded = list(dataset.ordered(batch_size=None)) 242 + 243 + assert len(loaded) == 5 244 + for i, sample in enumerate(loaded): 245 + assert sample.label == f"arr_{i}" 246 + np.testing.assert_array_almost_equal(sample.image, original_arrays[i]) 247 + 248 + def test_batch_iteration_with_dynamic_type(self, tmp_path): 249 + """Batching should work with dynamic types.""" 250 + tar_path = tmp_path / "batch_data.tar" 251 + 252 + with wds.writer.TarWriter(str(tar_path)) as sink: 253 + for i in range(20): 254 + sample = SimpleSample(name=f"item_{i}", value=i, score=float(i)) 255 + sink.write(sample.as_wds) 256 + 257 + schema = { 258 + "name": "SimpleSample", 259 + "version": "1.0.0", 260 + "fields": [ 261 + {"name": "name", "fieldType": {"$type": "local#primitive", "primitive": "str"}, "optional": False}, 262 + {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 263 + {"name": "score", "fieldType": {"$type": "local#primitive", "primitive": "float"}, "optional": False}, 264 + ] 265 + } 266 + 267 + DynamicType = schema_to_type(schema) 268 + dataset = atdata.Dataset[DynamicType](str(tar_path)) 269 + 270 + batches = list(dataset.ordered(batch_size=5)) 271 + assert len(batches) == 4 272 + 273 + for batch in batches: 274 + assert isinstance(batch, atdata.SampleBatch) 275 + assert len(batch.samples) == 5 276 + 277 + 278 + class TestTypeCaching: 279 + """Tests for type caching behavior.""" 280 + 281 + def test_same_schema_returns_cached_type(self): 282 + """Identical schemas should return same cached type.""" 283 + schema = { 284 + "name": "CachedSample", 285 + "version": "1.0.0", 286 + "fields": [ 287 + {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 288 + ] 289 + } 290 + 291 + Type1 = schema_to_type(schema) 292 + Type2 = schema_to_type(schema) 293 + 294 + assert Type1 is Type2 295 + 296 + def test_different_version_different_type(self): 297 + """Different version should produce different type.""" 298 + schema1 = { 299 + "name": "VersionedSample", 300 + "version": "1.0.0", 301 + "fields": [ 302 + {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 303 + ] 304 + } 305 + schema2 = { 306 + "name": "VersionedSample", 307 + "version": "2.0.0", 308 + "fields": [ 309 + {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 310 + ] 311 + } 312 + 313 + Type1 = schema_to_type(schema1) 314 + Type2 = schema_to_type(schema2) 315 + 316 + # Different versions = different types 317 + assert Type1 is not Type2 318 + 319 + def test_different_fields_different_type(self): 320 + """Different fields should produce different type.""" 321 + schema1 = { 322 + "name": "FieldSample", 323 + "version": "1.0.0", 324 + "fields": [ 325 + {"name": "a", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 326 + ] 327 + } 328 + schema2 = { 329 + "name": "FieldSample", 330 + "version": "1.0.0", 331 + "fields": [ 332 + {"name": "b", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 333 + ] 334 + } 335 + 336 + Type1 = schema_to_type(schema1) 337 + Type2 = schema_to_type(schema2) 338 + 339 + assert Type1 is not Type2 340 + 341 + def test_use_cache_false_bypasses_cache(self): 342 + """use_cache=False should always create new type.""" 343 + schema = { 344 + "name": "NoCacheSample", 345 + "version": "1.0.0", 346 + "fields": [ 347 + {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 348 + ] 349 + } 350 + 351 + Type1 = schema_to_type(schema, use_cache=False) 352 + Type2 = schema_to_type(schema, use_cache=False) 353 + 354 + # Without cache, each call creates new type 355 + assert Type1 is not Type2 356 + 357 + def test_clear_cache_removes_types(self): 358 + """clear_type_cache should remove all cached types.""" 359 + schema = { 360 + "name": "ClearableSample", 361 + "version": "1.0.0", 362 + "fields": [ 363 + {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 364 + ] 365 + } 366 + 367 + Type1 = schema_to_type(schema) 368 + clear_type_cache() 369 + Type2 = schema_to_type(schema) 370 + 371 + # After clear, should get new type 372 + assert Type1 is not Type2 373 + 374 + def test_get_cached_types_returns_cache_copy(self): 375 + """get_cached_types should return cache contents.""" 376 + schema = { 377 + "name": "TrackedSample", 378 + "version": "1.0.0", 379 + "fields": [ 380 + {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 381 + ] 382 + } 383 + 384 + schema_to_type(schema) 385 + cache = get_cached_types() 386 + 387 + assert len(cache) == 1 388 + assert "TrackedSample" in list(cache.keys())[0] 389 + 390 + 391 + class TestSchemaFromLocalIndex: 392 + """Tests for loading types from LocalIndex schemas.""" 393 + 394 + def test_publish_then_decode_schema(self, clean_redis): 395 + """Published schema should be decodable to usable type.""" 396 + index = atlocal.Index(redis=clean_redis) 397 + 398 + # Publish a schema 399 + schema_ref = index.publish_schema(SimpleSample) 400 + 401 + # Decode it back 402 + ReconstructedType = index.decode_schema(schema_ref) 403 + 404 + # Should be usable 405 + instance = ReconstructedType(name="test", value=42, score=0.5) 406 + assert instance.name == "test" 407 + assert instance.value == 42 408 + 409 + def test_publish_ndarray_then_decode(self, clean_redis): 410 + """NDArray schema should decode correctly.""" 411 + index = atlocal.Index(redis=clean_redis) 412 + 413 + schema_ref = index.publish_schema(ArraySample) 414 + ReconstructedType = index.decode_schema(schema_ref) 415 + 416 + arr = np.random.randn(8, 8).astype(np.float32) 417 + instance = ReconstructedType(label="test", image=arr) 418 + np.testing.assert_array_equal(instance.image, arr) 419 + 420 + def test_decoded_type_works_with_dataset(self, clean_redis, tmp_path): 421 + """Decoded type should work with Dataset iteration.""" 422 + index = atlocal.Index(redis=clean_redis) 423 + 424 + # Create dataset with original type 425 + tar_path = tmp_path / "original.tar" 426 + with wds.writer.TarWriter(str(tar_path)) as sink: 427 + for i in range(5): 428 + sample = SimpleSample(name=f"s_{i}", value=i, score=float(i)) 429 + sink.write(sample.as_wds) 430 + 431 + # Publish and decode schema 432 + schema_ref = index.publish_schema(SimpleSample) 433 + DecodedType = index.decode_schema(schema_ref) 434 + 435 + # Load with decoded type 436 + dataset = atdata.Dataset[DecodedType](str(tar_path)) 437 + loaded = list(dataset.ordered(batch_size=None)) 438 + 439 + assert len(loaded) == 5 440 + for i, sample in enumerate(loaded): 441 + assert sample.name == f"s_{i}" 442 + assert sample.value == i 443 + 444 + 445 + class TestSchemaValidation: 446 + """Tests for schema validation and error handling.""" 447 + 448 + def test_schema_without_name_raises(self): 449 + """Schema without name should raise ValueError.""" 450 + schema = { 451 + "version": "1.0.0", 452 + "fields": [ 453 + {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 454 + ] 455 + } 456 + 457 + with pytest.raises(ValueError, match="must have a 'name'"): 458 + schema_to_type(schema) 459 + 460 + def test_schema_without_fields_raises(self): 461 + """Schema without fields should raise ValueError.""" 462 + schema = { 463 + "name": "EmptySample", 464 + "version": "1.0.0", 465 + "fields": [] 466 + } 467 + 468 + with pytest.raises(ValueError, match="must have at least one field"): 469 + schema_to_type(schema) 470 + 471 + def test_field_without_name_raises(self): 472 + """Field without name should raise an error.""" 473 + schema = { 474 + "name": "BadFieldSample", 475 + "version": "1.0.0", 476 + "fields": [ 477 + {"fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 478 + ] 479 + } 480 + 481 + # Raises KeyError during cache key generation or ValueError during field processing 482 + with pytest.raises((KeyError, ValueError)): 483 + schema_to_type(schema) 484 + 485 + def test_unknown_primitive_raises(self): 486 + """Unknown primitive type should raise ValueError.""" 487 + schema = { 488 + "name": "UnknownPrimitive", 489 + "version": "1.0.0", 490 + "fields": [ 491 + {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "complex128"}, "optional": False}, 492 + ] 493 + } 494 + 495 + with pytest.raises(ValueError, match="Unknown primitive type"): 496 + schema_to_type(schema) 497 + 498 + def test_unknown_field_type_raises(self): 499 + """Unknown field type kind should raise ValueError.""" 500 + schema = { 501 + "name": "UnknownType", 502 + "version": "1.0.0", 503 + "fields": [ 504 + {"name": "value", "fieldType": {"$type": "local#custom"}, "optional": False}, 505 + ] 506 + } 507 + 508 + with pytest.raises(ValueError, match="Unknown field type kind"): 509 + schema_to_type(schema) 510 + 511 + 512 + class TestComplexSchemaScenarios: 513 + """Complex integration scenarios with dynamic types.""" 514 + 515 + def test_optional_ndarray_schema(self, tmp_path): 516 + """Optional NDArray field should handle None correctly.""" 517 + schema = { 518 + "name": "OptionalArraySample", 519 + "version": "1.0.0", 520 + "fields": [ 521 + {"name": "name", "fieldType": {"$type": "local#primitive", "primitive": "str"}, "optional": False}, 522 + {"name": "embedding", "fieldType": {"$type": "local#ndarray", "dtype": "float32"}, "optional": True}, 523 + ] 524 + } 525 + 526 + DynamicType = schema_to_type(schema) 527 + 528 + # Create dataset with some None values 529 + tar_path = tmp_path / "optional_array.tar" 530 + with wds.writer.TarWriter(str(tar_path)) as sink: 531 + for i in range(6): 532 + if i % 2 == 0: 533 + sample = OptionalSample(name=f"s_{i}", value=i, embedding=np.zeros(4, dtype=np.float32)) 534 + else: 535 + sample = OptionalSample(name=f"s_{i}", value=i, embedding=None) 536 + sink.write(sample.as_wds) 537 + 538 + # Note: The OptionalSample has different fields than DynamicType 539 + # This test verifies the dynamic type can be created, not cross-compatibility 540 + 541 + instance_with = DynamicType(name="test", embedding=np.zeros(4)) 542 + instance_without = DynamicType(name="test") 543 + 544 + assert instance_with.embedding is not None 545 + assert instance_without.embedding is None 546 + 547 + def test_nested_list_schema(self): 548 + """Nested list types should work.""" 549 + schema = { 550 + "name": "NestedListSample", 551 + "version": "1.0.0", 552 + "fields": [ 553 + {"name": "matrix", "fieldType": { 554 + "$type": "local#array", 555 + "items": {"$type": "local#array", "items": {"$type": "local#primitive", "primitive": "int"}} 556 + }, "optional": False}, 557 + ] 558 + } 559 + 560 + DynamicType = schema_to_type(schema) 561 + 562 + instance = DynamicType(matrix=[[1, 2], [3, 4], [5, 6]]) 563 + assert instance.matrix == [[1, 2], [3, 4], [5, 6]] 564 + 565 + def test_multiple_schemas_same_session(self, clean_redis): 566 + """Multiple different schemas should coexist.""" 567 + index = atlocal.Index(redis=clean_redis) 568 + 569 + # Publish multiple schemas 570 + ref1 = index.publish_schema(SimpleSample, version="1.0.0") 571 + ref2 = index.publish_schema(ArraySample, version="1.0.0") 572 + ref3 = index.publish_schema(ListSample, version="1.0.0") 573 + 574 + # Decode all 575 + Type1 = index.decode_schema(ref1) 576 + Type2 = index.decode_schema(ref2) 577 + Type3 = index.decode_schema(ref3) 578 + 579 + # All should be usable 580 + assert Type1(name="a", value=1, score=0.5).name == "a" 581 + assert Type2(label="b", image=np.zeros(4)).label == "b" 582 + assert Type3(tags=["x"], scores=[1.0]).tags == ["x"]
+676
tests/test_integration_e2e.py
··· 1 + """End-to-end integration tests for atdata data flow pipeline. 2 + 3 + Tests the complete workflow: Create → Store → Load → Iterate → Verify. 4 + 5 + These tests verify: 6 + - Full pipeline with various sample types 7 + - Multi-shard datasets with brace notation 8 + - Large batch handling and memory efficiency 9 + - Metadata round-trip preservation 10 + - Parquet export with transformations 11 + """ 12 + 13 + import pytest 14 + from dataclasses import dataclass 15 + from pathlib import Path 16 + 17 + import numpy as np 18 + from numpy.typing import NDArray 19 + import webdataset as wds 20 + 21 + import atdata 22 + 23 + 24 + ## 25 + # Test sample types 26 + 27 + 28 + @atdata.packable 29 + class SimpleSample: 30 + """Basic sample with primitive types only.""" 31 + name: str 32 + value: int 33 + score: float 34 + active: bool 35 + 36 + 37 + @atdata.packable 38 + class NDArraySample: 39 + """Sample with multiple NDArray fields of different shapes.""" 40 + label: int 41 + image: NDArray 42 + features: NDArray 43 + 44 + 45 + @atdata.packable 46 + class OptionalNDArraySample: 47 + """Sample with optional NDArray fields.""" 48 + label: int 49 + image: NDArray 50 + embeddings: NDArray | None = None 51 + 52 + 53 + @atdata.packable 54 + class BytesSample: 55 + """Sample with bytes field.""" 56 + name: str 57 + raw_data: bytes 58 + 59 + 60 + @atdata.packable 61 + class ListSample: 62 + """Sample with list fields.""" 63 + tags: list[str] 64 + scores: list[float] 65 + ids: list[int] 66 + 67 + 68 + @dataclass 69 + class InheritanceSample(atdata.PackableSample): 70 + """Sample using inheritance syntax instead of decorator.""" 71 + title: str 72 + count: int 73 + measurements: NDArray 74 + 75 + 76 + ## 77 + # Helper functions 78 + 79 + 80 + def create_simple_samples(n: int) -> list[SimpleSample]: 81 + """Create n simple samples with distinct values.""" 82 + return [ 83 + SimpleSample( 84 + name=f"sample_{i}", 85 + value=i * 10, 86 + score=float(i) * 0.5, 87 + active=(i % 2 == 0), 88 + ) 89 + for i in range(n) 90 + ] 91 + 92 + 93 + def create_ndarray_samples(n: int, img_shape: tuple = (64, 64)) -> list[NDArraySample]: 94 + """Create n NDArray samples with distinct values.""" 95 + return [ 96 + NDArraySample( 97 + label=i, 98 + image=np.random.randn(*img_shape).astype(np.float32), 99 + features=np.random.randn(128).astype(np.float32), 100 + ) 101 + for i in range(n) 102 + ] 103 + 104 + 105 + def create_optional_samples(n: int, include_optional: bool) -> list[OptionalNDArraySample]: 106 + """Create samples with or without optional embeddings.""" 107 + return [ 108 + OptionalNDArraySample( 109 + label=i, 110 + image=np.random.randn(32, 32).astype(np.float32), 111 + embeddings=np.random.randn(64).astype(np.float32) if include_optional else None, 112 + ) 113 + for i in range(n) 114 + ] 115 + 116 + 117 + def write_single_shard(path: Path, samples: list) -> str: 118 + """Write samples to a single tar file, return path.""" 119 + tar_path = path.as_posix() 120 + with wds.writer.TarWriter(tar_path) as sink: 121 + for sample in samples: 122 + sink.write(sample.as_wds) 123 + return tar_path 124 + 125 + 126 + def write_multi_shard( 127 + base_path: Path, 128 + samples: list, 129 + samples_per_shard: int = 10, 130 + ) -> tuple[str, int]: 131 + """Write samples to multiple shards, return brace pattern and shard count.""" 132 + pattern = (base_path / "shard-%06d.tar").as_posix() 133 + with wds.writer.ShardWriter(pattern=pattern, maxcount=samples_per_shard) as sink: 134 + for sample in samples: 135 + sink.write(sample.as_wds) 136 + 137 + n_shards = (len(samples) + samples_per_shard - 1) // samples_per_shard 138 + brace_pattern = ( 139 + base_path / f"shard-{{000000..{n_shards - 1:06d}}}.tar" 140 + ).as_posix() 141 + return brace_pattern, n_shards 142 + 143 + 144 + ## 145 + # Full Pipeline Tests 146 + 147 + 148 + class TestFullPipelineSimple: 149 + """End-to-end tests with simple primitive-only samples.""" 150 + 151 + def test_create_store_load_iterate_single_shard(self, tmp_path): 152 + """Full pipeline: create → store → load → iterate (single shard).""" 153 + n_samples = 50 154 + samples = create_simple_samples(n_samples) 155 + 156 + # Store 157 + tar_path = write_single_shard(tmp_path / "simple.tar", samples) 158 + 159 + # Load 160 + dataset = atdata.Dataset[SimpleSample](tar_path) 161 + 162 + # Iterate without batching 163 + loaded = list(dataset.ordered(batch_size=None)) 164 + 165 + # Verify 166 + assert len(loaded) == n_samples 167 + for i, sample in enumerate(loaded): 168 + assert isinstance(sample, SimpleSample) 169 + assert sample.name == f"sample_{i}" 170 + assert sample.value == i * 10 171 + assert sample.score == float(i) * 0.5 172 + assert sample.active == (i % 2 == 0) 173 + 174 + def test_create_store_load_iterate_batched(self, tmp_path): 175 + """Full pipeline with batching.""" 176 + n_samples = 100 177 + batch_size = 16 178 + samples = create_simple_samples(n_samples) 179 + 180 + tar_path = write_single_shard(tmp_path / "batched.tar", samples) 181 + dataset = atdata.Dataset[SimpleSample](tar_path) 182 + 183 + # Iterate with batching 184 + batches = list(dataset.ordered(batch_size=batch_size)) 185 + 186 + # Verify batch structure (WebDataset drops incomplete final batch) 187 + total_samples = sum(len(b.samples) for b in batches) 188 + assert total_samples >= (n_samples // batch_size) * batch_size 189 + 190 + for batch in batches: 191 + assert isinstance(batch, atdata.SampleBatch) 192 + assert batch.sample_type == SimpleSample 193 + assert len(batch.samples) <= batch_size 194 + 195 + # Verify aggregated attributes 196 + names = batch.name 197 + values = batch.value 198 + assert isinstance(names, list) 199 + assert isinstance(values, list) 200 + assert len(names) == len(batch.samples) 201 + assert len(values) == len(batch.samples) 202 + 203 + def test_inheritance_syntax_pipeline(self, tmp_path): 204 + """Full pipeline using inheritance-style sample definition.""" 205 + n_samples = 25 206 + samples = [ 207 + InheritanceSample( 208 + title=f"doc_{i}", 209 + count=i * 5, 210 + measurements=np.random.randn(10).astype(np.float32), 211 + ) 212 + for i in range(n_samples) 213 + ] 214 + 215 + tar_path = write_single_shard(tmp_path / "inheritance.tar", samples) 216 + dataset = atdata.Dataset[InheritanceSample](tar_path) 217 + 218 + loaded = list(dataset.ordered(batch_size=None)) 219 + 220 + assert len(loaded) == n_samples 221 + for i, sample in enumerate(loaded): 222 + assert isinstance(sample, InheritanceSample) 223 + assert sample.title == f"doc_{i}" 224 + assert sample.count == i * 5 225 + assert isinstance(sample.measurements, np.ndarray) 226 + 227 + 228 + class TestFullPipelineNDArray: 229 + """End-to-end tests with NDArray samples.""" 230 + 231 + def test_ndarray_serialization_roundtrip(self, tmp_path): 232 + """NDArray fields should serialize and deserialize exactly.""" 233 + n_samples = 20 234 + samples = create_ndarray_samples(n_samples, img_shape=(32, 32)) 235 + 236 + tar_path = write_single_shard(tmp_path / "ndarray.tar", samples) 237 + dataset = atdata.Dataset[NDArraySample](tar_path) 238 + 239 + loaded = list(dataset.ordered(batch_size=None)) 240 + 241 + assert len(loaded) == n_samples 242 + for original, loaded_sample in zip(samples, loaded): 243 + assert loaded_sample.label == original.label 244 + np.testing.assert_array_almost_equal(loaded_sample.image, original.image) 245 + np.testing.assert_array_almost_equal(loaded_sample.features, original.features) 246 + 247 + def test_ndarray_batch_stacking(self, tmp_path): 248 + """NDArray fields should stack into batch dimension.""" 249 + n_samples = 32 250 + batch_size = 8 251 + img_shape = (16, 16) 252 + feature_dim = 64 253 + 254 + samples = [ 255 + NDArraySample( 256 + label=i, 257 + image=np.full(img_shape, i, dtype=np.float32), 258 + features=np.full(feature_dim, i * 0.1, dtype=np.float32), 259 + ) 260 + for i in range(n_samples) 261 + ] 262 + 263 + tar_path = write_single_shard(tmp_path / "stacking.tar", samples) 264 + dataset = atdata.Dataset[NDArraySample](tar_path) 265 + 266 + batches = list(dataset.ordered(batch_size=batch_size)) 267 + 268 + for batch_idx, batch in enumerate(batches): 269 + # Check stacked shapes 270 + assert batch.image.shape == (batch_size, *img_shape) 271 + assert batch.features.shape == (batch_size, feature_dim) 272 + 273 + # Check values 274 + for i in range(batch_size): 275 + sample_idx = batch_idx * batch_size + i 276 + np.testing.assert_array_equal( 277 + batch.image[i], 278 + np.full(img_shape, sample_idx, dtype=np.float32), 279 + ) 280 + 281 + def test_optional_ndarray_with_values(self, tmp_path): 282 + """Optional NDArray with actual values should roundtrip.""" 283 + n_samples = 15 284 + samples = create_optional_samples(n_samples, include_optional=True) 285 + 286 + tar_path = write_single_shard(tmp_path / "optional_filled.tar", samples) 287 + dataset = atdata.Dataset[OptionalNDArraySample](tar_path) 288 + 289 + loaded = list(dataset.ordered(batch_size=None)) 290 + 291 + for original, loaded_sample in zip(samples, loaded): 292 + assert loaded_sample.embeddings is not None 293 + np.testing.assert_array_almost_equal( 294 + loaded_sample.embeddings, 295 + original.embeddings, 296 + ) 297 + 298 + def test_optional_ndarray_with_none(self, tmp_path): 299 + """Optional NDArray with None should roundtrip.""" 300 + n_samples = 15 301 + samples = create_optional_samples(n_samples, include_optional=False) 302 + 303 + tar_path = write_single_shard(tmp_path / "optional_none.tar", samples) 304 + dataset = atdata.Dataset[OptionalNDArraySample](tar_path) 305 + 306 + loaded = list(dataset.ordered(batch_size=None)) 307 + 308 + for loaded_sample in loaded: 309 + assert loaded_sample.embeddings is None 310 + 311 + def test_mixed_dtypes(self, tmp_path): 312 + """Various numpy dtypes should serialize correctly.""" 313 + @atdata.packable 314 + class MultiDtypeSample: 315 + f32: NDArray 316 + f64: NDArray 317 + i32: NDArray 318 + i64: NDArray 319 + u8: NDArray 320 + 321 + samples = [ 322 + MultiDtypeSample( 323 + f32=np.array([1.0, 2.0, 3.0], dtype=np.float32), 324 + f64=np.array([1.0, 2.0, 3.0], dtype=np.float64), 325 + i32=np.array([1, 2, 3], dtype=np.int32), 326 + i64=np.array([1, 2, 3], dtype=np.int64), 327 + u8=np.array([255, 128, 0], dtype=np.uint8), 328 + ) 329 + for _ in range(10) 330 + ] 331 + 332 + tar_path = write_single_shard(tmp_path / "multidtype.tar", samples) 333 + dataset = atdata.Dataset[MultiDtypeSample](tar_path) 334 + 335 + loaded = list(dataset.ordered(batch_size=None)) 336 + 337 + for original, loaded_sample in zip(samples, loaded): 338 + assert loaded_sample.f32.dtype == np.float32 339 + assert loaded_sample.f64.dtype == np.float64 340 + assert loaded_sample.i32.dtype == np.int32 341 + assert loaded_sample.i64.dtype == np.int64 342 + assert loaded_sample.u8.dtype == np.uint8 343 + np.testing.assert_array_equal(loaded_sample.f32, original.f32) 344 + 345 + 346 + class TestMultiShardPipeline: 347 + """End-to-end tests with multi-shard datasets using brace notation.""" 348 + 349 + def test_multi_shard_ordered_iteration(self, tmp_path): 350 + """Multi-shard dataset should iterate all samples in order.""" 351 + n_samples = 100 352 + samples_per_shard = 10 353 + samples = create_simple_samples(n_samples) 354 + 355 + brace_pattern, n_shards = write_multi_shard( 356 + tmp_path, 357 + samples, 358 + samples_per_shard=samples_per_shard, 359 + ) 360 + 361 + assert n_shards == 10 362 + 363 + dataset = atdata.Dataset[SimpleSample](brace_pattern) 364 + loaded = list(dataset.ordered(batch_size=None)) 365 + 366 + assert len(loaded) == n_samples 367 + 368 + # Verify ordering within each shard 369 + for i, sample in enumerate(loaded): 370 + assert sample.name == f"sample_{i}" 371 + 372 + def test_multi_shard_batched(self, tmp_path): 373 + """Multi-shard dataset with batching should work correctly.""" 374 + n_samples = 120 375 + samples_per_shard = 15 376 + batch_size = 8 377 + samples = create_simple_samples(n_samples) 378 + 379 + brace_pattern, n_shards = write_multi_shard( 380 + tmp_path, 381 + samples, 382 + samples_per_shard=samples_per_shard, 383 + ) 384 + 385 + dataset = atdata.Dataset[SimpleSample](brace_pattern) 386 + batches = list(dataset.ordered(batch_size=batch_size)) 387 + 388 + # Total samples should match 389 + total_samples = sum(len(b.samples) for b in batches) 390 + assert total_samples == (n_samples // batch_size) * batch_size 391 + 392 + def test_multi_shard_shuffled(self, tmp_path): 393 + """Multi-shard shuffled iteration should work.""" 394 + n_samples = 50 395 + samples_per_shard = 10 396 + samples = create_simple_samples(n_samples) 397 + 398 + brace_pattern, _ = write_multi_shard( 399 + tmp_path, 400 + samples, 401 + samples_per_shard=samples_per_shard, 402 + ) 403 + 404 + dataset = atdata.Dataset[SimpleSample](brace_pattern) 405 + 406 + # Collect some samples from shuffled iteration 407 + shuffled_samples = [] 408 + for sample in dataset.shuffled(batch_size=None): 409 + shuffled_samples.append(sample) 410 + if len(shuffled_samples) >= 30: 411 + break 412 + 413 + assert len(shuffled_samples) == 30 414 + 415 + # All samples should be valid SimpleSample instances 416 + for sample in shuffled_samples: 417 + assert isinstance(sample, SimpleSample) 418 + assert sample.name.startswith("sample_") 419 + 420 + def test_single_shard_via_brace_pattern(self, tmp_path): 421 + """Single shard via brace pattern should work.""" 422 + n_samples = 25 423 + samples = create_simple_samples(n_samples) 424 + 425 + # Create exactly one shard 426 + brace_pattern, n_shards = write_multi_shard( 427 + tmp_path, 428 + samples, 429 + samples_per_shard=100, # More than samples, so single shard 430 + ) 431 + 432 + assert n_shards == 1 433 + 434 + dataset = atdata.Dataset[SimpleSample](brace_pattern) 435 + loaded = list(dataset.ordered(batch_size=None)) 436 + 437 + assert len(loaded) == n_samples 438 + 439 + 440 + class TestLargeBatchHandling: 441 + """Tests for handling large batches and many samples.""" 442 + 443 + def test_large_batch_size(self, tmp_path): 444 + """Large batch sizes should work correctly.""" 445 + n_samples = 200 446 + batch_size = 64 447 + samples = create_simple_samples(n_samples) 448 + 449 + tar_path = write_single_shard(tmp_path / "large_batch.tar", samples) 450 + dataset = atdata.Dataset[SimpleSample](tar_path) 451 + 452 + batches = list(dataset.ordered(batch_size=batch_size)) 453 + 454 + # Verify we got the expected number of complete batches 455 + total_samples = sum(len(b.samples) for b in batches) 456 + assert total_samples >= (n_samples // batch_size) * batch_size 457 + for batch in batches: 458 + assert len(batch.samples) <= batch_size 459 + 460 + def test_many_samples_single_shard(self, tmp_path): 461 + """Many samples in single shard should work.""" 462 + n_samples = 500 463 + samples = create_simple_samples(n_samples) 464 + 465 + tar_path = write_single_shard(tmp_path / "many.tar", samples) 466 + dataset = atdata.Dataset[SimpleSample](tar_path) 467 + 468 + loaded = list(dataset.ordered(batch_size=None)) 469 + assert len(loaded) == n_samples 470 + 471 + def test_large_ndarray_samples(self, tmp_path): 472 + """Large NDArray fields should serialize correctly.""" 473 + n_samples = 10 474 + large_shape = (256, 256) # Larger images 475 + 476 + samples = create_ndarray_samples(n_samples, img_shape=large_shape) 477 + 478 + tar_path = write_single_shard(tmp_path / "large_ndarray.tar", samples) 479 + dataset = atdata.Dataset[NDArraySample](tar_path) 480 + 481 + loaded = list(dataset.ordered(batch_size=None)) 482 + 483 + for original, loaded_sample in zip(samples, loaded): 484 + assert loaded_sample.image.shape == large_shape 485 + np.testing.assert_array_almost_equal( 486 + loaded_sample.image, 487 + original.image, 488 + ) 489 + 490 + 491 + class TestBytesAndListSamples: 492 + """Tests for bytes and list field types.""" 493 + 494 + def test_bytes_field_roundtrip(self, tmp_path): 495 + """Bytes fields should roundtrip correctly.""" 496 + samples = [ 497 + BytesSample( 498 + name=f"item_{i}", 499 + raw_data=f"binary_data_{i}".encode("utf-8"), 500 + ) 501 + for i in range(20) 502 + ] 503 + 504 + tar_path = write_single_shard(tmp_path / "bytes.tar", samples) 505 + dataset = atdata.Dataset[BytesSample](tar_path) 506 + 507 + loaded = list(dataset.ordered(batch_size=None)) 508 + 509 + for original, loaded_sample in zip(samples, loaded): 510 + assert loaded_sample.name == original.name 511 + assert loaded_sample.raw_data == original.raw_data 512 + 513 + def test_list_fields_roundtrip(self, tmp_path): 514 + """List fields should roundtrip correctly.""" 515 + samples = [ 516 + ListSample( 517 + tags=[f"tag_{j}" for j in range(3)], 518 + scores=[float(j) * 0.1 for j in range(5)], 519 + ids=[i * 10 + j for j in range(4)], 520 + ) 521 + for i in range(15) 522 + ] 523 + 524 + tar_path = write_single_shard(tmp_path / "lists.tar", samples) 525 + dataset = atdata.Dataset[ListSample](tar_path) 526 + 527 + loaded = list(dataset.ordered(batch_size=None)) 528 + 529 + for original, loaded_sample in zip(samples, loaded): 530 + assert loaded_sample.tags == original.tags 531 + assert loaded_sample.scores == original.scores 532 + assert loaded_sample.ids == original.ids 533 + 534 + 535 + class TestMetadataRoundTrip: 536 + """Tests for metadata preservation through the pipeline.""" 537 + 538 + def test_dataset_with_metadata_url(self, tmp_path): 539 + """Dataset with metadata_url should fetch and cache metadata.""" 540 + from unittest.mock import Mock, patch, MagicMock 541 + import msgpack 542 + 543 + samples = create_simple_samples(10) 544 + tar_path = write_single_shard(tmp_path / "meta.tar", samples) 545 + 546 + test_metadata = { 547 + "version": "1.0.0", 548 + "created_by": "test", 549 + "sample_count": 10, 550 + "nested": {"key": "value"}, 551 + } 552 + 553 + # Create a proper mock that supports context manager protocol 554 + mock_response = MagicMock() 555 + mock_response.content = msgpack.packb(test_metadata) 556 + mock_response.raise_for_status = Mock() 557 + mock_response.__enter__ = Mock(return_value=mock_response) 558 + mock_response.__exit__ = Mock(return_value=False) 559 + 560 + with patch("atdata.dataset.requests.get", return_value=mock_response): 561 + dataset = atdata.Dataset[SimpleSample]( 562 + tar_path, 563 + metadata_url="http://example.com/meta.msgpack", 564 + ) 565 + 566 + # Fetch metadata 567 + metadata = dataset.metadata 568 + 569 + assert metadata == test_metadata 570 + assert metadata["version"] == "1.0.0" 571 + assert metadata["nested"]["key"] == "value" 572 + 573 + # Second access should use cache 574 + metadata2 = dataset.metadata 575 + assert metadata2 == test_metadata 576 + 577 + 578 + class TestParquetExport: 579 + """Tests for Parquet export functionality.""" 580 + 581 + def test_simple_parquet_export(self, tmp_path): 582 + """Simple samples should export to Parquet correctly.""" 583 + import pandas as pd 584 + 585 + n_samples = 50 586 + samples = create_simple_samples(n_samples) 587 + 588 + tar_path = write_single_shard(tmp_path / "for_parquet.tar", samples) 589 + dataset = atdata.Dataset[SimpleSample](tar_path) 590 + 591 + parquet_path = tmp_path / "output.parquet" 592 + dataset.to_parquet(parquet_path) 593 + 594 + # Verify Parquet file 595 + df = pd.read_parquet(parquet_path) 596 + assert len(df) == n_samples 597 + assert list(df.columns) == ["name", "value", "score", "active"] 598 + assert df["name"].iloc[0] == "sample_0" 599 + assert df["value"].iloc[0] == 0 600 + 601 + def test_parquet_export_with_maxcount(self, tmp_path): 602 + """Parquet export with maxcount should create segments.""" 603 + import pandas as pd 604 + 605 + n_samples = 45 606 + maxcount = 10 607 + samples = create_simple_samples(n_samples) 608 + 609 + tar_path = write_single_shard(tmp_path / "segmented.tar", samples) 610 + dataset = atdata.Dataset[SimpleSample](tar_path) 611 + 612 + parquet_path = tmp_path / "segments.parquet" 613 + dataset.to_parquet(parquet_path, maxcount=maxcount) 614 + 615 + # Should create 5 segment files (45 samples / 10 per file) 616 + segment_files = list(tmp_path.glob("segments-*.parquet")) 617 + assert len(segment_files) == 5 618 + 619 + # Total rows should match 620 + total_rows = sum(len(pd.read_parquet(f)) for f in segment_files) 621 + assert total_rows == n_samples 622 + 623 + 624 + class TestIterationModes: 625 + """Tests for different iteration modes.""" 626 + 627 + def test_ordered_is_deterministic(self, tmp_path): 628 + """Ordered iteration should be deterministic across multiple passes.""" 629 + n_samples = 30 630 + samples = create_simple_samples(n_samples) 631 + 632 + tar_path = write_single_shard(tmp_path / "ordered.tar", samples) 633 + dataset = atdata.Dataset[SimpleSample](tar_path) 634 + 635 + # Two passes should yield identical results 636 + pass1 = [s.name for s in dataset.ordered(batch_size=None)] 637 + pass2 = [s.name for s in dataset.ordered(batch_size=None)] 638 + 639 + assert pass1 == pass2 640 + 641 + def test_shuffled_changes_order(self, tmp_path): 642 + """Shuffled iteration should change order (with high probability).""" 643 + n_samples = 100 644 + samples = create_simple_samples(n_samples) 645 + 646 + tar_path = write_single_shard(tmp_path / "shuffle_test.tar", samples) 647 + dataset = atdata.Dataset[SimpleSample](tar_path) 648 + 649 + # Collect samples from multiple shuffled passes 650 + passes = [] 651 + for _ in range(3): 652 + names = [] 653 + for sample in dataset.shuffled(batch_size=None): 654 + names.append(sample.name) 655 + if len(names) >= n_samples: 656 + break 657 + passes.append(names) 658 + 659 + # At least two passes should differ (very high probability with 100 samples) 660 + # Note: This could theoretically fail, but probability is astronomically low 661 + assert passes[0] != passes[1] or passes[1] != passes[2] or passes[0] != passes[2] 662 + 663 + def test_batch_size_one(self, tmp_path): 664 + """batch_size=1 should return single-element batches.""" 665 + n_samples = 10 666 + samples = create_simple_samples(n_samples) 667 + 668 + tar_path = write_single_shard(tmp_path / "batch1.tar", samples) 669 + dataset = atdata.Dataset[SimpleSample](tar_path) 670 + 671 + batches = list(dataset.ordered(batch_size=1)) 672 + 673 + assert len(batches) == n_samples 674 + for batch in batches: 675 + assert isinstance(batch, atdata.SampleBatch) 676 + assert len(batch.samples) == 1
+650
tests/test_integration_edge_cases.py
··· 1 + """Integration tests for edge cases and data type coverage. 2 + 3 + Tests boundary conditions and unusual data patterns including: 4 + - Empty and single-sample datasets 5 + - Special numpy dtypes (float16, complex128) 6 + - Unicode and special characters 7 + - Very long strings and large arrays 8 + - Nested list types 9 + - All primitive type variations 10 + """ 11 + 12 + import pytest 13 + from pathlib import Path 14 + from dataclasses import dataclass 15 + from typing import Optional 16 + 17 + import numpy as np 18 + from numpy.typing import NDArray 19 + import webdataset as wds 20 + 21 + import atdata 22 + from atdata.local import LocalIndex, LocalDatasetEntry 23 + 24 + 25 + ## 26 + # Edge Case Sample Types 27 + 28 + 29 + @atdata.packable 30 + class EmptyCompatSample: 31 + """Sample type for empty dataset tests.""" 32 + id: int 33 + 34 + 35 + @atdata.packable 36 + class AllPrimitivesSample: 37 + """Sample with all primitive types.""" 38 + str_field: str 39 + int_field: int 40 + float_field: float 41 + bool_field: bool 42 + bytes_field: bytes 43 + 44 + 45 + @atdata.packable 46 + class OptionalFieldsSample: 47 + """Sample with optional fields.""" 48 + required_str: str 49 + optional_str: str | None 50 + optional_int: int | None 51 + optional_float: float | None 52 + optional_array: NDArray | None 53 + 54 + 55 + @atdata.packable 56 + class ListFieldsSample: 57 + """Sample with list fields.""" 58 + str_list: list[str] 59 + int_list: list[int] 60 + float_list: list[float] 61 + bool_list: list[bool] 62 + 63 + 64 + @atdata.packable 65 + class UnicodeSample: 66 + """Sample with unicode content.""" 67 + text: str 68 + label: str 69 + 70 + 71 + @atdata.packable 72 + class NDArraySample: 73 + """Sample with NDArray field.""" 74 + label: str 75 + data: NDArray 76 + 77 + 78 + ## 79 + # Helper Functions 80 + 81 + 82 + def create_tar_with_samples(tar_path: Path, samples: list) -> None: 83 + """Create a tar file with the given samples.""" 84 + tar_path.parent.mkdir(parents=True, exist_ok=True) 85 + with wds.writer.TarWriter(str(tar_path)) as writer: 86 + for sample in samples: 87 + writer.write(sample.as_wds) 88 + 89 + 90 + ## 91 + # Empty and Single Sample Tests 92 + 93 + 94 + class TestEmptyAndMinimalDatasets: 95 + """Tests for boundary dataset sizes.""" 96 + 97 + def test_single_sample_dataset(self, tmp_path): 98 + """Dataset with exactly one sample should work correctly.""" 99 + tar_path = tmp_path / "single-000000.tar" 100 + sample = EmptyCompatSample(id=42) 101 + create_tar_with_samples(tar_path, [sample]) 102 + 103 + ds = atdata.Dataset[EmptyCompatSample](str(tar_path)) 104 + samples = list(ds.ordered(batch_size=None)) 105 + 106 + assert len(samples) == 1 107 + assert samples[0].id == 42 108 + 109 + def test_single_sample_batch(self, tmp_path): 110 + """Batching single sample should produce batch of size 1.""" 111 + tar_path = tmp_path / "single-batch-000000.tar" 112 + sample = EmptyCompatSample(id=99) 113 + create_tar_with_samples(tar_path, [sample]) 114 + 115 + ds = atdata.Dataset[EmptyCompatSample](str(tar_path)) 116 + batches = list(ds.ordered(batch_size=10)) 117 + 118 + assert len(batches) >= 1 119 + assert len(batches[0].samples) == 1 120 + 121 + 122 + ## 123 + # Primitive Type Coverage Tests 124 + 125 + 126 + class TestPrimitiveTypes: 127 + """Tests for all primitive types.""" 128 + 129 + def test_all_primitives_roundtrip(self, tmp_path): 130 + """All primitive types should serialize and deserialize correctly.""" 131 + tar_path = tmp_path / "primitives-000000.tar" 132 + 133 + original = AllPrimitivesSample( 134 + str_field="hello world", 135 + int_field=42, 136 + float_field=3.14159, 137 + bool_field=True, 138 + bytes_field=b"\x00\x01\x02\xff", 139 + ) 140 + create_tar_with_samples(tar_path, [original]) 141 + 142 + ds = atdata.Dataset[AllPrimitivesSample](str(tar_path)) 143 + loaded = list(ds.ordered(batch_size=None))[0] 144 + 145 + assert loaded.str_field == "hello world" 146 + assert loaded.int_field == 42 147 + assert abs(loaded.float_field - 3.14159) < 1e-5 148 + assert loaded.bool_field is True 149 + assert loaded.bytes_field == b"\x00\x01\x02\xff" 150 + 151 + def test_extreme_int_values(self, tmp_path): 152 + """Very large and small integers should be preserved.""" 153 + tar_path = tmp_path / "extreme-int-000000.tar" 154 + 155 + @atdata.packable 156 + class ExtremeSample: 157 + value: int 158 + 159 + samples = [ 160 + ExtremeSample(value=0), 161 + ExtremeSample(value=-1), 162 + ExtremeSample(value=2**62), # Large positive 163 + ExtremeSample(value=-(2**62)), # Large negative 164 + ] 165 + create_tar_with_samples(tar_path, samples) 166 + 167 + ds = atdata.Dataset[ExtremeSample](str(tar_path)) 168 + loaded = list(ds.ordered(batch_size=None)) 169 + 170 + assert loaded[0].value == 0 171 + assert loaded[1].value == -1 172 + assert loaded[2].value == 2**62 173 + assert loaded[3].value == -(2**62) 174 + 175 + def test_special_float_values(self, tmp_path): 176 + """Special float values (inf, -inf, very small) should be handled.""" 177 + tar_path = tmp_path / "special-float-000000.tar" 178 + 179 + @atdata.packable 180 + class FloatSample: 181 + value: float 182 + 183 + samples = [ 184 + FloatSample(value=0.0), 185 + FloatSample(value=-0.0), 186 + FloatSample(value=1e-300), # Very small 187 + FloatSample(value=1e300), # Very large 188 + FloatSample(value=float("inf")), 189 + FloatSample(value=float("-inf")), 190 + ] 191 + create_tar_with_samples(tar_path, samples) 192 + 193 + ds = atdata.Dataset[FloatSample](str(tar_path)) 194 + loaded = list(ds.ordered(batch_size=None)) 195 + 196 + assert loaded[0].value == 0.0 197 + assert loaded[2].value == 1e-300 198 + assert loaded[3].value == 1e300 199 + assert loaded[4].value == float("inf") 200 + assert loaded[5].value == float("-inf") 201 + 202 + 203 + ## 204 + # Optional Field Tests 205 + 206 + 207 + class TestOptionalFields: 208 + """Tests for optional (nullable) fields.""" 209 + 210 + def test_optional_fields_with_values(self, tmp_path): 211 + """Optional fields with values should roundtrip correctly.""" 212 + tar_path = tmp_path / "optional-present-000000.tar" 213 + 214 + sample = OptionalFieldsSample( 215 + required_str="required", 216 + optional_str="optional", 217 + optional_int=42, 218 + optional_float=3.14, 219 + optional_array=np.array([1, 2, 3]), 220 + ) 221 + create_tar_with_samples(tar_path, [sample]) 222 + 223 + ds = atdata.Dataset[OptionalFieldsSample](str(tar_path)) 224 + loaded = list(ds.ordered(batch_size=None))[0] 225 + 226 + assert loaded.required_str == "required" 227 + assert loaded.optional_str == "optional" 228 + assert loaded.optional_int == 42 229 + assert loaded.optional_float == 3.14 230 + assert np.array_equal(loaded.optional_array, np.array([1, 2, 3])) 231 + 232 + def test_optional_fields_with_none(self, tmp_path): 233 + """Optional fields with None should roundtrip correctly.""" 234 + tar_path = tmp_path / "optional-none-000000.tar" 235 + 236 + sample = OptionalFieldsSample( 237 + required_str="required", 238 + optional_str=None, 239 + optional_int=None, 240 + optional_float=None, 241 + optional_array=None, 242 + ) 243 + create_tar_with_samples(tar_path, [sample]) 244 + 245 + ds = atdata.Dataset[OptionalFieldsSample](str(tar_path)) 246 + loaded = list(ds.ordered(batch_size=None))[0] 247 + 248 + assert loaded.required_str == "required" 249 + assert loaded.optional_str is None 250 + assert loaded.optional_int is None 251 + assert loaded.optional_float is None 252 + assert loaded.optional_array is None 253 + 254 + 255 + ## 256 + # List Field Tests 257 + 258 + 259 + class TestListFields: 260 + """Tests for list type fields.""" 261 + 262 + def test_list_fields_roundtrip(self, tmp_path): 263 + """List fields should serialize and deserialize correctly.""" 264 + tar_path = tmp_path / "lists-000000.tar" 265 + 266 + sample = ListFieldsSample( 267 + str_list=["a", "b", "c"], 268 + int_list=[1, 2, 3, 4, 5], 269 + float_list=[1.1, 2.2, 3.3], 270 + bool_list=[True, False, True], 271 + ) 272 + create_tar_with_samples(tar_path, [sample]) 273 + 274 + ds = atdata.Dataset[ListFieldsSample](str(tar_path)) 275 + loaded = list(ds.ordered(batch_size=None))[0] 276 + 277 + assert loaded.str_list == ["a", "b", "c"] 278 + assert loaded.int_list == [1, 2, 3, 4, 5] 279 + assert loaded.float_list == [1.1, 2.2, 3.3] 280 + assert loaded.bool_list == [True, False, True] 281 + 282 + def test_empty_lists(self, tmp_path): 283 + """Empty lists should be handled correctly.""" 284 + tar_path = tmp_path / "empty-lists-000000.tar" 285 + 286 + sample = ListFieldsSample( 287 + str_list=[], 288 + int_list=[], 289 + float_list=[], 290 + bool_list=[], 291 + ) 292 + create_tar_with_samples(tar_path, [sample]) 293 + 294 + ds = atdata.Dataset[ListFieldsSample](str(tar_path)) 295 + loaded = list(ds.ordered(batch_size=None))[0] 296 + 297 + assert loaded.str_list == [] 298 + assert loaded.int_list == [] 299 + assert loaded.float_list == [] 300 + assert loaded.bool_list == [] 301 + 302 + def test_large_lists(self, tmp_path): 303 + """Large lists should be handled correctly.""" 304 + tar_path = tmp_path / "large-lists-000000.tar" 305 + 306 + sample = ListFieldsSample( 307 + str_list=[f"item-{i}" for i in range(1000)], 308 + int_list=list(range(1000)), 309 + float_list=[float(i) for i in range(1000)], 310 + bool_list=[i % 2 == 0 for i in range(1000)], 311 + ) 312 + create_tar_with_samples(tar_path, [sample]) 313 + 314 + ds = atdata.Dataset[ListFieldsSample](str(tar_path)) 315 + loaded = list(ds.ordered(batch_size=None))[0] 316 + 317 + assert len(loaded.str_list) == 1000 318 + assert len(loaded.int_list) == 1000 319 + assert loaded.str_list[500] == "item-500" 320 + assert loaded.int_list[999] == 999 321 + 322 + 323 + ## 324 + # Unicode and Special Character Tests 325 + 326 + 327 + class TestUnicodeAndSpecialChars: 328 + """Tests for unicode and special characters.""" 329 + 330 + def test_unicode_strings(self, tmp_path): 331 + """Unicode strings should roundtrip correctly.""" 332 + tar_path = tmp_path / "unicode-000000.tar" 333 + 334 + samples = [ 335 + UnicodeSample(text="Hello World", label="ascii"), 336 + UnicodeSample(text="Bonjour le monde", label="accents"), 337 + UnicodeSample(text="Hallo Welt", label="german"), 338 + UnicodeSample(text="Witaj Swiecie", label="polish"), 339 + ] 340 + create_tar_with_samples(tar_path, samples) 341 + 342 + ds = atdata.Dataset[UnicodeSample](str(tar_path)) 343 + loaded = list(ds.ordered(batch_size=None)) 344 + 345 + assert loaded[0].text == "Hello World" 346 + assert loaded[1].text == "Bonjour le monde" 347 + 348 + def test_emoji(self, tmp_path): 349 + """Emoji should roundtrip correctly.""" 350 + tar_path = tmp_path / "emoji-000000.tar" 351 + 352 + sample = UnicodeSample( 353 + text="Hello World! Have a great day!", 354 + label="with-emoji" 355 + ) 356 + create_tar_with_samples(tar_path, [sample]) 357 + 358 + ds = atdata.Dataset[UnicodeSample](str(tar_path)) 359 + loaded = list(ds.ordered(batch_size=None))[0] 360 + 361 + assert "Hello" in loaded.text 362 + assert "great day" in loaded.text 363 + 364 + def test_cjk_characters(self, tmp_path): 365 + """CJK characters should roundtrip correctly.""" 366 + tar_path = tmp_path / "cjk-000000.tar" 367 + 368 + samples = [ 369 + UnicodeSample(text="Nihongo", label="japanese"), 370 + UnicodeSample(text="Zhongwen", label="chinese"), 371 + UnicodeSample(text="Hangugeo", label="korean"), 372 + ] 373 + create_tar_with_samples(tar_path, samples) 374 + 375 + ds = atdata.Dataset[UnicodeSample](str(tar_path)) 376 + loaded = list(ds.ordered(batch_size=None)) 377 + 378 + assert len(loaded) == 3 379 + 380 + def test_special_chars_in_string_fields(self, tmp_path): 381 + """Special characters (newlines, tabs, quotes) should roundtrip.""" 382 + tar_path = tmp_path / "special-chars-000000.tar" 383 + 384 + sample = UnicodeSample( 385 + text='Line1\nLine2\tTabbed\r\nWindows\0Null"Quotes"', 386 + label="special", 387 + ) 388 + create_tar_with_samples(tar_path, [sample]) 389 + 390 + ds = atdata.Dataset[UnicodeSample](str(tar_path)) 391 + loaded = list(ds.ordered(batch_size=None))[0] 392 + 393 + assert "Line1\nLine2" in loaded.text 394 + assert "\t" in loaded.text 395 + 396 + 397 + ## 398 + # NDArray Type Tests 399 + 400 + 401 + class TestNDArrayTypes: 402 + """Tests for various NDArray dtypes and shapes.""" 403 + 404 + def test_common_dtypes(self, tmp_path): 405 + """Common numpy dtypes should work correctly.""" 406 + dtypes = [np.float32, np.float64, np.int32, np.int64, np.uint8] 407 + 408 + for dtype in dtypes: 409 + tar_path = tmp_path / f"dtype-{dtype.__name__}-000000.tar" 410 + sample = NDArraySample( 411 + label=f"dtype-{dtype.__name__}", 412 + data=np.array([1, 2, 3, 4, 5], dtype=dtype), 413 + ) 414 + create_tar_with_samples(tar_path, [sample]) 415 + 416 + ds = atdata.Dataset[NDArraySample](str(tar_path)) 417 + loaded = list(ds.ordered(batch_size=None))[0] 418 + 419 + assert loaded.data.dtype == dtype 420 + assert np.array_equal(loaded.data, np.array([1, 2, 3, 4, 5], dtype=dtype)) 421 + 422 + def test_float16_dtype(self, tmp_path): 423 + """float16 (half precision) should work correctly.""" 424 + tar_path = tmp_path / "float16-000000.tar" 425 + sample = NDArraySample( 426 + label="float16", 427 + data=np.array([1.0, 2.0, 3.0], dtype=np.float16), 428 + ) 429 + create_tar_with_samples(tar_path, [sample]) 430 + 431 + ds = atdata.Dataset[NDArraySample](str(tar_path)) 432 + loaded = list(ds.ordered(batch_size=None))[0] 433 + 434 + assert loaded.data.dtype == np.float16 435 + 436 + def test_complex_dtype(self, tmp_path): 437 + """Complex dtypes should work correctly.""" 438 + tar_path = tmp_path / "complex-000000.tar" 439 + sample = NDArraySample( 440 + label="complex128", 441 + data=np.array([1 + 2j, 3 + 4j, 5 + 6j], dtype=np.complex128), 442 + ) 443 + create_tar_with_samples(tar_path, [sample]) 444 + 445 + ds = atdata.Dataset[NDArraySample](str(tar_path)) 446 + loaded = list(ds.ordered(batch_size=None))[0] 447 + 448 + assert loaded.data.dtype == np.complex128 449 + assert loaded.data[0] == 1 + 2j 450 + 451 + def test_multidimensional_arrays(self, tmp_path): 452 + """Multidimensional arrays should preserve shape.""" 453 + tar_path = tmp_path / "multidim-000000.tar" 454 + 455 + shapes = [(3, 4), (2, 3, 4), (2, 2, 2, 2)] 456 + 457 + for i, shape in enumerate(shapes): 458 + tar_path_i = tmp_path / f"multidim-{i}-000000.tar" 459 + sample = NDArraySample( 460 + label=f"shape-{shape}", 461 + data=np.ones(shape, dtype=np.float32), 462 + ) 463 + create_tar_with_samples(tar_path_i, [sample]) 464 + 465 + ds = atdata.Dataset[NDArraySample](str(tar_path_i)) 466 + loaded = list(ds.ordered(batch_size=None))[0] 467 + 468 + assert loaded.data.shape == shape 469 + 470 + def test_large_array(self, tmp_path): 471 + """Moderately large arrays should work correctly.""" 472 + tar_path = tmp_path / "large-array-000000.tar" 473 + 474 + # 1000x1000 float32 = 4MB 475 + large_array = np.random.randn(1000, 1000).astype(np.float32) 476 + sample = NDArraySample(label="large", data=large_array) 477 + create_tar_with_samples(tar_path, [sample]) 478 + 479 + ds = atdata.Dataset[NDArraySample](str(tar_path)) 480 + loaded = list(ds.ordered(batch_size=None))[0] 481 + 482 + assert loaded.data.shape == (1000, 1000) 483 + assert np.allclose(loaded.data, large_array) 484 + 485 + 486 + ## 487 + # String Edge Cases 488 + 489 + 490 + class TestStringEdgeCases: 491 + """Tests for string field edge cases.""" 492 + 493 + def test_empty_string(self, tmp_path): 494 + """Empty strings should be preserved.""" 495 + tar_path = tmp_path / "empty-string-000000.tar" 496 + sample = UnicodeSample(text="", label="empty") 497 + create_tar_with_samples(tar_path, [sample]) 498 + 499 + ds = atdata.Dataset[UnicodeSample](str(tar_path)) 500 + loaded = list(ds.ordered(batch_size=None))[0] 501 + 502 + assert loaded.text == "" 503 + assert loaded.label == "empty" 504 + 505 + def test_long_string(self, tmp_path): 506 + """Long strings should be handled correctly.""" 507 + tar_path = tmp_path / "long-string-000000.tar" 508 + 509 + # 100KB string 510 + long_text = "x" * (100 * 1024) 511 + sample = UnicodeSample(text=long_text, label="long") 512 + create_tar_with_samples(tar_path, [sample]) 513 + 514 + ds = atdata.Dataset[UnicodeSample](str(tar_path)) 515 + loaded = list(ds.ordered(batch_size=None))[0] 516 + 517 + assert len(loaded.text) == 100 * 1024 518 + assert loaded.text == long_text 519 + 520 + def test_binary_bytes_field(self, tmp_path): 521 + """Binary bytes with all possible byte values.""" 522 + tar_path = tmp_path / "binary-bytes-000000.tar" 523 + 524 + # All possible byte values 525 + all_bytes = bytes(range(256)) 526 + sample = AllPrimitivesSample( 527 + str_field="test", 528 + int_field=0, 529 + float_field=0.0, 530 + bool_field=False, 531 + bytes_field=all_bytes, 532 + ) 533 + create_tar_with_samples(tar_path, [sample]) 534 + 535 + ds = atdata.Dataset[AllPrimitivesSample](str(tar_path)) 536 + loaded = list(ds.ordered(batch_size=None))[0] 537 + 538 + assert loaded.bytes_field == all_bytes 539 + assert len(loaded.bytes_field) == 256 540 + 541 + 542 + ## 543 + # Schema and Index Edge Cases 544 + 545 + 546 + class TestSchemaEdgeCases: 547 + """Tests for schema edge cases.""" 548 + 549 + def test_schema_with_many_fields(self, clean_redis): 550 + """Schema with many fields should work correctly.""" 551 + 552 + @atdata.packable 553 + class ManyFieldsSample: 554 + f1: str 555 + f2: str 556 + f3: str 557 + f4: str 558 + f5: str 559 + f6: int 560 + f7: int 561 + f8: int 562 + f9: float 563 + f10: float 564 + 565 + index = LocalIndex(redis=clean_redis) 566 + schema_ref = index.publish_schema(ManyFieldsSample, version="1.0.0") 567 + schema = index.get_schema(schema_ref) 568 + 569 + assert len(schema["fields"]) == 10 570 + 571 + def test_dataset_name_with_special_chars(self, clean_redis): 572 + """Dataset names with special characters should work.""" 573 + index = LocalIndex(redis=clean_redis) 574 + schema_ref = index.publish_schema(EmptyCompatSample, version="1.0.0") 575 + 576 + # Various special names 577 + names = [ 578 + "dataset-with-dashes", 579 + "dataset_with_underscores", 580 + "dataset.with.dots", 581 + "UPPERCASE-name", 582 + ] 583 + 584 + for name in names: 585 + entry = LocalDatasetEntry( 586 + _name=name, 587 + _schema_ref=schema_ref, 588 + _data_urls=["s3://bucket/data.tar"], 589 + ) 590 + entry.write_to(clean_redis) 591 + 592 + retrieved = index.get_entry_by_name(name) 593 + assert retrieved.name == name 594 + 595 + 596 + ## 597 + # Batch Processing Edge Cases 598 + 599 + 600 + class TestBatchEdgeCases: 601 + """Tests for batch processing edge cases.""" 602 + 603 + def test_batch_size_larger_than_dataset(self, tmp_path): 604 + """Batch size larger than dataset size should work.""" 605 + tar_path = tmp_path / "small-dataset-000000.tar" 606 + samples = [EmptyCompatSample(id=i) for i in range(3)] 607 + create_tar_with_samples(tar_path, samples) 608 + 609 + ds = atdata.Dataset[EmptyCompatSample](str(tar_path)) 610 + batches = list(ds.ordered(batch_size=100)) 611 + 612 + # Should get at least one batch 613 + assert len(batches) >= 1 614 + # Total samples should be 3 615 + total_samples = sum(len(batch.samples) for batch in batches) 616 + assert total_samples == 3 617 + 618 + def test_batch_size_one(self, tmp_path): 619 + """Batch size of 1 should produce individual samples in batches.""" 620 + tar_path = tmp_path / "batch-one-000000.tar" 621 + samples = [EmptyCompatSample(id=i) for i in range(5)] 622 + create_tar_with_samples(tar_path, samples) 623 + 624 + ds = atdata.Dataset[EmptyCompatSample](str(tar_path)) 625 + batches = list(ds.ordered(batch_size=1)) 626 + 627 + assert len(batches) == 5 628 + for batch in batches: 629 + assert len(batch.samples) == 1 630 + 631 + def test_batch_aggregation_with_arrays(self, tmp_path): 632 + """Batch aggregation should stack NDArrays correctly.""" 633 + tar_path = tmp_path / "batch-arrays-000000.tar" 634 + 635 + samples = [ 636 + NDArraySample(label=f"s{i}", data=np.array([i, i + 1, i + 2])) 637 + for i in range(4) 638 + ] 639 + create_tar_with_samples(tar_path, samples) 640 + 641 + ds = atdata.Dataset[NDArraySample](str(tar_path)) 642 + batches = list(ds.ordered(batch_size=4)) 643 + 644 + batch = batches[0] 645 + # data attribute should be stacked 646 + stacked_data = batch.data 647 + 648 + assert stacked_data.shape == (4, 3) 649 + assert np.array_equal(stacked_data[0], np.array([0, 1, 2])) 650 + assert np.array_equal(stacked_data[3], np.array([3, 4, 5]))
+428
tests/test_integration_error_handling.py
··· 1 + """Integration tests for error handling and recovery. 2 + 3 + Tests error conditions and graceful failure including: 4 + - Missing schemas and data URLs 5 + - Malformed data (msgpack, tar) 6 + - Connection failures (Redis, S3, ATProto) 7 + - Authentication and rate limiting errors 8 + """ 9 + 10 + import pytest 11 + from pathlib import Path 12 + from unittest.mock import Mock, MagicMock, patch 13 + import tarfile 14 + import tempfile 15 + 16 + import numpy as np 17 + from numpy.typing import NDArray 18 + 19 + import atdata 20 + from atdata.local import LocalIndex, LocalDatasetEntry 21 + from atdata.atmosphere import AtmosphereClient, AtUri 22 + from atdata.atmosphere._types import LEXICON_NAMESPACE 23 + 24 + 25 + ## 26 + # Test sample types 27 + 28 + 29 + @atdata.packable 30 + class ErrorTestSample: 31 + """Sample for error handling tests.""" 32 + name: str 33 + value: int 34 + 35 + 36 + ## 37 + # Schema Error Tests 38 + 39 + 40 + class TestMissingSchema: 41 + """Tests for missing schema errors.""" 42 + 43 + def test_missing_schema_raises_keyerror(self, clean_redis): 44 + """Accessing non-existent schema should raise KeyError.""" 45 + index = LocalIndex(redis=clean_redis) 46 + 47 + with pytest.raises(KeyError): 48 + index.get_schema("local://schemas/NonExistent@1.0.0") 49 + 50 + def test_dataset_with_invalid_schema_ref(self, clean_redis): 51 + """Dataset entry with invalid schema ref should error on decode.""" 52 + index = LocalIndex(redis=clean_redis) 53 + 54 + entry = LocalDatasetEntry( 55 + _name="orphan-dataset", 56 + _schema_ref="local://schemas/DoesNotExist@1.0.0", 57 + _data_urls=["s3://bucket/data.tar"], 58 + ) 59 + entry.write_to(clean_redis) 60 + 61 + # Entry exists but schema doesn't 62 + retrieved = index.get_entry_by_name("orphan-dataset") 63 + assert retrieved is not None 64 + 65 + # Attempting to decode schema should fail 66 + with pytest.raises(KeyError): 67 + index.decode_schema(retrieved.schema_ref) 68 + 69 + 70 + ## 71 + # Data URL Error Tests 72 + 73 + 74 + class TestMissingDataUrls: 75 + """Tests for missing or inaccessible data URLs.""" 76 + 77 + def test_empty_data_urls_raises(self, clean_redis): 78 + """Dataset entry with empty URLs should be flagged.""" 79 + index = LocalIndex(redis=clean_redis) 80 + schema_ref = index.publish_schema(ErrorTestSample, version="1.0.0") 81 + 82 + entry = LocalDatasetEntry( 83 + _name="empty-urls", 84 + _schema_ref=schema_ref, 85 + _data_urls=[], 86 + ) 87 + entry.write_to(clean_redis) 88 + 89 + retrieved = index.get_entry_by_name("empty-urls") 90 + assert retrieved.data_urls == [] 91 + 92 + def test_nonexistent_tar_raises(self, tmp_path): 93 + """Attempting to read non-existent tar should raise.""" 94 + nonexistent_path = tmp_path / "does-not-exist.tar" 95 + 96 + ds = atdata.Dataset[ErrorTestSample](str(nonexistent_path)) 97 + 98 + # Iterating should fail 99 + with pytest.raises(FileNotFoundError): 100 + list(ds.ordered(batch_size=None)) 101 + 102 + 103 + ## 104 + # Malformed Data Tests 105 + 106 + 107 + class TestMalformedMsgpack: 108 + """Tests for corrupted msgpack data.""" 109 + 110 + def test_invalid_msgpack_in_tar(self, tmp_path): 111 + """Tar with invalid msgpack should raise on iteration.""" 112 + tar_path = tmp_path / "corrupted-000000.tar" 113 + 114 + import io 115 + 116 + # Create tar with invalid msgpack data 117 + with tarfile.open(tar_path, "w") as tar: 118 + # Add a valid key file 119 + key_data = b"sample-0" 120 + key_info = tarfile.TarInfo(name="sample-0.__key__") 121 + key_info.size = len(key_data) 122 + tar.addfile(key_info, fileobj=io.BytesIO(key_data)) 123 + 124 + # Add invalid msgpack data 125 + invalid_data = b"\xff\xff\xff\xff\xff" # Not valid msgpack 126 + info = tarfile.TarInfo(name="sample-0.msgpack") 127 + info.size = len(invalid_data) 128 + tar.addfile(info, fileobj=io.BytesIO(invalid_data)) 129 + 130 + ds = atdata.Dataset[ErrorTestSample](str(tar_path)) 131 + 132 + # Should raise an error when trying to deserialize 133 + with pytest.raises(Exception): # Could be msgpack error or ValueError 134 + list(ds.ordered(batch_size=None)) 135 + 136 + 137 + class TestCorruptedTar: 138 + """Tests for corrupted tar files.""" 139 + 140 + def test_truncated_tar_raises(self, tmp_path): 141 + """Truncated tar file should raise an error.""" 142 + tar_path = tmp_path / "truncated-000000.tar" 143 + 144 + # Create a valid tar then truncate it 145 + with tarfile.open(tar_path, "w") as tar: 146 + data = b"test data" 147 + info = tarfile.TarInfo(name="test.txt") 148 + info.size = len(data) 149 + import io 150 + tar.addfile(info, fileobj=io.BytesIO(data)) 151 + 152 + # Truncate the file 153 + with open(tar_path, "r+b") as f: 154 + f.truncate(50) # Truncate to partial content 155 + 156 + ds = atdata.Dataset[ErrorTestSample](str(tar_path)) 157 + 158 + with pytest.raises(Exception): # tarfile.ReadError or similar 159 + list(ds.ordered(batch_size=None)) 160 + 161 + def test_not_a_tar_file_raises(self, tmp_path): 162 + """Non-tar file should raise clear error.""" 163 + fake_tar = tmp_path / "fake-000000.tar" 164 + 165 + # Write random bytes 166 + with open(fake_tar, "wb") as f: 167 + f.write(b"This is not a tar file at all!") 168 + 169 + ds = atdata.Dataset[ErrorTestSample](str(fake_tar)) 170 + 171 + with pytest.raises(Exception): # tarfile.ReadError 172 + list(ds.ordered(batch_size=None)) 173 + 174 + 175 + ## 176 + # Redis Error Tests 177 + 178 + 179 + class TestRedisErrors: 180 + """Tests for Redis connection errors.""" 181 + 182 + def test_redis_connection_error(self): 183 + """Operations with bad Redis connection should fail cleanly.""" 184 + from redis import Redis, ConnectionError 185 + 186 + # Create index with invalid Redis connection 187 + bad_redis = Redis(host="nonexistent.invalid.host", port=9999, socket_timeout=0.1) 188 + 189 + index = LocalIndex(redis=bad_redis) 190 + 191 + # Operations should raise connection errors 192 + with pytest.raises((ConnectionError, Exception)): 193 + index.publish_schema(ErrorTestSample, version="1.0.0") 194 + 195 + def test_entry_lookup_with_bad_redis(self, clean_redis): 196 + """Entry lookup should fail cleanly if Redis becomes unavailable.""" 197 + index = LocalIndex(redis=clean_redis) 198 + 199 + # First, add an entry 200 + schema_ref = index.publish_schema(ErrorTestSample, version="1.0.0") 201 + entry = LocalDatasetEntry( 202 + _name="test-entry", 203 + _schema_ref=schema_ref, 204 + _data_urls=["s3://bucket/data.tar"], 205 + ) 206 + entry.write_to(clean_redis) 207 + 208 + # Entry should be retrievable 209 + retrieved = index.get_entry_by_name("test-entry") 210 + assert retrieved is not None 211 + 212 + 213 + ## 214 + # ATProto Error Tests 215 + 216 + 217 + class TestAtProtoErrors: 218 + """Tests for ATProto/Atmosphere errors.""" 219 + 220 + def test_unauthenticated_publish_raises(self): 221 + """Publishing without authentication should raise.""" 222 + mock_client = Mock() 223 + mock_client.me = None 224 + 225 + client = AtmosphereClient(_client=mock_client) 226 + 227 + # Not authenticated 228 + assert not client.is_authenticated 229 + 230 + from atdata.atmosphere import SchemaPublisher 231 + publisher = SchemaPublisher(client) 232 + 233 + with pytest.raises(ValueError, match="authenticated"): 234 + publisher.publish(ErrorTestSample, version="1.0.0") 235 + 236 + def test_invalid_at_uri_raises(self): 237 + """Parsing invalid AT URI should raise ValueError.""" 238 + invalid_uris = [ 239 + "not-a-uri", 240 + "https://example.com/path", 241 + "at://", 242 + "at://did:plc:abc", # Missing collection and rkey 243 + "at://did:plc:abc/collection", # Missing rkey 244 + ] 245 + 246 + for uri in invalid_uris: 247 + with pytest.raises(ValueError): 248 + AtUri.parse(uri) 249 + 250 + def test_api_error_response_handling(self): 251 + """API errors should be propagated appropriately.""" 252 + mock_client = Mock() 253 + mock_client.me = MagicMock() 254 + mock_client.me.did = "did:plc:test123" 255 + 256 + # Simulate an API error 257 + from atproto_client.exceptions import AtProtocolError 258 + 259 + mock_client.com.atproto.repo.create_record.side_effect = AtProtocolError( 260 + "API error occurred" 261 + ) 262 + 263 + # Create client and authenticate it 264 + client = AtmosphereClient(_client=mock_client) 265 + client._session = {"did": "did:plc:test123"} # Mark as authenticated 266 + 267 + from atdata.atmosphere import SchemaPublisher 268 + publisher = SchemaPublisher(client) 269 + 270 + # Should propagate the API error 271 + with pytest.raises(AtProtocolError): 272 + publisher.publish(ErrorTestSample, version="1.0.0") 273 + 274 + def test_expired_session_detection(self): 275 + """Expired session should be detectable.""" 276 + mock_client = Mock() 277 + mock_client.me = None 278 + mock_client.export_session_string.return_value = None 279 + 280 + client = AtmosphereClient(_client=mock_client) 281 + 282 + # Should not be authenticated 283 + assert not client.is_authenticated 284 + 285 + 286 + ## 287 + # Entry Not Found Tests 288 + 289 + 290 + class TestNotFoundErrors: 291 + """Tests for not-found error handling.""" 292 + 293 + def test_get_entry_by_name_not_found(self, clean_redis): 294 + """Getting non-existent entry by name should raise KeyError.""" 295 + index = LocalIndex(redis=clean_redis) 296 + 297 + with pytest.raises(KeyError): 298 + index.get_entry_by_name("nonexistent-dataset") 299 + 300 + def test_get_entry_by_cid_not_found(self, clean_redis): 301 + """Getting non-existent entry by CID should raise KeyError.""" 302 + index = LocalIndex(redis=clean_redis) 303 + 304 + with pytest.raises(KeyError): 305 + index.get_entry("bafyreifake123456789") 306 + 307 + 308 + ## 309 + # Error Message Quality Tests 310 + 311 + 312 + class TestErrorMessageQuality: 313 + """Tests that error messages are helpful and don't leak sensitive info.""" 314 + 315 + def test_missing_schema_error_includes_ref(self, clean_redis): 316 + """Missing schema error should include the schema reference.""" 317 + index = LocalIndex(redis=clean_redis) 318 + 319 + try: 320 + index.get_schema("local://schemas/MissingType@1.0.0") 321 + assert False, "Should have raised KeyError" 322 + except KeyError as e: 323 + # Error should mention the schema reference 324 + assert "MissingType" in str(e) or "local://" in str(e) 325 + 326 + def test_invalid_uri_error_is_clear(self): 327 + """Invalid AT URI error should explain the issue.""" 328 + try: 329 + AtUri.parse("not-valid") 330 + assert False, "Should have raised ValueError" 331 + except ValueError as e: 332 + # Error should explain it's not a valid URI 333 + assert "at://" in str(e).lower() or "uri" in str(e).lower() 334 + 335 + def test_auth_error_no_credential_leak(self): 336 + """Authentication errors should not leak credentials.""" 337 + mock_client = Mock() 338 + mock_client.me = None 339 + 340 + client = AtmosphereClient(_client=mock_client) 341 + 342 + from atdata.atmosphere import SchemaPublisher 343 + publisher = SchemaPublisher(client) 344 + 345 + try: 346 + publisher.publish(ErrorTestSample, version="1.0.0") 347 + except ValueError as e: 348 + error_msg = str(e) 349 + # Should not contain anything that looks like a password or token 350 + assert "password" not in error_msg.lower() 351 + assert "token" not in error_msg.lower() 352 + assert "secret" not in error_msg.lower() 353 + 354 + 355 + ## 356 + # Recovery Tests 357 + 358 + 359 + class TestRecovery: 360 + """Tests for recovery from errors.""" 361 + 362 + def test_can_continue_after_bad_sample(self, tmp_path, clean_redis): 363 + """System should be usable after encountering bad data.""" 364 + # First, try to read a bad file 365 + bad_tar = tmp_path / "bad-000000.tar" 366 + with open(bad_tar, "wb") as f: 367 + f.write(b"not a tar file") 368 + 369 + ds_bad = atdata.Dataset[ErrorTestSample](str(bad_tar)) 370 + try: 371 + list(ds_bad.ordered(batch_size=None)) 372 + except Exception: 373 + pass # Expected to fail 374 + 375 + # Now use a good file - should still work 376 + good_tar = tmp_path / "good-000000.tar" 377 + import webdataset as wds 378 + with wds.writer.TarWriter(str(good_tar)) as writer: 379 + sample = ErrorTestSample(name="good", value=42) 380 + writer.write(sample.as_wds) 381 + 382 + ds_good = atdata.Dataset[ErrorTestSample](str(good_tar)) 383 + samples = list(ds_good.ordered(batch_size=None)) 384 + 385 + assert len(samples) == 1 386 + assert samples[0].name == "good" 387 + 388 + def test_index_usable_after_failed_publish(self, clean_redis): 389 + """Index should remain usable after a failed operation.""" 390 + index = LocalIndex(redis=clean_redis) 391 + 392 + # Try to get a non-existent schema (fails as expected) 393 + with pytest.raises(KeyError): 394 + index.get_schema("local://schemas/NoSuch@1.0.0") 395 + 396 + # Index should still work 397 + schema_ref = index.publish_schema(ErrorTestSample, version="1.0.0") 398 + assert schema_ref is not None 399 + 400 + schema = index.get_schema(schema_ref) 401 + assert schema["name"] == "ErrorTestSample" 402 + 403 + 404 + ## 405 + # Validation Tests 406 + 407 + 408 + class TestInputValidation: 409 + """Tests for input validation.""" 410 + 411 + def test_empty_version_string(self, clean_redis): 412 + """Empty version string should be handled.""" 413 + index = LocalIndex(redis=clean_redis) 414 + 415 + # Empty version - implementation may accept or reject 416 + schema_ref = index.publish_schema(ErrorTestSample, version="") 417 + # If it accepts, it should store and retrieve correctly 418 + schema = index.get_schema(schema_ref) 419 + assert schema is not None 420 + 421 + def test_special_chars_in_version(self, clean_redis): 422 + """Special characters in version should be handled.""" 423 + index = LocalIndex(redis=clean_redis) 424 + 425 + schema_ref = index.publish_schema(ErrorTestSample, version="1.0.0-beta+build.123") 426 + schema = index.get_schema(schema_ref) 427 + 428 + assert schema["version"] == "1.0.0-beta+build.123"
+633
tests/test_integration_lens.py
··· 1 + """Integration tests for lens transformation chains. 2 + 3 + Tests complex lens workflows including: 4 + - Chained transformations (A → B → C) 5 + - Batch lens application 6 + - Optional field handling 7 + - Bidirectional round-trips (lens laws) 8 + - LensNetwork discovery 9 + - NDArray transformations 10 + """ 11 + 12 + import pytest 13 + from dataclasses import dataclass 14 + 15 + import numpy as np 16 + from numpy.typing import NDArray 17 + import webdataset as wds 18 + 19 + import atdata 20 + from atdata.lens import LensNetwork 21 + 22 + 23 + ## 24 + # Test sample types for lens chains 25 + 26 + 27 + @atdata.packable 28 + class FullRecord: 29 + """Complete record with all fields.""" 30 + id: int 31 + name: str 32 + email: str 33 + age: int 34 + score: float 35 + embedding: NDArray 36 + 37 + 38 + @atdata.packable 39 + class ProfileView: 40 + """View with profile information only.""" 41 + name: str 42 + email: str 43 + age: int 44 + 45 + 46 + @atdata.packable 47 + class NameView: 48 + """Minimal view with just name.""" 49 + name: str 50 + 51 + 52 + @atdata.packable 53 + class ScoredRecord: 54 + """Record with score and embedding.""" 55 + id: int 56 + score: float 57 + embedding: NDArray 58 + 59 + 60 + @atdata.packable 61 + class OptionalFieldSample: 62 + """Sample with optional fields.""" 63 + name: str 64 + value: int 65 + extra: str | None = None 66 + embedding: NDArray | None = None 67 + 68 + 69 + @atdata.packable 70 + class OptionalView: 71 + """View of optional sample.""" 72 + name: str 73 + extra: str | None = None 74 + 75 + 76 + ## 77 + # Lens definitions 78 + 79 + 80 + @atdata.lens 81 + def full_to_profile(full: FullRecord) -> ProfileView: 82 + """Extract profile from full record.""" 83 + return ProfileView( 84 + name=full.name, 85 + email=full.email, 86 + age=full.age, 87 + ) 88 + 89 + 90 + @full_to_profile.putter 91 + def full_to_profile_put(view: ProfileView, source: FullRecord) -> FullRecord: 92 + """Update full record from profile view.""" 93 + return FullRecord( 94 + id=source.id, 95 + name=view.name, 96 + email=view.email, 97 + age=view.age, 98 + score=source.score, 99 + embedding=source.embedding, 100 + ) 101 + 102 + 103 + @atdata.lens 104 + def profile_to_name(profile: ProfileView) -> NameView: 105 + """Extract just name from profile.""" 106 + return NameView(name=profile.name) 107 + 108 + 109 + @profile_to_name.putter 110 + def profile_to_name_put(view: NameView, source: ProfileView) -> ProfileView: 111 + """Update profile from name view.""" 112 + return ProfileView( 113 + name=view.name, 114 + email=source.email, 115 + age=source.age, 116 + ) 117 + 118 + 119 + @atdata.lens 120 + def full_to_scored(full: FullRecord) -> ScoredRecord: 121 + """Extract scoring data from full record.""" 122 + return ScoredRecord( 123 + id=full.id, 124 + score=full.score, 125 + embedding=full.embedding, 126 + ) 127 + 128 + 129 + @full_to_scored.putter 130 + def full_to_scored_put(view: ScoredRecord, source: FullRecord) -> FullRecord: 131 + """Update full record from scored view.""" 132 + return FullRecord( 133 + id=view.id, 134 + name=source.name, 135 + email=source.email, 136 + age=source.age, 137 + score=view.score, 138 + embedding=view.embedding, 139 + ) 140 + 141 + 142 + @atdata.lens 143 + def optional_to_view(opt: OptionalFieldSample) -> OptionalView: 144 + """Extract optional fields view.""" 145 + return OptionalView( 146 + name=opt.name, 147 + extra=opt.extra, 148 + ) 149 + 150 + 151 + @optional_to_view.putter 152 + def optional_to_view_put(view: OptionalView, source: OptionalFieldSample) -> OptionalFieldSample: 153 + """Update optional sample from view.""" 154 + return OptionalFieldSample( 155 + name=view.name, 156 + value=source.value, 157 + extra=view.extra, 158 + embedding=source.embedding, 159 + ) 160 + 161 + 162 + ## 163 + # Helper functions 164 + 165 + 166 + def create_full_records(n: int) -> list[FullRecord]: 167 + """Create n full records with distinct values.""" 168 + return [ 169 + FullRecord( 170 + id=i, 171 + name=f"user_{i}", 172 + email=f"user_{i}@example.com", 173 + age=20 + (i % 50), 174 + score=float(i) * 0.1, 175 + embedding=np.random.randn(64).astype(np.float32), 176 + ) 177 + for i in range(n) 178 + ] 179 + 180 + 181 + def write_dataset(path, samples) -> str: 182 + """Write samples to tar file, return path.""" 183 + tar_path = path.as_posix() 184 + with wds.writer.TarWriter(tar_path) as sink: 185 + for sample in samples: 186 + sink.write(sample.as_wds) 187 + return tar_path 188 + 189 + 190 + ## 191 + # Chained Transformation Tests 192 + 193 + 194 + class TestChainedTransformations: 195 + """Tests for chaining multiple lens transformations.""" 196 + 197 + def test_manual_chain_two_lenses(self, tmp_path): 198 + """Manually chain two lenses: Full → Profile → Name.""" 199 + n_samples = 20 200 + records = create_full_records(n_samples) 201 + 202 + tar_path = write_dataset(tmp_path / "full.tar", records) 203 + 204 + # Load as ProfileView first 205 + profile_ds = atdata.Dataset[FullRecord](tar_path).as_type(ProfileView) 206 + 207 + # Iterate and transform to NameView 208 + for i, profile in enumerate(profile_ds.ordered(batch_size=None)): 209 + assert isinstance(profile, ProfileView) 210 + 211 + # Apply second transformation manually 212 + name_view = profile_to_name(profile) 213 + assert isinstance(name_view, NameView) 214 + assert name_view.name == f"user_{i}" 215 + 216 + if i >= 10: 217 + break 218 + 219 + def test_chain_round_trip(self): 220 + """Chain of transformations should support round-trip.""" 221 + original = FullRecord( 222 + id=1, 223 + name="Alice", 224 + email="alice@test.com", 225 + age=30, 226 + score=0.95, 227 + embedding=np.array([1.0, 2.0, 3.0], dtype=np.float32), 228 + ) 229 + 230 + # Forward chain: Full → Profile → Name 231 + profile = full_to_profile(original) 232 + name = profile_to_name(profile) 233 + 234 + assert name.name == "Alice" 235 + 236 + # Reverse chain with updates 237 + new_name = NameView(name="Alice Updated") 238 + updated_profile = profile_to_name.put(new_name, profile) 239 + updated_full = full_to_profile.put(updated_profile, original) 240 + 241 + assert updated_full.name == "Alice Updated" 242 + assert updated_full.id == 1 # Preserved 243 + assert updated_full.score == 0.95 # Preserved 244 + 245 + def test_parallel_views_from_same_source(self, tmp_path): 246 + """Same source can have multiple views through different lenses.""" 247 + n_samples = 15 248 + records = create_full_records(n_samples) 249 + 250 + tar_path = write_dataset(tmp_path / "multi.tar", records) 251 + 252 + # Create two different views of the same data 253 + profile_ds = atdata.Dataset[FullRecord](tar_path).as_type(ProfileView) 254 + scored_ds = atdata.Dataset[FullRecord](tar_path).as_type(ScoredRecord) 255 + 256 + profiles = list(profile_ds.ordered(batch_size=None)) 257 + scored = list(scored_ds.ordered(batch_size=None)) 258 + 259 + assert len(profiles) == n_samples 260 + assert len(scored) == n_samples 261 + 262 + for i in range(n_samples): 263 + # Both views from same source 264 + assert profiles[i].name == f"user_{i}" 265 + assert scored[i].id == i 266 + assert scored[i].score == float(i) * 0.1 267 + 268 + 269 + class TestBatchLensApplication: 270 + """Tests for applying lenses to batched data.""" 271 + 272 + def test_lens_with_batched_iteration(self, tmp_path): 273 + """Lens should apply correctly to batched dataset iteration.""" 274 + n_samples = 32 275 + batch_size = 8 276 + records = create_full_records(n_samples) 277 + 278 + tar_path = write_dataset(tmp_path / "batch.tar", records) 279 + dataset = atdata.Dataset[FullRecord](tar_path).as_type(ProfileView) 280 + 281 + batches = list(dataset.ordered(batch_size=batch_size)) 282 + 283 + for batch in batches: 284 + assert isinstance(batch, atdata.SampleBatch) 285 + assert batch.sample_type == ProfileView 286 + 287 + # Verify samples are ProfileView instances 288 + for sample in batch.samples: 289 + assert isinstance(sample, ProfileView) 290 + assert hasattr(sample, "name") 291 + assert hasattr(sample, "email") 292 + assert hasattr(sample, "age") 293 + # Should not have FullRecord fields 294 + assert not hasattr(sample, "id") 295 + assert not hasattr(sample, "score") 296 + 297 + def test_batch_aggregation_after_lens(self, tmp_path): 298 + """Batch aggregation should work on lens-transformed samples.""" 299 + n_samples = 24 300 + batch_size = 6 301 + records = create_full_records(n_samples) 302 + 303 + tar_path = write_dataset(tmp_path / "agg.tar", records) 304 + dataset = atdata.Dataset[FullRecord](tar_path).as_type(ProfileView) 305 + 306 + batch_idx = 0 307 + for batch in dataset.ordered(batch_size=batch_size): 308 + # Access aggregated attributes 309 + names = batch.name 310 + emails = batch.email 311 + ages = batch.age 312 + 313 + assert isinstance(names, list) 314 + assert isinstance(emails, list) 315 + assert isinstance(ages, list) 316 + assert len(names) == len(batch.samples) 317 + 318 + # Verify values 319 + for i, name in enumerate(names): 320 + expected_idx = batch_idx * batch_size + i 321 + assert name == f"user_{expected_idx}" 322 + 323 + batch_idx += 1 324 + 325 + def test_ndarray_lens_with_batching(self, tmp_path): 326 + """Lens transforming NDArray fields should work with batching.""" 327 + n_samples = 20 328 + batch_size = 5 329 + records = create_full_records(n_samples) 330 + 331 + tar_path = write_dataset(tmp_path / "ndarray.tar", records) 332 + dataset = atdata.Dataset[FullRecord](tar_path).as_type(ScoredRecord) 333 + 334 + for batch in dataset.ordered(batch_size=batch_size): 335 + # NDArray should be stacked 336 + embeddings = batch.embedding 337 + assert isinstance(embeddings, np.ndarray) 338 + assert embeddings.shape == (batch_size, 64) 339 + 340 + 341 + class TestLensLaws: 342 + """Tests for lens law compliance (well-behavedness).""" 343 + 344 + def test_getput_law(self): 345 + """GetPut law: put(get(s), s) == s.""" 346 + source = FullRecord( 347 + id=42, 348 + name="Test", 349 + email="test@example.com", 350 + age=25, 351 + score=0.5, 352 + embedding=np.array([1.0, 2.0], dtype=np.float32), 353 + ) 354 + 355 + view = full_to_profile(source) 356 + result = full_to_profile.put(view, source) 357 + 358 + assert result.id == source.id 359 + assert result.name == source.name 360 + assert result.email == source.email 361 + assert result.age == source.age 362 + assert result.score == source.score 363 + np.testing.assert_array_equal(result.embedding, source.embedding) 364 + 365 + def test_putget_law(self): 366 + """PutGet law: get(put(v, s)) == v.""" 367 + source = FullRecord( 368 + id=42, 369 + name="Original", 370 + email="original@example.com", 371 + age=25, 372 + score=0.5, 373 + embedding=np.array([1.0, 2.0], dtype=np.float32), 374 + ) 375 + 376 + new_view = ProfileView( 377 + name="Updated", 378 + email="updated@example.com", 379 + age=30, 380 + ) 381 + 382 + updated = full_to_profile.put(new_view, source) 383 + retrieved = full_to_profile(updated) 384 + 385 + assert retrieved.name == new_view.name 386 + assert retrieved.email == new_view.email 387 + assert retrieved.age == new_view.age 388 + 389 + def test_putput_law(self): 390 + """PutPut law: put(v2, put(v1, s)) == put(v2, s).""" 391 + source = FullRecord( 392 + id=42, 393 + name="Original", 394 + email="original@example.com", 395 + age=25, 396 + score=0.5, 397 + embedding=np.array([1.0, 2.0], dtype=np.float32), 398 + ) 399 + 400 + view1 = ProfileView(name="First", email="first@example.com", age=26) 401 + view2 = ProfileView(name="Second", email="second@example.com", age=27) 402 + 403 + # Two ways to get final result 404 + result1 = full_to_profile.put(view2, full_to_profile.put(view1, source)) 405 + result2 = full_to_profile.put(view2, source) 406 + 407 + assert result1.name == result2.name 408 + assert result1.email == result2.email 409 + assert result1.age == result2.age 410 + 411 + 412 + class TestOptionalFieldLenses: 413 + """Tests for lenses handling optional fields.""" 414 + 415 + def test_optional_field_with_value(self, tmp_path): 416 + """Lens should handle optional fields that have values.""" 417 + samples = [ 418 + OptionalFieldSample( 419 + name=f"item_{i}", 420 + value=i * 10, 421 + extra=f"extra_{i}", 422 + embedding=np.random.randn(32).astype(np.float32), 423 + ) 424 + for i in range(10) 425 + ] 426 + 427 + tar_path = write_dataset(tmp_path / "opt_filled.tar", samples) 428 + dataset = atdata.Dataset[OptionalFieldSample](tar_path).as_type(OptionalView) 429 + 430 + for i, view in enumerate(dataset.ordered(batch_size=None)): 431 + assert isinstance(view, OptionalView) 432 + assert view.name == f"item_{i}" 433 + assert view.extra == f"extra_{i}" 434 + 435 + def test_optional_field_with_none(self, tmp_path): 436 + """Lens should handle optional fields that are None.""" 437 + samples = [ 438 + OptionalFieldSample( 439 + name=f"item_{i}", 440 + value=i * 10, 441 + extra=None, 442 + embedding=None, 443 + ) 444 + for i in range(10) 445 + ] 446 + 447 + tar_path = write_dataset(tmp_path / "opt_none.tar", samples) 448 + dataset = atdata.Dataset[OptionalFieldSample](tar_path).as_type(OptionalView) 449 + 450 + for view in dataset.ordered(batch_size=None): 451 + assert isinstance(view, OptionalView) 452 + assert view.extra is None 453 + 454 + def test_optional_field_lens_roundtrip(self): 455 + """Lens with optional fields should support round-trip.""" 456 + source = OptionalFieldSample( 457 + name="test", 458 + value=42, 459 + extra="optional", 460 + embedding=np.array([1.0, 2.0], dtype=np.float32), 461 + ) 462 + 463 + view = optional_to_view(source) 464 + assert view.extra == "optional" 465 + 466 + # Update with None 467 + new_view = OptionalView(name="updated", extra=None) 468 + updated = optional_to_view.put(new_view, source) 469 + 470 + assert updated.name == "updated" 471 + assert updated.extra is None 472 + assert updated.value == 42 # Preserved 473 + 474 + 475 + class TestLensNetworkDiscovery: 476 + """Tests for LensNetwork registry and discovery.""" 477 + 478 + def test_registered_lens_discoverable(self): 479 + """Registered lenses should be discoverable via LensNetwork.""" 480 + network = LensNetwork() 481 + 482 + # The lenses defined above should be registered 483 + lens = network.transform(FullRecord, ProfileView) 484 + assert lens is not None 485 + assert lens.source_type == FullRecord 486 + assert lens.view_type == ProfileView 487 + 488 + def test_unregistered_lens_raises(self): 489 + """Querying unregistered lens should raise ValueError.""" 490 + @atdata.packable 491 + class UnknownSource: 492 + x: int 493 + 494 + @atdata.packable 495 + class UnknownView: 496 + y: int 497 + 498 + network = LensNetwork() 499 + 500 + with pytest.raises(ValueError, match="No registered lens"): 501 + network.transform(UnknownSource, UnknownView) 502 + 503 + def test_multiple_lenses_registered(self): 504 + """Multiple lenses can be registered and retrieved independently.""" 505 + network = LensNetwork() 506 + 507 + # All our test lenses should be registered 508 + lens1 = network.transform(FullRecord, ProfileView) 509 + lens2 = network.transform(ProfileView, NameView) 510 + lens3 = network.transform(FullRecord, ScoredRecord) 511 + 512 + assert lens1 is not lens2 513 + assert lens2 is not lens3 514 + assert lens1.view_type == ProfileView 515 + assert lens2.view_type == NameView 516 + assert lens3.view_type == ScoredRecord 517 + 518 + 519 + class TestNDArrayTransformations: 520 + """Tests for lenses that transform NDArray fields.""" 521 + 522 + def test_ndarray_field_preserved(self, tmp_path): 523 + """NDArray fields should be correctly preserved through lens.""" 524 + records = create_full_records(10) 525 + 526 + tar_path = write_dataset(tmp_path / "ndarray.tar", records) 527 + dataset = atdata.Dataset[FullRecord](tar_path).as_type(ScoredRecord) 528 + 529 + for i, scored in enumerate(dataset.ordered(batch_size=None)): 530 + assert isinstance(scored.embedding, np.ndarray) 531 + assert scored.embedding.shape == (64,) 532 + assert scored.embedding.dtype == np.float32 533 + np.testing.assert_array_almost_equal( 534 + scored.embedding, 535 + records[i].embedding, 536 + ) 537 + 538 + def test_ndarray_transformation_lens(self): 539 + """Lens that transforms NDArray values.""" 540 + @atdata.packable 541 + class RawData: 542 + values: NDArray 543 + 544 + @atdata.packable 545 + class NormalizedData: 546 + normalized: NDArray 547 + 548 + @atdata.lens 549 + def normalize(raw: RawData) -> NormalizedData: 550 + arr = raw.values 551 + normalized = (arr - arr.mean()) / (arr.std() + 1e-8) 552 + return NormalizedData(normalized=normalized) 553 + 554 + source = RawData(values=np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)) 555 + view = normalize(source) 556 + 557 + assert isinstance(view.normalized, np.ndarray) 558 + # Normalized should have mean ~0 and std ~1 559 + assert abs(view.normalized.mean()) < 0.01 560 + assert abs(view.normalized.std() - 1.0) < 0.01 561 + 562 + 563 + class TestComplexLensScenarios: 564 + """Complex integration scenarios combining multiple features.""" 565 + 566 + def test_dataset_lens_chain_with_batching(self, tmp_path): 567 + """Full pipeline: Dataset → Lens → Batch iteration.""" 568 + n_samples = 50 569 + batch_size = 10 570 + records = create_full_records(n_samples) 571 + 572 + tar_path = write_dataset(tmp_path / "complex.tar", records) 573 + 574 + # Create lens-transformed dataset 575 + dataset = atdata.Dataset[FullRecord](tar_path).as_type(ProfileView) 576 + 577 + total_samples = 0 578 + for batch in dataset.ordered(batch_size=batch_size): 579 + assert isinstance(batch, atdata.SampleBatch) 580 + assert batch.sample_type == ProfileView 581 + 582 + # Apply second lens to each sample 583 + for sample in batch.samples: 584 + name_view = profile_to_name(sample) 585 + assert isinstance(name_view, NameView) 586 + assert name_view.name.startswith("user_") 587 + 588 + total_samples += len(batch.samples) 589 + 590 + assert total_samples == n_samples 591 + 592 + def test_shuffled_iteration_with_lens(self, tmp_path): 593 + """Lens should work with shuffled iteration.""" 594 + n_samples = 30 595 + records = create_full_records(n_samples) 596 + 597 + tar_path = write_dataset(tmp_path / "shuffle.tar", records) 598 + dataset = atdata.Dataset[FullRecord](tar_path).as_type(ProfileView) 599 + 600 + seen_names = set() 601 + for profile in dataset.shuffled(batch_size=None): 602 + assert isinstance(profile, ProfileView) 603 + seen_names.add(profile.name) 604 + if len(seen_names) >= 20: 605 + break 606 + 607 + # Should have seen multiple distinct names 608 + assert len(seen_names) >= 10 609 + 610 + def test_lens_preserves_all_fields(self, tmp_path): 611 + """Lens transformation should preserve all view fields exactly.""" 612 + records = [ 613 + FullRecord( 614 + id=i, 615 + name=f"name_{i}", 616 + email=f"email_{i}@test.com", 617 + age=20 + i, 618 + score=float(i) * 0.5, 619 + embedding=np.full(64, float(i), dtype=np.float32), 620 + ) 621 + for i in range(10) 622 + ] 623 + 624 + tar_path = write_dataset(tmp_path / "preserve.tar", records) 625 + dataset = atdata.Dataset[FullRecord](tar_path).as_type(ScoredRecord) 626 + 627 + for i, scored in enumerate(dataset.ordered(batch_size=None)): 628 + assert scored.id == i 629 + assert scored.score == float(i) * 0.5 630 + np.testing.assert_array_equal( 631 + scored.embedding, 632 + np.full(64, float(i), dtype=np.float32), 633 + )
+663
tests/test_integration_local.py
··· 1 + """Integration tests for local storage complete workflow. 2 + 3 + Tests end-to-end local storage workflows including: 4 + - Full Repo workflow: Init → publish_schema → insert → query → load 5 + - Schema versioning and CID consistency 6 + - Dataset discovery and querying 7 + - Metadata persistence through full cycle 8 + - cache_local mode comparison 9 + """ 10 + 11 + import pytest 12 + from dataclasses import dataclass 13 + from pathlib import Path 14 + 15 + import numpy as np 16 + from numpy.typing import NDArray 17 + from moto import mock_aws 18 + import webdataset as wds 19 + 20 + import atdata 21 + import atdata.local as atlocal 22 + 23 + 24 + ## 25 + # Test sample types 26 + 27 + 28 + @dataclass 29 + class WorkflowSample(atdata.PackableSample): 30 + """Sample for workflow tests.""" 31 + name: str 32 + value: int 33 + score: float 34 + 35 + 36 + @dataclass 37 + class ArrayWorkflowSample(atdata.PackableSample): 38 + """Sample with array for workflow tests.""" 39 + label: str 40 + data: NDArray 41 + 42 + 43 + @dataclass 44 + class MetadataSample(atdata.PackableSample): 45 + """Sample for metadata workflow tests.""" 46 + id: int 47 + content: str 48 + 49 + 50 + ## 51 + # Fixtures 52 + 53 + 54 + @pytest.fixture 55 + def mock_s3(): 56 + """Provide mock S3 environment using moto.""" 57 + with mock_aws(): 58 + import boto3 59 + creds = { 60 + 'AWS_ACCESS_KEY_ID': 'testing', 61 + 'AWS_SECRET_ACCESS_KEY': 'testing' 62 + } 63 + s3_client = boto3.client( 64 + 's3', 65 + aws_access_key_id=creds['AWS_ACCESS_KEY_ID'], 66 + aws_secret_access_key=creds['AWS_SECRET_ACCESS_KEY'], 67 + region_name='us-east-1' 68 + ) 69 + bucket_name = 'integration-test-bucket' 70 + s3_client.create_bucket(Bucket=bucket_name) 71 + yield { 72 + 'credentials': creds, 73 + 'bucket': bucket_name, 74 + 'hive_path': f'{bucket_name}/datasets', 75 + 's3_client': s3_client 76 + } 77 + 78 + 79 + def create_workflow_dataset(tmp_path: Path, n_samples: int = 10) -> atdata.Dataset: 80 + """Create a WorkflowSample dataset.""" 81 + tmp_path.mkdir(parents=True, exist_ok=True) 82 + tar_path = tmp_path / "workflow-000000.tar" 83 + with wds.writer.TarWriter(str(tar_path)) as sink: 84 + for i in range(n_samples): 85 + sample = WorkflowSample( 86 + name=f"item_{i}", 87 + value=i * 100, 88 + score=float(i) * 0.5, 89 + ) 90 + sink.write(sample.as_wds) 91 + return atdata.Dataset[WorkflowSample](url=str(tar_path)) 92 + 93 + 94 + def create_array_dataset(tmp_path: Path, n_samples: int = 5) -> atdata.Dataset: 95 + """Create an ArrayWorkflowSample dataset.""" 96 + tmp_path.mkdir(parents=True, exist_ok=True) 97 + tar_path = tmp_path / "array-000000.tar" 98 + with wds.writer.TarWriter(str(tar_path)) as sink: 99 + for i in range(n_samples): 100 + sample = ArrayWorkflowSample( 101 + label=f"array_{i}", 102 + data=np.random.randn(32, 32).astype(np.float32), 103 + ) 104 + sink.write(sample.as_wds) 105 + return atdata.Dataset[ArrayWorkflowSample](url=str(tar_path)) 106 + 107 + 108 + ## 109 + # Full Workflow Tests 110 + 111 + 112 + class TestFullRepoWorkflow: 113 + """End-to-end tests for complete Repo workflow.""" 114 + 115 + @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") 116 + @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") 117 + def test_init_publish_schema_insert_query(self, mock_s3, clean_redis, tmp_path): 118 + """Full workflow: init repo → publish schema → insert → query entry.""" 119 + # Initialize repo 120 + repo = atlocal.Repo( 121 + s3_credentials=mock_s3['credentials'], 122 + hive_path=mock_s3['hive_path'], 123 + redis=clean_redis 124 + ) 125 + 126 + # Publish schema first 127 + schema_ref = repo.index.publish_schema(WorkflowSample) 128 + assert schema_ref is not None 129 + assert "WorkflowSample" in schema_ref 130 + 131 + # Create and insert dataset 132 + ds = create_workflow_dataset(tmp_path, n_samples=15) 133 + entry, new_ds = repo.insert(ds, name="workflow-test", maxcount=100) 134 + 135 + # Query back 136 + assert entry.cid is not None 137 + assert entry.name == "workflow-test" 138 + assert len(entry.data_urls) > 0 139 + 140 + # Verify in index 141 + all_entries = repo.index.all_entries 142 + assert len(all_entries) == 1 143 + assert all_entries[0].cid == entry.cid 144 + 145 + @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") 146 + @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") 147 + def test_multiple_datasets_same_schema(self, mock_s3, clean_redis, tmp_path): 148 + """Insert multiple datasets with same schema type.""" 149 + repo = atlocal.Repo( 150 + s3_credentials=mock_s3['credentials'], 151 + hive_path=mock_s3['hive_path'], 152 + redis=clean_redis 153 + ) 154 + 155 + # Create multiple datasets 156 + ds1 = create_workflow_dataset(tmp_path / "ds1", n_samples=10) 157 + ds2 = create_workflow_dataset(tmp_path / "ds2", n_samples=20) 158 + ds3 = create_workflow_dataset(tmp_path / "ds3", n_samples=5) 159 + 160 + entry1, _ = repo.insert(ds1, name="dataset-1", maxcount=100) 161 + entry2, _ = repo.insert(ds2, name="dataset-2", maxcount=100) 162 + entry3, _ = repo.insert(ds3, name="dataset-3", maxcount=100) 163 + 164 + # All should have same schema_ref pattern 165 + assert "WorkflowSample" in entry1.schema_ref 166 + assert "WorkflowSample" in entry2.schema_ref 167 + assert "WorkflowSample" in entry3.schema_ref 168 + 169 + # But different CIDs (different URLs) 170 + assert entry1.cid != entry2.cid 171 + assert entry2.cid != entry3.cid 172 + 173 + # All should be in index 174 + all_entries = repo.index.all_entries 175 + assert len(all_entries) == 3 176 + 177 + @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") 178 + @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") 179 + def test_different_schema_types(self, mock_s3, clean_redis, tmp_path): 180 + """Insert datasets with different schema types.""" 181 + repo = atlocal.Repo( 182 + s3_credentials=mock_s3['credentials'], 183 + hive_path=mock_s3['hive_path'], 184 + redis=clean_redis 185 + ) 186 + 187 + # Different sample types 188 + ds1 = create_workflow_dataset(tmp_path / "simple", n_samples=5) 189 + ds2 = create_array_dataset(tmp_path / "array", n_samples=3) 190 + 191 + entry1, _ = repo.insert(ds1, name="simple-ds", maxcount=100) 192 + entry2, _ = repo.insert(ds2, name="array-ds", maxcount=100) 193 + 194 + # Different schema refs 195 + assert "WorkflowSample" in entry1.schema_ref 196 + assert "ArrayWorkflowSample" in entry2.schema_ref 197 + 198 + assert len(repo.index.all_entries) == 2 199 + 200 + 201 + class TestSchemaManagement: 202 + """Tests for schema publishing and retrieval.""" 203 + 204 + def test_publish_schema_creates_record(self, clean_redis): 205 + """Publishing schema should create a retrievable record.""" 206 + index = atlocal.Index(redis=clean_redis) 207 + 208 + schema_ref = index.publish_schema(WorkflowSample) 209 + assert schema_ref is not None 210 + 211 + # Should be able to get schema back 212 + schema = index.get_schema(schema_ref) 213 + assert schema is not None 214 + # Schema name may or may not include module prefix 215 + assert "WorkflowSample" in schema["name"] 216 + 217 + # Should have correct fields 218 + field_names = {f["name"] for f in schema["fields"]} 219 + assert "name" in field_names 220 + assert "value" in field_names 221 + assert "score" in field_names 222 + 223 + def test_publish_schema_with_version(self, clean_redis): 224 + """Publishing schema with version should include version.""" 225 + index = atlocal.Index(redis=clean_redis) 226 + 227 + schema_ref = index.publish_schema(WorkflowSample, version="2.0.0") 228 + assert "2.0.0" in schema_ref 229 + 230 + schema = index.get_schema(schema_ref) 231 + assert schema["version"] == "2.0.0" 232 + 233 + def test_publish_schema_with_ndarray(self, clean_redis): 234 + """Schema with NDArray field should publish correctly.""" 235 + index = atlocal.Index(redis=clean_redis) 236 + 237 + schema_ref = index.publish_schema(ArrayWorkflowSample) 238 + schema = index.get_schema(schema_ref) 239 + 240 + # Find the data field 241 + data_field = next(f for f in schema["fields"] if f["name"] == "data") 242 + assert data_field["fieldType"]["$type"] == "local#ndarray" 243 + 244 + def test_list_schemas(self, clean_redis): 245 + """Should list all published schemas.""" 246 + index = atlocal.Index(redis=clean_redis) 247 + 248 + # Publish multiple schemas 249 + index.publish_schema(WorkflowSample, version="1.0.0") 250 + index.publish_schema(ArrayWorkflowSample, version="1.0.0") 251 + 252 + schemas = list(index.list_schemas()) 253 + assert len(schemas) >= 2 254 + 255 + def test_decode_schema_creates_type(self, clean_redis): 256 + """decode_schema should reconstruct a usable type.""" 257 + index = atlocal.Index(redis=clean_redis) 258 + 259 + schema_ref = index.publish_schema(WorkflowSample) 260 + reconstructed = index.decode_schema(schema_ref) 261 + 262 + assert reconstructed is not None 263 + # Should be able to create instances 264 + instance = reconstructed(name="test", value=42, score=0.5) 265 + assert instance.name == "test" 266 + assert instance.value == 42 267 + 268 + 269 + class TestCIDDeterminism: 270 + """Tests for CID generation consistency.""" 271 + 272 + def test_same_content_same_cid(self): 273 + """Identical content should produce identical CIDs.""" 274 + entry1 = atlocal.LocalDatasetEntry( 275 + _name="test", 276 + _schema_ref="local://schemas/Test@1.0.0", 277 + _data_urls=["s3://bucket/data.tar"], 278 + _metadata={"key": "value"}, 279 + ) 280 + entry2 = atlocal.LocalDatasetEntry( 281 + _name="test", 282 + _schema_ref="local://schemas/Test@1.0.0", 283 + _data_urls=["s3://bucket/data.tar"], 284 + _metadata={"key": "value"}, 285 + ) 286 + 287 + assert entry1.cid == entry2.cid 288 + 289 + def test_different_urls_different_cid(self): 290 + """Different data URLs should produce different CIDs.""" 291 + entry1 = atlocal.LocalDatasetEntry( 292 + _name="test", 293 + _schema_ref="local://schemas/Test@1.0.0", 294 + _data_urls=["s3://bucket/data-v1.tar"], 295 + ) 296 + entry2 = atlocal.LocalDatasetEntry( 297 + _name="test", 298 + _schema_ref="local://schemas/Test@1.0.0", 299 + _data_urls=["s3://bucket/data-v2.tar"], 300 + ) 301 + 302 + assert entry1.cid != entry2.cid 303 + 304 + def test_different_schema_different_cid(self): 305 + """Different schema refs should produce different CIDs.""" 306 + entry1 = atlocal.LocalDatasetEntry( 307 + _name="test", 308 + _schema_ref="local://schemas/TypeA@1.0.0", 309 + _data_urls=["s3://bucket/data.tar"], 310 + ) 311 + entry2 = atlocal.LocalDatasetEntry( 312 + _name="test", 313 + _schema_ref="local://schemas/TypeB@1.0.0", 314 + _data_urls=["s3://bucket/data.tar"], 315 + ) 316 + 317 + assert entry1.cid != entry2.cid 318 + 319 + def test_name_does_not_affect_cid(self): 320 + """Dataset name should not affect CID (only content matters).""" 321 + entry1 = atlocal.LocalDatasetEntry( 322 + _name="name-one", 323 + _schema_ref="local://schemas/Test@1.0.0", 324 + _data_urls=["s3://bucket/data.tar"], 325 + ) 326 + entry2 = atlocal.LocalDatasetEntry( 327 + _name="name-two", 328 + _schema_ref="local://schemas/Test@1.0.0", 329 + _data_urls=["s3://bucket/data.tar"], 330 + ) 331 + 332 + # CID based on schema_ref and data_urls, not name 333 + assert entry1.cid == entry2.cid 334 + 335 + def test_cid_format_is_valid(self): 336 + """CIDs should have valid ATProto-compatible format.""" 337 + entry = atlocal.LocalDatasetEntry( 338 + _name="test", 339 + _schema_ref="local://schemas/Test@1.0.0", 340 + _data_urls=["s3://bucket/data.tar"], 341 + ) 342 + 343 + # CIDv1 with dag-cbor starts with 'bafy' 344 + assert entry.cid.startswith("bafy") 345 + # Should be base32 encoded (alphanumeric lowercase) 346 + assert entry.cid.isalnum() 347 + assert entry.cid.islower() 348 + 349 + 350 + class TestDatasetDiscovery: 351 + """Tests for querying and discovering datasets.""" 352 + 353 + def test_get_entry_by_name(self, clean_redis): 354 + """Should retrieve entry by name.""" 355 + index = atlocal.Index(redis=clean_redis) 356 + 357 + # Add entries 358 + entry1 = atlocal.LocalDatasetEntry( 359 + _name="findme", 360 + _schema_ref="local://schemas/Test@1.0.0", 361 + _data_urls=["s3://bucket/findme.tar"], 362 + ) 363 + entry1.write_to(clean_redis) 364 + 365 + # Query by name 366 + found = index.get_entry_by_name("findme") 367 + assert found is not None 368 + assert found.name == "findme" 369 + 370 + def test_get_entry_by_cid(self, clean_redis): 371 + """Should retrieve entry by CID.""" 372 + index = atlocal.Index(redis=clean_redis) 373 + 374 + entry = atlocal.LocalDatasetEntry( 375 + _name="bycid", 376 + _schema_ref="local://schemas/Test@1.0.0", 377 + _data_urls=["s3://bucket/bycid.tar"], 378 + ) 379 + entry.write_to(clean_redis) 380 + 381 + # Query by CID 382 + found = index.get_entry(cid=entry.cid) 383 + assert found is not None 384 + assert found.cid == entry.cid 385 + assert found.name == "bycid" 386 + 387 + def test_list_all_datasets(self, clean_redis): 388 + """Should list all datasets in index.""" 389 + index = atlocal.Index(redis=clean_redis) 390 + 391 + # Add multiple entries 392 + for i in range(5): 393 + entry = atlocal.LocalDatasetEntry( 394 + _name=f"dataset-{i}", 395 + _schema_ref="local://schemas/Test@1.0.0", 396 + _data_urls=[f"s3://bucket/dataset-{i}.tar"], 397 + ) 398 + entry.write_to(clean_redis) 399 + 400 + # List all 401 + all_entries = list(index.entries) 402 + assert len(all_entries) == 5 403 + 404 + names = {e.name for e in all_entries} 405 + for i in range(5): 406 + assert f"dataset-{i}" in names 407 + 408 + def test_entries_generator_is_lazy(self, clean_redis): 409 + """entries property should be a generator, not load all at once.""" 410 + index = atlocal.Index(redis=clean_redis) 411 + 412 + # Add entries 413 + for i in range(10): 414 + entry = atlocal.LocalDatasetEntry( 415 + _name=f"lazy-{i}", 416 + _schema_ref="local://schemas/Test@1.0.0", 417 + _data_urls=[f"s3://bucket/lazy-{i}.tar"], 418 + ) 419 + entry.write_to(clean_redis) 420 + 421 + # Should be a generator 422 + entries = index.entries 423 + import types 424 + assert isinstance(entries, types.GeneratorType) 425 + 426 + # Can iterate partially 427 + first_three = [] 428 + for i, entry in enumerate(entries): 429 + first_three.append(entry) 430 + if i >= 2: 431 + break 432 + assert len(first_three) == 3 433 + 434 + 435 + class TestMetadataPersistence: 436 + """Tests for metadata preservation through storage cycle.""" 437 + 438 + @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") 439 + @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") 440 + def test_metadata_preserved_through_insert(self, mock_s3, clean_redis, tmp_path): 441 + """Metadata should be preserved when inserting dataset.""" 442 + repo = atlocal.Repo( 443 + s3_credentials=mock_s3['credentials'], 444 + hive_path=mock_s3['hive_path'], 445 + redis=clean_redis 446 + ) 447 + 448 + ds = create_workflow_dataset(tmp_path, n_samples=5) 449 + ds._metadata = { 450 + "version": "1.0.0", 451 + "author": "test", 452 + "created": "2024-01-01", 453 + "nested": {"key": "value", "count": 42}, 454 + } 455 + 456 + entry, new_ds = repo.insert(ds, name="with-metadata", maxcount=100) 457 + 458 + # Metadata should be in entry 459 + assert entry.metadata is not None 460 + assert entry.metadata["version"] == "1.0.0" 461 + assert entry.metadata["author"] == "test" 462 + assert entry.metadata["nested"]["key"] == "value" 463 + assert entry.metadata["nested"]["count"] == 42 464 + 465 + def test_metadata_round_trip_redis(self, clean_redis): 466 + """Metadata should round-trip through Redis correctly.""" 467 + original = atlocal.LocalDatasetEntry( 468 + _name="meta-test", 469 + _schema_ref="local://schemas/Test@1.0.0", 470 + _data_urls=["s3://bucket/data.tar"], 471 + _metadata={ 472 + "string": "hello", 473 + "number": 123, 474 + "float": 3.14, 475 + "bool": True, 476 + "list": [1, 2, 3], 477 + "nested": {"a": 1, "b": 2}, 478 + }, 479 + ) 480 + 481 + original.write_to(clean_redis) 482 + loaded = atlocal.LocalDatasetEntry.from_redis(clean_redis, original.cid) 483 + 484 + assert loaded.metadata == original.metadata 485 + assert loaded.metadata["string"] == "hello" 486 + assert loaded.metadata["number"] == 123 487 + assert loaded.metadata["list"] == [1, 2, 3] 488 + assert loaded.metadata["nested"]["a"] == 1 489 + 490 + def test_none_metadata_handled(self, clean_redis): 491 + """None metadata should be handled gracefully.""" 492 + entry = atlocal.LocalDatasetEntry( 493 + _name="no-meta", 494 + _schema_ref="local://schemas/Test@1.0.0", 495 + _data_urls=["s3://bucket/data.tar"], 496 + _metadata=None, 497 + ) 498 + 499 + entry.write_to(clean_redis) 500 + loaded = atlocal.LocalDatasetEntry.from_redis(clean_redis, entry.cid) 501 + 502 + # Should be None or empty, not error 503 + assert loaded.metadata is None or loaded.metadata == {} 504 + 505 + 506 + class TestCacheLocalModes: 507 + """Tests comparing cache_local=True vs cache_local=False modes.""" 508 + 509 + @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") 510 + @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") 511 + def test_cache_local_true_produces_valid_entry(self, mock_s3, clean_redis, tmp_path): 512 + """cache_local=True should produce valid index entry.""" 513 + repo = atlocal.Repo( 514 + s3_credentials=mock_s3['credentials'], 515 + hive_path=mock_s3['hive_path'], 516 + redis=clean_redis 517 + ) 518 + 519 + ds = create_workflow_dataset(tmp_path, n_samples=10) 520 + entry, new_ds = repo.insert(ds, name="cached", cache_local=True, maxcount=100) 521 + 522 + assert entry.cid is not None 523 + assert len(entry.data_urls) > 0 524 + assert ".tar" in entry.data_urls[0] 525 + 526 + @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") 527 + @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") 528 + def test_cache_local_false_produces_valid_entry(self, mock_s3, clean_redis, tmp_path): 529 + """cache_local=False should produce valid index entry.""" 530 + repo = atlocal.Repo( 531 + s3_credentials=mock_s3['credentials'], 532 + hive_path=mock_s3['hive_path'], 533 + redis=clean_redis 534 + ) 535 + 536 + ds = create_workflow_dataset(tmp_path, n_samples=10) 537 + entry, new_ds = repo.insert(ds, name="direct", cache_local=False, maxcount=100) 538 + 539 + assert entry.cid is not None 540 + assert len(entry.data_urls) > 0 541 + assert ".tar" in entry.data_urls[0] 542 + 543 + @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") 544 + @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") 545 + def test_both_modes_produce_same_structure(self, mock_s3, clean_redis, tmp_path): 546 + """Both cache modes should produce entries with same structure.""" 547 + repo = atlocal.Repo( 548 + s3_credentials=mock_s3['credentials'], 549 + hive_path=mock_s3['hive_path'], 550 + redis=clean_redis 551 + ) 552 + 553 + ds1 = create_workflow_dataset(tmp_path / "cached", n_samples=10) 554 + ds2 = create_workflow_dataset(tmp_path / "direct", n_samples=10) 555 + 556 + entry1, _ = repo.insert(ds1, name="cached-mode", cache_local=True, maxcount=100) 557 + entry2, _ = repo.insert(ds2, name="direct-mode", cache_local=False, maxcount=100) 558 + 559 + # Both should have valid structure 560 + assert entry1.schema_ref == entry2.schema_ref # Same type 561 + assert len(entry1.data_urls) == len(entry2.data_urls) # Same shard count 562 + 563 + 564 + class TestIndexEntryProtocol: 565 + """Tests for IndexEntry protocol compliance.""" 566 + 567 + def test_local_entry_implements_protocol(self): 568 + """LocalDatasetEntry should implement IndexEntry protocol.""" 569 + from atdata._protocols import IndexEntry 570 + 571 + entry = atlocal.LocalDatasetEntry( 572 + _name="protocol-test", 573 + _schema_ref="local://schemas/Test@1.0.0", 574 + _data_urls=["s3://bucket/data.tar"], 575 + ) 576 + 577 + assert isinstance(entry, IndexEntry) 578 + 579 + def test_entry_has_required_properties(self): 580 + """Entry should have all required IndexEntry properties.""" 581 + entry = atlocal.LocalDatasetEntry( 582 + _name="props-test", 583 + _schema_ref="local://schemas/Test@1.0.0", 584 + _data_urls=["s3://bucket/data.tar"], 585 + _metadata={"key": "value"}, 586 + ) 587 + 588 + # Required properties 589 + assert hasattr(entry, 'name') 590 + assert hasattr(entry, 'schema_ref') 591 + assert hasattr(entry, 'data_urls') 592 + assert hasattr(entry, 'metadata') 593 + assert hasattr(entry, 'cid') 594 + 595 + # Values accessible 596 + assert entry.name == "props-test" 597 + assert entry.schema_ref == "local://schemas/Test@1.0.0" 598 + assert entry.data_urls == ["s3://bucket/data.tar"] 599 + assert entry.metadata == {"key": "value"} 600 + 601 + def test_legacy_properties_work(self): 602 + """Legacy properties should still work for backwards compatibility.""" 603 + entry = atlocal.LocalDatasetEntry( 604 + _name="legacy-test", 605 + _schema_ref="local://schemas/Test@1.0.0", 606 + _data_urls=["s3://bucket/legacy.tar"], 607 + ) 608 + 609 + # Legacy aliases 610 + assert entry.wds_url == "s3://bucket/legacy.tar" 611 + assert entry.sample_kind == "local://schemas/Test@1.0.0" 612 + 613 + 614 + class TestMultiShardStorage: 615 + """Tests for multi-shard dataset storage.""" 616 + 617 + @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") 618 + @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") 619 + def test_large_dataset_creates_multiple_shards(self, mock_s3, clean_redis, tmp_path): 620 + """Large dataset should create multiple shard files.""" 621 + repo = atlocal.Repo( 622 + s3_credentials=mock_s3['credentials'], 623 + hive_path=mock_s3['hive_path'], 624 + redis=clean_redis 625 + ) 626 + 627 + # Create dataset with many samples 628 + tar_path = tmp_path / "large-000000.tar" 629 + with wds.writer.TarWriter(str(tar_path)) as sink: 630 + for i in range(100): 631 + sample = WorkflowSample( 632 + name=f"item_{i}", 633 + value=i, 634 + score=float(i), 635 + ) 636 + sink.write(sample.as_wds) 637 + 638 + ds = atdata.Dataset[WorkflowSample](url=str(tar_path)) 639 + 640 + # Insert with small maxcount to force sharding 641 + entry, new_ds = repo.insert(ds, name="sharded", maxcount=10) 642 + 643 + # Should have multiple shards (URL with brace notation) 644 + assert "{" in new_ds.url and "}" in new_ds.url 645 + 646 + @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") 647 + @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") 648 + def test_single_shard_no_brace_notation(self, mock_s3, clean_redis, tmp_path): 649 + """Small dataset should result in single shard without brace notation.""" 650 + repo = atlocal.Repo( 651 + s3_credentials=mock_s3['credentials'], 652 + hive_path=mock_s3['hive_path'], 653 + redis=clean_redis 654 + ) 655 + 656 + ds = create_workflow_dataset(tmp_path, n_samples=5) 657 + 658 + # Large maxcount ensures single shard 659 + entry, new_ds = repo.insert(ds, name="single", maxcount=1000) 660 + 661 + # Should be single file, no brace notation 662 + assert "{" not in new_ds.url 663 + assert ".tar" in new_ds.url
+646
tests/test_integration_promotion.py
··· 1 + """Integration tests for the promotion pipeline (local → atmosphere). 2 + 3 + Tests end-to-end promotion workflows including: 4 + - Full promotion with local index and mocked atmosphere 5 + - Schema deduplication across multiple promotions 6 + - Metadata preservation during promotion 7 + - Multi-dataset promotion with shared schemas 8 + - Large dataset handling with many shards 9 + """ 10 + 11 + import pytest 12 + from pathlib import Path 13 + from unittest.mock import Mock, MagicMock, patch 14 + from dataclasses import dataclass 15 + 16 + import numpy as np 17 + from numpy.typing import NDArray 18 + import webdataset as wds 19 + 20 + import atdata 21 + from atdata.local import LocalIndex, LocalDatasetEntry 22 + from atdata.promote import promote_to_atmosphere, _find_existing_schema 23 + from atdata.atmosphere import AtmosphereClient 24 + from atdata.atmosphere._types import LEXICON_NAMESPACE 25 + 26 + 27 + ## 28 + # Test sample types 29 + 30 + 31 + @atdata.packable 32 + class PromotionSample: 33 + """Sample for promotion tests.""" 34 + name: str 35 + value: int 36 + 37 + 38 + @atdata.packable 39 + class PromotionArraySample: 40 + """Sample with NDArray for promotion tests.""" 41 + label: str 42 + features: NDArray 43 + 44 + 45 + ## 46 + # Fixtures 47 + 48 + 49 + @pytest.fixture 50 + def mock_atproto_client(): 51 + """Create a mock atproto SDK client.""" 52 + mock = Mock() 53 + mock.me = MagicMock() 54 + mock.me.did = "did:plc:promotion123" 55 + mock.me.handle = "promotion.test.social" 56 + 57 + mock_profile = Mock() 58 + mock_profile.did = "did:plc:promotion123" 59 + mock_profile.handle = "promotion.test.social" 60 + mock.login.return_value = mock_profile 61 + mock.export_session_string.return_value = "test-session-export" 62 + 63 + return mock 64 + 65 + 66 + @pytest.fixture 67 + def authenticated_client(mock_atproto_client): 68 + """Create an authenticated AtmosphereClient.""" 69 + client = AtmosphereClient(_client=mock_atproto_client) 70 + client.login("promotion.test.social", "test-password") 71 + return client 72 + 73 + 74 + @pytest.fixture 75 + def local_index_with_data(clean_redis, tmp_path): 76 + """Create a LocalIndex with a sample dataset.""" 77 + index = LocalIndex(redis=clean_redis) 78 + 79 + # Publish schema 80 + schema_ref = index.publish_schema(PromotionSample, version="1.0.0") 81 + 82 + # Create a tar file with samples 83 + tar_path = tmp_path / "promotion-test-000000.tar" 84 + with wds.writer.TarWriter(str(tar_path)) as writer: 85 + for i in range(5): 86 + sample = PromotionSample(name=f"sample-{i}", value=i * 10) 87 + writer.write(sample.as_wds) 88 + 89 + # Create entry 90 + entry = LocalDatasetEntry( 91 + _name="promotion-test-dataset", 92 + _schema_ref=schema_ref, 93 + _data_urls=[str(tar_path)], 94 + _metadata={"version": "1.0", "sample_count": 5}, 95 + ) 96 + entry.write_to(clean_redis) 97 + 98 + return index, entry 99 + 100 + 101 + ## 102 + # Full Promotion Workflow Tests 103 + 104 + 105 + class TestFullPromotionWorkflow: 106 + """End-to-end tests for promotion workflow.""" 107 + 108 + def test_promote_local_to_atmosphere( 109 + self, local_index_with_data, authenticated_client, mock_atproto_client 110 + ): 111 + """Full workflow: LocalIndex dataset → promote → AtmosphereIndex.""" 112 + local_index, local_entry = local_index_with_data 113 + 114 + # Setup mock responses for atmosphere operations 115 + schema_response = Mock() 116 + schema_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/promoted-schema" 117 + 118 + dataset_response = Mock() 119 + dataset_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.dataset/promoted-dataset" 120 + 121 + mock_atproto_client.com.atproto.repo.create_record.side_effect = [ 122 + schema_response, 123 + dataset_response, 124 + ] 125 + 126 + # Mock list_records to return empty (no existing schema) 127 + mock_list_response = Mock() 128 + mock_list_response.records = [] 129 + mock_list_response.cursor = None 130 + mock_atproto_client.com.atproto.repo.list_records.return_value = mock_list_response 131 + 132 + # Promote 133 + result_uri = promote_to_atmosphere( 134 + local_entry, 135 + local_index, 136 + authenticated_client, 137 + ) 138 + 139 + assert result_uri is not None 140 + assert "at://" in result_uri 141 + 142 + def test_promoted_dataset_preserves_name( 143 + self, local_index_with_data, authenticated_client, mock_atproto_client 144 + ): 145 + """Promoted dataset should preserve the original name.""" 146 + local_index, local_entry = local_index_with_data 147 + 148 + # Setup mocks 149 + mock_list_response = Mock() 150 + mock_list_response.records = [] 151 + mock_list_response.cursor = None 152 + mock_atproto_client.com.atproto.repo.list_records.return_value = mock_list_response 153 + 154 + schema_response = Mock() 155 + schema_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/s1" 156 + 157 + dataset_response = Mock() 158 + dataset_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.dataset/d1" 159 + 160 + mock_atproto_client.com.atproto.repo.create_record.side_effect = [ 161 + schema_response, 162 + dataset_response, 163 + ] 164 + 165 + promote_to_atmosphere(local_entry, local_index, authenticated_client) 166 + 167 + # Check that dataset was published with correct name 168 + calls = mock_atproto_client.com.atproto.repo.create_record.call_args_list 169 + dataset_call = calls[-1] # Last call is for dataset 170 + record = dataset_call.kwargs["data"]["record"] 171 + assert record["name"] == "promotion-test-dataset" 172 + 173 + def test_promoted_dataset_preserves_data_urls( 174 + self, local_index_with_data, authenticated_client, mock_atproto_client 175 + ): 176 + """Promoted dataset should use the original data URLs.""" 177 + local_index, local_entry = local_index_with_data 178 + 179 + # Setup mocks 180 + mock_list_response = Mock() 181 + mock_list_response.records = [] 182 + mock_list_response.cursor = None 183 + mock_atproto_client.com.atproto.repo.list_records.return_value = mock_list_response 184 + 185 + schema_response = Mock() 186 + schema_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/s1" 187 + 188 + dataset_response = Mock() 189 + dataset_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.dataset/d1" 190 + 191 + mock_atproto_client.com.atproto.repo.create_record.side_effect = [ 192 + schema_response, 193 + dataset_response, 194 + ] 195 + 196 + promote_to_atmosphere(local_entry, local_index, authenticated_client) 197 + 198 + # Check that dataset was published with original URLs 199 + calls = mock_atproto_client.com.atproto.repo.create_record.call_args_list 200 + dataset_call = calls[-1] 201 + record = dataset_call.kwargs["data"]["record"] 202 + assert "storage" in record 203 + assert local_entry.data_urls[0] in str(record["storage"]["urls"]) 204 + 205 + 206 + ## 207 + # Schema Deduplication Tests 208 + 209 + 210 + class TestSchemaDeduplication: 211 + """Tests for schema deduplication during promotion.""" 212 + 213 + def test_reuses_existing_schema( 214 + self, local_index_with_data, authenticated_client, mock_atproto_client 215 + ): 216 + """Promotion should reuse existing schema instead of creating duplicate.""" 217 + local_index, local_entry = local_index_with_data 218 + 219 + # Patch _find_existing_schema to return an existing schema URI 220 + with patch("atdata.promote._find_existing_schema") as mock_find: 221 + mock_find.return_value = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/existing" 222 + 223 + # Only dataset should be created (schema exists) 224 + dataset_response = Mock() 225 + dataset_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.dataset/d1" 226 + mock_atproto_client.com.atproto.repo.create_record.return_value = dataset_response 227 + 228 + promote_to_atmosphere(local_entry, local_index, authenticated_client) 229 + 230 + # Should only have 1 create_record call (for dataset, not schema) 231 + assert mock_atproto_client.com.atproto.repo.create_record.call_count == 1 232 + 233 + # Verify it was the dataset call 234 + call_kwargs = mock_atproto_client.com.atproto.repo.create_record.call_args.kwargs 235 + assert "dataset" in call_kwargs["data"]["collection"] 236 + 237 + def test_creates_schema_when_not_found( 238 + self, local_index_with_data, authenticated_client, mock_atproto_client 239 + ): 240 + """Promotion should create new schema when none exists.""" 241 + local_index, local_entry = local_index_with_data 242 + 243 + # Mock empty list (no existing schemas) 244 + mock_list_response = Mock() 245 + mock_list_response.records = [] 246 + mock_list_response.cursor = None 247 + mock_atproto_client.com.atproto.repo.list_records.return_value = mock_list_response 248 + 249 + # Both schema and dataset should be created 250 + schema_response = Mock() 251 + schema_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/new" 252 + 253 + dataset_response = Mock() 254 + dataset_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.dataset/d1" 255 + 256 + mock_atproto_client.com.atproto.repo.create_record.side_effect = [ 257 + schema_response, 258 + dataset_response, 259 + ] 260 + 261 + promote_to_atmosphere(local_entry, local_index, authenticated_client) 262 + 263 + # Should have 2 create_record calls (schema + dataset) 264 + assert mock_atproto_client.com.atproto.repo.create_record.call_count == 2 265 + 266 + def test_version_mismatch_creates_new_schema( 267 + self, local_index_with_data, authenticated_client, mock_atproto_client 268 + ): 269 + """Different version should create new schema even if name matches.""" 270 + local_index, local_entry = local_index_with_data 271 + 272 + # Mock existing schema with different version 273 + existing_schema = Mock() 274 + existing_schema.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/v1" 275 + existing_schema.value = { 276 + "name": "test_integration_promotion.PromotionSample", 277 + "version": "2.0.0", # Different version! 278 + } 279 + 280 + mock_list_response = Mock() 281 + mock_list_response.records = [existing_schema] 282 + mock_list_response.cursor = None 283 + mock_atproto_client.com.atproto.repo.list_records.return_value = mock_list_response 284 + 285 + # Both should be created (version mismatch) 286 + schema_response = Mock() 287 + schema_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/v1new" 288 + 289 + dataset_response = Mock() 290 + dataset_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.dataset/d1" 291 + 292 + mock_atproto_client.com.atproto.repo.create_record.side_effect = [ 293 + schema_response, 294 + dataset_response, 295 + ] 296 + 297 + promote_to_atmosphere(local_entry, local_index, authenticated_client) 298 + 299 + # Should have 2 create_record calls (new schema + dataset) 300 + assert mock_atproto_client.com.atproto.repo.create_record.call_count == 2 301 + 302 + 303 + ## 304 + # Metadata Preservation Tests 305 + 306 + 307 + class TestMetadataPreservation: 308 + """Tests for metadata preservation during promotion.""" 309 + 310 + def test_metadata_included_in_promoted_dataset( 311 + self, local_index_with_data, authenticated_client, mock_atproto_client 312 + ): 313 + """Metadata from local entry should be included in promoted dataset.""" 314 + local_index, local_entry = local_index_with_data 315 + 316 + # Setup mocks 317 + mock_list_response = Mock() 318 + mock_list_response.records = [] 319 + mock_list_response.cursor = None 320 + mock_atproto_client.com.atproto.repo.list_records.return_value = mock_list_response 321 + 322 + schema_response = Mock() 323 + schema_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/s1" 324 + 325 + dataset_response = Mock() 326 + dataset_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.dataset/d1" 327 + 328 + mock_atproto_client.com.atproto.repo.create_record.side_effect = [ 329 + schema_response, 330 + dataset_response, 331 + ] 332 + 333 + promote_to_atmosphere(local_entry, local_index, authenticated_client) 334 + 335 + # Check metadata was passed 336 + calls = mock_atproto_client.com.atproto.repo.create_record.call_args_list 337 + dataset_call = calls[-1] 338 + record = dataset_call.kwargs["data"]["record"] 339 + 340 + # The metadata should be in the record (may be msgpack encoded) 341 + assert "metadata" in record 342 + 343 + def test_none_metadata_handled(self, clean_redis, authenticated_client, mock_atproto_client): 344 + """Entry without metadata should promote successfully.""" 345 + index = LocalIndex(redis=clean_redis) 346 + schema_ref = index.publish_schema(PromotionSample, version="1.0.0") 347 + 348 + entry = LocalDatasetEntry( 349 + _name="no-metadata-dataset", 350 + _schema_ref=schema_ref, 351 + _data_urls=["s3://bucket/data.tar"], 352 + # No _metadata specified 353 + ) 354 + entry.write_to(clean_redis) 355 + 356 + # Setup mocks 357 + mock_list_response = Mock() 358 + mock_list_response.records = [] 359 + mock_list_response.cursor = None 360 + mock_atproto_client.com.atproto.repo.list_records.return_value = mock_list_response 361 + 362 + schema_response = Mock() 363 + schema_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/s1" 364 + 365 + dataset_response = Mock() 366 + dataset_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.dataset/d1" 367 + 368 + mock_atproto_client.com.atproto.repo.create_record.side_effect = [ 369 + schema_response, 370 + dataset_response, 371 + ] 372 + 373 + # Should not raise 374 + result = promote_to_atmosphere(entry, index, authenticated_client) 375 + assert result is not None 376 + 377 + 378 + ## 379 + # Multi-Dataset Promotion Tests 380 + 381 + 382 + class TestMultiDatasetPromotion: 383 + """Tests for promoting multiple datasets.""" 384 + 385 + def test_multiple_datasets_share_schema( 386 + self, clean_redis, authenticated_client, mock_atproto_client 387 + ): 388 + """Multiple datasets using same schema should reuse the schema.""" 389 + index = LocalIndex(redis=clean_redis) 390 + schema_ref = index.publish_schema(PromotionSample, version="1.0.0") 391 + 392 + # Create multiple entries with same schema 393 + entries = [] 394 + for i in range(3): 395 + entry = LocalDatasetEntry( 396 + _name=f"dataset-{i}", 397 + _schema_ref=schema_ref, 398 + _data_urls=[f"s3://bucket/data-{i}.tar"], 399 + ) 400 + entry.write_to(clean_redis) 401 + entries.append(entry) 402 + 403 + # Track whether schema has been "published" to atmosphere 404 + schema_published = {"value": False} 405 + schema_uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/shared" 406 + 407 + def mock_find_existing(client, name, version): 408 + # Return schema URI after first promotion 409 + if schema_published["value"]: 410 + return schema_uri 411 + return None 412 + 413 + # Setup create_record responses 414 + schema_response = Mock() 415 + schema_response.uri = schema_uri 416 + 417 + dataset_responses = [ 418 + Mock(uri=f"at://did:plc:test/{LEXICON_NAMESPACE}.dataset/d{i}") 419 + for i in range(3) 420 + ] 421 + 422 + mock_atproto_client.com.atproto.repo.create_record.side_effect = [ 423 + schema_response, # First promotion creates schema 424 + dataset_responses[0], # First dataset 425 + dataset_responses[1], # Second dataset 426 + dataset_responses[2], # Third dataset 427 + ] 428 + 429 + with patch("atdata.promote._find_existing_schema", side_effect=mock_find_existing): 430 + # Promote all three 431 + for i, entry in enumerate(entries): 432 + promote_to_atmosphere(entry, index, authenticated_client) 433 + # After first promotion, schema exists 434 + if i == 0: 435 + schema_published["value"] = True 436 + 437 + # Should have 4 create_record calls: 1 schema + 3 datasets 438 + assert mock_atproto_client.com.atproto.repo.create_record.call_count == 4 439 + 440 + 441 + ## 442 + # Large Dataset Tests 443 + 444 + 445 + class TestLargeDatasetPromotion: 446 + """Tests for promoting datasets with many shards.""" 447 + 448 + def test_many_shards_promoted( 449 + self, clean_redis, authenticated_client, mock_atproto_client 450 + ): 451 + """Dataset with many shards should have all URLs promoted.""" 452 + index = LocalIndex(redis=clean_redis) 453 + schema_ref = index.publish_schema(PromotionSample, version="1.0.0") 454 + 455 + # Create entry with many shards 456 + shard_urls = [f"s3://bucket/shard-{i:06d}.tar" for i in range(100)] 457 + entry = LocalDatasetEntry( 458 + _name="large-dataset", 459 + _schema_ref=schema_ref, 460 + _data_urls=shard_urls, 461 + ) 462 + entry.write_to(clean_redis) 463 + 464 + # Setup mocks 465 + mock_list_response = Mock() 466 + mock_list_response.records = [] 467 + mock_list_response.cursor = None 468 + mock_atproto_client.com.atproto.repo.list_records.return_value = mock_list_response 469 + 470 + schema_response = Mock() 471 + schema_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/s1" 472 + 473 + dataset_response = Mock() 474 + dataset_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.dataset/large" 475 + 476 + mock_atproto_client.com.atproto.repo.create_record.side_effect = [ 477 + schema_response, 478 + dataset_response, 479 + ] 480 + 481 + promote_to_atmosphere(entry, index, authenticated_client) 482 + 483 + # Verify all 100 URLs were included 484 + calls = mock_atproto_client.com.atproto.repo.create_record.call_args_list 485 + dataset_call = calls[-1] 486 + record = dataset_call.kwargs["data"]["record"] 487 + storage_urls = record["storage"]["urls"] 488 + 489 + assert len(storage_urls) == 100 490 + assert storage_urls[0] == "s3://bucket/shard-000000.tar" 491 + assert storage_urls[99] == "s3://bucket/shard-000099.tar" 492 + 493 + 494 + ## 495 + # Error Handling Tests 496 + 497 + 498 + class TestPromotionErrors: 499 + """Tests for error handling during promotion.""" 500 + 501 + def test_empty_data_urls_raises(self, clean_redis, authenticated_client): 502 + """Promotion of entry with no data URLs should raise.""" 503 + index = LocalIndex(redis=clean_redis) 504 + schema_ref = index.publish_schema(PromotionSample, version="1.0.0") 505 + 506 + entry = LocalDatasetEntry( 507 + _name="empty-dataset", 508 + _schema_ref=schema_ref, 509 + _data_urls=[], 510 + ) 511 + 512 + with pytest.raises(ValueError, match="has no data URLs"): 513 + promote_to_atmosphere(entry, index, authenticated_client) 514 + 515 + def test_missing_schema_raises( 516 + self, clean_redis, authenticated_client, mock_atproto_client 517 + ): 518 + """Promotion with missing local schema should raise.""" 519 + index = LocalIndex(redis=clean_redis) 520 + 521 + # Entry references a schema that doesn't exist 522 + entry = LocalDatasetEntry( 523 + _name="orphan-dataset", 524 + _schema_ref="local://schemas/NonExistent@1.0.0", 525 + _data_urls=["s3://bucket/data.tar"], 526 + ) 527 + 528 + with pytest.raises(KeyError): 529 + promote_to_atmosphere(entry, index, authenticated_client) 530 + 531 + 532 + ## 533 + # Custom Options Tests 534 + 535 + 536 + class TestPromotionOptions: 537 + """Tests for promotion with custom options.""" 538 + 539 + def test_custom_name_override( 540 + self, local_index_with_data, authenticated_client, mock_atproto_client 541 + ): 542 + """Custom name should override local entry name.""" 543 + local_index, local_entry = local_index_with_data 544 + 545 + # Setup mocks 546 + mock_list_response = Mock() 547 + mock_list_response.records = [] 548 + mock_list_response.cursor = None 549 + mock_atproto_client.com.atproto.repo.list_records.return_value = mock_list_response 550 + 551 + schema_response = Mock() 552 + schema_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/s1" 553 + 554 + dataset_response = Mock() 555 + dataset_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.dataset/d1" 556 + 557 + mock_atproto_client.com.atproto.repo.create_record.side_effect = [ 558 + schema_response, 559 + dataset_response, 560 + ] 561 + 562 + promote_to_atmosphere( 563 + local_entry, 564 + local_index, 565 + authenticated_client, 566 + name="custom-promoted-name", 567 + ) 568 + 569 + calls = mock_atproto_client.com.atproto.repo.create_record.call_args_list 570 + dataset_call = calls[-1] 571 + record = dataset_call.kwargs["data"]["record"] 572 + assert record["name"] == "custom-promoted-name" 573 + 574 + def test_tags_and_license( 575 + self, local_index_with_data, authenticated_client, mock_atproto_client 576 + ): 577 + """Tags and license should be passed to promoted dataset.""" 578 + local_index, local_entry = local_index_with_data 579 + 580 + # Setup mocks 581 + mock_list_response = Mock() 582 + mock_list_response.records = [] 583 + mock_list_response.cursor = None 584 + mock_atproto_client.com.atproto.repo.list_records.return_value = mock_list_response 585 + 586 + schema_response = Mock() 587 + schema_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/s1" 588 + 589 + dataset_response = Mock() 590 + dataset_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.dataset/d1" 591 + 592 + mock_atproto_client.com.atproto.repo.create_record.side_effect = [ 593 + schema_response, 594 + dataset_response, 595 + ] 596 + 597 + promote_to_atmosphere( 598 + local_entry, 599 + local_index, 600 + authenticated_client, 601 + tags=["ml", "training", "images"], 602 + license="Apache-2.0", 603 + ) 604 + 605 + calls = mock_atproto_client.com.atproto.repo.create_record.call_args_list 606 + dataset_call = calls[-1] 607 + record = dataset_call.kwargs["data"]["record"] 608 + 609 + assert record.get("tags") == ["ml", "training", "images"] 610 + assert record.get("license") == "Apache-2.0" 611 + 612 + def test_description_passed( 613 + self, local_index_with_data, authenticated_client, mock_atproto_client 614 + ): 615 + """Description should be passed to promoted dataset.""" 616 + local_index, local_entry = local_index_with_data 617 + 618 + # Setup mocks 619 + mock_list_response = Mock() 620 + mock_list_response.records = [] 621 + mock_list_response.cursor = None 622 + mock_atproto_client.com.atproto.repo.list_records.return_value = mock_list_response 623 + 624 + schema_response = Mock() 625 + schema_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/s1" 626 + 627 + dataset_response = Mock() 628 + dataset_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.dataset/d1" 629 + 630 + mock_atproto_client.com.atproto.repo.create_record.side_effect = [ 631 + schema_response, 632 + dataset_response, 633 + ] 634 + 635 + promote_to_atmosphere( 636 + local_entry, 637 + local_index, 638 + authenticated_client, 639 + description="A promoted dataset for testing purposes.", 640 + ) 641 + 642 + calls = mock_atproto_client.com.atproto.repo.create_record.call_args_list 643 + dataset_call = calls[-1] 644 + record = dataset_call.kwargs["data"]["record"] 645 + 646 + assert record.get("description") == "A promoted dataset for testing purposes."