A loose federation of distributed, typed datasets
1"""Tests for data source implementations."""
2
3from pathlib import Path
4from unittest.mock import Mock, patch, MagicMock
5
6import pytest
7import webdataset as wds
8
9import atdata
10from atdata._sources import URLSource, S3Source, BlobSource
11from atdata._protocols import DataSource
12
13
14# Test sample type
@atdata.packable
class SourceTestSample:
    """Minimal packable sample used as the payload in the data source tests.

    Holds one string field and one integer field.
    """

    name: str
    value: int
21
22
def create_test_tar(path: Path, samples: list[dict]) -> None:
    """Write ``samples`` into a tar shard at ``path`` for use as a test fixture.

    Each dict is expanded into a ``SourceTestSample`` and written to the tar
    in the sample's ``as_wds`` form (msgpack samples, per the original
    description of this helper).

    Args:
        path: Destination of the tar file; handed to the writer as a string.
        samples: One dict of ``SourceTestSample`` keyword arguments per sample.
            An empty list writes no samples.
    """
    with wds.writer.TarWriter(str(path)) as sink:
        # The position of a sample is not needed, so iterate the dicts directly.
        for data in samples:
            sink.write(SourceTestSample(**data).as_wds)
29
30
class TestURLSource:
    """Checks for URLSource: protocol conformance, shard listing, shard access."""

    def test_conforms_to_protocol(self):
        """A URLSource instance passes the isinstance check against DataSource."""
        assert isinstance(URLSource("http://example.com/data.tar"), DataSource)

    def test_shard_list_single_url(self):
        """A URL without brace patterns is listed as a single shard, unchanged."""
        url = "http://example.com/data.tar"
        assert URLSource(url).shard_list == [url]

    def test_shard_list_brace_expansion(self):
        """A numeric brace range expands into one shard per value."""
        expected = [
            "data-000.tar",
            "data-001.tar",
            "data-002.tar",
        ]
        assert URLSource("data-{000..002}.tar").shard_list == expected

    def test_shard_list_complex_brace_pattern(self):
        """An alternation combined with a range expands to every combination."""
        expected = [
            "s3://bucket/train-00.tar",
            "s3://bucket/train-01.tar",
            "s3://bucket/test-00.tar",
            "s3://bucket/test-01.tar",
        ]
        src = URLSource("s3://bucket/{train,test}-{00..01}.tar")
        assert src.shard_list == expected

    def test_shards_yields_streams(self, tmp_path):
        """Iterating ``shards`` produces (url, stream) pairs."""
        # Build a one-sample tar file on disk to read back.
        shard_file = tmp_path / "test.tar"
        create_test_tar(shard_file, [{"name": "test", "value": 42}])
        shard_url = str(shard_file)

        pairs = list(URLSource(shard_url).shards)

        assert len(pairs) == 1
        url, stream = pairs[0]
        assert url == shard_url
        assert hasattr(stream, "read")

    def test_open_shard(self, tmp_path):
        """``open_shard`` returns a readable object for a listed shard."""
        shard_file = tmp_path / "test.tar"
        create_test_tar(shard_file, [{"name": "test", "value": 42}])
        shard_url = str(shard_file)

        opened = URLSource(shard_url).open_shard(shard_url)

        assert hasattr(opened, "read")

    def test_open_shard_not_found(self, tmp_path):
        """``open_shard`` raises KeyError when asked for an unlisted shard."""
        shard_file = tmp_path / "test.tar"
        create_test_tar(shard_file, [{"name": "test", "value": 42}])
        src = URLSource(str(shard_file))

        with pytest.raises(KeyError, match="Shard not found"):
            src.open_shard("nonexistent.tar")

    def test_dataset_integration(self, tmp_path):
        """A Dataset built on a URLSource yields the written samples in order."""
        shard_file = tmp_path / "test.tar"
        records = [
            {"name": "sample1", "value": 1},
            {"name": "sample2", "value": 2},
        ]
        create_test_tar(shard_file, records)

        ds = atdata.Dataset[SourceTestSample](URLSource(str(shard_file)))
        loaded = list(ds.ordered())

        assert len(loaded) == 2
        assert loaded[0].name == "sample1"
        assert loaded[1].value == 2
