A loose federation of distributed, typed datasets
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

at main 515 lines 18 kB view raw
"""Tests for data source implementations."""

from pathlib import Path
from unittest.mock import Mock, patch, MagicMock

import pytest
import webdataset as wds

import atdata
from atdata._sources import URLSource, S3Source, BlobSource
from atdata._protocols import DataSource


# Test sample type
@atdata.packable
class SourceTestSample:
    """Simple sample for testing data sources."""

    name: str
    value: int


def create_test_tar(path: Path, samples: list[dict]) -> None:
    """Create a test tar file with msgpack samples.

    Args:
        path: Destination path of the tar archive.
        samples: Keyword dicts; each one is used to construct a single
            ``SourceTestSample`` which is written via its ``as_wds`` form.
    """
    with wds.writer.TarWriter(str(path)) as sink:
        for data in samples:
            sample = SourceTestSample(**data)
            sink.write(sample.as_wds)


class TestURLSource:
    """Tests for URLSource."""

    def test_conforms_to_protocol(self):
        """URLSource should satisfy DataSource protocol."""
        source = URLSource("http://example.com/data.tar")
        assert isinstance(source, DataSource)

    def test_shard_list_single_url(self):
        """shard_list returns single URL unchanged."""
        source = URLSource("http://example.com/data.tar")
        assert source.shard_list == ["http://example.com/data.tar"]

    def test_shard_list_brace_expansion(self):
        """shard_list expands brace patterns."""
        source = URLSource("data-{000..002}.tar")
        assert source.shard_list == [
            "data-000.tar",
            "data-001.tar",
            "data-002.tar",
        ]

    def test_shard_list_complex_brace_pattern(self):
        """shard_list handles complex brace patterns."""
        source = URLSource("s3://bucket/{train,test}-{00..01}.tar")
        assert source.shard_list == [
            "s3://bucket/train-00.tar",
            "s3://bucket/train-01.tar",
            "s3://bucket/test-00.tar",
            "s3://bucket/test-01.tar",
        ]

    def test_shards_yields_streams(self, tmp_path):
        """shards property yields (url, stream) pairs."""
        # Create test tar file
        tar_path = tmp_path / "test.tar"
        create_test_tar(tar_path, [{"name": "test", "value": 42}])

        source = URLSource(str(tar_path))
        shards = list(source.shards)

        assert len(shards) == 1
        url, stream = shards[0]
        try:
            assert url == str(tar_path)
            assert hasattr(stream, "read")
        finally:
            # The stream wraps a real file handle on a local tar; close it
            # so the test does not leak descriptors / emit ResourceWarning.
            stream.close()

    def test_open_shard(self, tmp_path):
        """open_shard opens a specific shard."""
        tar_path = tmp_path / "test.tar"
        create_test_tar(tar_path, [{"name": "test", "value": 42}])

        source = URLSource(str(tar_path))
        stream = source.open_shard(str(tar_path))

        try:
            assert hasattr(stream, "read")
        finally:
            # Release the underlying file handle opened by open_shard.
            stream.close()

    def test_open_shard_not_found(self, tmp_path):
        """open_shard raises KeyError for unknown shard."""
        tar_path = tmp_path / "test.tar"
        create_test_tar(tar_path, [{"name": "test", "value": 42}])

        source = URLSource(str(tar_path))

        with pytest.raises(KeyError, match="Shard not found"):
            source.open_shard("nonexistent.tar")

    def test_dataset_integration(self, tmp_path):
        """URLSource works with Dataset."""
        tar_path = tmp_path / "test.tar"
        create_test_tar(
            tar_path,
            [
                {"name": "sample1", "value": 1},
                {"name": "sample2", "value": 2},
            ],
        )

        source = URLSource(str(tar_path))
        ds = atdata.Dataset[SourceTestSample](source)

        samples = list(ds.ordered())
        assert len(samples) == 2
        assert samples[0].name == "sample1"
        assert samples[1].value == 2


