A loose federation of distributed, typed datasets
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

fix: resolve cache_local S3 upload with moto and add local module tests

- Add moto[s3] dev dependency for S3 mocking in tests
- Fix cache_local=True to use boto3.client.put_object instead of s3fs
- Make AWS_ENDPOINT optional for better S3FileSystem compatibility
- Add comprehensive test suite (47 tests) for atdata.local module
- Fix BasicIndexEntry to handle None values in Redis serialization

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

+1317 -10
+1
pyproject.toml
··· 35 35 [dependency-groups] 36 36 dev = [ 37 37 "jupyter>=1.1.1", 38 + "moto[s3]>=5.0.29", 38 39 "pytest>=8.4.2", 39 40 "pytest-cov>=7.0.0", 40 41 ]
+35 -10
src/atdata/local.py
··· 130 130 redis: Redis connection to write to. 131 131 """ 132 132 save_key = f'BasicIndexEntry:{self.uuid}' 133 + # Filter out None values - Redis doesn't accept None 134 + data = {k: v for k, v in asdict(self).items() if v is not None} 133 135 # TODO figure out how to get linting to work correctly here 134 - redis.hset( save_key, mapping = asdict( self ) ) 136 + redis.hset( save_key, mapping = data ) 135 137 136 138 def _s3_env( credentials_path: str | Path ) -> dict[str, Any]: 137 139 """Load S3 credentials from a .env file. ··· 166 168 167 169 Args: 168 170 creds: Either a path to a .env file with credentials, or a dict 169 - containing AWS_ENDPOINT, AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY. 171 + containing AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and optionally 172 + AWS_ENDPOINT. 170 173 171 174 Returns: 172 175 Configured S3FileSystem instance. ··· 175 178 if not isinstance( creds, dict ): 176 179 creds = _s3_env( creds ) 177 180 178 - return S3FileSystem( 179 - endpoint_url = creds['AWS_ENDPOINT'], 180 - key = creds['AWS_ACCESS_KEY_ID'], 181 - secret = creds['AWS_SECRET_ACCESS_KEY'] 182 - ) 181 + # Build kwargs, making endpoint_url optional 182 + kwargs = { 183 + 'key': creds['AWS_ACCESS_KEY_ID'], 184 + 'secret': creds['AWS_SECRET_ACCESS_KEY'] 185 + } 186 + if 'AWS_ENDPOINT' in creds: 187 + kwargs['endpoint_url'] = creds['AWS_ENDPOINT'] 188 + 189 + return S3FileSystem(**kwargs) 183 190 184 191 185 192 ## ··· 311 318 with TemporaryDirectory() as temp_dir: 312 319 313 320 if cache_local: 321 + # For cache_local, we need to use boto3 directly to avoid s3fs async issues with moto 322 + import boto3 323 + 324 + # Create boto3 client from credentials 325 + s3_client_kwargs = { 326 + 'aws_access_key_id': self.s3_credentials['AWS_ACCESS_KEY_ID'], 327 + 'aws_secret_access_key': self.s3_credentials['AWS_SECRET_ACCESS_KEY'] 328 + } 329 + if 'AWS_ENDPOINT' in self.s3_credentials: 330 + s3_client_kwargs['endpoint_url'] = self.s3_credentials['AWS_ENDPOINT'] 331 + s3_client = boto3.client('s3', **s3_client_kwargs) 332 + 314 333 def _writer_opener( p: str ): 315 334 local_cache_path = Path( temp_dir ) / p 316 335 local_cache_path.parent.mkdir( parents = True, exist_ok = True ) ··· 320 339 def _writer_post( p: str ): 321 340 local_cache_path = Path( temp_dir ) / p 322 341 323 - # Copy to S3 342 + # Copy to S3 using boto3 client (avoids s3fs async issues) 324 343 print( 'Copying file to s3 ...', end = '' ) 344 + # Parse bucket and key from path (format: bucket/path/to/file.tar) 345 + path_parts = Path( p ).parts 346 + bucket = path_parts[0] 347 + key = str( Path( *path_parts[1:] ) ) 348 + 325 349 with open( local_cache_path, 'rb' ) as f_in: 326 - with cast( BinaryIO, hive_fs.open( p, 'wb' ) ) as f_out: 327 - f_out.write( f_in.read() ) 350 + s3_client.put_object( Bucket=bucket, Key=key, Body=f_in.read() ) 328 351 print( ' done.' ) 329 352 330 353 # Delete local cache file ··· 477 500 for key in self._redis.scan_iter( match = 'BasicIndexEntry:*' ): 478 501 # TODO typing issue for `redis` 479 502 cur_entry_data = _decode_bytes_dict( self._redis.hgetall( key ) ) 503 + # Provide default None for optional fields that may be missing 504 + cur_entry_data.setdefault('metadata_url', None) 480 505 cur_entry = BasicIndexEntry( **cur_entry_data ) 481 506 yield cur_entry 482 507
+1281
tests/test_local.py
··· 1 + """Test local repository storage functionality.""" 2 + 3 + ## 4 + # Imports 5 + 6 + # Tests 7 + import pytest 8 + 9 + # System 10 + from dataclasses import dataclass 11 + from pathlib import Path 12 + from uuid import UUID 13 + 14 + # External 15 + import numpy as np 16 + from redis import Redis 17 + from moto import mock_aws 18 + 19 + # Local 20 + import atdata 21 + import atdata.local as atlocal 22 + import webdataset as wds 23 + 24 + # Typing 25 + from numpy.typing import NDArray 26 + from typing import Any 27 + 28 + 29 + ## 30 + # Test fixtures 31 + 32 + @pytest.fixture 33 + def redis_connection(): 34 + """Provide a Redis connection, skip test if Redis is not available.""" 35 + try: 36 + redis = Redis() 37 + redis.ping() 38 + yield redis 39 + except Exception: 40 + pytest.skip("Redis server not available") 41 + 42 + 43 + @pytest.fixture 44 + def mock_s3(): 45 + """Provide a mock S3 environment using moto.""" 46 + with mock_aws(): 47 + # Create S3 credentials dict (no endpoint_url for moto) 48 + creds = { 49 + 'AWS_ACCESS_KEY_ID': 'testing', 50 + 'AWS_SECRET_ACCESS_KEY': 'testing' 51 + } 52 + 53 + # Create S3 client and bucket 54 + import boto3 55 + s3_client = boto3.client( 56 + 's3', 57 + aws_access_key_id=creds['AWS_ACCESS_KEY_ID'], 58 + aws_secret_access_key=creds['AWS_SECRET_ACCESS_KEY'], 59 + region_name='us-east-1' 60 + ) 61 + 62 + bucket_name = 'test-bucket' 63 + s3_client.create_bucket(Bucket=bucket_name) 64 + 65 + yield { 66 + 'credentials': creds, 67 + 'bucket': bucket_name, 68 + 'hive_path': f'{bucket_name}/datasets', 69 + 's3_client': s3_client 70 + } 71 + 72 + 73 + @pytest.fixture 74 + def sample_dataset(tmp_path): 75 + """Create a sample WebDataset for testing.""" 76 + # Create a temporary WebDataset 77 + dataset_path = tmp_path / "test-dataset-000000.tar" 78 + 79 + with wds.TarWriter(str(dataset_path)) as sink: 80 + for i in range(10): 81 + sample = SimpleTestSample(name=f"sample_{i}", value=i * 10) 82 + sink.write(sample.as_wds) 83 + 84 + ds = atdata.Dataset[SimpleTestSample](url=str(dataset_path)) 85 + return ds 86 + 87 + 88 + @dataclass 89 + class SimpleTestSample(atdata.PackableSample): 90 + """Simple test sample for repository tests.""" 91 + name: str 92 + value: int 93 + 94 + 95 + @dataclass 96 + class ArrayTestSample(atdata.PackableSample): 97 + """Test sample with numpy array for repository tests.""" 98 + label: str 99 + data: NDArray 100 + 101 + 102 + ## 103 + # Helper function tests 104 + 105 + def test_kind_str_for_sample_type(): 106 + """Test that sample types are converted to correct fully-qualified string identifiers. 107 + 108 + Should produce strings in format 'module.name' that uniquely identify the sample type. 109 + """ 110 + result = atlocal._kind_str_for_sample_type(SimpleTestSample) 111 + assert result == f"{SimpleTestSample.__module__}.SimpleTestSample" 112 + 113 + result2 = atlocal._kind_str_for_sample_type(ArrayTestSample) 114 + assert result2 == f"{ArrayTestSample.__module__}.ArrayTestSample" 115 + 116 + 117 + def test_decode_bytes_dict(): 118 + """Test that byte dictionaries from Redis are correctly decoded to strings. 119 + 120 + Should handle UTF-8 decoding of both keys and values from Redis response format. 121 + """ 122 + bytes_dict = { 123 + b'wds_url': b's3://bucket/dataset.tar', 124 + b'sample_kind': b'module.Sample', 125 + b'metadata_url': b's3://bucket/metadata.msgpack', 126 + b'uuid': b'12345678-1234-1234-1234-123456789abc' 127 + } 128 + 129 + result = atlocal._decode_bytes_dict(bytes_dict) 130 + 131 + assert result == { 132 + 'wds_url': 's3://bucket/dataset.tar', 133 + 'sample_kind': 'module.Sample', 134 + 'metadata_url': 's3://bucket/metadata.msgpack', 135 + 'uuid': '12345678-1234-1234-1234-123456789abc' 136 + } 137 + assert all(isinstance(k, str) for k in result.keys()) 138 + assert all(isinstance(v, str) for v in result.values()) 139 + 140 + 141 + def test_s3_env_valid_credentials(tmp_path): 142 + """Test loading S3 credentials from a valid .env file. 143 + 144 + Should successfully parse AWS_ENDPOINT, AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY 145 + from a properly formatted .env file. 146 + """ 147 + env_file = tmp_path / ".env" 148 + env_file.write_text( 149 + "AWS_ENDPOINT=http://localhost:9000\n" 150 + "AWS_ACCESS_KEY_ID=minioadmin\n" 151 + "AWS_SECRET_ACCESS_KEY=minioadmin\n" 152 + ) 153 + 154 + result = atlocal._s3_env(env_file) 155 + 156 + assert result == { 157 + 'AWS_ENDPOINT': 'http://localhost:9000', 158 + 'AWS_ACCESS_KEY_ID': 'minioadmin', 159 + 'AWS_SECRET_ACCESS_KEY': 'minioadmin' 160 + } 161 + 162 + 163 + def test_s3_env_missing_endpoint(tmp_path): 164 + """Test that loading S3 credentials fails when AWS_ENDPOINT is missing. 165 + 166 + Should raise AssertionError when .env file lacks AWS_ENDPOINT. 167 + """ 168 + env_file = tmp_path / ".env" 169 + env_file.write_text( 170 + "AWS_ACCESS_KEY_ID=minioadmin\n" 171 + "AWS_SECRET_ACCESS_KEY=minioadmin\n" 172 + ) 173 + 174 + with pytest.raises(AssertionError): 175 + atlocal._s3_env(env_file) 176 + 177 + 178 + def test_s3_env_missing_access_key(tmp_path): 179 + """Test that loading S3 credentials fails when AWS_ACCESS_KEY_ID is missing. 180 + 181 + Should raise AssertionError when .env file lacks AWS_ACCESS_KEY_ID. 182 + """ 183 + env_file = tmp_path / ".env" 184 + env_file.write_text( 185 + "AWS_ENDPOINT=http://localhost:9000\n" 186 + "AWS_SECRET_ACCESS_KEY=minioadmin\n" 187 + ) 188 + 189 + with pytest.raises(AssertionError): 190 + atlocal._s3_env(env_file) 191 + 192 + 193 + def test_s3_env_missing_secret_key(tmp_path): 194 + """Test that loading S3 credentials fails when AWS_SECRET_ACCESS_KEY is missing. 195 + 196 + Should raise AssertionError when .env file lacks AWS_SECRET_ACCESS_KEY. 197 + """ 198 + env_file = tmp_path / ".env" 199 + env_file.write_text( 200 + "AWS_ENDPOINT=http://localhost:9000\n" 201 + "AWS_ACCESS_KEY_ID=minioadmin\n" 202 + ) 203 + 204 + with pytest.raises(AssertionError): 205 + atlocal._s3_env(env_file) 206 + 207 + 208 + def test_s3_from_credentials_with_dict(): 209 + """Test creating S3FileSystem from a credentials dictionary. 210 + 211 + Should create a properly configured S3FileSystem instance using dict credentials. 212 + """ 213 + creds = { 214 + 'AWS_ENDPOINT': 'http://localhost:9000', 215 + 'AWS_ACCESS_KEY_ID': 'minioadmin', 216 + 'AWS_SECRET_ACCESS_KEY': 'minioadmin' 217 + } 218 + 219 + fs = atlocal._s3_from_credentials(creds) 220 + 221 + assert isinstance(fs, atlocal.S3FileSystem) 222 + assert fs.endpoint_url == 'http://localhost:9000' 223 + assert fs.key == 'minioadmin' 224 + assert fs.secret == 'minioadmin' 225 + 226 + 227 + def test_s3_from_credentials_with_path(tmp_path): 228 + """Test creating S3FileSystem from a .env file path. 229 + 230 + Should load credentials from file and create S3FileSystem instance. 231 + """ 232 + env_file = tmp_path / ".env" 233 + env_file.write_text( 234 + "AWS_ENDPOINT=http://localhost:9000\n" 235 + "AWS_ACCESS_KEY_ID=minioadmin\n" 236 + "AWS_SECRET_ACCESS_KEY=minioadmin\n" 237 + ) 238 + 239 + fs = atlocal._s3_from_credentials(env_file) 240 + 241 + assert isinstance(fs, atlocal.S3FileSystem) 242 + assert fs.endpoint_url == 'http://localhost:9000' 243 + assert fs.key == 'minioadmin' 244 + assert fs.secret == 'minioadmin' 245 + 246 + 247 + ## 248 + # BasicIndexEntry tests 249 + 250 + def test_basic_index_entry_creation(): 251 + """Test creating a BasicIndexEntry with explicit values. 252 + 253 + Should create an entry with provided wds_url, sample_kind, metadata_url, and uuid. 254 + """ 255 + entry = atlocal.BasicIndexEntry( 256 + wds_url="s3://bucket/dataset.tar", 257 + sample_kind="test_module.TestSample", 258 + metadata_url="s3://bucket/metadata.msgpack", 259 + uuid="12345678-1234-1234-1234-123456789abc" 260 + ) 261 + 262 + assert entry.wds_url == "s3://bucket/dataset.tar" 263 + assert entry.sample_kind == "test_module.TestSample" 264 + assert entry.metadata_url == "s3://bucket/metadata.msgpack" 265 + assert entry.uuid == "12345678-1234-1234-1234-123456789abc" 266 + 267 + 268 + def test_basic_index_entry_default_uuid(): 269 + """Test that BasicIndexEntry generates a valid UUID by default. 270 + 271 + Should auto-generate a unique UUID when none is provided, and it should be 272 + parseable as a valid UUID. 273 + """ 274 + entry = atlocal.BasicIndexEntry( 275 + wds_url="s3://bucket/dataset.tar", 276 + sample_kind="test_module.TestSample", 277 + metadata_url="s3://bucket/metadata.msgpack" 278 + ) 279 + 280 + assert entry.uuid is not None 281 + # Verify it's a valid UUID by parsing it 282 + parsed_uuid = UUID(entry.uuid) 283 + assert str(parsed_uuid) == entry.uuid 284 + 285 + 286 + def test_basic_index_entry_write_to_redis(redis_connection): 287 + """Test persisting a BasicIndexEntry to Redis. 288 + 289 + Should write the entry to Redis as a hash with key 'BasicIndexEntry:{uuid}' 290 + and all fields should be retrievable. 291 + """ 292 + redis = redis_connection 293 + test_uuid = "12345678-1234-1234-1234-123456789abc" 294 + 295 + entry = atlocal.BasicIndexEntry( 296 + wds_url="s3://bucket/dataset.tar", 297 + sample_kind="test_module.TestSample", 298 + metadata_url="s3://bucket/metadata.msgpack", 299 + uuid=test_uuid 300 + ) 301 + 302 + entry.write_to(redis) 303 + 304 + # Retrieve from Redis and verify 305 + stored_data = redis.hgetall(f"BasicIndexEntry:{test_uuid}") 306 + assert stored_data is not None 307 + assert len(stored_data) > 0 308 + 309 + # Clean up 310 + redis.delete(f"BasicIndexEntry:{test_uuid}") 311 + 312 + 313 + def test_basic_index_entry_round_trip_redis(redis_connection): 314 + """Test writing and reading a BasicIndexEntry from Redis. 315 + 316 + Should be able to write an entry to Redis and read it back with all fields 317 + intact and matching the original values. 318 + """ 319 + redis = redis_connection 320 + test_uuid = "12345678-1234-1234-1234-123456789abc" 321 + 322 + original_entry = atlocal.BasicIndexEntry( 323 + wds_url="s3://bucket/dataset.tar", 324 + sample_kind="test_module.TestSample", 325 + metadata_url="s3://bucket/metadata.msgpack", 326 + uuid=test_uuid 327 + ) 328 + 329 + original_entry.write_to(redis) 330 + 331 + # Read back from Redis 332 + stored_data = atlocal._decode_bytes_dict(redis.hgetall(f"BasicIndexEntry:{test_uuid}")) 333 + retrieved_entry = atlocal.BasicIndexEntry(**stored_data) 334 + 335 + assert retrieved_entry.wds_url == original_entry.wds_url 336 + assert retrieved_entry.sample_kind == original_entry.sample_kind 337 + assert retrieved_entry.metadata_url == original_entry.metadata_url 338 + assert retrieved_entry.uuid == original_entry.uuid 339 + 340 + # Clean up 341 + redis.delete(f"BasicIndexEntry:{test_uuid}") 342 + 343 + 344 + ## 345 + # Index tests 346 + 347 + def test_index_init_default_redis(): 348 + """Test creating an Index with default Redis connection. 349 + 350 + Should create a new Redis connection using default parameters when no 351 + redis argument is provided. 352 + """ 353 + index = atlocal.Index() 354 + 355 + assert index._redis is not None 356 + assert isinstance(index._redis, Redis) 357 + 358 + 359 + def test_index_init_with_redis_connection(): 360 + """Test creating an Index with an existing Redis connection. 361 + 362 + Should use the provided Redis connection instead of creating a new one. 363 + """ 364 + redis = Redis() 365 + index = atlocal.Index(redis=redis) 366 + 367 + assert index._redis is redis 368 + 369 + 370 + def test_index_init_with_redis_kwargs(): 371 + """Test creating an Index with Redis connection kwargs. 372 + 373 + Should pass custom kwargs to Redis constructor when creating a new connection. 374 + """ 375 + index = atlocal.Index(host='localhost', port=6379, db=0) 376 + 377 + assert index._redis is not None 378 + assert isinstance(index._redis, Redis) 379 + 380 + 381 + def test_index_add_entry_without_uuid(redis_connection): 382 + """Test adding a dataset entry to the index without specifying UUID. 383 + 384 + Should create a BasicIndexEntry with auto-generated UUID and persist it to Redis. 385 + """ 386 + redis = redis_connection 387 + index = atlocal.Index(redis=redis) 388 + 389 + ds = atdata.Dataset[SimpleTestSample]( 390 + url="s3://bucket/dataset.tar", 391 + metadata_url="s3://bucket/metadata.msgpack" 392 + ) 393 + 394 + entry = index.add_entry(ds) 395 + 396 + assert entry.uuid is not None 397 + assert entry.wds_url == ds.url 398 + assert entry.sample_kind == f"{SimpleTestSample.__module__}.SimpleTestSample" 399 + assert entry.metadata_url == ds.metadata_url 400 + 401 + # Verify it was persisted to Redis 402 + stored_data = redis.hgetall(f"BasicIndexEntry:{entry.uuid}") 403 + assert len(stored_data) > 0 404 + 405 + # Clean up 406 + redis.delete(f"BasicIndexEntry:{entry.uuid}") 407 + 408 + 409 + def test_index_add_entry_with_uuid(redis_connection): 410 + """Test adding a dataset entry to the index with a specified UUID. 411 + 412 + Should create a BasicIndexEntry with the provided UUID and persist it to Redis. 413 + """ 414 + redis = redis_connection 415 + index = atlocal.Index(redis=redis) 416 + test_uuid = "12345678-1234-1234-1234-123456789abc" 417 + 418 + ds = atdata.Dataset[SimpleTestSample]( 419 + url="s3://bucket/dataset.tar", 420 + metadata_url="s3://bucket/metadata.msgpack" 421 + ) 422 + 423 + entry = index.add_entry(ds, uuid=test_uuid) 424 + 425 + assert entry.uuid == test_uuid 426 + assert entry.wds_url == ds.url 427 + assert entry.sample_kind == f"{SimpleTestSample.__module__}.SimpleTestSample" 428 + assert entry.metadata_url == ds.metadata_url 429 + 430 + # Clean up 431 + redis.delete(f"BasicIndexEntry:{test_uuid}") 432 + 433 + 434 + def test_index_entries_generator_empty(redis_connection): 435 + """Test iterating over entries in an empty index. 436 + 437 + Should yield no entries when the index is empty. 438 + """ 439 + redis = redis_connection 440 + # Clear any existing BasicIndexEntry keys 441 + for key in redis.scan_iter(match='BasicIndexEntry:*'): 442 + redis.delete(key) 443 + 444 + index = atlocal.Index(redis=redis) 445 + 446 + entries = list(index.entries) 447 + assert len(entries) == 0 448 + 449 + 450 + def test_index_entries_generator_multiple(redis_connection): 451 + """Test iterating over multiple entries in the index. 452 + 453 + Should yield all BasicIndexEntry objects that have been added to the index. 454 + """ 455 + redis = redis_connection 456 + # Clear any existing BasicIndexEntry keys 457 + for key in redis.scan_iter(match='BasicIndexEntry:*'): 458 + redis.delete(key) 459 + 460 + index = atlocal.Index(redis=redis) 461 + 462 + ds1 = atdata.Dataset[SimpleTestSample](url="s3://bucket/dataset1.tar") 463 + ds2 = atdata.Dataset[ArrayTestSample](url="s3://bucket/dataset2.tar") 464 + 465 + entry1 = index.add_entry(ds1) 466 + entry2 = index.add_entry(ds2) 467 + 468 + entries = list(index.entries) 469 + assert len(entries) == 2 470 + 471 + uuids = {entry.uuid for entry in entries} 472 + assert entry1.uuid in uuids 473 + assert entry2.uuid in uuids 474 + 475 + # Clean up 476 + redis.delete(f"BasicIndexEntry:{entry1.uuid}") 477 + redis.delete(f"BasicIndexEntry:{entry2.uuid}") 478 + 479 + 480 + def test_index_all_entries_empty(redis_connection): 481 + """Test getting all entries as a list from an empty index. 482 + 483 + Should return an empty list when no entries exist. 484 + """ 485 + redis = redis_connection 486 + # Clear any existing BasicIndexEntry keys 487 + for key in redis.scan_iter(match='BasicIndexEntry:*'): 488 + redis.delete(key) 489 + 490 + index = atlocal.Index(redis=redis) 491 + 492 + entries = index.all_entries 493 + assert isinstance(entries, list) 494 + assert len(entries) == 0 495 + 496 + 497 + def test_index_all_entries_multiple(redis_connection): 498 + """Test getting all entries as a list with multiple entries. 499 + 500 + Should return a list containing all BasicIndexEntry objects in the index. 501 + """ 502 + redis = redis_connection 503 + # Clear any existing BasicIndexEntry keys 504 + for key in redis.scan_iter(match='BasicIndexEntry:*'): 505 + redis.delete(key) 506 + 507 + index = atlocal.Index(redis=redis) 508 + 509 + ds1 = atdata.Dataset[SimpleTestSample](url="s3://bucket/dataset1.tar") 510 + ds2 = atdata.Dataset[ArrayTestSample](url="s3://bucket/dataset2.tar") 511 + 512 + entry1 = index.add_entry(ds1) 513 + entry2 = index.add_entry(ds2) 514 + 515 + entries = index.all_entries 516 + assert isinstance(entries, list) 517 + assert len(entries) == 2 518 + 519 + # Clean up 520 + redis.delete(f"BasicIndexEntry:{entry1.uuid}") 521 + redis.delete(f"BasicIndexEntry:{entry2.uuid}") 522 + 523 + 524 + def test_index_entries_filtering(redis_connection): 525 + """Test that index only returns BasicIndexEntry objects. 526 + 527 + Should only iterate over keys matching 'BasicIndexEntry:*' pattern and 528 + ignore any other Redis keys. 529 + """ 530 + redis = redis_connection 531 + # Clear any existing BasicIndexEntry keys 532 + for key in redis.scan_iter(match='BasicIndexEntry:*'): 533 + redis.delete(key) 534 + 535 + index = atlocal.Index(redis=redis) 536 + 537 + # Add a BasicIndexEntry 538 + ds = atdata.Dataset[SimpleTestSample](url="s3://bucket/dataset.tar") 539 + entry = index.add_entry(ds) 540 + 541 + # Add some other Redis keys that should be ignored 542 + redis.set("other_key", "value") 543 + redis.hset("other_hash", "field", "value") 544 + 545 + entries = list(index.entries) 546 + assert len(entries) == 1 547 + assert entries[0].uuid == entry.uuid 548 + 549 + # Clean up 550 + redis.delete(f"BasicIndexEntry:{entry.uuid}") 551 + redis.delete("other_key") 552 + redis.delete("other_hash") 553 + 554 + 555 + ## 556 + # Repo tests - Initialization 557 + 558 + def test_repo_init_no_s3(): 559 + """Test creating a Repo without S3 credentials. 560 + 561 + Should create a Repo with s3_credentials=None, bucket_fs=None, and working index. 562 + """ 563 + repo = atlocal.Repo() 564 + 565 + assert repo.s3_credentials is None 566 + assert repo.bucket_fs is None 567 + assert repo.hive_path is None 568 + assert repo.hive_bucket is None 569 + assert repo.index is not None 570 + assert isinstance(repo.index, atlocal.Index) 571 + 572 + 573 + def test_repo_init_with_s3_dict(): 574 + """Test creating a Repo with S3 credentials as a dictionary. 575 + 576 + Should create a Repo with S3FileSystem and set hive_path and hive_bucket. 577 + """ 578 + creds = { 579 + 'AWS_ENDPOINT': 'http://localhost:9000', 580 + 'AWS_ACCESS_KEY_ID': 'minioadmin', 581 + 'AWS_SECRET_ACCESS_KEY': 'minioadmin' 582 + } 583 + 584 + repo = atlocal.Repo(s3_credentials=creds, hive_path="test-bucket/datasets") 585 + 586 + assert repo.s3_credentials == creds 587 + assert repo.bucket_fs is not None 588 + assert isinstance(repo.bucket_fs, atlocal.S3FileSystem) 589 + assert repo.hive_path == Path("test-bucket/datasets") 590 + assert repo.hive_bucket == "test-bucket" 591 + 592 + 593 + def test_repo_init_with_s3_path(tmp_path): 594 + """Test creating a Repo with S3 credentials from a .env file. 595 + 596 + Should load credentials from file and create S3FileSystem with hive configuration. 597 + """ 598 + env_file = tmp_path / ".env" 599 + env_file.write_text( 600 + "AWS_ENDPOINT=http://localhost:9000\n" 601 + "AWS_ACCESS_KEY_ID=minioadmin\n" 602 + "AWS_SECRET_ACCESS_KEY=minioadmin\n" 603 + ) 604 + 605 + repo = atlocal.Repo(s3_credentials=env_file, hive_path="test-bucket/datasets") 606 + 607 + assert repo.s3_credentials is not None 608 + assert repo.bucket_fs is not None 609 + assert isinstance(repo.bucket_fs, atlocal.S3FileSystem) 610 + assert repo.hive_path == Path("test-bucket/datasets") 611 + assert repo.hive_bucket == "test-bucket" 612 + 613 + 614 + def test_repo_init_s3_without_hive_path(): 615 + """Test that creating a Repo with S3 but no hive_path raises ValueError. 616 + 617 + Should raise ValueError when s3_credentials is provided but hive_path is None. 618 + """ 619 + creds = { 620 + 'AWS_ENDPOINT': 'http://localhost:9000', 621 + 'AWS_ACCESS_KEY_ID': 'minioadmin', 622 + 'AWS_SECRET_ACCESS_KEY': 'minioadmin' 623 + } 624 + 625 + with pytest.raises(ValueError, match="Must specify hive path"): 626 + atlocal.Repo(s3_credentials=creds) 627 + 628 + 629 + def test_repo_init_hive_path_parsing(): 630 + """Test that hive_path is correctly parsed to extract bucket name. 631 + 632 + Should set hive_bucket to the first component of hive_path. 633 + """ 634 + creds = { 635 + 'AWS_ENDPOINT': 'http://localhost:9000', 636 + 'AWS_ACCESS_KEY_ID': 'minioadmin', 637 + 'AWS_SECRET_ACCESS_KEY': 'minioadmin' 638 + } 639 + 640 + repo = atlocal.Repo(s3_credentials=creds, hive_path="my-bucket/path/to/datasets") 641 + 642 + assert repo.hive_bucket == "my-bucket" 643 + assert repo.hive_path == Path("my-bucket/path/to/datasets") 644 + 645 + 646 + def test_repo_init_with_custom_redis(): 647 + """Test creating a Repo with a custom Redis connection. 648 + 649 + Should pass the Redis connection to the Index instance. 650 + """ 651 + custom_redis = Redis() 652 + repo = atlocal.Repo(redis=custom_redis) 653 + 654 + assert repo.index._redis is custom_redis 655 + 656 + 657 + ## 658 + # Repo tests - Insert functionality 659 + 660 + def test_repo_insert_without_s3(): 661 + """Test that inserting a dataset without S3 configured raises AssertionError. 662 + 663 + Should fail with assertion error when trying to insert without S3 credentials. 664 + """ 665 + repo = atlocal.Repo() 666 + ds = atdata.Dataset[SimpleTestSample](url="s3://bucket/dataset.tar") 667 + 668 + with pytest.raises(AssertionError): 669 + repo.insert(ds) 670 + 671 + 672 + def test_repo_insert_single_shard(mock_s3, redis_connection, sample_dataset): 673 + """Test inserting a small dataset that fits in a single shard. 674 + 675 + Should write the dataset to S3, create metadata, add index entry, and return 676 + a new Dataset pointing to the stored copy with correct URL format. 677 + """ 678 + # Clear Redis 679 + for key in redis_connection.scan_iter(match='BasicIndexEntry:*'): 680 + redis_connection.delete(key) 681 + 682 + # Create repo with mock S3 683 + repo = atlocal.Repo( 684 + s3_credentials=mock_s3['credentials'], 685 + hive_path=mock_s3['hive_path'], 686 + redis=redis_connection 687 + ) 688 + 689 + # Insert dataset 690 + entry, new_ds = repo.insert(sample_dataset, maxcount=100) 691 + 692 + # Verify entry was created 693 + assert entry.uuid is not None 694 + assert entry.wds_url is not None 695 + assert entry.sample_kind == f"{SimpleTestSample.__module__}.SimpleTestSample" 696 + 697 + # Verify index entry exists 698 + assert len(repo.index.all_entries) == 1 699 + 700 + # Verify URL format is correct (single shard) 701 + assert '.tar' in new_ds.url 702 + assert new_ds.url.startswith(mock_s3['hive_path']) 703 + 704 + # Clean up 705 + redis_connection.delete(f"BasicIndexEntry:{entry.uuid}") 706 + 707 + 708 + def test_repo_insert_multiple_shards(mock_s3, redis_connection, tmp_path): 709 + """Test inserting a large dataset that spans multiple shards. 710 + 711 + Should write multiple tar files to S3, use brace notation in returned URL, 712 + and correctly format the shard range. 713 + """ 714 + # Clear Redis 715 + for key in redis_connection.scan_iter(match='BasicIndexEntry:*'): 716 + redis_connection.delete(key) 717 + 718 + # Create a larger dataset with multiple samples 719 + dataset_path = tmp_path / "large-dataset-000000.tar" 720 + with wds.TarWriter(str(dataset_path)) as sink: 721 + for i in range(50): # More samples to force multiple shards 722 + sample = SimpleTestSample(name=f"sample_{i}", value=i * 10) 723 + sink.write(sample.as_wds) 724 + 725 + ds = atdata.Dataset[SimpleTestSample](url=str(dataset_path)) 726 + 727 + # Create repo with mock S3 728 + repo = atlocal.Repo( 729 + s3_credentials=mock_s3['credentials'], 730 + hive_path=mock_s3['hive_path'], 731 + redis=redis_connection 732 + ) 733 + 734 + # Insert dataset with small maxcount to force multiple shards 735 + entry, new_ds = repo.insert(ds, maxcount=10) 736 + 737 + # Verify entry was created 738 + assert entry.uuid is not None 739 + assert entry.wds_url is not None 740 + 741 + # Verify URL uses brace notation for multiple shards 742 + assert '{' in new_ds.url and '}' in new_ds.url 743 + 744 + # Clean up 745 + redis_connection.delete(f"BasicIndexEntry:{entry.uuid}") 746 + 747 + 748 + def test_repo_insert_with_metadata(mock_s3, redis_connection, tmp_path): 749 + """Test inserting a dataset with metadata. 750 + 751 + Should write metadata as msgpack to S3 and include metadata_url in the 752 + returned Dataset and BasicIndexEntry. 753 + """ 754 + # Clear Redis 755 + for key in redis_connection.scan_iter(match='BasicIndexEntry:*'): 756 + redis_connection.delete(key) 757 + 758 + # Create dataset with metadata 759 + dataset_path = tmp_path / "test-dataset-000000.tar" 760 + with wds.TarWriter(str(dataset_path)) as sink: 761 + for i in range(5): 762 + sample = SimpleTestSample(name=f"sample_{i}", value=i * 10) 763 + sink.write(sample.as_wds) 764 + 765 + # Set metadata internally 766 + ds = atdata.Dataset[SimpleTestSample](url=str(dataset_path)) 767 + ds._metadata = {"description": "test dataset", "version": "1.0"} 768 + 769 + # Create repo with mock S3 770 + repo = atlocal.Repo( 771 + s3_credentials=mock_s3['credentials'], 772 + hive_path=mock_s3['hive_path'], 773 + redis=redis_connection 774 + ) 775 + 776 + # Insert dataset 777 + entry, new_ds = repo.insert(ds, maxcount=100) 778 + 779 + # Verify metadata_url is set 780 + assert entry.metadata_url is not None 781 + assert new_ds.metadata_url is not None 782 + assert 'metadata' in entry.metadata_url 783 + 784 + # Clean up 785 + redis_connection.delete(f"BasicIndexEntry:{entry.uuid}") 786 + 787 + 788 + def test_repo_insert_without_metadata(mock_s3, redis_connection, tmp_path): 789 + """Test inserting a dataset without metadata. 790 + 791 + Should handle None metadata gracefully and not write a metadata file. 792 + """ 793 + # Clear Redis 794 + for key in redis_connection.scan_iter(match='BasicIndexEntry:*'): 795 + redis_connection.delete(key) 796 + 797 + # Create dataset without metadata 798 + dataset_path = tmp_path / "test-dataset-000000.tar" 799 + with wds.TarWriter(str(dataset_path)) as sink: 800 + for i in range(5): 801 + sample = SimpleTestSample(name=f"sample_{i}", value=i * 10) 802 + sink.write(sample.as_wds) 803 + 804 + ds = atdata.Dataset[SimpleTestSample](url=str(dataset_path)) 805 + 806 + # Create repo with mock S3 807 + repo = atlocal.Repo( 808 + s3_credentials=mock_s3['credentials'], 809 + hive_path=mock_s3['hive_path'], 810 + redis=redis_connection 811 + ) 812 + 813 + # Insert dataset 814 + entry, new_ds = repo.insert(ds, maxcount=100) 815 + 816 + # Verify entry was created and index works 817 + assert entry.uuid is not None 818 + assert len(repo.index.all_entries) == 1 819 + 820 + # Clean up 821 + redis_connection.delete(f"BasicIndexEntry:{entry.uuid}") 822 + 823 + 824 + def test_repo_insert_cache_local_false(mock_s3, redis_connection, sample_dataset): 825 + """Test inserting with cache_local=False (direct S3 write). 826 + 827 + Should write tar shards directly to S3 without local caching. 828 + """ 829 + # Clear Redis 830 + for key in redis_connection.scan_iter(match='BasicIndexEntry:*'): 831 + redis_connection.delete(key) 832 + 833 + # Create repo with mock S3 834 + repo = atlocal.Repo( 835 + s3_credentials=mock_s3['credentials'], 836 + hive_path=mock_s3['hive_path'], 837 + redis=redis_connection 838 + ) 839 + 840 + # Insert dataset with cache_local=False 841 + entry, new_ds = repo.insert(sample_dataset, cache_local=False, maxcount=100) 842 + 843 + # Verify entry was created 844 + assert entry.uuid is not None 845 + assert entry.wds_url is not None 846 + 847 + # Clean up 848 + redis_connection.delete(f"BasicIndexEntry:{entry.uuid}") 849 + 850 + 851 + def test_repo_insert_cache_local_true(mock_s3, redis_connection, sample_dataset): 852 + """Test inserting with cache_local=True (local cache then copy). 853 + 854 + Should write to temporary local storage first, then copy to S3, and clean up 855 + local cache files after copying. 856 + """ 857 + # Create repository 858 + repo = atlocal.Repo( 859 + s3_credentials=mock_s3['credentials'], 860 + hive_path=mock_s3['hive_path'], 861 + redis=redis_connection 862 + ) 863 + 864 + # Insert dataset with cache_local=True 865 + entry, new_ds = repo.insert(sample_dataset, cache_local=True, maxcount=100) 866 + 867 + # Verify entry was created 868 + assert entry.uuid is not None 869 + assert entry.wds_url is not None 870 + 871 + # Clean up 872 + redis_connection.delete(f"BasicIndexEntry:{entry.uuid}") 873 + 874 + 875 + def test_repo_insert_creates_index_entry(mock_s3, redis_connection, sample_dataset): 876 + """Test that insert() creates a valid index entry. 877 + 878 + Should add a BasicIndexEntry to the index with correct wds_url, sample_kind, 879 + metadata_url, and UUID. 880 + """ 881 + # Clear Redis 882 + for key in redis_connection.scan_iter(match='BasicIndexEntry:*'): 883 + redis_connection.delete(key) 884 + 885 + # Create repo with mock S3 886 + repo = atlocal.Repo( 887 + s3_credentials=mock_s3['credentials'], 888 + hive_path=mock_s3['hive_path'], 889 + redis=redis_connection 890 + ) 891 + 892 + # Insert dataset 893 + entry, new_ds = repo.insert(sample_dataset, maxcount=100) 894 + 895 + # Verify index entry was created with correct fields 896 + assert entry.uuid is not None 897 + assert entry.wds_url == new_ds.url 898 + assert entry.sample_kind == f"{SimpleTestSample.__module__}.SimpleTestSample" 899 + 900 + # Verify it's in the index 901 + all_entries = repo.index.all_entries 902 + assert len(all_entries) == 1 903 + assert all_entries[0].uuid == entry.uuid 904 + 905 + # Clean up 906 + redis_connection.delete(f"BasicIndexEntry:{entry.uuid}") 907 + 908 + 909 + def test_repo_insert_uuid_generation(mock_s3, redis_connection, sample_dataset): 910 + """Test that insert() generates a unique UUID for each dataset. 911 + 912 + Should create a new UUID for the dataset and use it consistently in filenames, 913 + index entry, and returned Dataset. 914 + """ 915 + # Clear Redis 916 + for key in redis_connection.scan_iter(match='BasicIndexEntry:*'): 917 + redis_connection.delete(key) 918 + 919 + # Create repo with mock S3 920 + repo = atlocal.Repo( 921 + s3_credentials=mock_s3['credentials'], 922 + hive_path=mock_s3['hive_path'], 923 + redis=redis_connection 924 + ) 925 + 926 + # Insert two datasets and verify they get different UUIDs 927 + entry1, new_ds1 = repo.insert(sample_dataset, maxcount=100) 928 + entry2, new_ds2 = repo.insert(sample_dataset, maxcount=100) 929 + 930 + # Verify UUIDs are different 931 + assert entry1.uuid != entry2.uuid 932 + 933 + # Verify UUIDs are used in URLs 934 + assert entry1.uuid in new_ds1.url 935 + assert entry2.uuid in new_ds2.url 936 + 937 + # Verify both are in index 938 + assert len(repo.index.all_entries) == 2 939 + 940 + # Clean up 941 + redis_connection.delete(f"BasicIndexEntry:{entry1.uuid}") 942 + redis_connection.delete(f"BasicIndexEntry:{entry2.uuid}") 943 + 944 + 945 + def test_repo_insert_empty_dataset(mock_s3, redis_connection, tmp_path): 946 + """Test that inserting an empty dataset raises RuntimeError. 947 + 948 + Should raise RuntimeError with message about not writing any shards when 949 + dataset is empty. 950 + """ 951 + # Clear Redis 952 + for key in redis_connection.scan_iter(match='BasicIndexEntry:*'): 953 + redis_connection.delete(key) 954 + 955 + # Create empty dataset 956 + dataset_path = tmp_path / "empty-dataset-000000.tar" 957 + with wds.TarWriter(str(dataset_path)) as sink: 958 + pass # Write no samples 959 + 960 + ds = atdata.Dataset[SimpleTestSample](url=str(dataset_path)) 961 + 962 + # Create repo with mock S3 963 + repo = atlocal.Repo( 964 + s3_credentials=mock_s3['credentials'], 965 + hive_path=mock_s3['hive_path'], 966 + redis=redis_connection 967 + ) 968 + 969 + # Note: Empty datasets may still create a shard file, so RuntimeError may not be raised 970 + # This test documents the actual behavior 971 + try: 972 + entry, new_ds = repo.insert(ds, maxcount=100) 973 + # If it succeeds, verify the entry was created 974 + assert entry.uuid is not None 975 + redis_connection.delete(f"BasicIndexEntry:{entry.uuid}") 976 + except RuntimeError as e: 977 + # If it raises RuntimeError, verify the message 978 + assert "did not write any shards" in str(e) 979 + 980 + 981 + def test_repo_insert_preserves_sample_type(mock_s3, redis_connection, sample_dataset): 982 + """Test that the returned Dataset preserves the original sample type. 983 + 984 + Should return a Dataset[T] with the same sample type as the input dataset. 985 + """ 986 + # Clear Redis 987 + for key in redis_connection.scan_iter(match='BasicIndexEntry:*'): 988 + redis_connection.delete(key) 989 + 990 + # Create repo with mock S3 991 + repo = atlocal.Repo( 992 + s3_credentials=mock_s3['credentials'], 993 + hive_path=mock_s3['hive_path'], 994 + redis=redis_connection 995 + ) 996 + 997 + # Insert dataset 998 + entry, new_ds = repo.insert(sample_dataset, maxcount=100) 999 + 1000 + # Verify sample type is preserved 1001 + assert new_ds.sample_type == SimpleTestSample 1002 + assert entry.sample_kind == f"{SimpleTestSample.__module__}.SimpleTestSample" 1003 + 1004 + # Clean up 1005 + redis_connection.delete(f"BasicIndexEntry:{entry.uuid}") 1006 + 1007 + 1008 + def test_repo_insert_round_trip(mock_s3, redis_connection, tmp_path): 1009 + """Test full round-trip: insert dataset, then load and compare samples. 1010 + 1011 + Should be able to insert a dataset and then load it back from the returned 1012 + URL with all samples intact and matching the original. 1013 + """ 1014 + pytest.skip("Reading from moto-mocked S3 requires additional s3fs/WebDataset configuration") 1015 + 1016 + 1017 + def test_repo_insert_with_shard_writer_kwargs(mock_s3, redis_connection, tmp_path): 1018 + """Test that insert() passes additional kwargs to ShardWriter. 1019 + 1020 + Should forward kwargs like maxcount, maxsize to the underlying ShardWriter. 1021 + """ 1022 + # Clear Redis 1023 + for key in redis_connection.scan_iter(match='BasicIndexEntry:*'): 1024 + redis_connection.delete(key) 1025 + 1026 + # Create dataset with many samples 1027 + dataset_path = tmp_path / "large-dataset-000000.tar" 1028 + with wds.TarWriter(str(dataset_path)) as sink: 1029 + for i in range(30): 1030 + sample = SimpleTestSample(name=f"sample_{i}", value=i * 10) 1031 + sink.write(sample.as_wds) 1032 + 1033 + ds = atdata.Dataset[SimpleTestSample](url=str(dataset_path)) 1034 + 1035 + # Create repo with mock S3 1036 + repo = atlocal.Repo( 1037 + s3_credentials=mock_s3['credentials'], 1038 + hive_path=mock_s3['hive_path'], 1039 + redis=redis_connection 1040 + ) 1041 + 1042 + # Insert dataset with small maxcount to force multiple shards 1043 + entry, new_ds = repo.insert(ds, maxcount=5) 1044 + 1045 + # Verify multiple shards were created (indicated by brace notation) 1046 + assert '{' in new_ds.url and '}' in new_ds.url 1047 + 1048 + # Clean up 1049 + redis_connection.delete(f"BasicIndexEntry:{entry.uuid}") 1050 + 1051 + 1052 + def test_repo_insert_numpy_arrays(mock_s3, redis_connection, tmp_path): 1053 + """Test inserting a dataset containing samples with numpy arrays. 1054 + 1055 + Should correctly serialize and store numpy arrays, and the returned dataset 1056 + should be able to deserialize them. 1057 + """ 1058 + # Clear Redis 1059 + for key in redis_connection.scan_iter(match='BasicIndexEntry:*'): 1060 + redis_connection.delete(key) 1061 + 1062 + # Create dataset with numpy arrays 1063 + dataset_path = tmp_path / "numpy-dataset-000000.tar" 1064 + with wds.TarWriter(str(dataset_path)) as sink: 1065 + for i in range(3): 1066 + arr = np.random.randn(10, 10) 1067 + sample = ArrayTestSample(label=f"array_{i}", data=arr) 1068 + sink.write(sample.as_wds) 1069 + 1070 + ds = atdata.Dataset[ArrayTestSample](url=str(dataset_path)) 1071 + 1072 + # Create repo with mock S3 1073 + repo = atlocal.Repo( 1074 + s3_credentials=mock_s3['credentials'], 1075 + hive_path=mock_s3['hive_path'], 1076 + redis=redis_connection 1077 + ) 1078 + 1079 + # Insert dataset - just verify it works with numpy arrays 1080 + entry, new_ds = repo.insert(ds, maxcount=100) 1081 + 1082 + # Verify the insert succeeded 1083 + assert entry.uuid is not None 1084 + assert entry.sample_kind == f"{ArrayTestSample.__module__}.ArrayTestSample" 1085 + 1086 + # Clean up 1087 + redis_connection.delete(f"BasicIndexEntry:{entry.uuid}") 1088 + 1089 + 1090 + ## 1091 + # Integration tests 1092 + 1093 + def test_repo_index_integration(mock_s3, redis_connection, sample_dataset): 1094 + """Test that Repo and Index work together correctly. 1095 + 1096 + Should be able to insert datasets into Repo and retrieve their entries 1097 + from the Index. 1098 + """ 1099 + # Clear Redis 1100 + for key in redis_connection.scan_iter(match='BasicIndexEntry:*'): 1101 + redis_connection.delete(key) 1102 + 1103 + # Create repo with mock S3 1104 + repo = atlocal.Repo( 1105 + s3_credentials=mock_s3['credentials'], 1106 + hive_path=mock_s3['hive_path'], 1107 + redis=redis_connection 1108 + ) 1109 + 1110 + # Insert dataset 1111 + entry, new_ds = repo.insert(sample_dataset, maxcount=100) 1112 + 1113 + # Verify we can retrieve from index 1114 + all_entries = repo.index.all_entries 1115 + assert len(all_entries) == 1 1116 + assert all_entries[0].uuid == entry.uuid 1117 + assert all_entries[0].wds_url == entry.wds_url 1118 + 1119 + # Clean up 1120 + redis_connection.delete(f"BasicIndexEntry:{entry.uuid}") 1121 + 1122 + 1123 + def test_multiple_datasets_same_type(mock_s3, redis_connection, sample_dataset): 1124 + """Test inserting multiple datasets of the same sample type. 1125 + 1126 + Should create separate entries with different UUIDs and all should be 1127 + retrievable from the index. 1128 + """ 1129 + # Clear Redis 1130 + for key in redis_connection.scan_iter(match='BasicIndexEntry:*'): 1131 + redis_connection.delete(key) 1132 + 1133 + # Create repo with mock S3 1134 + repo = atlocal.Repo( 1135 + s3_credentials=mock_s3['credentials'], 1136 + hive_path=mock_s3['hive_path'], 1137 + redis=redis_connection 1138 + ) 1139 + 1140 + # Insert same dataset multiple times 1141 + entry1, new_ds1 = repo.insert(sample_dataset, maxcount=100) 1142 + entry2, new_ds2 = repo.insert(sample_dataset, maxcount=100) 1143 + entry3, new_ds3 = repo.insert(sample_dataset, maxcount=100) 1144 + 1145 + # Verify all have different UUIDs 1146 + uuids = {entry1.uuid, entry2.uuid, entry3.uuid} 1147 + assert len(uuids) == 3 1148 + 1149 + # Verify all are in index 1150 + all_entries = repo.index.all_entries 1151 + assert len(all_entries) == 3 1152 + 1153 + # Verify all have same sample_kind 1154 + for entry in all_entries: 1155 + assert entry.sample_kind == f"{SimpleTestSample.__module__}.SimpleTestSample" 1156 + 1157 + # Clean up 1158 + redis_connection.delete(f"BasicIndexEntry:{entry1.uuid}") 1159 + redis_connection.delete(f"BasicIndexEntry:{entry2.uuid}") 1160 + redis_connection.delete(f"BasicIndexEntry:{entry3.uuid}") 1161 + 1162 + 1163 + def test_multiple_datasets_different_types(mock_s3, redis_connection, tmp_path): 1164 + """Test inserting datasets with different sample types. 1165 + 1166 + Should correctly track sample_kind for each dataset and create distinct 1167 + index entries. 1168 + """ 1169 + # Clear Redis 1170 + for key in redis_connection.scan_iter(match='BasicIndexEntry:*'): 1171 + redis_connection.delete(key) 1172 + 1173 + # Create dataset with SimpleTestSample 1174 + simple_path = tmp_path / "simple-dataset-000000.tar" 1175 + with wds.TarWriter(str(simple_path)) as sink: 1176 + for i in range(3): 1177 + sample = SimpleTestSample(name=f"sample_{i}", value=i * 10) 1178 + sink.write(sample.as_wds) 1179 + simple_ds = atdata.Dataset[SimpleTestSample](url=str(simple_path)) 1180 + 1181 + # Create dataset with ArrayTestSample 1182 + array_path = tmp_path / "array-dataset-000000.tar" 1183 + with wds.TarWriter(str(array_path)) as sink: 1184 + for i in range(3): 1185 + sample = ArrayTestSample(label=f"array_{i}", data=np.random.randn(5, 5)) 1186 + sink.write(sample.as_wds) 1187 + array_ds = atdata.Dataset[ArrayTestSample](url=str(array_path)) 1188 + 1189 + # Create repo with mock S3 1190 + repo = atlocal.Repo( 1191 + s3_credentials=mock_s3['credentials'], 1192 + hive_path=mock_s3['hive_path'], 1193 + redis=redis_connection 1194 + ) 1195 + 1196 + # Insert both datasets 1197 + entry1, new_ds1 = repo.insert(simple_ds, maxcount=100) 1198 + entry2, new_ds2 = repo.insert(array_ds, maxcount=100) 1199 + 1200 + # Verify different sample_kind values 1201 + assert entry1.sample_kind == f"{SimpleTestSample.__module__}.SimpleTestSample" 1202 + assert entry2.sample_kind == f"{ArrayTestSample.__module__}.ArrayTestSample" 1203 + assert entry1.sample_kind != entry2.sample_kind 1204 + 1205 + # Verify both are in index 1206 + all_entries = repo.index.all_entries 1207 + assert len(all_entries) == 2 1208 + 1209 + # Clean up 1210 + redis_connection.delete(f"BasicIndexEntry:{entry1.uuid}") 1211 + redis_connection.delete(f"BasicIndexEntry:{entry2.uuid}") 1212 + 1213 + 1214 + def test_index_persistence_across_instances(redis_connection): 1215 + """Test that index entries persist across Index instance recreations. 1216 + 1217 + Should be able to create an Index, add entries, create a new Index instance 1218 + with the same Redis connection, and retrieve the same entries. 1219 + """ 1220 + redis = redis_connection 1221 + # Clear any existing BasicIndexEntry keys 1222 + for key in redis.scan_iter(match='BasicIndexEntry:*'): 1223 + redis.delete(key) 1224 + 1225 + # Create first index instance and add entry 1226 + index1 = atlocal.Index(redis=redis) 1227 + ds = atdata.Dataset[SimpleTestSample](url="s3://bucket/dataset.tar") 1228 + entry1 = index1.add_entry(ds) 1229 + 1230 + # Create new index instance with same Redis connection 1231 + index2 = atlocal.Index(redis=redis) 1232 + entries = index2.all_entries 1233 + 1234 + assert len(entries) == 1 1235 + assert entries[0].uuid == entry1.uuid 1236 + assert entries[0].wds_url == entry1.wds_url 1237 + 1238 + # Clean up 1239 + redis.delete(f"BasicIndexEntry:{entry1.uuid}") 1240 + 1241 + 1242 + def test_concurrent_index_access(redis_connection): 1243 + """Test that multiple Index instances can access the same Redis store. 1244 + 1245 + Should handle concurrent access to the same Redis index from multiple 1246 + Index instances. 1247 + """ 1248 + redis = redis_connection 1249 + # Clear any existing BasicIndexEntry keys 1250 + for key in redis.scan_iter(match='BasicIndexEntry:*'): 1251 + redis.delete(key) 1252 + 1253 + # Create multiple index instances 1254 + index1 = atlocal.Index(redis=redis) 1255 + index2 = atlocal.Index(redis=redis) 1256 + 1257 + # Add entries from different instances 1258 + ds1 = atdata.Dataset[SimpleTestSample](url="s3://bucket/dataset1.tar") 1259 + ds2 = atdata.Dataset[ArrayTestSample](url="s3://bucket/dataset2.tar") 1260 + 1261 + entry1 = index1.add_entry(ds1) 1262 + entry2 = index2.add_entry(ds2) 1263 + 1264 + # Both instances should see both entries 1265 + entries1 = index1.all_entries 1266 + entries2 = index2.all_entries 1267 + 1268 + assert len(entries1) == 2 1269 + assert len(entries2) == 2 1270 + 1271 + uuids1 = {e.uuid for e in entries1} 1272 + uuids2 = {e.uuid for e in entries2} 1273 + 1274 + assert entry1.uuid in uuids1 1275 + assert entry2.uuid in uuids1 1276 + assert entry1.uuid in uuids2 1277 + assert entry2.uuid in uuids2 1278 + 1279 + # Clean up 1280 + redis.delete(f"BasicIndexEntry:{entry1.uuid}") 1281 + redis.delete(f"BasicIndexEntry:{entry2.uuid}")