A loose federation of distributed, typed datasets
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

at main 633 lines 21 kB view raw
"""Integration tests for error handling and recovery.

Tests error conditions and graceful failure including:
- Missing schemas and data URLs
- Malformed data (msgpack, tar)
- Connection failures (Redis, S3, ATProto)
- Authentication and API errors
- Timeout scenarios
- Partial failures in multi-shard datasets
"""

import io
import tarfile
from unittest.mock import MagicMock, Mock, patch

import pytest
import webdataset as wds

import atdata
from atdata.atmosphere import AtmosphereClient, AtUri, SchemaPublisher
from atdata.local import LocalDatasetEntry, LocalIndex


##
# Test sample types


@atdata.packable
class ErrorTestSample:
    """Sample for error handling tests."""

    name: str
    value: int


##
# Schema Error Tests


class TestMissingSchema:
    """Tests for missing schema errors."""

    def test_missing_schema_raises_keyerror(self, clean_redis):
        """Accessing non-existent schema should raise KeyError."""
        index = LocalIndex(redis=clean_redis)

        with pytest.raises(KeyError):
            index.get_schema("local://schemas/NonExistent@1.0.0")

    def test_dataset_with_invalid_schema_ref(self, clean_redis):
        """Dataset entry with invalid schema ref should error on decode."""
        index = LocalIndex(redis=clean_redis)

        entry = LocalDatasetEntry(
            name="orphan-dataset",
            schema_ref="local://schemas/DoesNotExist@1.0.0",
            data_urls=["s3://bucket/data.tar"],
        )
        entry.write_to(clean_redis)

        # The entry itself is stored even though its schema was never published
        retrieved = index.get_entry_by_name("orphan-dataset")
        assert retrieved is not None

        # Decoding the dangling schema reference should fail
        with pytest.raises(KeyError):
            index.decode_schema(retrieved.schema_ref)


##
# Data URL Error Tests


class TestMissingDataUrls:
    """Tests for missing or inaccessible data URLs."""

    def test_empty_data_urls_raises(self, clean_redis):
        """Dataset entry with empty URLs should be stored and read back as empty.

        NOTE: despite the test name, nothing raises here; this pins the
        current behavior that an empty ``data_urls`` list is accepted.
        """
        index = LocalIndex(redis=clean_redis)
        schema_ref = index.publish_schema(ErrorTestSample, version="1.0.0")

        entry = LocalDatasetEntry(
            name="empty-urls",
            schema_ref=schema_ref,
            data_urls=[],
        )
        entry.write_to(clean_redis)

        retrieved = index.get_entry_by_name("empty-urls")
        assert retrieved.data_urls == []

    def test_nonexistent_tar_raises(self, tmp_path):
        """Attempting to read non-existent tar should raise."""
        nonexistent_path = tmp_path / "does-not-exist.tar"

        ds = atdata.Dataset[ErrorTestSample](str(nonexistent_path))

        # Construction is lazy; the failure surfaces on iteration
        with pytest.raises(FileNotFoundError):
            list(ds.ordered(batch_size=None))


##
# Malformed Data Tests


class TestMalformedMsgpack:
    """Tests for corrupted msgpack data."""

    def test_invalid_msgpack_in_tar(self, tmp_path):
        """Tar with invalid msgpack should raise on iteration."""
        tar_path = tmp_path / "corrupted-000000.tar"

        # Create tar with invalid msgpack data
        with tarfile.open(tar_path, "w") as tar:
            # Add a valid key file
            key_data = b"sample-0"
            key_info = tarfile.TarInfo(name="sample-0.__key__")
            key_info.size = len(key_data)
            tar.addfile(key_info, fileobj=io.BytesIO(key_data))

            # Add invalid msgpack data
            invalid_data = b"\xff\xff\xff\xff\xff"  # Not valid msgpack
            info = tarfile.TarInfo(name="sample-0.msgpack")
            info.size = len(invalid_data)
            tar.addfile(info, fileobj=io.BytesIO(invalid_data))

        ds = atdata.Dataset[ErrorTestSample](str(tar_path))

        # Should raise an error when trying to deserialize
        with pytest.raises(Exception):  # Could be msgpack error or ValueError
            list(ds.ordered(batch_size=None))