class TestS3Source:
    """Tests for S3Source."""

    def test_conforms_to_protocol(self):
        """S3Source should satisfy DataSource protocol."""
        source = S3Source(bucket="test", keys=["data.tar"])
        assert isinstance(source, DataSource)

    def test_shard_list(self):
        """shard_list returns S3 URIs."""
        source = S3Source(bucket="my-bucket", keys=["a.tar", "b.tar"])
        assert source.shard_list == [
            "s3://my-bucket/a.tar",
            "s3://my-bucket/b.tar",
        ]

    def test_from_urls(self):
        """from_urls parses S3 URLs correctly."""
        source = S3Source.from_urls(
            [
                "s3://bucket/path/a.tar",
                "s3://bucket/path/b.tar",
            ]
        )

        assert source.bucket == "bucket"
        assert source.keys == ["path/a.tar", "path/b.tar"]

    def test_from_urls_with_credentials(self):
        """from_urls passes credentials through."""
        source = S3Source.from_urls(
            ["s3://bucket/data.tar"],
            endpoint="https://r2.example.com",
            access_key="AKID",
            secret_key="SECRET",
        )

        assert source.endpoint == "https://r2.example.com"
        assert source.access_key == "AKID"
        assert source.secret_key == "SECRET"

    def test_from_urls_empty(self):
        """from_urls raises on empty list."""
        with pytest.raises(ValueError, match="cannot be empty"):
            S3Source.from_urls([])

    def test_from_urls_invalid_scheme(self):
        """from_urls raises on non-s3 URLs."""
        with pytest.raises(ValueError, match="Not an S3 URL"):
            S3Source.from_urls(["https://example.com/data.tar"])

    def test_from_urls_multiple_buckets(self):
        """from_urls raises when URLs span buckets."""
        with pytest.raises(ValueError, match="same bucket"):
            S3Source.from_urls(
                [
                    "s3://bucket-a/data.tar",
                    "s3://bucket-b/data.tar",
                ]
            )

    def test_from_credentials(self):
        """from_credentials creates source from dict."""
        creds = {
            "AWS_ACCESS_KEY_ID": "AKID",
            "AWS_SECRET_ACCESS_KEY": "SECRET",
            "AWS_ENDPOINT": "https://r2.example.com",
        }

        source = S3Source.from_credentials(creds, "bucket", ["data.tar"])

        assert source.bucket == "bucket"
        assert source.keys == ["data.tar"]
        assert source.endpoint == "https://r2.example.com"
        assert source.access_key == "AKID"
        assert source.secret_key == "SECRET"

    def test_shards_uses_boto3(self):
        """shards() uses boto3 client to fetch objects."""
        mock_body = MagicMock()
        mock_body.read.return_value = b"tar data"

        with patch("boto3.client") as mock_boto:
            mock_client = Mock()
            mock_client.get_object.return_value = {"Body": mock_body}
            mock_boto.return_value = mock_client

            source = S3Source(
                bucket="test-bucket",
                keys=["data.tar"],
                access_key="AKID",
                secret_key="SECRET",
            )

            shards = list(source.shards)

            assert len(shards) == 1
            uri, stream = shards[0]
            assert uri == "s3://test-bucket/data.tar"
            # The source should hand back the boto3 Body object itself.
            assert stream is mock_body

            mock_client.get_object.assert_called_once_with(
                Bucket="test-bucket",
                Key="data.tar",
            )

    def test_open_shard_uses_boto3(self):
        """open_shard() uses boto3 client to fetch specific object."""
        mock_body = MagicMock()

        with patch("boto3.client") as mock_boto:
            mock_client = Mock()
            mock_client.get_object.return_value = {"Body": mock_body}
            mock_boto.return_value = mock_client

            source = S3Source(
                bucket="test-bucket",
                keys=["a.tar", "b.tar"],
                access_key="AKID",
                secret_key="SECRET",
            )

            stream = source.open_shard("s3://test-bucket/b.tar")

            assert stream is mock_body
            mock_client.get_object.assert_called_once_with(
                Bucket="test-bucket",
                Key="b.tar",
            )

    def test_open_shard_not_found(self):
        """open_shard raises KeyError for unknown shard."""
        source = S3Source(bucket="bucket", keys=["a.tar"])

        with pytest.raises(KeyError, match="Shard not found"):
            source.open_shard("s3://bucket/unknown.tar")

    def test_client_uses_endpoint(self):
        """Client is created with custom endpoint."""
        with patch("boto3.client") as mock_boto:
            mock_boto.return_value = Mock()

            source = S3Source(
                bucket="bucket",
                keys=["data.tar"],
                endpoint="https://custom.endpoint.com",
                access_key="AKID",
                secret_key="SECRET",
            )

            # Trigger client creation
            source._get_client()

            mock_boto.assert_called_once_with(
                "s3",
                endpoint_url="https://custom.endpoint.com",
                aws_access_key_id="AKID",
                aws_secret_access_key="SECRET",
            )

    def test_client_caching(self):
        """Client is cached after first creation."""
        with patch("boto3.client") as mock_boto:
            mock_client = Mock()
            mock_boto.return_value = mock_client

            source = S3Source(
                bucket="bucket",
                keys=["data.tar"],
                access_key="AKID",
                secret_key="SECRET",
            )

            # Call twice
            client1 = source._get_client()
            client2 = source._get_client()

            assert client1 is client2
            assert mock_boto.call_count == 1