115
116
class TestS3Source:
    """Checks for S3Source: construction helpers, shard access, client handling."""

    def test_conforms_to_protocol(self):
        """An S3Source instance passes the isinstance check against DataSource."""
        assert isinstance(S3Source(bucket="test", keys=["data.tar"]), DataSource)

    def test_shard_list(self):
        """``shard_list`` renders each key as an ``s3://bucket/key`` URI."""
        expected = [
            "s3://my-bucket/a.tar",
            "s3://my-bucket/b.tar",
        ]
        src = S3Source(bucket="my-bucket", keys=["a.tar", "b.tar"])
        assert src.shard_list == expected

    def test_from_urls(self):
        """``from_urls`` splits S3 URLs into a bucket and a list of keys."""
        urls = [
            "s3://bucket/path/a.tar",
            "s3://bucket/path/b.tar",
        ]

        src = S3Source.from_urls(urls)

        assert src.bucket == "bucket"
        assert src.keys == ["path/a.tar", "path/b.tar"]

    def test_from_urls_with_credentials(self):
        """``from_urls`` stores the endpoint and key pair it is given."""
        src = S3Source.from_urls(
            ["s3://bucket/data.tar"],
            endpoint="https://r2.example.com",
            access_key="AKID",
            secret_key="SECRET",
        )

        assert src.endpoint == "https://r2.example.com"
        assert src.access_key == "AKID"
        assert src.secret_key == "SECRET"

    def test_from_urls_empty(self):
        """``from_urls`` rejects an empty URL list with ValueError."""
        with pytest.raises(ValueError, match="cannot be empty"):
            S3Source.from_urls([])

    def test_from_urls_invalid_scheme(self):
        """``from_urls`` rejects a URL that does not use the s3 scheme."""
        with pytest.raises(ValueError, match="Not an S3 URL"):
            S3Source.from_urls(["https://example.com/data.tar"])

    def test_from_urls_multiple_buckets(self):
        """``from_urls`` rejects URLs that refer to different buckets."""
        mixed_urls = [
            "s3://bucket-a/data.tar",
            "s3://bucket-b/data.tar",
        ]
        with pytest.raises(ValueError, match="same bucket"):
            S3Source.from_urls(mixed_urls)

    def test_from_credentials(self):
        """``from_credentials`` maps the credential dict onto source attributes."""
        creds = {
            "AWS_ACCESS_KEY_ID": "AKID",
            "AWS_SECRET_ACCESS_KEY": "SECRET",
            "AWS_ENDPOINT": "https://r2.example.com",
        }

        src = S3Source.from_credentials(creds, "bucket", ["data.tar"])

        assert src.bucket == "bucket"
        assert src.keys == ["data.tar"]
        assert src.endpoint == "https://r2.example.com"
        assert src.access_key == "AKID"
        assert src.secret_key == "SECRET"

    def test_shards_uses_boto3(self):
        """Iterating ``shards`` fetches each key through the boto3 client."""
        mock_body = MagicMock()
        mock_body.read.return_value = b"tar data"
        mock_client = Mock(**{"get_object.return_value": {"Body": mock_body}})

        with patch("boto3.client", return_value=mock_client):
            src = S3Source(
                bucket="test-bucket",
                keys=["data.tar"],
                access_key="AKID",
                secret_key="SECRET",
            )
            # Consume the shards while boto3.client is still patched.
            pairs = list(src.shards)

        assert len(pairs) == 1
        uri, stream = pairs[0]
        assert uri == "s3://test-bucket/data.tar"
        assert stream == mock_body

        mock_client.get_object.assert_called_once_with(
            Bucket="test-bucket",
            Key="data.tar",
        )

    def test_open_shard_uses_boto3(self):
        """``open_shard`` fetches only the requested key through boto3."""
        mock_body = MagicMock()
        mock_client = Mock(**{"get_object.return_value": {"Body": mock_body}})

        with patch("boto3.client", return_value=mock_client):
            src = S3Source(
                bucket="test-bucket",
                keys=["a.tar", "b.tar"],
                access_key="AKID",
                secret_key="SECRET",
            )
            stream = src.open_shard("s3://test-bucket/b.tar")

        assert stream == mock_body
        mock_client.get_object.assert_called_once_with(
            Bucket="test-bucket",
            Key="b.tar",
        )

    def test_open_shard_not_found(self):
        """``open_shard`` raises KeyError when asked for an unlisted shard."""
        src = S3Source(bucket="bucket", keys=["a.tar"])

        with pytest.raises(KeyError, match="Shard not found"):
            src.open_shard("s3://bucket/unknown.tar")

    def test_client_uses_endpoint(self):
        """The boto3 client is created with the custom endpoint and key pair."""
        with patch("boto3.client", return_value=Mock()) as mock_boto:
            src = S3Source(
                bucket="bucket",
                keys=["data.tar"],
                endpoint="https://custom.endpoint.com",
                access_key="AKID",
                secret_key="SECRET",
            )

            # Asking for the client is what triggers its creation.
            src._get_client()

            mock_boto.assert_called_once_with(
                "s3",
                endpoint_url="https://custom.endpoint.com",
                aws_access_key_id="AKID",
                aws_secret_access_key="SECRET",
            )

    def test_client_caching(self):
        """Two ``_get_client`` calls return one object and build it only once."""
        with patch("boto3.client", return_value=Mock()) as mock_boto:
            src = S3Source(
                bucket="bucket",
                keys=["data.tar"],
                access_key="AKID",
                secret_key="SECRET",
            )

            first = src._get_client()
            second = src._get_client()

            assert first is second
            assert mock_boto.call_count == 1