class TestCorruptedTar:
    """Tests for corrupted tar files."""

    def test_truncated_tar_raises(self, tmp_path):
        """Truncated tar file should raise an error."""
        tar_path = tmp_path / "truncated-000000.tar"

        # Create a valid tar then truncate it
        with tarfile.open(tar_path, "w") as tar:
            data = b"test data"
            info = tarfile.TarInfo(name="test.txt")
            info.size = len(data)
            tar.addfile(info, fileobj=io.BytesIO(data))

        # Truncate to 50 bytes: shorter than a single 512-byte tar header,
        # so even the first member header is incomplete
        with open(tar_path, "r+b") as f:
            f.truncate(50)

        ds = atdata.Dataset[ErrorTestSample](str(tar_path))

        with pytest.raises(Exception):  # tarfile.ReadError or similar
            list(ds.ordered(batch_size=None))

    def test_not_a_tar_file_raises(self, tmp_path):
        """Non-tar file should raise clear error."""
        fake_tar = tmp_path / "fake-000000.tar"

        # Write bytes that are not a tar archive
        with open(fake_tar, "wb") as f:
            f.write(b"This is not a tar file at all!")

        ds = atdata.Dataset[ErrorTestSample](str(fake_tar))

        with pytest.raises(Exception):  # tarfile.ReadError
            list(ds.ordered(batch_size=None))


##
# Redis Error Tests


class TestRedisErrors:
    """Tests for Redis connection errors."""

    def test_redis_connection_error(self):
        """Operations with bad Redis connection should fail cleanly."""
        from redis import Redis

        # Create index with an unresolvable Redis host
        bad_redis = Redis(
            host="nonexistent.invalid.host", port=9999, socket_timeout=0.1
        )

        index = LocalIndex(redis=bad_redis)

        # A redis ConnectionError is expected; any Exception is accepted in
        # case LocalIndex wraps it. The builtin ConnectionError is not shadowed.
        with pytest.raises(Exception):
            index.publish_schema(ErrorTestSample, version="1.0.0")

    def test_entry_lookup_with_bad_redis(self, clean_redis):
        """Entry lookup should fail cleanly if Redis becomes unavailable."""
        from redis import Redis

        index = LocalIndex(redis=clean_redis)

        # First, add an entry while Redis is reachable
        schema_ref = index.publish_schema(ErrorTestSample, version="1.0.0")
        entry = LocalDatasetEntry(
            name="test-entry",
            schema_ref=schema_ref,
            data_urls=["s3://bucket/data.tar"],
        )
        entry.write_to(clean_redis)

        # Entry should be retrievable while Redis is up
        retrieved = index.get_entry_by_name("test-entry")
        assert retrieved is not None

        # Simulate Redis being unavailable: an index backed by an unreachable
        # host must raise on lookup rather than hang or return stale data
        unavailable_index = LocalIndex(
            redis=Redis(
                host="nonexistent.invalid.host", port=9999, socket_timeout=0.1
            )
        )
        with pytest.raises(Exception):  # redis ConnectionError or a wrapper
            unavailable_index.get_entry_by_name("test-entry")


##
# ATProto Error Tests


class TestAtProtoErrors:
    """Tests for ATProto/Atmosphere errors."""

    def test_unauthenticated_publish_raises(self):
        """Publishing without authentication should raise."""
        mock_client = Mock()
        mock_client.me = None

        client = AtmosphereClient(_client=mock_client)

        # Not authenticated
        assert not client.is_authenticated

        publisher = SchemaPublisher(client)

        with pytest.raises(ValueError, match="authenticated"):
            publisher.publish(ErrorTestSample, version="1.0.0")

    @pytest.mark.parametrize(
        "uri",
        [
            "not-a-uri",
            "https://example.com/path",
            "at://",
            "at://did:plc:abc",  # Missing collection and rkey
            "at://did:plc:abc/collection",  # Missing rkey
        ],
    )
    def test_invalid_at_uri_raises(self, uri):
        """Parsing invalid AT URI should raise ValueError.

        Parametrized so every malformed URI is reported independently
        instead of stopping at the first failure.
        """
        with pytest.raises(ValueError):
            AtUri.parse(uri)

    def test_api_error_response_handling(self):
        """API errors should be propagated appropriately."""
        from atproto_client.exceptions import AtProtocolError

        mock_client = Mock()
        mock_client.me = MagicMock()
        mock_client.me.did = "did:plc:test123"

        # Simulate an API error on record creation
        mock_client.com.atproto.repo.create_record.side_effect = AtProtocolError(
            "API error occurred"
        )

        # Create client and mark it as authenticated
        client = AtmosphereClient(_client=mock_client)
        client._session = {"did": "did:plc:test123"}

        publisher = SchemaPublisher(client)

        # Should propagate the API error unchanged
        with pytest.raises(AtProtocolError):
            publisher.publish(ErrorTestSample, version="1.0.0")

    def test_expired_session_detection(self):
        """Expired session should be detectable."""
        mock_client = Mock()
        mock_client.me = None
        mock_client.export_session_string.return_value = None

        client = AtmosphereClient(_client=mock_client)

        # Should not be authenticated
        assert not client.is_authenticated


