A loose federation of distributed, typed datasets
1"""Integration tests for error handling and recovery.
2
3Tests error conditions and graceful failure including:
4- Missing schemas and data URLs
5- Malformed data (msgpack, tar)
6- Connection failures (Redis, S3, ATProto)
7- Authentication and rate limiting errors
8- Timeout scenarios
9- Partial failures in multi-shard datasets
10"""
11
12import pytest
13from unittest.mock import Mock, MagicMock, patch
14import tarfile
15import io
16
17
18import atdata
19import webdataset as wds
20from atdata.local import LocalIndex, LocalDatasetEntry
21from atdata.atmosphere import AtmosphereClient, AtUri
22
23
24##
25# Test sample types
26
27
@atdata.packable
class ErrorTestSample:
    """Sample for error handling tests.

    A minimal two-field sample type shared by every test in this module.
    Tests build instances with ``ErrorTestSample(name=..., value=...)`` and
    write them via ``sample.as_wds``; that attribute is presumably supplied
    by ``@atdata.packable`` (not visible here -- confirm in atdata).
    """

    # Human-readable label; tests compare it to tell which sample/shard was read.
    name: str
    # Integer payload carried alongside the label.
    value: int
34
35
36##
37# Schema Error Tests
38
39
class TestMissingSchema:
    """Schema lookups against refs that were never published."""

    def test_missing_schema_raises_keyerror(self, clean_redis):
        """``get_schema`` on an unknown ref must raise ``KeyError``."""
        idx = LocalIndex(redis=clean_redis)
        unknown_ref = "local://schemas/NonExistent@1.0.0"

        with pytest.raises(KeyError):
            idx.get_schema(unknown_ref)

    def test_dataset_with_invalid_schema_ref(self, clean_redis):
        """An entry may point at a missing schema; decoding that schema fails.

        Writing and reading back the entry itself must still work -- the
        dangling reference only surfaces when ``decode_schema`` is called.
        """
        idx = LocalIndex(redis=clean_redis)

        LocalDatasetEntry(
            name="orphan-dataset",
            schema_ref="local://schemas/DoesNotExist@1.0.0",
            data_urls=["s3://bucket/data.tar"],
        ).write_to(clean_redis)

        # The entry is stored and retrievable even though its schema is absent.
        orphan = idx.get_entry_by_name("orphan-dataset")
        assert orphan is not None

        # Resolving the dangling schema reference is where the failure shows up.
        with pytest.raises(KeyError):
            idx.decode_schema(orphan.schema_ref)
68
69
70##
71# Data URL Error Tests
72
73
class TestMissingDataUrls:
    """Tests for missing or inaccessible data URLs."""

    def test_empty_data_urls_raises(self, clean_redis):
        """An entry with an empty ``data_urls`` list is stored and read back.

        Despite the historical test name, nothing is raised or flagged here:
        this test requires that writing such an entry succeeds and that the
        empty list is preserved on retrieval rather than dropped or replaced.
        (The name is kept so existing test IDs / ``-k`` selections still work.)
        """
        index = LocalIndex(redis=clean_redis)
        schema_ref = index.publish_schema(ErrorTestSample, version="1.0.0")

        entry = LocalDatasetEntry(
            name="empty-urls",
            schema_ref=schema_ref,
            data_urls=[],
        )
        entry.write_to(clean_redis)

        retrieved = index.get_entry_by_name("empty-urls")
        assert retrieved is not None
        assert retrieved.data_urls == []

    def test_nonexistent_tar_raises(self, tmp_path):
        """Iterating a dataset whose only shard does not exist raises."""
        nonexistent_path = tmp_path / "does-not-exist.tar"

        ds = atdata.Dataset[ErrorTestSample](str(nonexistent_path))

        # Construction is lazy; the failure surfaces once iteration opens the shard.
        with pytest.raises(FileNotFoundError):
            list(ds.ordered(batch_size=None))
101
102
103##
104# Malformed Data Tests
105
106
class TestMalformedMsgpack:
    """Shards whose msgpack payload cannot be decoded into a sample."""

    @staticmethod
    def _add_member(tar, member_name, payload):
        """Append the in-memory bytes ``payload`` to ``tar`` as ``member_name``."""
        member = tarfile.TarInfo(name=member_name)
        member.size = len(payload)
        tar.addfile(member, fileobj=io.BytesIO(payload))

    def test_invalid_msgpack_in_tar(self, tmp_path):
        """Iterating a shard whose ``.msgpack`` member is garbage must raise."""
        shard = tmp_path / "corrupted-000000.tar"

        with tarfile.open(shard, "w") as tar:
            # A well-formed key member for the sample ...
            self._add_member(tar, "sample-0.__key__", b"sample-0")
            # ... paired with bytes that do not encode a sample: 0xFF is a
            # single msgpack negative fixint, so this is -1 plus trailing junk.
            self._add_member(tar, "sample-0.msgpack", b"\xff\xff\xff\xff\xff")

        dataset = atdata.Dataset[ErrorTestSample](str(shard))

        # The exact type depends on where decoding fails (msgpack itself or
        # sample construction), so any exception is accepted.
        with pytest.raises(Exception):
            list(dataset.ordered(batch_size=None))