296
297
class TestBlobSource:
    """Tests for BlobSource (ATProto PDS blob storage)."""

    def test_conforms_to_protocol(self):
        """A BlobSource instance passes the isinstance check against DataSource."""
        source = BlobSource(blob_refs=[{"did": "did:plc:abc", "cid": "bafyrei123"}])
        assert isinstance(source, DataSource)

    def test_list_shards(self):
        """``list_shards`` renders each blob reference as an AT URI."""
        source = BlobSource(
            blob_refs=[
                {"did": "did:plc:abc", "cid": "bafyrei111"},
                {"did": "did:plc:abc", "cid": "bafyrei222"},
            ]
        )
        assert source.list_shards() == [
            "at://did:plc:abc/blob/bafyrei111",
            "at://did:plc:abc/blob/bafyrei222",
        ]

    def test_from_refs_simple_format(self):
        """``from_refs`` accepts the simple ``{did, cid}`` reference format."""
        source = BlobSource.from_refs(
            [
                {"did": "did:plc:abc", "cid": "bafyrei123"},
            ]
        )
        assert len(source.blob_refs) == 1
        assert source.blob_refs[0]["did"] == "did:plc:abc"
        assert source.blob_refs[0]["cid"] == "bafyrei123"

    def test_from_refs_with_endpoint(self):
        """``from_refs`` stores the ``pds_endpoint`` it is given."""
        source = BlobSource.from_refs(
            [{"did": "did:plc:abc", "cid": "bafyrei123"}],
            pds_endpoint="https://pds.example.com",
        )
        assert source.pds_endpoint == "https://pds.example.com"

    def test_from_refs_empty(self):
        """``from_refs`` rejects an empty reference list with ValueError."""
        with pytest.raises(ValueError, match="cannot be empty"):
            BlobSource.from_refs([])

    def test_from_refs_invalid_format(self):
        """``from_refs`` rejects a reference dict it does not recognise."""
        with pytest.raises(ValueError, match="Invalid blob reference format"):
            BlobSource.from_refs([{"invalid": "data"}])

    def test_from_refs_atproto_format_without_did(self):
        """``from_refs`` reports a missing DID for an ATProto-style reference."""
        with pytest.raises(ValueError, match="requires 'did' field"):
            BlobSource.from_refs([{"ref": {"$link": "bafyrei123"}}])

    def test_resolve_pds_endpoint_uses_cache(self):
        """A cached endpoint for a DID is returned by the resolver."""
        source = BlobSource(blob_refs=[{"did": "did:plc:abc", "cid": "cid"}])

        # Seed the cache so the expected value can only come from it.
        source._endpoint_cache["did:plc:abc"] = "https://cached.pds.com"

        endpoint = source._resolve_pds_endpoint("did:plc:abc")
        assert endpoint == "https://cached.pds.com"

    def test_resolve_pds_endpoint_uses_provided_endpoint(self):
        """An explicitly provided ``pds_endpoint`` is what the resolver returns."""
        source = BlobSource(
            blob_refs=[{"did": "did:plc:abc", "cid": "cid"}],
            pds_endpoint="https://my.pds.com",
        )

        endpoint = source._resolve_pds_endpoint("did:plc:abc")
        assert endpoint == "https://my.pds.com"

    def test_get_blob_url(self):
        """``_get_blob_url`` builds the getBlob XRPC URL from endpoint, DID, CID."""
        source = BlobSource(
            blob_refs=[{"did": "did:plc:abc", "cid": "bafyrei123"}],
            pds_endpoint="https://pds.example.com",
        )

        url = source._get_blob_url("did:plc:abc", "bafyrei123")
        assert (
            url
            == "https://pds.example.com/xrpc/com.atproto.sync.getBlob?did=did:plc:abc&cid=bafyrei123"
        )

    def test_shards_fetches_blobs(self):
        """Iterating ``shards`` fetches each blob over HTTP via ``requests.get``."""
        mock_response = Mock()
        mock_response.raw = Mock()
        mock_response.raise_for_status = Mock()

        with patch("requests.get", return_value=mock_response) as mock_get:
            source = BlobSource(
                blob_refs=[{"did": "did:plc:abc", "cid": "bafyrei123"}],
                pds_endpoint="https://pds.example.com",
            )

            # Consume the shards while requests.get is still patched.
            shards = list(source.shards)

        assert len(shards) == 1
        shard_id, stream = shards[0]
        assert shard_id == "at://did:plc:abc/blob/bafyrei123"
        assert stream is mock_response.raw

        mock_get.assert_called_once_with(
            "https://pds.example.com/xrpc/com.atproto.sync.getBlob?did=did:plc:abc&cid=bafyrei123",
            stream=True,
            timeout=60,
        )

    def test_open_shard_fetches_single_blob(self):
        """``open_shard`` issues one request, for the requested blob only."""
        mock_response = Mock()
        mock_response.raw = Mock()
        mock_response.raise_for_status = Mock()

        with patch("requests.get", return_value=mock_response) as mock_get:
            source = BlobSource(
                blob_refs=[
                    {"did": "did:plc:abc", "cid": "bafyrei111"},
                    {"did": "did:plc:abc", "cid": "bafyrei222"},
                ],
                pds_endpoint="https://pds.example.com",
            )

            stream = source.open_shard("at://did:plc:abc/blob/bafyrei222")

        assert stream is mock_response.raw
        mock_get.assert_called_once()
        # The first positional argument of the request is the blob URL.
        requested_url = mock_get.call_args.args[0]
        assert "bafyrei222" in requested_url

    def test_open_shard_not_found(self):
        """``open_shard`` raises KeyError when asked for an unlisted shard."""
        source = BlobSource(blob_refs=[{"did": "did:plc:abc", "cid": "bafyrei123"}])

        with pytest.raises(KeyError, match="Shard not found"):
            source.open_shard("at://did:plc:abc/blob/unknown")

    def test_open_shard_invalid_format(self):
        """open_shard raises KeyError for malformed shard IDs not in the list.

        Neither ID below belongs to the source, so both are expected to be
        reported as unknown shards ("Shard not found") rather than with a
        format-specific error.
        """
        source = BlobSource(
            blob_refs=[{"did": "did:plc:abc", "cid": "bafyrei123"}],
        )

        # A string that is not an AT URI at all: KeyError (not in list).
        with pytest.raises(KeyError, match="Shard not found"):
            source.open_shard("not-an-at-uri")

        # An AT URI with the wrong path layout: also KeyError (not in list).
        with pytest.raises(KeyError, match="Shard not found"):
            source.open_shard("at://did:plc:abc/wrong/format")