##
# Entry Not Found Tests


class TestNotFoundErrors:
    """Tests for not-found error handling."""

    def test_get_entry_by_name_not_found(self, clean_redis):
        """Getting non-existent entry by name should raise KeyError."""
        index = LocalIndex(redis=clean_redis)

        with pytest.raises(KeyError):
            index.get_entry_by_name("nonexistent-dataset")

    def test_get_entry_by_cid_not_found(self, clean_redis):
        """Getting non-existent entry by CID should raise KeyError."""
        index = LocalIndex(redis=clean_redis)

        with pytest.raises(KeyError):
            index.get_entry("bafyreifake123456789")


##
# Error Message Quality Tests


class TestErrorMessageQuality:
    """Tests that error messages are helpful and don't leak sensitive info."""

    def test_missing_schema_error_includes_ref(self, clean_redis):
        """Missing schema error should include the schema reference."""
        index = LocalIndex(redis=clean_redis)

        with pytest.raises(KeyError) as exc_info:
            index.get_schema("local://schemas/MissingType@1.0.0")

        # Error should mention the schema reference
        message = str(exc_info.value)
        assert "MissingType" in message or "local://" in message

    def test_invalid_uri_error_is_clear(self):
        """Invalid AT URI error should explain the issue."""
        with pytest.raises(ValueError) as exc_info:
            AtUri.parse("not-valid")

        # Error should explain it's not a valid URI
        message = str(exc_info.value).lower()
        assert "at://" in message or "uri" in message

    def test_auth_error_no_credential_leak(self):
        """Authentication errors should not leak credentials."""
        mock_client = Mock()
        mock_client.me = None

        client = AtmosphereClient(_client=mock_client)
        publisher = SchemaPublisher(client)

        # Unauthenticated publish raises ValueError (see
        # test_unauthenticated_publish_raises); requiring the raise keeps
        # this test from passing vacuously when nothing is raised.
        with pytest.raises(ValueError) as exc_info:
            publisher.publish(ErrorTestSample, version="1.0.0")

        # Should not contain anything that looks like a password or token
        error_msg = str(exc_info.value).lower()
        for sensitive in ("password", "token", "secret"):
            assert sensitive not in error_msg


##
# Recovery Tests


class TestRecovery:
    """Tests for recovery from errors."""

    def test_can_continue_after_bad_sample(self, tmp_path, clean_redis):
        """System should be usable after encountering bad data."""
        # First, read a bad file and confirm it actually fails
        bad_tar = tmp_path / "bad-000000.tar"
        with open(bad_tar, "wb") as f:
            f.write(b"not a tar file")

        ds_bad = atdata.Dataset[ErrorTestSample](str(bad_tar))
        with pytest.raises(Exception):  # tarfile.ReadError or similar
            list(ds_bad.ordered(batch_size=None))

        # Now use a good file - should still work
        good_tar = tmp_path / "good-000000.tar"
        with wds.writer.TarWriter(str(good_tar)) as writer:
            sample = ErrorTestSample(name="good", value=42)
            writer.write(sample.as_wds)

        ds_good = atdata.Dataset[ErrorTestSample](str(good_tar))
        samples = list(ds_good.ordered(batch_size=None))

        assert len(samples) == 1
        assert samples[0].name == "good"

    def test_index_usable_after_failed_publish(self, clean_redis):
        """Index should remain usable after a failed schema lookup."""
        index = LocalIndex(redis=clean_redis)

        # Look up a non-existent schema (fails as expected)
        with pytest.raises(KeyError):
            index.get_schema("local://schemas/NoSuch@1.0.0")

        # Index should still accept publishes and serve lookups
        schema_ref = index.publish_schema(ErrorTestSample, version="1.0.0")
        assert schema_ref is not None

        schema = index.get_schema(schema_ref)
        assert schema["name"] == "ErrorTestSample"