133
134
class TestCorruptedTar:
    """Tests for corrupted tar files."""

    def test_truncated_tar_raises(self, tmp_path):
        """A tar cut short inside its first header block must fail to iterate."""
        tar_path = tmp_path / "truncated-000000.tar"

        # Create a valid single-member tar first (``io`` is imported at module
        # level, so no local import is needed here).
        data = b"test data"
        with tarfile.open(tar_path, "w") as tar:
            info = tarfile.TarInfo(name="test.txt")
            info.size = len(data)
            tar.addfile(info, fileobj=io.BytesIO(data))

        # Tar headers occupy 512-byte blocks, so keeping only 50 bytes leaves
        # a fragment of the very first header.
        with open(tar_path, "r+b") as f:
            f.truncate(50)

        ds = atdata.Dataset[ErrorTestSample](str(tar_path))

        with pytest.raises(Exception):  # tarfile.ReadError or similar
            list(ds.ordered(batch_size=None))

    def test_not_a_tar_file_raises(self, tmp_path):
        """A file that is not a tar archive at all must fail to iterate."""
        fake_tar = tmp_path / "fake-000000.tar"

        # Plain text with a .tar name; the extension alone must not be trusted.
        with open(fake_tar, "wb") as f:
            f.write(b"This is not a tar file at all!")

        ds = atdata.Dataset[ErrorTestSample](str(fake_tar))

        with pytest.raises(Exception):  # tarfile.ReadError
            list(ds.ordered(batch_size=None))
172
173
174##
175# Redis Error Tests
176
177
class TestRedisErrors:
    """Tests for Redis connection errors."""

    def test_redis_connection_error(self):
        """Operations against an unreachable Redis should raise a Redis error."""
        from redis import Redis
        from redis.exceptions import RedisError

        # ``.invalid`` is a reserved TLD (RFC 6761), so this host never resolves.
        bad_redis = Redis(
            host="nonexistent.invalid.host", port=9999, socket_timeout=0.1
        )

        index = LocalIndex(redis=bad_redis)

        # Resolution/connect failures surface as redis ConnectionError, a slow
        # resolver as redis TimeoutError; both derive from RedisError. (The
        # old ``(ConnectionError, Exception)`` tuple accepted *any* error and
        # also shadowed the builtin ``ConnectionError``.)
        with pytest.raises(RedisError):
            index.publish_schema(ErrorTestSample, version="1.0.0")

    def test_entry_lookup_with_bad_redis(self, clean_redis):
        """Entry lookup should fail cleanly if Redis becomes unavailable."""
        from redis.exceptions import ConnectionError as RedisConnectionError
        from redis.exceptions import RedisError

        index = LocalIndex(redis=clean_redis)

        # First, add an entry while Redis is healthy.
        schema_ref = index.publish_schema(ErrorTestSample, version="1.0.0")
        entry = LocalDatasetEntry(
            name="test-entry",
            schema_ref=schema_ref,
            data_urls=["s3://bucket/data.tar"],
        )
        entry.write_to(clean_redis)

        # Sanity check: the entry is retrievable before the outage.
        assert index.get_entry_by_name("test-entry") is not None

        # Simulate Redis going away: redis-py checks a connection out of the
        # pool for each command/pipeline, so failing that checkout makes every
        # subsequent call fail. ``patch.object`` restores the pool afterwards.
        outage = RedisConnectionError("simulated Redis outage")
        with patch.object(
            clean_redis.connection_pool, "get_connection", side_effect=outage
        ):
            with pytest.raises(RedisError):
                index.get_entry_by_name("test-entry")