class TestBlobSource:
    """Tests for BlobSource (ATProto PDS blob storage)."""

    def test_conforms_to_protocol(self):
        """BlobSource should satisfy DataSource protocol."""
        source = BlobSource(blob_refs=[{"did": "did:plc:abc", "cid": "bafyrei123"}])
        assert isinstance(source, DataSource)

    def test_list_shards(self):
        """list_shards returns AT URIs."""
        source = BlobSource(
            blob_refs=[
                {"did": "did:plc:abc", "cid": "bafyrei111"},
                {"did": "did:plc:abc", "cid": "bafyrei222"},
            ]
        )
        assert source.list_shards() == [
            "at://did:plc:abc/blob/bafyrei111",
            "at://did:plc:abc/blob/bafyrei222",
        ]

    def test_from_refs_simple_format(self):
        """from_refs accepts simple {did, cid} format."""
        source = BlobSource.from_refs(
            [
                {"did": "did:plc:abc", "cid": "bafyrei123"},
            ]
        )
        assert len(source.blob_refs) == 1
        assert source.blob_refs[0]["did"] == "did:plc:abc"
        assert source.blob_refs[0]["cid"] == "bafyrei123"

    def test_from_refs_with_endpoint(self):
        """from_refs accepts pds_endpoint parameter."""
        source = BlobSource.from_refs(
            [{"did": "did:plc:abc", "cid": "bafyrei123"}],
            pds_endpoint="https://pds.example.com",
        )
        assert source.pds_endpoint == "https://pds.example.com"

    def test_from_refs_empty(self):
        """from_refs raises on empty list."""
        with pytest.raises(ValueError, match="cannot be empty"):
            BlobSource.from_refs([])

    def test_from_refs_invalid_format(self):
        """from_refs raises on invalid blob reference format."""
        with pytest.raises(ValueError, match="Invalid blob reference format"):
            BlobSource.from_refs([{"invalid": "data"}])

    def test_from_refs_atproto_format_without_did(self):
        """from_refs raises helpful error for ATProto format without DID."""
        with pytest.raises(ValueError, match="requires 'did' field"):
            BlobSource.from_refs([{"ref": {"$link": "bafyrei123"}}])

    def test_resolve_pds_endpoint_uses_cache(self):
        """PDS endpoint resolution is cached."""
        source = BlobSource(blob_refs=[{"did": "did:plc:abc", "cid": "cid"}])

        # Pre-populate cache
        source._endpoint_cache["did:plc:abc"] = "https://cached.pds.com"

        endpoint = source._resolve_pds_endpoint("did:plc:abc")
        assert endpoint == "https://cached.pds.com"

    def test_resolve_pds_endpoint_uses_provided_endpoint(self):
        """Provided pds_endpoint is used instead of resolution."""
        source = BlobSource(
            blob_refs=[{"did": "did:plc:abc", "cid": "cid"}],
            pds_endpoint="https://my.pds.com",
        )

        endpoint = source._resolve_pds_endpoint("did:plc:abc")
        assert endpoint == "https://my.pds.com"

    def test_get_blob_url(self):
        """_get_blob_url constructs correct URL."""
        source = BlobSource(
            blob_refs=[{"did": "did:plc:abc", "cid": "bafyrei123"}],
            pds_endpoint="https://pds.example.com",
        )

        url = source._get_blob_url("did:plc:abc", "bafyrei123")
        assert (
            url
            == "https://pds.example.com/xrpc/com.atproto.sync.getBlob?did=did:plc:abc&cid=bafyrei123"
        )

    def test_shards_fetches_blobs(self):
        """shards property fetches blobs via HTTP."""
        mock_response = Mock()
        mock_response.raw = Mock()
        mock_response.raise_for_status = Mock()

        with patch("requests.get", return_value=mock_response) as mock_get:
            source = BlobSource(
                blob_refs=[{"did": "did:plc:abc", "cid": "bafyrei123"}],
                pds_endpoint="https://pds.example.com",
            )

            shards = list(source.shards)

            assert len(shards) == 1
            shard_id, stream = shards[0]
            assert shard_id == "at://did:plc:abc/blob/bafyrei123"
            assert stream is mock_response.raw

            mock_get.assert_called_once_with(
                "https://pds.example.com/xrpc/com.atproto.sync.getBlob?did=did:plc:abc&cid=bafyrei123",
                stream=True,
                timeout=60,
            )

    def test_open_shard_fetches_single_blob(self):
        """open_shard fetches a specific blob."""
        mock_response = Mock()
        mock_response.raw = Mock()
        mock_response.raise_for_status = Mock()

        with patch("requests.get", return_value=mock_response) as mock_get:
            source = BlobSource(
                blob_refs=[
                    {"did": "did:plc:abc", "cid": "bafyrei111"},
                    {"did": "did:plc:abc", "cid": "bafyrei222"},
                ],
                pds_endpoint="https://pds.example.com",
            )

            stream = source.open_shard("at://did:plc:abc/blob/bafyrei222")

            assert stream is mock_response.raw
            mock_get.assert_called_once()
            call_args = mock_get.call_args
            # First positional argument is the blob URL; it must target
            # the requested CID, not the first ref in the list.
            assert "bafyrei222" in call_args[0][0]

    def test_open_shard_not_found(self):
        """open_shard raises KeyError for unknown shard."""
        source = BlobSource(blob_refs=[{"did": "did:plc:abc", "cid": "bafyrei123"}])

        with pytest.raises(KeyError, match="Shard not found"):
            source.open_shard("at://did:plc:abc/blob/unknown")

    def test_open_shard_invalid_format(self):
        """open_shard raises KeyError for shard IDs that are not AT blob URIs."""
        # Malformed shard IDs can never appear in the source's shard list,
        # so they are rejected with the same KeyError as unknown shards
        # rather than a separate format error.
        source = BlobSource(
            blob_refs=[{"did": "did:plc:abc", "cid": "bafyrei123"}],
        )

        # A non-AT URI should raise KeyError (not in list)
        with pytest.raises(KeyError, match="Shard not found"):
            source.open_shard("not-an-at-uri")

        # An AT URI with wrong format should also raise KeyError (not in list)
        with pytest.raises(KeyError, match="Shard not found"):
            source.open_shard("at://did:plc:abc/wrong/format")