##
# Validation Tests


class TestInputValidation:
    """Tests for input validation."""

    def test_empty_version_string(self, clean_redis):
        """Empty version string should be handled."""
        index = LocalIndex(redis=clean_redis)

        # This test expects an empty version to be accepted; it must then
        # store and retrieve correctly
        schema_ref = index.publish_schema(ErrorTestSample, version="")
        schema = index.get_schema(schema_ref)
        assert schema is not None

    def test_special_chars_in_version(self, clean_redis):
        """Special characters in version should be handled."""
        index = LocalIndex(redis=clean_redis)

        schema_ref = index.publish_schema(
            ErrorTestSample, version="1.0.0-beta+build.123"
        )
        schema = index.get_schema(schema_ref)

        assert schema["version"] == "1.0.0-beta+build.123"


##
# Timeout Tests


class TestTimeoutScenarios:
    """Tests for timeout and slow connection scenarios."""

    def test_redis_socket_timeout(self):
        """Redis operations should fail with socket timeout."""
        from redis import Redis

        # Very short timeouts against a non-routable IP force a fast failure
        redis = Redis(
            host="10.255.255.1",  # Non-routable IP
            port=6379,
            socket_timeout=0.01,
            socket_connect_timeout=0.01,
        )

        index = LocalIndex(redis=redis)

        # Should time out quickly rather than hang
        with pytest.raises(Exception):  # TimeoutError or ConnectionError
            index.publish_schema(ErrorTestSample, version="1.0.0")

    def test_slow_iteration_continues(self, tmp_path):
        """Dataset iteration should yield every sample of a multi-sample shard.

        NOTE: no slow reads are simulated yet; this is a baseline check
        that normal iteration completes.
        """
        # Create a valid dataset
        tar_path = tmp_path / "slow-000000.tar"
        with wds.writer.TarWriter(str(tar_path)) as writer:
            for i in range(5):
                sample = ErrorTestSample(name=f"sample_{i}", value=i)
                writer.write(sample.as_wds)

        ds = atdata.Dataset[ErrorTestSample](str(tar_path))

        # Normal iteration should work
        samples = list(ds.ordered(batch_size=None))
        assert len(samples) == 5


##
# Partial Failure Tests