212
213
214##
215# ATProto Error Tests
216
217
class TestAtProtoErrors:
    """Tests for ATProto/Atmosphere errors."""

    def test_unauthenticated_publish_raises(self):
        """Publishing without authentication should raise."""
        mock_client = Mock()
        mock_client.me = None

        client = AtmosphereClient(_client=mock_client)

        # Not authenticated
        assert not client.is_authenticated

        from atdata.atmosphere import SchemaPublisher

        publisher = SchemaPublisher(client)

        with pytest.raises(ValueError, match="authenticated"):
            publisher.publish(ErrorTestSample, version="1.0.0")

    # Parametrized (rather than looped) so each malformed URI is reported as
    # its own test case and one failure does not mask the others.
    @pytest.mark.parametrize(
        "uri",
        [
            "not-a-uri",
            "https://example.com/path",
            "at://",
            "at://did:plc:abc",  # Missing collection and rkey
            "at://did:plc:abc/collection",  # Missing rkey
        ],
    )
    def test_invalid_at_uri_raises(self, uri):
        """Parsing an invalid AT URI should raise ValueError."""
        with pytest.raises(ValueError):
            AtUri.parse(uri)

    def test_api_error_response_handling(self):
        """API errors should be propagated appropriately."""
        mock_client = Mock()
        mock_client.me = MagicMock()
        mock_client.me.did = "did:plc:test123"

        # Simulate an API error
        from atproto_client.exceptions import AtProtocolError

        mock_client.com.atproto.repo.create_record.side_effect = AtProtocolError(
            "API error occurred"
        )

        # Create client and authenticate it
        client = AtmosphereClient(_client=mock_client)
        client._session = {"did": "did:plc:test123"}  # Mark as authenticated

        from atdata.atmosphere import SchemaPublisher

        publisher = SchemaPublisher(client)

        # Should propagate the API error
        with pytest.raises(AtProtocolError):
            publisher.publish(ErrorTestSample, version="1.0.0")

    def test_expired_session_detection(self):
        """A client with no user and no exportable session is unauthenticated."""
        mock_client = Mock()
        mock_client.me = None
        mock_client.export_session_string.return_value = None

        client = AtmosphereClient(_client=mock_client)

        # Should not be authenticated
        assert not client.is_authenticated
287
288
289##
290# Entry Not Found Tests
291
292
class TestNotFoundErrors:
    """Lookups for entries that were never written to the index."""

    def test_get_entry_by_name_not_found(self, clean_redis):
        """An unknown dataset name yields ``KeyError``, not ``None``."""
        lookup_by_name = LocalIndex(redis=clean_redis).get_entry_by_name

        with pytest.raises(KeyError):
            lookup_by_name("nonexistent-dataset")

    def test_get_entry_by_cid_not_found(self, clean_redis):
        """An unknown CID yields ``KeyError``, not ``None``."""
        lookup_by_cid = LocalIndex(redis=clean_redis).get_entry

        with pytest.raises(KeyError):
            lookup_by_cid("bafyreifake123456789")
309
310
311##
312# Error Message Quality Tests
313
314
class TestErrorMessageQuality:
    """Tests that error messages are helpful and don't leak sensitive info."""

    def test_missing_schema_error_includes_ref(self, clean_redis):
        """Missing schema error should include the schema reference."""
        index = LocalIndex(redis=clean_redis)

        with pytest.raises(KeyError) as exc_info:
            index.get_schema("local://schemas/MissingType@1.0.0")

        # Error should mention the schema reference
        message = str(exc_info.value)
        assert "MissingType" in message or "local://" in message

    def test_invalid_uri_error_is_clear(self):
        """Invalid AT URI error should explain the issue."""
        with pytest.raises(ValueError) as exc_info:
            AtUri.parse("not-valid")

        # Error should explain it's not a valid URI
        message = str(exc_info.value).lower()
        assert "at://" in message or "uri" in message

    def test_auth_error_no_credential_leak(self):
        """Authentication errors should not leak credentials.

        The publish call must actually raise; previously a missing exception
        let this test pass without checking anything.
        """
        mock_client = Mock()
        mock_client.me = None

        client = AtmosphereClient(_client=mock_client)

        from atdata.atmosphere import SchemaPublisher

        publisher = SchemaPublisher(client)

        with pytest.raises(ValueError) as exc_info:
            publisher.publish(ErrorTestSample, version="1.0.0")

        # Should not contain anything that looks like a password or token
        error_msg = str(exc_info.value).lower()
        for sensitive_word in ("password", "token", "secret"):
            assert sensitive_word not in error_msg