456
457
class TestDatasetWithDataSource:
    """Integration checks for building a Dataset from sources and plain URLs."""

    def test_dataset_accepts_url_source(self, tmp_path):
        """A Dataset built from a URLSource keeps that source and its shards."""
        shard_file = tmp_path / "test.tar"
        create_test_tar(shard_file, [{"name": "test", "value": 42}])
        src = URLSource(str(shard_file))

        ds = atdata.Dataset[SourceTestSample](src)

        assert ds.source is src
        assert ds.shard_list == [str(shard_file)]

    def test_dataset_accepts_string_url(self, tmp_path):
        """A plain string URL is wrapped in a URLSource by the Dataset."""
        shard_file = tmp_path / "test.tar"
        create_test_tar(shard_file, [{"name": "test", "value": 42}])
        shard_url = str(shard_file)

        ds = atdata.Dataset[SourceTestSample](shard_url)

        assert isinstance(ds.source, URLSource)
        assert ds.url == shard_url

    def test_dataset_backward_compat_url_kwarg(self, tmp_path):
        """The ``url=`` keyword is still accepted and wraps into a URLSource."""
        shard_file = tmp_path / "test.tar"
        create_test_tar(shard_file, [{"name": "test", "value": 42}])
        shard_url = str(shard_file)

        ds = atdata.Dataset[SourceTestSample](url=shard_url)

        assert isinstance(ds.source, URLSource)
        assert ds.url == shard_url

    def test_dataset_source_property(self, tmp_path):
        """``Dataset.source`` is the very object the Dataset was built from."""
        shard_file = tmp_path / "test.tar"
        create_test_tar(shard_file, [{"name": "test", "value": 42}])
        src = URLSource(str(shard_file))

        ds = atdata.Dataset[SourceTestSample](src)

        assert ds.source is src

    def test_dataset_multiple_shards(self, tmp_path):
        """A brace-pattern URL spanning two shards yields samples from both."""
        # Write two single-sample shards named to match the brace pattern.
        for index in range(2):
            shard_file = tmp_path / f"data-{index:06d}.tar"
            create_test_tar(shard_file, [{"name": f"shard{index}", "value": index}])

        pattern = str(tmp_path / "data-{000000..000001}.tar")
        ds = atdata.Dataset[SourceTestSample](pattern)

        loaded = list(ds.ordered())
        assert len(loaded) == 2
        assert {sample.name for sample in loaded} == {"shard0", "shard1"}
515 assert names == {"shard0", "shard1"}