class TestDatasetWithDataSource:
    """Integration tests for Dataset with different DataSource types."""

    def test_dataset_accepts_url_source(self, tmp_path):
        """Dataset can be created with URLSource."""
        tar_path = tmp_path / "test.tar"
        create_test_tar(tar_path, [{"name": "test", "value": 42}])

        source = URLSource(str(tar_path))
        ds = atdata.Dataset[SourceTestSample](source)

        assert ds.source is source
        assert ds.shard_list == [str(tar_path)]

    def test_dataset_accepts_string_url(self, tmp_path):
        """Dataset auto-wraps string URLs in URLSource."""
        tar_path = tmp_path / "test.tar"
        create_test_tar(tar_path, [{"name": "test", "value": 42}])

        ds = atdata.Dataset[SourceTestSample](str(tar_path))

        assert isinstance(ds.source, URLSource)
        assert ds.url == str(tar_path)

    def test_dataset_backward_compat_url_kwarg(self, tmp_path):
        """Dataset accepts url= keyword for backward compatibility."""
        tar_path = tmp_path / "test.tar"
        create_test_tar(tar_path, [{"name": "test", "value": 42}])

        ds = atdata.Dataset[SourceTestSample](url=str(tar_path))

        assert isinstance(ds.source, URLSource)
        assert ds.url == str(tar_path)

    def test_dataset_source_property(self, tmp_path):
        """Dataset.source property returns the underlying DataSource."""
        tar_path = tmp_path / "test.tar"
        create_test_tar(tar_path, [{"name": "test", "value": 42}])

        source = URLSource(str(tar_path))
        ds = atdata.Dataset[SourceTestSample](source)

        assert ds.source is source

    def test_dataset_multiple_shards(self, tmp_path):
        """Dataset works with multi-shard sources."""
        # Create two shards
        for i in range(2):
            tar_path = tmp_path / f"data-{i:06d}.tar"
            create_test_tar(tar_path, [{"name": f"shard{i}", "value": i}])

        pattern = str(tmp_path / "data-{000000..000001}.tar")
        ds = atdata.Dataset[SourceTestSample](pattern)

        samples = list(ds.ordered())
        assert len(samples) == 2
        names = {s.name for s in samples}
        assert names == {"shard0", "shard1"}