357
358
359##
360# Recovery Tests
361
362
class TestRecovery:
    """Tests for recovery from errors."""

    def test_can_continue_after_bad_sample(self, tmp_path):
        """System should be usable after encountering bad data.

        Needs no Redis, so the ``clean_redis`` fixture is not requested.
        """
        # First, read a bad file and confirm it really does fail -- otherwise
        # the "recovery" half of the test would prove nothing.
        bad_tar = tmp_path / "bad-000000.tar"
        bad_tar.write_bytes(b"not a tar file")

        ds_bad = atdata.Dataset[ErrorTestSample](str(bad_tar))
        with pytest.raises(Exception):  # tarfile.ReadError or similar
            list(ds_bad.ordered(batch_size=None))

        # Now use a good file - should still work
        good_tar = tmp_path / "good-000000.tar"
        with wds.writer.TarWriter(str(good_tar)) as writer:
            writer.write(ErrorTestSample(name="good", value=42).as_wds)

        ds_good = atdata.Dataset[ErrorTestSample](str(good_tar))
        samples = list(ds_good.ordered(batch_size=None))

        assert len(samples) == 1
        assert samples[0].name == "good"
        assert samples[0].value == 42

    def test_index_usable_after_failed_publish(self, clean_redis):
        """Index should remain usable after a failed lookup."""
        index = LocalIndex(redis=clean_redis)

        # Try to get a non-existent schema (fails as expected)
        with pytest.raises(KeyError):
            index.get_schema("local://schemas/NoSuch@1.0.0")

        # Index should still work
        schema_ref = index.publish_schema(ErrorTestSample, version="1.0.0")
        assert schema_ref is not None

        schema = index.get_schema(schema_ref)
        assert schema["name"] == "ErrorTestSample"
407
408
409##
410# Validation Tests
411
412
class TestInputValidation:
    """Tests for input validation."""

    def test_empty_version_string(self, clean_redis):
        """An empty version string must be accepted and be retrievable.

        This test requires acceptance: if ``publish_schema`` raised for an
        empty version, the test would fail. Should rejection become the
        intended behavior, change this to expect the raised error instead.
        """
        index = LocalIndex(redis=clean_redis)

        schema_ref = index.publish_schema(ErrorTestSample, version="")
        schema = index.get_schema(schema_ref)
        assert schema is not None

    def test_special_chars_in_version(self, clean_redis):
        """SemVer pre-release/build suffixes in a version survive a round-trip."""
        index = LocalIndex(redis=clean_redis)

        schema_ref = index.publish_schema(
            ErrorTestSample, version="1.0.0-beta+build.123"
        )
        schema = index.get_schema(schema_ref)

        assert schema["version"] == "1.0.0-beta+build.123"
436
437
438##
439# Timeout Tests
440
441
class TestTimeoutScenarios:
    """Tests for timeout and slow connection scenarios."""

    def test_redis_socket_timeout(self):
        """Redis operations against an unreachable host fail fast, not hang."""
        from redis import Redis
        from redis.exceptions import RedisError

        # A private-range address expected to have no listener, with tiny
        # timeouts so the connect attempt gives up almost immediately.
        unreachable = Redis(
            host="10.255.255.1",
            port=6379,
            socket_timeout=0.01,
            socket_connect_timeout=0.01,
        )

        index = LocalIndex(redis=unreachable)

        # Depending on the network this is redis TimeoutError or
        # ConnectionError; both derive from RedisError.
        with pytest.raises(RedisError):
            index.publish_schema(ErrorTestSample, version="1.0.0")

    def test_slow_iteration_continues(self, tmp_path):
        """Baseline: a small valid shard iterates completely and in order.

        No slowness is actually simulated here; this only pins that ordered
        iteration over a healthy shard returns every sample intact.
        """
        tar_path = tmp_path / "slow-000000.tar"
        with wds.writer.TarWriter(str(tar_path)) as writer:
            for i in range(5):
                writer.write(ErrorTestSample(name=f"sample_{i}", value=i).as_wds)

        ds = atdata.Dataset[ErrorTestSample](str(tar_path))

        samples = list(ds.ordered(batch_size=None))
        assert [s.name for s in samples] == [f"sample_{i}" for i in range(5)]
        assert [s.value for s in samples] == list(range(5))