class TestPartialFailures:
    """Tests for partial failures in multi-shard scenarios."""

    def test_multi_shard_with_missing_middle_shard(self, tmp_path):
        """Multi-shard dataset with missing shard should fail cleanly."""
        # Create first and third shard, skip second
        for i in [0, 2]:
            tar_path = tmp_path / f"data-{i:06d}.tar"
            with wds.writer.TarWriter(str(tar_path)) as writer:
                sample = ErrorTestSample(name=f"shard_{i}", value=i)
                writer.write(sample.as_wds)

        # Use brace notation that expects all three shards
        url = str(tmp_path / "data-{000000..000002}.tar")
        ds = atdata.Dataset[ErrorTestSample](url)

        # Should fail when hitting missing shard
        with pytest.raises(FileNotFoundError):
            list(ds.ordered(batch_size=None))

    def test_multi_shard_with_corrupted_shard(self, tmp_path):
        """Multi-shard dataset with one corrupted shard should fail."""
        # Create two good shards
        for i in range(2):
            tar_path = tmp_path / f"data-{i:06d}.tar"
            with wds.writer.TarWriter(str(tar_path)) as writer:
                sample = ErrorTestSample(name=f"shard_{i}", value=i)
                writer.write(sample.as_wds)

        # Create a corrupted third shard
        corrupted_path = tmp_path / "data-000002.tar"
        with open(corrupted_path, "wb") as f:
            f.write(b"this is not a valid tar file")

        url = str(tmp_path / "data-{000000..000002}.tar")
        ds = atdata.Dataset[ErrorTestSample](url)

        # Should fail when hitting corrupted shard
        with pytest.raises(Exception):  # tarfile.ReadError or similar
            list(ds.ordered(batch_size=None))

    def test_empty_shard_in_multi_shard(self, tmp_path):
        """Empty shard in multi-shard dataset should be handled."""
        # Create one shard with data
        tar_path = tmp_path / "data-000000.tar"
        with wds.writer.TarWriter(str(tar_path)) as writer:
            sample = ErrorTestSample(name="sample", value=42)
            writer.write(sample.as_wds)

        # Create an empty tar (valid but no samples)
        empty_path = tmp_path / "data-000001.tar"
        with tarfile.open(empty_path, "w"):
            pass  # Empty tar

        url = str(tmp_path / "data-{000000..000001}.tar")
        ds = atdata.Dataset[ErrorTestSample](url)

        # The empty shard contributes nothing; only the populated shard's
        # single sample should come through
        samples = list(ds.ordered(batch_size=None))
        assert [s.name for s in samples] == ["sample"]

    def test_good_shards_before_bad_are_processed(self, tmp_path):
        """Samples from good shards before bad one should be accessible."""
        # Create first good shard with multiple samples
        tar_path = tmp_path / "data-000000.tar"
        with wds.writer.TarWriter(str(tar_path)) as writer:
            for i in range(3):
                sample = ErrorTestSample(name=f"good_{i}", value=i)
                writer.write(sample.as_wds)

        # Create second corrupted shard
        corrupted_path = tmp_path / "data-000001.tar"
        with open(corrupted_path, "wb") as f:
            f.write(b"corrupted data")

        url = str(tmp_path / "data-{000000..000001}.tar")
        ds = atdata.Dataset[ErrorTestSample](url)

        # Iterate and collect what we can before the corrupted shard fails.
        # Deliberately best-effort: the failure itself is covered by
        # test_multi_shard_with_corrupted_shard.
        collected = []
        try:
            for sample in ds.ordered(batch_size=None):
                collected.append(sample)
        except Exception:
            pass  # Expected to fail on second shard

        # How many good samples arrive before the failure depends on
        # WebDataset's buffering, but anything yielded must come from the
        # good shard and cannot exceed its size.
        assert len(collected) <= 3
        assert all(s.name.startswith("good_") for s in collected)


##
# S3 Error Simulation Tests


class TestS3ErrorSimulation:
    """Tests for S3-related error scenarios using mocks."""

    def test_s3_access_denied_error(self):
        """S3 access denied should raise clear error."""
        from atdata import S3Source
        from botocore.exceptions import ClientError

        # Create source with mock credentials
        source = S3Source(
            bucket="test-bucket",
            keys=["data.tar"],
            access_key="test",
            secret_key="test",
        )

        # Mock the client after source creation
        with patch.object(source, "_get_client") as mock_get_client:
            mock_client = Mock()
            mock_client.get_object.side_effect = ClientError(
                {"Error": {"Code": "AccessDenied", "Message": "Access Denied"}},
                "GetObject",
            )
            mock_get_client.return_value = mock_client

            # Opening shard should propagate the error
            # Use full S3 URI as returned by shard_list
            with pytest.raises(ClientError):
                source.open_shard("s3://test-bucket/data.tar")

    def test_s3_connection_timeout_simulation(self):
        """S3 connection timeout should raise appropriate error."""
        from atdata import S3Source
        from botocore.exceptions import ConnectTimeoutError

        # Create source with mock credentials
        source = S3Source(
            bucket="test-bucket",
            keys=["data.tar"],
            access_key="test",
            secret_key="test",
        )

        # Mock the client after source creation
        with patch.object(source, "_get_client") as mock_get_client:
            mock_client = Mock()
            mock_client.get_object.side_effect = ConnectTimeoutError(
                endpoint_url="s3://test"
            )
            mock_get_client.return_value = mock_client

            # Use full S3 URI as returned by shard_list
            with pytest.raises(ConnectTimeoutError):
                source.open_shard("s3://test-bucket/data.tar")