477
478
479##
480# Partial Failure Tests
481
482
class TestPartialFailures:
    """Tests for partial failures in multi-shard scenarios."""

    @staticmethod
    def _write_shard(path, names_and_values):
        """Write each ``(name, value)`` pair as an ErrorTestSample to ``path``."""
        with wds.writer.TarWriter(str(path)) as writer:
            for name, value in names_and_values:
                writer.write(ErrorTestSample(name=name, value=value).as_wds)

    def test_multi_shard_with_missing_middle_shard(self, tmp_path):
        """Multi-shard dataset with missing shard should fail cleanly."""
        # Create first and third shard, skip second
        for i in (0, 2):
            self._write_shard(tmp_path / f"data-{i:06d}.tar", [(f"shard_{i}", i)])

        # Use brace notation that expects all three shards
        url = str(tmp_path / "data-{000000..000002}.tar")
        ds = atdata.Dataset[ErrorTestSample](url)

        # Should fail when hitting missing shard
        with pytest.raises(FileNotFoundError):
            list(ds.ordered(batch_size=None))

    def test_multi_shard_with_corrupted_shard(self, tmp_path):
        """Multi-shard dataset with one corrupted shard should fail."""
        # Create two good shards
        for i in range(2):
            self._write_shard(tmp_path / f"data-{i:06d}.tar", [(f"shard_{i}", i)])

        # Create a corrupted third shard
        (tmp_path / "data-000002.tar").write_bytes(b"this is not a valid tar file")

        url = str(tmp_path / "data-{000000..000002}.tar")
        ds = atdata.Dataset[ErrorTestSample](url)

        # Should fail when hitting corrupted shard
        with pytest.raises(Exception):  # tarfile.ReadError or similar
            list(ds.ordered(batch_size=None))

    def test_empty_shard_in_multi_shard(self, tmp_path):
        """A valid-but-empty shard contributes no samples and raises no error."""
        # One shard with data ...
        self._write_shard(tmp_path / "data-000000.tar", [("sample", 42)])

        # ... and a well-formed tar archive with no members.
        with tarfile.open(tmp_path / "data-000001.tar", "w"):
            pass

        url = str(tmp_path / "data-{000000..000001}.tar")
        ds = atdata.Dataset[ErrorTestSample](url)

        # Iteration must complete (an exception here fails the test) and yield
        # exactly the single sample from the non-empty shard.
        samples = list(ds.ordered(batch_size=None))
        assert [s.name for s in samples] == ["sample"]
        assert samples[0].value == 42

    def test_good_shards_before_bad_are_processed(self, tmp_path):
        """Samples from a good shard arrive before a later bad shard fails.

        How many arrive before the error depends on WebDataset's per-sample
        buffering (the last sample of a shard may be held back until the next
        file is read), so the test checks two things. Iteration must fail,
        and whatever was collected must be an in-order prefix of shard 0.
        """
        # First good shard with multiple samples
        self._write_shard(
            tmp_path / "data-000000.tar", [(f"good_{i}", i) for i in range(3)]
        )

        # Second, corrupted shard
        (tmp_path / "data-000001.tar").write_bytes(b"corrupted data")

        url = str(tmp_path / "data-{000000..000001}.tar")
        ds = atdata.Dataset[ErrorTestSample](url)

        # Iterate and collect what we can; the bad shard must still raise.
        collected = []
        with pytest.raises(Exception):  # tarfile.ReadError or similar
            for sample in ds.ordered(batch_size=None):
                collected.append(sample)

        names = [s.name for s in collected]
        assert len(names) <= 3
        assert names == [f"good_{i}" for i in range(len(names))]
574
575
576##
577# S3 Error Simulation Tests
578
579
class TestS3ErrorSimulation:
    """S3 failure modes, simulated by stubbing out the boto client."""

    @staticmethod
    def _make_source():
        """Build an ``S3Source`` over one dummy key with placeholder credentials."""
        from atdata import S3Source

        return S3Source(
            bucket="test-bucket",
            keys=["data.tar"],
            access_key="test",
            secret_key="test",
        )

    def _assert_open_shard_raises(self, error, expected_type):
        """Make ``get_object`` raise ``error``; check ``open_shard`` propagates it."""
        source = self._make_source()

        # The client is patched only after the source exists, mirroring how a
        # real source lazily obtains its client.
        with patch.object(source, "_get_client") as fake_get_client:
            fake_client = Mock()
            fake_client.get_object.side_effect = error
            fake_get_client.return_value = fake_client

            # open_shard receives full S3 URIs, as produced by shard_list.
            with pytest.raises(expected_type):
                source.open_shard("s3://test-bucket/data.tar")

    def test_s3_access_denied_error(self):
        """An AccessDenied ``ClientError`` from ``get_object`` reaches the caller."""
        from botocore.exceptions import ClientError

        denied = ClientError(
            {"Error": {"Code": "AccessDenied", "Message": "Access Denied"}},
            "GetObject",
        )
        self._assert_open_shard_raises(denied, ClientError)

    def test_s3_connection_timeout_simulation(self):
        """A connect timeout from ``get_object`` reaches the caller unchanged."""
        from botocore.exceptions import ConnectTimeoutError

        timed_out = ConnectTimeoutError(endpoint_url="s3://test")
        self._assert_open_shard_raises(timed_out, ConnectTimeoutError)
633 source.open_shard("s3://test-bucket/data.tar")