A loose federation of distributed, typed datasets
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

ci: modernize GitHub Actions with caching, concurrency, lint checks, and trusted publishing

- Add uv caching (enable-cache: true) for faster CI runs
- Add concurrency control to cancel in-progress runs on new commits
- Add ruff lint job (check + format) targeting src/ and tests/
- Switch to --locked flag for reproducible dependency resolution
- Add fail-fast: false to test matrix to see all failures
- Enable Codecov coverage upload
- Switch PyPI publishing to trusted publishing (OIDC)
- Split publish workflow into build and publish jobs with artifacts
- Fix all ruff lint issues (unused imports, undefined names)
- Format entire codebase with ruff format (42 files)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

+2078 -1214
.chainlink/issues.db

This is a binary file and will not be displayed.

+40 -21
.github/workflows/uv-publish-pypi.yml
··· 1 - # 2 - 3 1 name: Build and upload package to PyPI 4 2 5 3 on: ··· 9 7 10 8 permissions: 11 9 contents: read 10 + id-token: write 12 11 13 - jobs: 12 + concurrency: 13 + group: ${{ github.workflow }}-${{ github.ref }} 14 + cancel-in-progress: true 14 15 15 - uv-build-release-pypi-publish: 16 - name: "Build release distribution and publish to PyPI" 16 + jobs: 17 + build: 18 + name: Build release distribution 17 19 runs-on: ubuntu-latest 18 - environment: 19 - name: pypi 20 - 20 + 21 21 steps: 22 22 - uses: actions/checkout@v5 23 - 24 - - name: "Set up Python" 23 + 24 + - name: Set up Python 25 25 uses: actions/setup-python@v5 26 26 with: 27 27 python-version-file: "pyproject.toml" 28 28 29 29 - name: Install uv 30 30 uses: astral-sh/setup-uv@v6 31 - 31 + with: 32 + enable-cache: true 33 + 32 34 - name: Install project 33 - run: uv sync --all-extras --dev 34 - # TODO Better to use --locked for author control over versions? 35 - # run: uv sync --locked --all-extras --dev 36 - 35 + run: uv sync --locked --all-extras --dev 36 + 37 37 - name: Build release distributions 38 38 run: uv build 39 - 40 - - name: Publish to PyPI 41 - env: 42 - UV_PUBLISH_TOKEN: ${{ secrets.UV_PUBLISH_TOKEN }} 43 - run: uv publish 39 + 40 + - name: Upload dist artifacts 41 + uses: actions/upload-artifact@v4 42 + with: 43 + name: dist 44 + path: dist/ 44 45 46 + publish: 47 + name: Publish to PyPI 48 + runs-on: ubuntu-latest 49 + needs: build 50 + environment: 51 + name: pypi 52 + url: https://pypi.org/project/atdata/ 45 53 46 - ## 54 + steps: 55 + - name: Download dist artifacts 56 + uses: actions/download-artifact@v4 57 + with: 58 + name: dist 59 + path: dist/ 60 + 61 + - name: Install uv 62 + uses: astral-sh/setup-uv@v6 63 + 64 + - name: Publish to PyPI 65 + run: uv publish --trusted-publishing always dist/*
+40 -19
.github/workflows/uv-test.yml
··· 1 - # 2 - 3 1 name: Run tests with `uv` 4 2 5 3 on: ··· 11 9 branches: 12 10 - main 13 11 12 + permissions: 13 + contents: read 14 + 15 + concurrency: 16 + group: ${{ github.workflow }}-${{ github.ref }} 17 + cancel-in-progress: true 18 + 14 19 jobs: 15 - uv-test: 16 - name: Run tests 20 + lint: 21 + name: Lint 22 + runs-on: ubuntu-latest 23 + steps: 24 + - uses: actions/checkout@v5 25 + 26 + - name: Install uv 27 + uses: astral-sh/setup-uv@v6 28 + with: 29 + enable-cache: true 30 + 31 + - name: Install the project 32 + run: uv sync --locked --dev 33 + 34 + - name: Run ruff check 35 + run: uv run ruff check src/ tests/ 36 + 37 + - name: Run ruff format check 38 + run: uv run ruff format --check src/ tests/ 39 + 40 + test: 41 + name: Test (py${{ matrix.python-version }}, redis${{ matrix.redis-version }}) 17 42 runs-on: ubuntu-latest 18 43 environment: 19 44 name: test 20 45 strategy: 46 + fail-fast: false 21 47 matrix: 22 - python-version: [3.12, 3.13, 3.14] 48 + python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] 23 49 redis-version: [6, 7] 24 50 25 51 steps: 26 52 - uses: actions/checkout@v5 27 53 28 - - name: "Set up Python" 54 + - name: Set up Python 29 55 uses: actions/setup-python@v5 30 56 with: 31 57 python-version: ${{ matrix.python-version }} 32 - # python-version-file: "pyproject.toml" 33 58 34 59 - name: Install uv 35 60 uses: astral-sh/setup-uv@v6 61 + with: 62 + enable-cache: true 36 63 37 64 - name: Install the project 38 - run: uv sync --all-extras --dev 39 - # TODO Better to use --locked for author control over versions? 40 - # run: uv sync --locked --all-extras --dev 65 + run: uv sync --locked --all-extras --dev 41 66 42 67 - name: Start Redis 43 68 uses: supercharge/redis-github-action@1.8.1 ··· 47 72 - name: Run tests with coverage 48 73 run: uv run pytest --cov=atdata --cov-report=xml --cov-report=term 49 74 50 - # - name: Upload coverage to Codecov 51 - # uses: codecov/codecov-action@v5 52 - # with: 53 - # # file: ./coverage.xml # Claude hallucination -- fascinating! 54 - # fail_ci_if_error: false 55 - # token: ${{ secrets.CODECOV_TOKEN }} 56 - 57 - 58 - # 75 + - name: Upload coverage to Codecov 76 + uses: codecov/codecov-action@v5 77 + with: 78 + fail_ci_if_error: false 79 + token: ${{ secrets.CODECOV_TOKEN }}
+1
CHANGELOG.md
··· 25 25 - **Comprehensive integration test suite**: 593 tests covering E2E flows, error handling, edge cases 26 26 27 27 ### Changed 28 + - Review GitHub workflows and recommend CI improvements (#405) 28 29 - Fix type signatures for Dataset.ordered and Dataset.shuffled (GH#28) (#404) 29 30 - Investigate quartodoc Example section rendering - missing CSS classes on pre/code tags (#401) 30 31 - Update all docstrings from Example: to Examples: format (#403)
+1 -1
src/atdata/__init__.py
··· 88 88 from . import atmosphere as atmosphere 89 89 90 90 # CLI entry point 91 - from .cli import main as main 91 + from .cli import main as main
+6 -2
src/atdata/_cid.py
··· 64 64 # Build raw CID bytes: 65 65 # CIDv1 = version(1) + codec(dag-cbor) + multihash 66 66 # Multihash = code(sha256) + size(32) + digest 67 - raw_cid_bytes = bytes([CID_VERSION_1, CODEC_DAG_CBOR, HASH_SHA256, SHA256_SIZE]) + sha256_hash 67 + raw_cid_bytes = ( 68 + bytes([CID_VERSION_1, CODEC_DAG_CBOR, HASH_SHA256, SHA256_SIZE]) + sha256_hash 69 + ) 68 70 69 71 # Encode to base32 multibase string 70 72 return libipld.encode_cid(raw_cid_bytes) ··· 87 89 >>> cid = generate_cid_from_bytes(cbor_bytes) 88 90 """ 89 91 sha256_hash = hashlib.sha256(data_bytes).digest() 90 - raw_cid_bytes = bytes([CID_VERSION_1, CODEC_DAG_CBOR, HASH_SHA256, SHA256_SIZE]) + sha256_hash 92 + raw_cid_bytes = ( 93 + bytes([CID_VERSION_1, CODEC_DAG_CBOR, HASH_SHA256, SHA256_SIZE]) + sha256_hash 94 + ) 91 95 return libipld.encode_cid(raw_cid_bytes) 92 96 93 97
+7 -5
src/atdata/_helpers.py
··· 22 22 23 23 ## 24 24 25 - def array_to_bytes( x: np.ndarray ) -> bytes: 25 + 26 + def array_to_bytes(x: np.ndarray) -> bytes: 26 27 """Convert a numpy array to bytes for msgpack serialization. 27 28 28 29 Uses numpy's native ``save()`` format to preserve array dtype and shape. ··· 37 38 Uses ``allow_pickle=True`` to support object dtypes. 38 39 """ 39 40 np_bytes = BytesIO() 40 - np.save( np_bytes, x, allow_pickle = True ) 41 + np.save(np_bytes, x, allow_pickle=True) 41 42 return np_bytes.getvalue() 42 43 43 - def bytes_to_array( b: bytes ) -> np.ndarray: 44 + 45 + def bytes_to_array(b: bytes) -> np.ndarray: 44 46 """Convert serialized bytes back to a numpy array. 45 47 46 48 Reverses the serialization performed by ``array_to_bytes()``. ··· 54 56 Note: 55 57 Uses ``allow_pickle=True`` to support object dtypes. 56 58 """ 57 - np_bytes = BytesIO( b ) 58 - return np.load( np_bytes, allow_pickle = True ) 59 + np_bytes = BytesIO(b) 60 + return np.load(np_bytes, allow_pickle=True)
+8 -4
src/atdata/_hf_api.py
··· 46 46 47 47 if TYPE_CHECKING: 48 48 from ._protocols import AbstractIndex 49 - from .local import S3DataStore 50 49 51 50 ## 52 51 # Type variables ··· 77 76 >>> for split_name, dataset in ds_dict.items(): 78 77 ... print(f"{split_name}: {len(dataset.shard_list)} shards") 79 78 """ 79 + 80 80 # TODO The above has a line for "Parameters:" that should be "Type Parameters:"; this is a temporary fix for `quartodoc` auto-generation bugs. 81 81 82 82 def __init__( ··· 464 464 data_urls = entry.data_urls 465 465 466 466 # Check if index has a data store 467 - if hasattr(index, 'data_store') and index.data_store is not None: 467 + if hasattr(index, "data_store") and index.data_store is not None: 468 468 store = index.data_store 469 469 470 470 # Import here to avoid circular imports at module level ··· 638 638 source, schema_ref = _resolve_indexed_path(path, index) 639 639 640 640 # Resolve sample_type from schema if not provided 641 - resolved_type: Type = sample_type if sample_type is not None else index.decode_schema(schema_ref) 641 + resolved_type: Type = ( 642 + sample_type if sample_type is not None else index.decode_schema(schema_ref) 643 + ) 642 644 643 645 # Create dataset from the resolved source (includes credentials if S3) 644 646 ds = Dataset[resolved_type](source) ··· 647 649 # Indexed datasets are single-split by default 648 650 return ds 649 651 650 - return DatasetDict({"train": ds}, sample_type=resolved_type, streaming=streaming) 652 + return DatasetDict( 653 + {"train": ds}, sample_type=resolved_type, streaming=streaming 654 + ) 651 655 652 656 # Use DictSample as default when no type specified 653 657 resolved_type = sample_type if sample_type is not None else DictSample
-1
src/atdata/_protocols.py
··· 32 32 from typing import ( 33 33 IO, 34 34 Any, 35 - ClassVar, 36 35 Iterator, 37 36 Optional, 38 37 Protocol,
+6 -2
src/atdata/_schema_codec.py
··· 203 203 namespace={ 204 204 "__post_init__": lambda self: PackableSample.__post_init__(self), 205 205 "__schema_version__": version, 206 - "__schema_ref__": schema.get("$ref", None), # Store original ref if available 206 + "__schema_ref__": schema.get( 207 + "$ref", None 208 + ), # Store original ref if available 207 209 }, 208 210 ) 209 211 ··· 239 241 240 242 if kind == "primitive": 241 243 primitive = field_type.get("primitive", "str") 242 - py_type = primitive # str, int, float, bool, bytes are all valid Python type names 244 + py_type = ( 245 + primitive # str, int, float, bool, bytes are all valid Python type names 246 + ) 243 247 elif kind == "ndarray": 244 248 py_type = "NDArray[Any]" 245 249 elif kind == "array":
+8 -3
src/atdata/_sources.py
··· 167 167 client_kwargs["region_name"] = self.region 168 168 elif not self.endpoint: 169 169 # Default region for AWS S3 170 - client_kwargs["region_name"] = os.environ.get("AWS_DEFAULT_REGION", "us-east-1") 170 + client_kwargs["region_name"] = os.environ.get( 171 + "AWS_DEFAULT_REGION", "us-east-1" 172 + ) 171 173 172 174 self._client = boto3.client("s3", **client_kwargs) 173 175 return self._client ··· 219 221 if not shard_id.startswith(f"s3://{self.bucket}/"): 220 222 raise KeyError(f"Shard not in this bucket: {shard_id}") 221 223 222 - key = shard_id[len(f"s3://{self.bucket}/"):] 224 + key = shard_id[len(f"s3://{self.bucket}/") :] 223 225 client = self._get_client() 224 226 response = client.get_object(Bucket=self.bucket, Key=key) 225 227 return response["Body"] ··· 355 357 356 358 blob_refs: list[dict[str, str]] 357 359 pds_endpoint: str | None = None 358 - _endpoint_cache: dict[str, str] = field(default_factory=dict, repr=False, compare=False) 360 + _endpoint_cache: dict[str, str] = field( 361 + default_factory=dict, repr=False, compare=False 362 + ) 359 363 360 364 def _resolve_pds_endpoint(self, did: str) -> str: 361 365 """Resolve PDS endpoint for a DID, with caching.""" ··· 447 451 url = self._get_blob_url(did, cid) 448 452 449 453 import requests 454 + 450 455 response = requests.get(url, stream=True, timeout=60) 451 456 response.raise_for_status() 452 457 return response.raw
+13 -4
src/atdata/_stub_manager.py
··· 153 153 """Alias for _module_filename for backwards compatibility.""" 154 154 return self._module_filename(name, version) 155 155 156 - def _module_path(self, name: str, version: str, authority: str = DEFAULT_AUTHORITY) -> Path: 156 + def _module_path( 157 + self, name: str, version: str, authority: str = DEFAULT_AUTHORITY 158 + ) -> Path: 157 159 """Get full path to module file for a schema. 158 160 159 161 Args: ··· 166 168 """ 167 169 return self._stub_dir / authority / self._module_filename(name, version) 168 170 169 - def _stub_path(self, name: str, version: str, authority: str = DEFAULT_AUTHORITY) -> Path: 171 + def _stub_path( 172 + self, name: str, version: str, authority: str = DEFAULT_AUTHORITY 173 + ) -> Path: 170 174 """Alias for _module_path for backwards compatibility.""" 171 175 return self._module_path(name, version, authority) 172 176 ··· 207 211 authority_dir.mkdir(parents=True, exist_ok=True) 208 212 init_path = authority_dir / "__init__.py" 209 213 if not init_path.exists(): 210 - init_path.write_text(f'"""Auto-generated schema modules for {authority}."""\n') 214 + init_path.write_text( 215 + f'"""Auto-generated schema modules for {authority}."""\n' 216 + ) 211 217 212 218 def _write_module_atomic(self, path: Path, content: str, authority: str) -> None: 213 219 """Write module file atomically using temp file and rename. ··· 355 361 356 362 return cls 357 363 358 - def _import_class_from_module(self, module_path: Path, class_name: str) -> Optional[Type]: 364 + def _import_class_from_module( 365 + self, module_path: Path, class_name: str 366 + ) -> Optional[Type]: 359 367 """Import a class from a generated module file. 360 368 361 369 Uses importlib to dynamically load the module and extract the class. ··· 395 403 def _print_ide_hint(self) -> None: 396 404 """Print a one-time hint about IDE configuration.""" 397 405 import sys as _sys 406 + 398 407 print( 399 408 f"\n[atdata] Generated schema module in: {self._stub_dir}\n" 400 409 f"[atdata] For IDE support, add this path to your type checker:\n"
+19 -5
src/atdata/_type_utils.py
··· 9 9 10 10 # Mapping from numpy dtype strings to schema dtype names 11 11 NUMPY_DTYPE_MAP = { 12 - "float16": "float16", "float32": "float32", "float64": "float64", 13 - "int8": "int8", "int16": "int16", "int32": "int32", "int64": "int64", 14 - "uint8": "uint8", "uint16": "uint16", "uint32": "uint32", "uint64": "uint64", 15 - "bool": "bool", "complex64": "complex64", "complex128": "complex128", 12 + "float16": "float16", 13 + "float32": "float32", 14 + "float64": "float64", 15 + "int8": "int8", 16 + "int16": "int16", 17 + "int32": "int32", 18 + "int64": "int64", 19 + "uint8": "uint8", 20 + "uint16": "uint16", 21 + "uint32": "uint32", 22 + "uint64": "uint64", 23 + "bool": "bool", 24 + "complex64": "complex64", 25 + "complex128": "complex128", 16 26 } 17 27 18 28 # Mapping from Python primitive types to schema type names 19 29 PRIMITIVE_TYPE_MAP = { 20 - str: "str", int: "int", float: "float", bool: "bool", bytes: "bytes", 30 + str: "str", 31 + int: "int", 32 + float: "float", 33 + bool: "bool", 34 + bytes: "bytes", 21 35 } 22 36 23 37
+1
src/atdata/atmosphere/__init__.py
··· 84 84 def metadata(self) -> Optional[dict]: 85 85 """Metadata from the record, if any.""" 86 86 import msgpack 87 + 87 88 metadata_bytes = self._record.get("metadata") 88 89 if metadata_bytes is None: 89 90 return None
+3 -1
src/atdata/atmosphere/_types.py
··· 56 56 57 57 parts = uri[5:].split("/") 58 58 if len(parts) < 3: 59 - raise ValueError(f"Invalid AT URI: expected authority/collection/rkey: {uri}") 59 + raise ValueError( 60 + f"Invalid AT URI: expected authority/collection/rkey: {uri}" 61 + ) 60 62 61 63 return cls( 62 64 authority=parts[0],
+6 -1
src/atdata/atmosphere/client.py
··· 18 18 if _atproto_client_class is None: 19 19 try: 20 20 from atproto import Client 21 + 21 22 _atproto_client_class = Client 22 23 except ImportError as e: 23 24 raise ImportError( ··· 325 326 # Convert to dict format suitable for embedding in records 326 327 return { 327 328 "$type": "blob", 328 - "ref": {"$link": blob_ref.ref.link if hasattr(blob_ref.ref, "link") else str(blob_ref.ref)}, 329 + "ref": { 330 + "$link": blob_ref.ref.link 331 + if hasattr(blob_ref.ref, "link") 332 + else str(blob_ref.ref) 333 + }, 329 334 "mimeType": blob_ref.mime_type, 330 335 "size": blob_ref.size, 331 336 }
+1
src/atdata/atmosphere/lens.py
··· 21 21 22 22 # Import for type checking only 23 23 from typing import TYPE_CHECKING 24 + 24 25 if TYPE_CHECKING: 25 26 from ..lens import Lens 26 27
+2 -2
src/atdata/atmosphere/records.py
··· 19 19 20 20 # Import for type checking only to avoid circular imports 21 21 from typing import TYPE_CHECKING 22 + 22 23 if TYPE_CHECKING: 23 24 from ..dataset import PackableSample, Dataset 24 25 ··· 394 395 return storage.get("blobs", []) 395 396 elif "storageExternal" in storage_type: 396 397 raise ValueError( 397 - "Dataset uses external URL storage, not blobs. " 398 - "Use get_urls() instead." 398 + "Dataset uses external URL storage, not blobs. Use get_urls() instead." 399 399 ) 400 400 else: 401 401 raise ValueError(f"Unknown storage type: {storage_type}")
+12 -4
src/atdata/atmosphere/schema.py
··· 17 17 LEXICON_NAMESPACE, 18 18 ) 19 19 from .._type_utils import ( 20 - numpy_dtype_to_string, 21 20 unwrap_optional, 22 21 is_ndarray_type, 23 22 extract_ndarray_dtype, ··· 25 24 26 25 # Import for type checking only to avoid circular imports 27 26 from typing import TYPE_CHECKING 27 + 28 28 if TYPE_CHECKING: 29 29 from ..dataset import PackableSample 30 30 ··· 88 88 TypeError: If a field type is not supported. 89 89 """ 90 90 if not is_dataclass(sample_type): 91 - raise ValueError(f"{sample_type.__name__} must be a dataclass (use @packable)") 91 + raise ValueError( 92 + f"{sample_type.__name__} must be a dataclass (use @packable)" 93 + ) 92 94 93 95 # Build the schema record 94 96 schema_record = self._build_schema_record( ··· 153 155 return FieldType(kind="primitive", primitive="bytes") 154 156 155 157 if is_ndarray_type(python_type): 156 - return FieldType(kind="ndarray", dtype=extract_ndarray_dtype(python_type), shape=None) 158 + return FieldType( 159 + kind="ndarray", dtype=extract_ndarray_dtype(python_type), shape=None 160 + ) 157 161 158 162 origin = get_origin(python_type) 159 163 if origin is list: 160 164 args = get_args(python_type) 161 - items = self._python_type_to_field_type(args[0]) if args else FieldType(kind="primitive", primitive="str") 165 + items = ( 166 + self._python_type_to_field_type(args[0]) 167 + if args 168 + else FieldType(kind="primitive", primitive="str") 169 + ) 162 170 return FieldType(kind="array", items=items) 163 171 164 172 if is_dataclass(python_type):
+1 -1
src/atdata/atmosphere/store.py
··· 20 20 21 21 from __future__ import annotations 22 22 23 - import io 24 23 import tempfile 25 24 from dataclasses import dataclass 26 25 from typing import TYPE_CHECKING, Any ··· 29 28 30 29 if TYPE_CHECKING: 31 30 from ..dataset import Dataset 31 + from .._sources import BlobSource 32 32 from .client import AtmosphereClient 33 33 34 34
+12 -3
src/atdata/cli/__init__.py
··· 42 42 formatter_class=argparse.RawDescriptionHelpFormatter, 43 43 ) 44 44 parser.add_argument( 45 - "--version", "-v", 45 + "--version", 46 + "-v", 46 47 action="store_true", 47 48 help="Show version information", 48 49 ) ··· 83 84 help="MinIO console port (default: 9001)", 84 85 ) 85 86 up_parser.add_argument( 86 - "--detach", "-d", 87 + "--detach", 88 + "-d", 87 89 action="store_true", 88 90 default=True, 89 91 help="Run containers in detached mode (default: True)", ··· 95 97 help="Stop local development containers", 96 98 ) 97 99 down_parser.add_argument( 98 - "--volumes", "-v", 100 + "--volumes", 101 + "-v", 99 102 action="store_true", 100 103 help="Also remove volumes (deletes all data)", 101 104 ) ··· 165 168 """Show version information.""" 166 169 try: 167 170 from atdata import __version__ 171 + 168 172 version = __version__ 169 173 except ImportError: 170 174 # Fallback to package metadata 171 175 from importlib.metadata import version as pkg_version 176 + 172 177 version = pkg_version("atdata") 173 178 174 179 print(f"atdata {version}") ··· 183 188 ) -> int: 184 189 """Start local development infrastructure.""" 185 190 from .local import local_up 191 + 186 192 return local_up( 187 193 redis_port=redis_port, 188 194 minio_port=minio_port, ··· 194 200 def _cmd_local_down(remove_volumes: bool) -> int: 195 201 """Stop local development infrastructure.""" 196 202 from .local import local_down 203 + 197 204 return local_down(remove_volumes=remove_volumes) 198 205 199 206 200 207 def _cmd_local_status() -> int: 201 208 """Show status of local infrastructure.""" 202 209 from .local import local_status 210 + 203 211 return local_status() 204 212 205 213 206 214 def _cmd_diagnose(host: str, port: int) -> int: 207 215 """Diagnose Redis configuration.""" 208 216 from .diagnose import diagnose_redis 217 + 209 218 return diagnose_redis(host=host, port=port) 210 219 211 220
+12 -8
src/atdata/cli/diagnose.py
··· 5 5 """ 6 6 7 7 import sys 8 - from typing import Any 9 8 10 9 11 10 def _print_status(label: str, ok: bool, detail: str = "") -> None: ··· 41 40 # Try to connect 42 41 try: 43 42 from redis import Redis 43 + 44 44 redis = Redis(host=host, port=port, socket_connect_timeout=5) 45 45 redis.ping() 46 46 _print_status("Connection", True, "connected") ··· 70 70 _print_status( 71 71 "AOF Persistence", 72 72 aof_ok, 73 - "enabled" if aof_ok else "DISABLED - data may be lost on restart!" 73 + "enabled" if aof_ok else "DISABLED - data may be lost on restart!", 74 74 ) 75 75 if not aof_ok: 76 76 issues_found = True ··· 85 85 _print_status( 86 86 "RDB Persistence", 87 87 rdb_ok, 88 - f"configured ({save_config})" if rdb_ok else "DISABLED" 88 + f"configured ({save_config})" if rdb_ok else "DISABLED", 89 89 ) 90 90 # RDB disabled is only a warning if AOF is enabled 91 91 except Exception as e: ··· 95 95 try: 96 96 policy = redis.config_get("maxmemory-policy").get("maxmemory-policy", "unknown") 97 97 # Safe policies that won't evict index data 98 - safe_policies = {"noeviction", "volatile-lru", "volatile-lfu", "volatile-ttl", "volatile-random"} 98 + safe_policies = { 99 + "noeviction", 100 + "volatile-lru", 101 + "volatile-lfu", 102 + "volatile-ttl", 103 + "volatile-random", 104 + } 99 105 policy_ok = policy in safe_policies 100 106 101 107 if policy_ok: ··· 104 110 _print_status( 105 111 "Memory Policy", 106 112 False, 107 - f"{policy} - may evict index data! Use 'noeviction' or 'volatile-*'" 113 + f"{policy} - may evict index data! Use 'noeviction' or 'volatile-*'", 108 114 ) 109 115 issues_found = True 110 116 except Exception as e: ··· 141 147 for key in redis.scan_iter(match="LocalSchema:*", count=100): 142 148 schema_count += 1 143 149 _print_status( 144 - "atdata Keys", 145 - True, 146 - f"{dataset_count} datasets, {schema_count} schemas" 150 + "atdata Keys", True, f"{dataset_count} datasets, {schema_count} schemas" 147 151 ) 148 152 except Exception as e: 149 153 _print_status("atdata Keys", False, f"check failed: {e}")
+4 -1
src/atdata/cli/local.py
··· 144 144 elif shutil.which("docker-compose"): 145 145 base_cmd = ["docker-compose"] 146 146 else: 147 - raise RuntimeError("Neither 'docker compose' nor 'docker-compose' available") 147 + raise RuntimeError( 148 + "Neither 'docker compose' nor 'docker-compose' available" 149 + ) 148 150 else: 149 151 raise RuntimeError("Docker not found") 150 152 ··· 195 197 196 198 # Wait a moment for containers to be healthy 197 199 import time 200 + 198 201 time.sleep(2) 199 202 200 203 # Show status
+191 -172
src/atdata/dataset.py
··· 41 41 ) 42 42 from abc import ABC 43 43 44 - from ._sources import URLSource, S3Source 44 + from ._sources import URLSource 45 45 from ._protocols import DataSource 46 46 47 47 from tqdm import tqdm ··· 65 65 TypeAlias, 66 66 dataclass_transform, 67 67 overload, 68 - Literal, 69 68 ) 70 69 from numpy.typing import NDArray 71 70 ··· 85 84 WDSRawBatch: TypeAlias = Dict[str, Any] 86 85 87 86 SampleExportRow: TypeAlias = Dict[str, Any] 88 - SampleExportMap: TypeAlias = Callable[['PackableSample'], SampleExportRow] 87 + SampleExportMap: TypeAlias = Callable[["PackableSample"], SampleExportRow] 89 88 90 89 91 90 ## 92 91 # Main base classes 93 92 94 - DT = TypeVar( 'DT' ) 93 + DT = TypeVar("DT") 95 94 96 95 97 - def _make_packable( x ): 96 + def _make_packable(x): 98 97 """Convert numpy arrays to bytes; pass through other values unchanged.""" 99 - if isinstance( x, np.ndarray ): 100 - return eh.array_to_bytes( x ) 98 + if isinstance(x, np.ndarray): 99 + return eh.array_to_bytes(x) 101 100 return x 102 101 103 102 104 - def _is_possibly_ndarray_type( t ): 103 + def _is_possibly_ndarray_type(t): 105 104 """Return True if type annotation is NDArray or Optional[NDArray].""" 106 105 if t == NDArray: 107 106 return True 108 - if isinstance( t, types.UnionType ): 109 - return any( x == NDArray for x in t.__args__ ) 107 + if isinstance(t, types.UnionType): 108 + return any(x == NDArray for x in t.__args__) 110 109 return False 110 + 111 111 112 112 class DictSample: 113 113 """Dynamic sample type providing dict-like access to raw msgpack data. ··· 141 141 converted to numpy arrays when accessed through a typed sample class. 142 142 """ 143 143 144 - __slots__ = ('_data',) 144 + __slots__ = ("_data",) 145 145 146 146 def __init__(self, _data: dict[str, Any] | None = None, **kwargs: Any) -> None: 147 147 """Create a DictSample from a dictionary or keyword arguments. ··· 151 151 **kwargs: Field values if _data is not provided. 152 152 """ 153 153 if _data is not None: 154 - object.__setattr__(self, '_data', _data) 154 + object.__setattr__(self, "_data", _data) 155 155 else: 156 - object.__setattr__(self, '_data', kwargs) 156 + object.__setattr__(self, "_data", kwargs) 157 157 158 158 @classmethod 159 - def from_data(cls, data: dict[str, Any]) -> 'DictSample': 159 + def from_data(cls, data: dict[str, Any]) -> "DictSample": 160 160 """Create a DictSample from unpacked msgpack data. 161 161 162 162 Args: ··· 168 168 return cls(_data=data) 169 169 170 170 @classmethod 171 - def from_bytes(cls, bs: bytes) -> 'DictSample': 171 + def from_bytes(cls, bs: bytes) -> "DictSample": 172 172 """Create a DictSample from raw msgpack bytes. 173 173 174 174 Args: ··· 192 192 AttributeError: If the field doesn't exist. 193 193 """ 194 194 # Avoid infinite recursion for _data lookup 195 - if name == '_data': 195 + if name == "_data": 196 196 raise AttributeError(name) 197 197 try: 198 198 return self._data[name] ··· 258 258 return msgpack.packb(self._data) 259 259 260 260 @property 261 - def as_wds(self) -> 'WDSRawSample': 261 + def as_wds(self) -> "WDSRawSample": 262 262 """Pack this sample's data for writing to WebDataset. 263 263 264 264 Returns: 265 265 A dictionary with ``__key__`` and ``msgpack`` fields. 266 266 """ 267 267 return { 268 - '__key__': str(uuid.uuid1(0, 0)), 269 - 'msgpack': self.packed, 268 + "__key__": str(uuid.uuid1(0, 0)), 269 + "msgpack": self.packed, 270 270 } 271 271 272 272 def __repr__(self) -> str: 273 - fields = ', '.join(f'{k}=...' for k in self._data.keys()) 274 - return f'DictSample({fields})' 273 + fields = ", ".join(f"{k}=..." for k in self._data.keys()) 274 + return f"DictSample({fields})" 275 275 276 276 277 277 @dataclass 278 - class PackableSample( ABC ): 278 + class PackableSample(ABC): 279 279 """Base class for samples that can be serialized with msgpack. 280 280 281 281 This abstract base class provides automatic serialization/deserialization ··· 298 298 >>> restored = MyData.from_bytes(packed) # Deserialize 299 299 """ 300 300 301 - def _ensure_good( self ): 301 + def _ensure_good(self): 302 302 """Convert bytes to NDArray for fields annotated as NDArray or NDArray | None.""" 303 303 304 304 # Auto-convert known types when annotated 305 305 # for var_name, var_type in vars( self.__class__ )['__annotations__'].items(): 306 - for field in dataclasses.fields( self ): 306 + for field in dataclasses.fields(self): 307 307 var_name = field.name 308 308 var_type = field.type 309 309 310 310 # Annotation for this variable is to be an NDArray 311 - if _is_possibly_ndarray_type( var_type ): 311 + if _is_possibly_ndarray_type(var_type): 312 312 # ... so, we'll always auto-convert to numpy 313 313 314 - var_cur_value = getattr( self, var_name ) 314 + var_cur_value = getattr(self, var_name) 315 315 316 316 # Execute the appropriate conversion for intermediate data 317 317 # based on what is provided 318 318 319 - if isinstance( var_cur_value, np.ndarray ): 319 + if isinstance(var_cur_value, np.ndarray): 320 320 # Already the correct type, no conversion needed 321 321 continue 322 322 323 - elif isinstance( var_cur_value, bytes ): 323 + elif isinstance(var_cur_value, bytes): 324 324 # Design note: bytes in NDArray-typed fields are always interpreted 325 325 # as serialized arrays. This means raw bytes fields must not be 326 326 # annotated as NDArray. 327 - setattr( self, var_name, eh.bytes_to_array( var_cur_value ) ) 327 + setattr(self, var_name, eh.bytes_to_array(var_cur_value)) 328 328 329 - def __post_init__( self ): 329 + def __post_init__(self): 330 330 self._ensure_good() 331 331 332 332 ## 333 333 334 334 @classmethod 335 - def from_data( cls, data: WDSRawSample ) -> Self: 335 + def from_data(cls, data: WDSRawSample) -> Self: 336 336 """Create a sample instance from unpacked msgpack data. 337 337 338 338 Args: ··· 341 341 Returns: 342 342 New instance with NDArray fields auto-converted from bytes. 343 343 """ 344 - return cls( **data ) 345 - 344 + return cls(**data) 345 + 346 346 @classmethod 347 - def from_bytes( cls, bs: bytes ) -> Self: 347 + def from_bytes(cls, bs: bytes) -> Self: 348 348 """Create a sample instance from raw msgpack bytes. 349 349 350 350 Args: ··· 353 353 Returns: 354 354 A new instance of this sample class deserialized from the bytes. 355 355 """ 356 - return cls.from_data( ormsgpack.unpackb( bs ) ) 356 + return cls.from_data(ormsgpack.unpackb(bs)) 357 357 358 358 @property 359 - def packed( self ) -> bytes: 359 + def packed(self) -> bytes: 360 360 """Pack this sample's data into msgpack bytes. 361 361 362 362 NDArray fields are automatically converted to bytes before packing. ··· 371 371 372 372 # Make sure that all of our (possibly unpackable) data is in a packable 373 373 # format 374 - o = { 375 - k: _make_packable( v ) 376 - for k, v in vars( self ).items() 377 - } 374 + o = {k: _make_packable(v) for k, v in vars(self).items()} 378 375 379 - ret = msgpack.packb( o ) 376 + ret = msgpack.packb(o) 380 377 381 378 if ret is None: 382 - raise RuntimeError( f'Failed to pack sample to bytes: {o}' ) 379 + raise RuntimeError(f"Failed to pack sample to bytes: {o}") 383 380 384 381 return ret 385 - 382 + 386 383 @property 387 - def as_wds( self ) -> WDSRawSample: 384 + def as_wds(self) -> WDSRawSample: 388 385 """Pack this sample's data for writing to WebDataset. 389 386 390 387 Returns: ··· 397 394 """ 398 395 return { 399 396 # Generates a UUID that is timelike-sortable 400 - '__key__': str( uuid.uuid1( 0, 0 ) ), 401 - 'msgpack': self.packed, 397 + "__key__": str(uuid.uuid1(0, 0)), 398 + "msgpack": self.packed, 402 399 } 403 400 404 - def _batch_aggregate( xs: Sequence ): 401 + 402 + def _batch_aggregate(xs: Sequence): 405 403 """Stack arrays into numpy array with batch dim; otherwise return list.""" 406 404 if not xs: 407 405 return [] 408 - if isinstance( xs[0], np.ndarray ): 409 - return np.array( list( xs ) ) 410 - return list( xs ) 406 + if isinstance(xs[0], np.ndarray): 407 + return np.array(list(xs)) 408 + return list(xs) 409 + 411 410 412 - class SampleBatch( Generic[DT] ): 411 + class SampleBatch(Generic[DT]): 413 412 """A batch of samples with automatic attribute aggregation. 414 413 415 414 This class wraps a sequence of samples and provides magic ``__getattr__`` ··· 437 436 subscripted syntax ``SampleBatch[MyType](samples)`` rather than 438 437 calling the constructor directly with an unsubscripted class. 439 438 """ 439 + 440 440 # Design note: The docstring uses "Parameters:" for type parameters because 441 441 # quartodoc doesn't yet support "Type Parameters:" sections in generated docs. 442 442 443 - def __init__( self, samples: Sequence[DT] ): 443 + def __init__(self, samples: Sequence[DT]): 444 444 """Create a batch from a sequence of samples. 445 445 446 446 Args: ··· 448 448 Each sample must be an instance of a type derived from 449 449 ``PackableSample``. 450 450 """ 451 - self.samples = list( samples ) 451 + self.samples = list(samples) 452 452 self._aggregate_cache = dict() 453 453 self._sample_type_cache: Type | None = None 454 454 455 455 @property 456 - def sample_type( self ) -> Type: 456 + def sample_type(self) -> Type: 457 457 """The type of each sample in this batch. 458 458 459 459 Returns: 460 460 The type parameter ``DT`` used when creating this ``SampleBatch[DT]``. 461 461 """ 462 462 if self._sample_type_cache is None: 463 - self._sample_type_cache = typing.get_args( self.__orig_class__)[0] 463 + self._sample_type_cache = typing.get_args(self.__orig_class__)[0] 464 464 assert self._sample_type_cache is not None 465 465 return self._sample_type_cache 466 466 467 - def __getattr__( self, name ): 467 + def __getattr__(self, name): 468 468 """Aggregate an attribute across all samples in the batch. 469 469 470 470 This magic method enables attribute-style access to aggregated sample ··· 481 481 AttributeError: If the attribute doesn't exist on the sample type. 482 482 """ 483 483 # Aggregate named params of sample type 484 - if name in vars( self.sample_type )['__annotations__']: 484 + if name in vars(self.sample_type)["__annotations__"]: 485 485 if name not in self._aggregate_cache: 486 486 self._aggregate_cache[name] = _batch_aggregate( 487 - [ getattr( x, name ) 488 - for x in self.samples ] 487 + [getattr(x, name) for x in self.samples] 489 488 ) 490 489 491 490 return self._aggregate_cache[name] 492 491 493 - raise AttributeError( f'No sample attribute named {name}' ) 492 + raise AttributeError(f"No sample attribute named {name}") 494 493 495 494 496 - ST = TypeVar( 'ST', bound = PackableSample ) 497 - RT = TypeVar( 'RT', bound = PackableSample ) 495 + ST = TypeVar("ST", bound=PackableSample) 496 + RT = TypeVar("RT", bound=PackableSample) 498 497 499 498 500 499 class _ShardListStage(wds.utils.PipelineStage): ··· 532 531 yield sample 533 532 534 533 535 - class Dataset( Generic[ST] ): 534 + class Dataset(Generic[ST]): 536 535 """A typed dataset built on WebDataset with lens transformations. 537 536 538 537 This class wraps WebDataset tar archives and provides type-safe iteration ··· 566 565 subscripted syntax ``Dataset[MyType](url)`` rather than calling the 567 566 constructor directly with an unsubscripted class. 568 567 """ 568 + 569 569 # Design note: The docstring uses "Parameters:" for type parameters because 570 570 # quartodoc doesn't yet support "Type Parameters:" sections in generated docs. 571 571 572 572 @property 573 - def sample_type( self ) -> Type: 573 + def sample_type(self) -> Type: 574 574 """The type of each returned sample from this dataset's iterator. 575 575 576 576 Returns: 577 577 The type parameter ``ST`` used when creating this ``Dataset[ST]``. 578 578 """ 579 579 if self._sample_type_cache is None: 580 - self._sample_type_cache = typing.get_args( self.__orig_class__ )[0] 580 + self._sample_type_cache = typing.get_args(self.__orig_class__)[0] 581 581 assert self._sample_type_cache is not None 582 582 return self._sample_type_cache 583 + 583 584 @property 584 - def batch_type( self ) -> Type: 585 + def batch_type(self) -> Type: 585 586 """The type of batches produced by this dataset. 586 587 587 588 Returns: ··· 589 590 """ 590 591 return SampleBatch[self.sample_type] 591 592 592 - def __init__( self, 593 - source: DataSource | str | None = None, 594 - metadata_url: str | None = None, 595 - *, 596 - url: str | None = None, 597 - ) -> None: 593 + def __init__( 594 + self, 595 + source: DataSource | str | None = None, 596 + metadata_url: str | None = None, 597 + *, 598 + url: str | None = None, 599 + ) -> None: 598 600 """Create a dataset from a DataSource or URL. 599 601 600 602 Args: ··· 642 644 """The underlying data source for this dataset.""" 643 645 return self._source 644 646 645 - def as_type( self, other: Type[RT] ) -> 'Dataset[RT]': 647 + def as_type(self, other: Type[RT]) -> "Dataset[RT]": 646 648 """View this dataset through a different sample type using a registered lens. 647 649 648 650 Args: ··· 658 660 ValueError: If no registered lens exists between the current 659 661 sample type and the target type. 660 662 """ 661 - ret = Dataset[other]( self._source ) 663 + ret = Dataset[other](self._source) 662 664 # Get the singleton lens registry 663 665 lenses = LensNetwork() 664 - ret._output_lens = lenses.transform( self.sample_type, ret.sample_type ) 666 + ret._output_lens = lenses.transform(self.sample_type, ret.sample_type) 665 667 return ret 666 668 667 669 @property ··· 695 697 Use :meth:`list_shards` instead. 696 698 """ 697 699 import warnings 700 + 698 701 warnings.warn( 699 702 "shard_list is deprecated, use list_shards() instead", 700 703 DeprecationWarning, ··· 703 706 return self.list_shards() 704 707 705 708 @property 706 - def metadata( self ) -> dict[str, Any] | None: 709 + def metadata(self) -> dict[str, Any] | None: 707 710 """Fetch and cache metadata from metadata_url. 708 711 709 712 Returns: ··· 716 719 return None 717 720 718 721 if self._metadata is None: 719 - with requests.get( self.metadata_url, stream = True ) as response: 722 + with requests.get(self.metadata_url, stream=True) as response: 720 723 response.raise_for_status() 721 - self._metadata = msgpack.unpackb( response.content, raw = False ) 722 - 724 + self._metadata = msgpack.unpackb(response.content, raw=False) 725 + 723 726 # Use our cached values 724 727 return self._metadata 725 - 728 + 726 729 @overload 727 - def ordered( self, 728 - batch_size: None = None, 729 - ) -> Iterable[ST]: ... 730 + def ordered( 731 + self, 732 + batch_size: None = None, 733 + ) -> Iterable[ST]: ... 730 734 731 735 @overload 732 - def ordered( self, 733 - batch_size: int, 734 - ) -> Iterable[SampleBatch[ST]]: ... 736 + def ordered( 737 + self, 738 + batch_size: int, 739 + ) -> Iterable[SampleBatch[ST]]: ... 735 740 736 - def ordered( self, 737 - batch_size: int | None = None, 738 - ) -> Iterable[ST] | Iterable[SampleBatch[ST]]: 741 + def ordered( 742 + self, 743 + batch_size: int | None = None, 744 + ) -> Iterable[ST] | Iterable[SampleBatch[ST]]: 739 745 """Iterate over the dataset in order. 740 746 741 747 Args: ··· 762 768 _StreamOpenerStage(self._source), 763 769 wds.tariterators.tar_file_expander, 764 770 wds.tariterators.group_by_keys, 765 - wds.filters.map( self.wrap ), 771 + wds.filters.map(self.wrap), 766 772 ) 767 773 768 774 return wds.pipeline.DataPipeline( ··· 771 777 _StreamOpenerStage(self._source), 772 778 wds.tariterators.tar_file_expander, 773 779 wds.tariterators.group_by_keys, 774 - wds.filters.batched( batch_size ), 775 - wds.filters.map( self.wrap_batch ), 780 + wds.filters.batched(batch_size), 781 + wds.filters.map(self.wrap_batch), 776 782 ) 777 783 778 784 @overload 779 - def shuffled( self, 780 - buffer_shards: int = 100, 781 - buffer_samples: int = 10_000, 782 - batch_size: None = None, 783 - ) -> Iterable[ST]: ... 785 + def shuffled( 786 + self, 787 + buffer_shards: int = 100, 788 + buffer_samples: int = 10_000, 789 + batch_size: None = None, 790 + ) -> Iterable[ST]: ... 784 791 785 792 @overload 786 - def shuffled( self, 787 - buffer_shards: int = 100, 788 - buffer_samples: int = 10_000, 789 - *, 790 - batch_size: int, 791 - ) -> Iterable[SampleBatch[ST]]: ... 793 + def shuffled( 794 + self, 795 + buffer_shards: int = 100, 796 + buffer_samples: int = 10_000, 797 + *, 798 + batch_size: int, 799 + ) -> Iterable[SampleBatch[ST]]: ... 792 800 793 - def shuffled( self, 794 - buffer_shards: int = 100, 795 - buffer_samples: int = 10_000, 796 - batch_size: int | None = None, 797 - ) -> Iterable[ST] | Iterable[SampleBatch[ST]]: 801 + def shuffled( 802 + self, 803 + buffer_shards: int = 100, 804 + buffer_samples: int = 10_000, 805 + batch_size: int | None = None, 806 + ) -> Iterable[ST] | Iterable[SampleBatch[ST]]: 798 807 """Iterate over the dataset in random order. 799 808 800 809 Args: ··· 823 832 if batch_size is None: 824 833 return wds.pipeline.DataPipeline( 825 834 _ShardListStage(self._source), 826 - wds.filters.shuffle( buffer_shards ), 835 + wds.filters.shuffle(buffer_shards), 827 836 wds.shardlists.split_by_worker, 828 837 _StreamOpenerStage(self._source), 829 838 wds.tariterators.tar_file_expander, 830 839 wds.tariterators.group_by_keys, 831 - wds.filters.shuffle( buffer_samples ), 832 - wds.filters.map( self.wrap ), 840 + wds.filters.shuffle(buffer_samples), 841 + wds.filters.map(self.wrap), 833 842 ) 834 843 835 844 return wds.pipeline.DataPipeline( 836 845 _ShardListStage(self._source), 837 - wds.filters.shuffle( buffer_shards ), 846 + wds.filters.shuffle(buffer_shards), 838 847 wds.shardlists.split_by_worker, 839 848 _StreamOpenerStage(self._source), 840 849 wds.tariterators.tar_file_expander, 841 850 wds.tariterators.group_by_keys, 842 - wds.filters.shuffle( buffer_samples ), 843 - wds.filters.batched( batch_size ), 844 - wds.filters.map( self.wrap_batch ), 851 + wds.filters.shuffle(buffer_samples), 852 + wds.filters.batched(batch_size), 853 + wds.filters.map(self.wrap_batch), 845 854 ) 846 - 855 + 847 856 # Design note: Uses pandas for parquet export. Could be replaced with 848 857 # direct fastparquet calls to reduce dependencies if needed. 849 - def to_parquet( self, path: Pathlike, 850 - sample_map: Optional[SampleExportMap] = None, 851 - maxcount: Optional[int] = None, 852 - **kwargs, 853 - ): 858 + def to_parquet( 859 + self, 860 + path: Pathlike, 861 + sample_map: Optional[SampleExportMap] = None, 862 + maxcount: Optional[int] = None, 863 + **kwargs, 864 + ): 854 865 """Export dataset contents to parquet format. 855 866 856 867 Converts all samples to a pandas DataFrame and saves to parquet file(s). ··· 890 901 ## 891 902 892 903 # Normalize args 893 - path = Path( path ) 904 + path = Path(path) 894 905 if sample_map is None: 895 906 sample_map = asdict 896 - 897 - verbose = kwargs.get( 'verbose', False ) 898 907 899 - it = self.ordered( batch_size = None ) 908 + verbose = kwargs.get("verbose", False) 909 + 910 + it = self.ordered(batch_size=None) 900 911 if verbose: 901 - it = tqdm( it ) 912 + it = tqdm(it) 902 913 903 914 # 904 915 905 916 if maxcount is None: 906 917 # Load and save full dataset 907 - df = pd.DataFrame( [ sample_map( x ) 908 - for x in self.ordered( batch_size = None ) ] ) 909 - df.to_parquet( path, **kwargs ) 910 - 918 + df = pd.DataFrame([sample_map(x) for x in self.ordered(batch_size=None)]) 919 + df.to_parquet(path, **kwargs) 920 + 911 921 else: 912 922 # Load and save dataset in segments of size `maxcount` 913 923 914 924 cur_segment = 0 915 925 cur_buffer = [] 916 - path_template = (path.parent / f'{path.stem}-{{:06d}}{path.suffix}').as_posix() 926 + path_template = ( 927 + path.parent / f"{path.stem}-{{:06d}}{path.suffix}" 928 + ).as_posix() 917 929 918 - for x in self.ordered( batch_size = None ): 919 - cur_buffer.append( sample_map( x ) ) 930 + for x in self.ordered(batch_size=None): 931 + cur_buffer.append(sample_map(x)) 920 932 921 - if len( cur_buffer ) >= maxcount: 933 + if len(cur_buffer) >= maxcount: 922 934 # Write current segment 923 - cur_path = path_template.format( cur_segment ) 924 - df = pd.DataFrame( cur_buffer ) 925 - df.to_parquet( cur_path, **kwargs ) 935 + cur_path = path_template.format(cur_segment) 936 + df = pd.DataFrame(cur_buffer) 937 + df.to_parquet(cur_path, **kwargs) 926 938 927 939 cur_segment += 1 928 940 cur_buffer = [] 929 - 930 - if len( cur_buffer ) > 0: 941 + 942 + if len(cur_buffer) > 0: 931 943 # Write one last segment with remainder 932 - cur_path = path_template.format( cur_segment ) 933 - df = pd.DataFrame( cur_buffer ) 934 - df.to_parquet( cur_path, **kwargs ) 944 + cur_path = path_template.format(cur_segment) 945 + df = pd.DataFrame(cur_buffer) 946 + df.to_parquet(cur_path, **kwargs) 935 947 936 - def wrap( self, sample: WDSRawSample ) -> ST: 948 + def wrap(self, sample: WDSRawSample) -> ST: 937 949 """Wrap a raw msgpack sample into the appropriate dataset-specific type. 938 950 939 951 Args: ··· 944 956 A deserialized sample of type ``ST``, optionally transformed through 945 957 a lens if ``as_type()`` was called. 946 958 """ 947 - if 'msgpack' not in sample: 948 - raise ValueError(f"Sample missing 'msgpack' key, got keys: {list(sample.keys())}") 949 - if not isinstance(sample['msgpack'], bytes): 950 - raise ValueError(f"Expected sample['msgpack'] to be bytes, got {type(sample['msgpack']).__name__}") 959 + if "msgpack" not in sample: 960 + raise ValueError( 961 + f"Sample missing 'msgpack' key, got keys: {list(sample.keys())}" 962 + ) 963 + if not isinstance(sample["msgpack"], bytes): 964 + raise ValueError( 965 + f"Expected sample['msgpack'] to be bytes, got {type(sample['msgpack']).__name__}" 966 + ) 951 967 952 968 if self._output_lens is None: 953 - return self.sample_type.from_bytes( sample['msgpack'] ) 969 + return self.sample_type.from_bytes(sample["msgpack"]) 954 970 955 - source_sample = self._output_lens.source_type.from_bytes( sample['msgpack'] ) 956 - return self._output_lens( source_sample ) 971 + source_sample = self._output_lens.source_type.from_bytes(sample["msgpack"]) 972 + return self._output_lens(source_sample) 957 973 958 - def wrap_batch( self, batch: WDSRawBatch ) -> SampleBatch[ST]: 974 + def wrap_batch(self, batch: WDSRawBatch) -> SampleBatch[ST]: 959 975 """Wrap a batch of raw msgpack samples into a typed SampleBatch. 960 976 961 977 Args: ··· 971 987 aggregates them into a batch. 972 988 """ 973 989 974 - if 'msgpack' not in batch: 975 - raise ValueError(f"Batch missing 'msgpack' key, got keys: {list(batch.keys())}") 990 + if "msgpack" not in batch: 991 + raise ValueError( 992 + f"Batch missing 'msgpack' key, got keys: {list(batch.keys())}" 993 + ) 976 994 977 995 if self._output_lens is None: 978 - batch_unpacked = [ self.sample_type.from_bytes( bs ) 979 - for bs in batch['msgpack'] ] 980 - return SampleBatch[self.sample_type]( batch_unpacked ) 996 + batch_unpacked = [ 997 + self.sample_type.from_bytes(bs) for bs in batch["msgpack"] 998 + ] 999 + return SampleBatch[self.sample_type](batch_unpacked) 981 1000 982 - batch_source = [ self._output_lens.source_type.from_bytes( bs ) 983 - for bs in batch['msgpack'] ] 984 - batch_view = [ self._output_lens( s ) 985 - for s in batch_source ] 986 - return SampleBatch[self.sample_type]( batch_view ) 1001 + batch_source = [ 1002 + self._output_lens.source_type.from_bytes(bs) for bs in batch["msgpack"] 1003 + ] 1004 + batch_view = [self._output_lens(s) for s in batch_source] 1005 + return SampleBatch[self.sample_type](batch_view) 987 1006 988 1007 989 - _T = TypeVar('_T') 1008 + _T = TypeVar("_T") 990 1009 991 1010 992 1011 @dataclass_transform() 993 - def packable( cls: type[_T] ) -> type[_T]: 1012 + def packable(cls: type[_T]) -> type[_T]: 994 1013 """Decorator to convert a regular class into a ``PackableSample``. 995 1014 996 1015 This decorator transforms a class into a dataclass that inherits from ··· 1029 1048 class_annotations = cls.__annotations__ 1030 1049 1031 1050 # Add in dataclass niceness to original class 1032 - as_dataclass = dataclass( cls ) 1051 + as_dataclass = dataclass(cls) 1033 1052 1034 1053 # This triggers a bunch of behind-the-scenes stuff for the newly annotated class 1035 1054 @dataclass 1036 - class as_packable( as_dataclass, PackableSample ): 1037 - def __post_init__( self ): 1038 - return PackableSample.__post_init__( self ) 1039 - 1055 + class as_packable(as_dataclass, PackableSample): 1056 + def __post_init__(self): 1057 + return PackableSample.__post_init__(self) 1058 + 1040 1059 # Restore original class identity for better repr/debugging 1041 1060 as_packable.__name__ = class_name 1042 1061 as_packable.__qualname__ = class_name ··· 1047 1066 1048 1067 # Fix qualnames of dataclass-generated methods so they don't show 1049 1068 # 'packable.<locals>.as_packable' in help() and IDE hints 1050 - old_qualname_prefix = 'packable.<locals>.as_packable' 1051 - for attr_name in ('__init__', '__repr__', '__eq__', '__post_init__'): 1069 + old_qualname_prefix = "packable.<locals>.as_packable" 1070 + for attr_name in ("__init__", "__repr__", "__eq__", "__post_init__"): 1052 1071 attr = getattr(as_packable, attr_name, None) 1053 - if attr is not None and hasattr(attr, '__qualname__'): 1072 + if attr is not None and hasattr(attr, "__qualname__"): 1054 1073 if attr.__qualname__.startswith(old_qualname_prefix): 1055 1074 attr.__qualname__ = attr.__qualname__.replace( 1056 1075 old_qualname_prefix, class_name, 1 ··· 1066 1085 1067 1086 ## 1068 1087 1069 - return as_packable 1088 + return as_packable
+35 -32
src/atdata/lens.py
··· 54 54 Optional, 55 55 Generic, 56 56 # 57 - TYPE_CHECKING 57 + TYPE_CHECKING, 58 58 ) 59 59 60 60 if TYPE_CHECKING: ··· 66 66 ## 67 67 # Typing helpers 68 68 69 - DatasetType: TypeAlias = Type['PackableSample'] 69 + DatasetType: TypeAlias = Type["PackableSample"] 70 70 LensSignature: TypeAlias = Tuple[DatasetType, DatasetType] 71 71 72 - S = TypeVar( 'S', bound = Packable ) 73 - V = TypeVar( 'V', bound = Packable ) 72 + S = TypeVar("S", bound=Packable) 73 + V = TypeVar("V", bound=Packable) 74 74 type LensGetter[S, V] = Callable[[S], V] 75 75 type LensPutter[S, V] = Callable[[V, S], S] 76 76 ··· 78 78 ## 79 79 # Shortcut decorators 80 80 81 - class Lens( Generic[S, V] ): 81 + 82 + class Lens(Generic[S, V]): 82 83 """A bidirectional transformation between two sample types. 83 84 84 85 A lens provides a way to view and update data of type ``S`` (source) as if ··· 99 100 ... def name_lens_put(view: NameOnly, source: FullData) -> FullData: 100 101 ... return FullData(name=view.name, age=source.age) 101 102 """ 103 + 102 104 # TODO The above has a line for "Parameters:" that should be "Type Parameters:"; this is a temporary fix for `quartodoc` auto-generation bugs. 103 105 104 - def __init__( self, get: LensGetter[S, V], 105 - put: Optional[LensPutter[S, V]] = None 106 - ) -> None: 106 + def __init__( 107 + self, get: LensGetter[S, V], put: Optional[LensPutter[S, V]] = None 108 + ) -> None: 107 109 """Initialize a lens with a getter and optional putter function. 108 110 109 111 Args: ··· 122 124 123 125 # Check argument validity 124 126 125 - sig = inspect.signature( get ) 126 - input_types = list( sig.parameters.values() ) 127 + sig = inspect.signature(get) 128 + input_types = list(sig.parameters.values()) 127 129 if len(input_types) != 1: 128 130 raise ValueError( 129 131 f"Lens getter must have exactly one parameter, got {len(input_types)}: " ··· 131 133 ) 132 134 133 135 # Update function details for this object as returned by annotation 134 - functools.update_wrapper( self, get ) 136 + functools.update_wrapper(self, get) 135 137 136 138 self.source_type: Type[Packable] = input_types[0].annotation 137 139 self.view_type: Type[Packable] = sig.return_annotation ··· 142 144 # Determine and store the putter 143 145 if put is None: 144 146 # Trivial putter does not update the source 145 - def _trivial_put( v: V, s: S ) -> S: 147 + def _trivial_put(v: V, s: S) -> S: 146 148 return s 149 + 147 150 put = _trivial_put 148 151 self._putter = put 149 - 152 + 150 153 # 151 154 152 - def putter( self, put: LensPutter[S, V] ) -> LensPutter[S, V]: 155 + def putter(self, put: LensPutter[S, V]) -> LensPutter[S, V]: 153 156 """Decorator to register a putter function for this lens. 154 157 155 158 Args: ··· 167 170 ## 168 171 self._putter = put 169 172 return put 170 - 173 + 171 174 # Methods to actually execute transformations 172 175 173 - def put( self, v: V, s: S ) -> S: 176 + def put(self, v: V, s: S) -> S: 174 177 """Update the source based on a modified view. 175 178 176 179 Args: ··· 180 183 Returns: 181 184 An updated source of type ``S`` that reflects changes from the view. 182 185 """ 183 - return self._putter( v, s ) 186 + return self._putter(v, s) 184 187 185 - def get( self, s: S ) -> V: 188 + def get(self, s: S) -> V: 186 189 """Transform the source into the view type. 187 190 188 191 Args: ··· 191 194 Returns: 192 195 A view of the source as type ``V``. 193 196 """ 194 - return self( s ) 197 + return self(s) 195 198 196 - def __call__( self, s: S ) -> V: 199 + def __call__(self, s: S) -> V: 197 200 """Apply the lens transformation (same as ``get()``).""" 198 - return self._getter( s ) 201 + return self._getter(s) 199 202 200 203 201 - def lens( f: LensGetter[S, V] ) -> Lens[S, V]: 204 + def lens(f: LensGetter[S, V]) -> Lens[S, V]: 202 205 """Decorator to create and register a lens transformation. 203 206 204 207 This decorator converts a getter function into a ``Lens`` object and ··· 221 224 ... def extract_name_put(view: NameOnly, source: FullData) -> FullData: 222 225 ... return FullData(name=view.name, age=source.age) 223 226 """ 224 - ret = Lens[S, V]( f ) 225 - _network.register( ret ) 227 + ret = Lens[S, V](f) 228 + _network.register(ret) 226 229 return ret 227 230 228 231 ··· 251 254 252 255 def __init__(self): 253 256 """Initialize the lens registry (only on first instantiation).""" 254 - if not hasattr(self, '_initialized'): # Check if already initialized 257 + if not hasattr(self, "_initialized"): # Check if already initialized 255 258 self._registry: Dict[LensSignature, Lens] = dict() 256 259 self._initialized = True 257 - 258 - def register( self, _lens: Lens ): 260 + 261 + def register(self, _lens: Lens): 259 262 """Register a lens as the canonical transformation between two types. 260 263 261 264 Args: ··· 267 270 overwritten. 268 271 """ 269 272 self._registry[_lens.source_type, _lens.view_type] = _lens 270 - 271 - def transform( self, source: DatasetType, view: DatasetType ) -> Lens: 273 + 274 + def transform(self, source: DatasetType, view: DatasetType) -> Lens: 272 275 """Look up the lens transformation between two sample types. 273 276 274 277 Args: ··· 285 288 Currently only supports direct transformations. Compositional 286 289 transformations (chaining multiple lenses) are not yet implemented. 287 290 """ 288 - ret = self._registry.get( (source, view), None ) 291 + ret = self._registry.get((source, view), None) 289 292 if ret is None: 290 - raise ValueError( f'No registered lens from source {source} to view {view}' ) 293 + raise ValueError(f"No registered lens from source {source} to view {view}") 291 294 292 295 return ret 293 296 294 297 295 298 # Global singleton registry instance 296 - _network = LensNetwork() 299 + _network = LensNetwork()
+151 -128
src/atdata/local.py
··· 24 24 ) 25 25 from atdata._cid import generate_cid 26 26 from atdata._type_utils import ( 27 - numpy_dtype_to_string, 28 27 PRIMITIVE_TYPE_MAP, 29 28 unwrap_optional, 30 29 is_ndarray_type, 31 30 extract_ndarray_dtype, 32 31 ) 33 - from atdata._protocols import IndexEntry, AbstractDataStore, Packable 32 + from atdata._protocols import AbstractDataStore, Packable 34 33 35 34 from pathlib import Path 36 35 from uuid import uuid4 ··· 57 56 Generator, 58 57 Iterator, 59 58 BinaryIO, 60 - Union, 61 59 Optional, 62 60 Literal, 63 61 cast, ··· 70 68 import json 71 69 import warnings 72 70 73 - T = TypeVar( 'T', bound = PackableSample ) 71 + T = TypeVar("T", bound=PackableSample) 74 72 75 73 # Redis key prefixes for index entries and schemas 76 74 REDIS_KEY_DATASET_ENTRY = "LocalDatasetEntry" ··· 355 353 ## 356 354 # Helpers 357 355 358 - def _kind_str_for_sample_type( st: Type[Packable] ) -> str: 356 + 357 + def _kind_str_for_sample_type(st: Type[Packable]) -> str: 359 358 """Return fully-qualified 'module.name' string for a sample type.""" 360 - return f'{st.__module__}.{st.__name__}' 359 + return f"{st.__module__}.{st.__name__}" 361 360 362 361 363 362 def _create_s3_write_callbacks( ··· 385 384 import boto3 386 385 387 386 s3_client_kwargs = { 388 - 'aws_access_key_id': credentials['AWS_ACCESS_KEY_ID'], 389 - 'aws_secret_access_key': credentials['AWS_SECRET_ACCESS_KEY'] 387 + "aws_access_key_id": credentials["AWS_ACCESS_KEY_ID"], 388 + "aws_secret_access_key": credentials["AWS_SECRET_ACCESS_KEY"], 390 389 } 391 - if 'AWS_ENDPOINT' in credentials: 392 - s3_client_kwargs['endpoint_url'] = credentials['AWS_ENDPOINT'] 393 - s3_client = boto3.client('s3', **s3_client_kwargs) 390 + if "AWS_ENDPOINT" in credentials: 391 + s3_client_kwargs["endpoint_url"] = credentials["AWS_ENDPOINT"] 392 + s3_client = boto3.client("s3", **s3_client_kwargs) 394 393 395 394 def _writer_opener(p: str): 396 395 local_path = Path(temp_dir) / p 397 396 local_path.parent.mkdir(parents=True, exist_ok=True) 398 - return open(local_path, 'wb') 397 + return open(local_path, "wb") 399 398 400 399 def _writer_post(p: str): 401 400 local_path = Path(temp_dir) / p ··· 403 402 bucket = path_parts[0] 404 403 key = str(Path(*path_parts[1:])) 405 404 406 - with open(local_path, 'rb') as f_in: 405 + with open(local_path, "rb") as f_in: 407 406 s3_client.put_object(Bucket=bucket, Key=key, Body=f_in.read()) 408 407 409 408 local_path.unlink() ··· 417 416 assert fs is not None, "S3FileSystem required when cache_local=False" 418 417 419 418 def _direct_opener(s: str): 420 - return cast(BinaryIO, fs.open(f's3://{s}', 'wb')) 419 + return cast(BinaryIO, fs.open(f"s3://{s}", "wb")) 421 420 422 421 def _direct_post(s: str): 423 422 if add_s3_prefix: ··· 426 425 written_shards.append(s) 427 426 428 427 return _direct_opener, _direct_post 428 + 429 429 430 430 ## 431 431 # Schema helpers ··· 452 452 and legacy format: 'local://schemas/{module.Class}@{version}' 453 453 """ 454 454 if ref.startswith(_ATDATA_URI_PREFIX): 455 - path = ref[len(_ATDATA_URI_PREFIX):] 455 + path = ref[len(_ATDATA_URI_PREFIX) :] 456 456 elif ref.startswith(_LEGACY_URI_PREFIX): 457 - path = ref[len(_LEGACY_URI_PREFIX):] 457 + path = ref[len(_LEGACY_URI_PREFIX) :] 458 458 else: 459 459 raise ValueError(f"Invalid schema reference: {ref}") 460 460 ··· 485 485 def _python_type_to_field_type(python_type: Any) -> dict: 486 486 """Convert Python type annotation to schema field type dict.""" 487 487 if python_type in PRIMITIVE_TYPE_MAP: 488 - return {"$type": "local#primitive", "primitive": PRIMITIVE_TYPE_MAP[python_type]} 488 + return { 489 + "$type": "local#primitive", 490 + "primitive": PRIMITIVE_TYPE_MAP[python_type], 491 + } 489 492 490 493 if is_ndarray_type(python_type): 491 494 return {"$type": "local#ndarray", "dtype": extract_ndarray_dtype(python_type)} ··· 493 496 origin = get_origin(python_type) 494 497 if origin is list: 495 498 args = get_args(python_type) 496 - items = _python_type_to_field_type(args[0]) if args else {"$type": "local#primitive", "primitive": "str"} 499 + items = ( 500 + _python_type_to_field_type(args[0]) 501 + if args 502 + else {"$type": "local#primitive", "primitive": "str"} 503 + ) 497 504 return {"$type": "local#array", "items": items} 498 505 499 506 if is_dataclass(python_type): ··· 541 548 field_type, is_optional = unwrap_optional(field_type) 542 549 field_type_dict = _python_type_to_field_type(field_type) 543 550 544 - field_defs.append({ 545 - "name": f.name, 546 - "fieldType": field_type_dict, 547 - "optional": is_optional, 548 - }) 551 + field_defs.append( 552 + { 553 + "name": f.name, 554 + "fieldType": field_type_dict, 555 + "optional": is_optional, 556 + } 557 + ) 549 558 550 559 return { 551 560 "name": sample_type.__name__, ··· 558 567 559 568 ## 560 569 # Redis object model 570 + 561 571 562 572 @dataclass 563 573 class LocalDatasetEntry: ··· 577 587 data_urls: WebDataset URLs for the data. 578 588 metadata: Arbitrary metadata dictionary, or None if not set. 579 589 """ 590 + 580 591 ## 581 592 582 593 name: str ··· 638 649 Args: 639 650 redis: Redis connection to write to. 640 651 """ 641 - save_key = f'{REDIS_KEY_DATASET_ENTRY}:{self.cid}' 652 + save_key = f"{REDIS_KEY_DATASET_ENTRY}:{self.cid}" 642 653 data = { 643 - 'name': self.name, 644 - 'schema_ref': self.schema_ref, 645 - 'data_urls': msgpack.packb(self.data_urls), # Serialize list 646 - 'cid': self.cid, 654 + "name": self.name, 655 + "schema_ref": self.schema_ref, 656 + "data_urls": msgpack.packb(self.data_urls), # Serialize list 657 + "cid": self.cid, 647 658 } 648 659 if self.metadata is not None: 649 - data['metadata'] = msgpack.packb(self.metadata) 660 + data["metadata"] = msgpack.packb(self.metadata) 650 661 if self._legacy_uuid is not None: 651 - data['legacy_uuid'] = self._legacy_uuid 662 + data["legacy_uuid"] = self._legacy_uuid 652 663 653 664 redis.hset(save_key, mapping=data) # type: ignore[arg-type] 654 665 ··· 666 677 Raises: 667 678 KeyError: If entry not found. 668 679 """ 669 - save_key = f'{REDIS_KEY_DATASET_ENTRY}:{cid}' 680 + save_key = f"{REDIS_KEY_DATASET_ENTRY}:{cid}" 670 681 raw_data = redis.hgetall(save_key) 671 682 if not raw_data: 672 683 raise KeyError(f"{REDIS_KEY_DATASET_ENTRY} not found: {cid}") 673 684 674 685 # Decode string fields, keep binary fields as bytes for msgpack 675 686 raw_data_typed = cast(dict[bytes, bytes], raw_data) 676 - name = raw_data_typed[b'name'].decode('utf-8') 677 - schema_ref = raw_data_typed[b'schema_ref'].decode('utf-8') 678 - cid_value = raw_data_typed.get(b'cid', b'').decode('utf-8') or None 679 - legacy_uuid = raw_data_typed.get(b'legacy_uuid', b'').decode('utf-8') or None 687 + name = raw_data_typed[b"name"].decode("utf-8") 688 + schema_ref = raw_data_typed[b"schema_ref"].decode("utf-8") 689 + cid_value = raw_data_typed.get(b"cid", b"").decode("utf-8") or None 690 + legacy_uuid = raw_data_typed.get(b"legacy_uuid", b"").decode("utf-8") or None 680 691 681 692 # Deserialize msgpack fields (stored as raw bytes) 682 - data_urls = msgpack.unpackb(raw_data_typed[b'data_urls']) 693 + data_urls = msgpack.unpackb(raw_data_typed[b"data_urls"]) 683 694 metadata = None 684 - if b'metadata' in raw_data_typed: 685 - metadata = msgpack.unpackb(raw_data_typed[b'metadata']) 695 + if b"metadata" in raw_data_typed: 696 + metadata = msgpack.unpackb(raw_data_typed[b"metadata"]) 686 697 687 698 return cls( 688 699 name=name, ··· 697 708 # Backwards compatibility alias 698 709 BasicIndexEntry = LocalDatasetEntry 699 710 700 - def _s3_env( credentials_path: str | Path ) -> dict[str, Any]: 711 + 712 + def _s3_env(credentials_path: str | Path) -> dict[str, Any]: 701 713 """Load S3 credentials from .env file. 702 714 703 715 Args: ··· 710 722 Raises: 711 723 ValueError: If any required key is missing from the .env file. 712 724 """ 713 - credentials_path = Path( credentials_path ) 714 - env_values = dotenv_values( credentials_path ) 725 + credentials_path = Path(credentials_path) 726 + env_values = dotenv_values(credentials_path) 715 727 716 - required_keys = ('AWS_ENDPOINT', 'AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY') 728 + required_keys = ("AWS_ENDPOINT", "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY") 717 729 missing = [k for k in required_keys if k not in env_values] 718 730 if missing: 719 - raise ValueError(f"Missing required keys in {credentials_path}: {', '.join(missing)}") 731 + raise ValueError( 732 + f"Missing required keys in {credentials_path}: {', '.join(missing)}" 733 + ) 720 734 721 735 return {k: env_values[k] for k in required_keys} 722 736 723 - def _s3_from_credentials( creds: str | Path | dict ) -> S3FileSystem: 737 + 738 + def _s3_from_credentials(creds: str | Path | dict) -> S3FileSystem: 724 739 """Create S3FileSystem from credentials dict or .env file path.""" 725 - if not isinstance( creds, dict ): 726 - creds = _s3_env( creds ) 740 + if not isinstance(creds, dict): 741 + creds = _s3_env(creds) 727 742 728 743 # Build kwargs, making endpoint_url optional 729 744 kwargs = { 730 - 'key': creds['AWS_ACCESS_KEY_ID'], 731 - 'secret': creds['AWS_SECRET_ACCESS_KEY'] 745 + "key": creds["AWS_ACCESS_KEY_ID"], 746 + "secret": creds["AWS_SECRET_ACCESS_KEY"], 732 747 } 733 - if 'AWS_ENDPOINT' in creds: 734 - kwargs['endpoint_url'] = creds['AWS_ENDPOINT'] 748 + if "AWS_ENDPOINT" in creds: 749 + kwargs["endpoint_url"] = creds["AWS_ENDPOINT"] 735 750 736 751 return S3FileSystem(**kwargs) 737 752 738 753 739 754 ## 740 755 # Classes 756 + 741 757 742 758 class Repo: 743 759 """Repository for storing and managing atdata datasets. ··· 795 811 796 812 if s3_credentials is None: 797 813 self.s3_credentials = None 798 - elif isinstance( s3_credentials, dict ): 814 + elif isinstance(s3_credentials, dict): 799 815 self.s3_credentials = s3_credentials 800 816 else: 801 - self.s3_credentials = _s3_env( s3_credentials ) 817 + self.s3_credentials = _s3_env(s3_credentials) 802 818 803 819 if self.s3_credentials is None: 804 820 self.bucket_fs = None 805 821 else: 806 - self.bucket_fs = _s3_from_credentials( self.s3_credentials ) 822 + self.bucket_fs = _s3_from_credentials(self.s3_credentials) 807 823 808 824 if self.bucket_fs is not None: 809 825 if hive_path is None: 810 - raise ValueError( 'Must specify hive path within bucket' ) 811 - self.hive_path = Path( hive_path ) 826 + raise ValueError("Must specify hive path within bucket") 827 + self.hive_path = Path(hive_path) 812 828 self.hive_bucket = self.hive_path.parts[0] 813 829 else: 814 830 self.hive_path = None ··· 816 832 817 833 # 818 834 819 - self.index = Index( redis = redis ) 835 + self.index = Index(redis=redis) 820 836 821 837 ## 822 838 823 - def insert(self, 824 - ds: Dataset[T], 825 - *, 826 - name: str, 827 - cache_local: bool = False, 828 - schema_ref: str | None = None, 829 - **kwargs 830 - ) -> tuple[LocalDatasetEntry, Dataset[T]]: 839 + def insert( 840 + self, 841 + ds: Dataset[T], 842 + *, 843 + name: str, 844 + cache_local: bool = False, 845 + schema_ref: str | None = None, 846 + **kwargs, 847 + ) -> tuple[LocalDatasetEntry, Dataset[T]]: 831 848 """Insert a dataset into the repository. 832 849 833 850 Writes the dataset to S3 as WebDataset tar files, stores metadata, ··· 851 868 RuntimeError: If no shards were written. 852 869 """ 853 870 if self.s3_credentials is None: 854 - raise ValueError("S3 credentials required for insert(). Initialize Repo with s3_credentials.") 871 + raise ValueError( 872 + "S3 credentials required for insert(). Initialize Repo with s3_credentials." 873 + ) 855 874 if self.hive_bucket is None or self.hive_path is None: 856 - raise ValueError("hive_path required for insert(). Initialize Repo with hive_path.") 875 + raise ValueError( 876 + "hive_path required for insert(). Initialize Repo with hive_path." 877 + ) 857 878 858 - new_uuid = str( uuid4() ) 879 + new_uuid = str(uuid4()) 859 880 860 - hive_fs = _s3_from_credentials( self.s3_credentials ) 881 + hive_fs = _s3_from_credentials(self.s3_credentials) 861 882 862 883 # Write metadata 863 884 metadata_path = ( 864 - self.hive_path 865 - / 'metadata' 866 - / f'atdata-metadata--{new_uuid}.msgpack' 885 + self.hive_path / "metadata" / f"atdata-metadata--{new_uuid}.msgpack" 867 886 ) 868 887 # Note: S3 doesn't need directories created beforehand - s3fs handles this 869 888 870 889 if ds.metadata is not None: 871 890 # Use s3:// prefix to ensure s3fs treats this as an S3 path 872 - with cast( BinaryIO, hive_fs.open( f's3://{metadata_path.as_posix()}', 'wb' ) ) as f: 873 - meta_packed = msgpack.packb( ds.metadata ) 891 + with cast( 892 + BinaryIO, hive_fs.open(f"s3://{metadata_path.as_posix()}", "wb") 893 + ) as f: 894 + meta_packed = msgpack.packb(ds.metadata) 874 895 assert meta_packed is not None 875 - f.write( cast( bytes, meta_packed ) ) 876 - 896 + f.write(cast(bytes, meta_packed)) 877 897 878 898 # Write data 879 - shard_pattern = ( 880 - self.hive_path 881 - / f'atdata--{new_uuid}--%06d.tar' 882 - ).as_posix() 899 + shard_pattern = (self.hive_path / f"atdata--{new_uuid}--%06d.tar").as_posix() 883 900 884 901 written_shards: list[str] = [] 885 902 with TemporaryDirectory() as temp_dir: ··· 902 919 sink.write(sample.as_wds) 903 920 904 921 # Make a new Dataset object for the written dataset copy 905 - if len( written_shards ) == 0: 906 - raise RuntimeError( 'Cannot form new dataset entry -- did not write any shards' ) 907 - 908 - elif len( written_shards ) < 2: 922 + if len(written_shards) == 0: 923 + raise RuntimeError( 924 + "Cannot form new dataset entry -- did not write any shards" 925 + ) 926 + 927 + elif len(written_shards) < 2: 909 928 new_dataset_url = ( 910 - self.hive_path 911 - / ( Path( written_shards[0] ).name ) 929 + self.hive_path / (Path(written_shards[0]).name) 912 930 ).as_posix() 913 931 914 932 else: 915 933 shard_s3_format = ( 916 - ( 917 - self.hive_path 918 - / f'atdata--{new_uuid}' 919 - ).as_posix() 920 - ) + '--{shard_id}.tar' 921 - shard_id_braced = '{' + f'{0:06d}..{len( written_shards ) - 1:06d}' + '}' 922 - new_dataset_url = shard_s3_format.format( shard_id = shard_id_braced ) 934 + (self.hive_path / f"atdata--{new_uuid}").as_posix() 935 + ) + "--{shard_id}.tar" 936 + shard_id_braced = "{" + f"{0:06d}..{len(written_shards) - 1:06d}" + "}" 937 + new_dataset_url = shard_s3_format.format(shard_id=shard_id_braced) 923 938 924 939 new_dataset = Dataset[ds.sample_type]( 925 940 url=new_dataset_url, ··· 993 1008 # Providing stub_dir implies auto_stubs=True 994 1009 if auto_stubs or stub_dir is not None: 995 1010 from ._stub_manager import StubManager 1011 + 996 1012 self._stub_manager: StubManager | None = StubManager(stub_dir=stub_dir) 997 1013 else: 998 1014 self._stub_manager = None ··· 1130 1146 Yields: 1131 1147 LocalDatasetEntry objects from the index. 1132 1148 """ 1133 - prefix = f'{REDIS_KEY_DATASET_ENTRY}:' 1134 - for key in self._redis.scan_iter(match=f'{prefix}*'): 1135 - key_str = key.decode('utf-8') if isinstance(key, bytes) else key 1136 - cid = key_str[len(prefix):] 1149 + prefix = f"{REDIS_KEY_DATASET_ENTRY}:" 1150 + for key in self._redis.scan_iter(match=f"{prefix}*"): 1151 + key_str = key.decode("utf-8") if isinstance(key, bytes) else key 1152 + cid = key_str[len(prefix) :] 1137 1153 yield LocalDatasetEntry.from_redis(self._redis, cid) 1138 1154 1139 - def add_entry(self, 1140 - ds: Dataset, 1141 - *, 1142 - name: str, 1143 - schema_ref: str | None = None, 1144 - metadata: dict | None = None, 1145 - ) -> LocalDatasetEntry: 1155 + def add_entry( 1156 + self, 1157 + ds: Dataset, 1158 + *, 1159 + name: str, 1160 + schema_ref: str | None = None, 1161 + metadata: dict | None = None, 1162 + ) -> LocalDatasetEntry: 1146 1163 """Add a dataset to the index. 1147 1164 1148 1165 Creates a LocalDatasetEntry for the dataset and persists it to Redis. ··· 1158 1175 """ 1159 1176 ## 1160 1177 if schema_ref is None: 1161 - schema_ref = f"local://schemas/{_kind_str_for_sample_type(ds.sample_type)}@1.0.0" 1178 + schema_ref = ( 1179 + f"local://schemas/{_kind_str_for_sample_type(ds.sample_type)}@1.0.0" 1180 + ) 1162 1181 1163 1182 # Normalize URL to list 1164 1183 data_urls = [ds.url] ··· 1237 1256 Returns: 1238 1257 IndexEntry for the inserted dataset. 1239 1258 """ 1240 - metadata = kwargs.get('metadata') 1259 + metadata = kwargs.get("metadata") 1241 1260 1242 1261 if self._data_store is not None: 1243 1262 # Write shards to data store, then index the new URLs 1244 - prefix = kwargs.get('prefix', name) 1245 - cache_local = kwargs.get('cache_local', False) 1263 + prefix = kwargs.get("prefix", name) 1264 + cache_local = kwargs.get("cache_local", False) 1246 1265 1247 1266 written_urls = self._data_store.write_shards( 1248 1267 ds, ··· 1306 1325 latest_version: tuple[int, int, int] | None = None 1307 1326 latest_version_str: str | None = None 1308 1327 1309 - prefix = f'{REDIS_KEY_SCHEMA}:' 1310 - for key in self._redis.scan_iter(match=f'{prefix}*'): 1311 - key_str = key.decode('utf-8') if isinstance(key, bytes) else key 1312 - schema_id = key_str[len(prefix):] 1328 + prefix = f"{REDIS_KEY_SCHEMA}:" 1329 + for key in self._redis.scan_iter(match=f"{prefix}*"): 1330 + key_str = key.decode("utf-8") if isinstance(key, bytes) else key 1331 + schema_id = key_str[len(prefix) :] 1313 1332 1314 1333 if "@" not in schema_id: 1315 1334 continue ··· 1361 1380 # This catches non-packable types early with a clear error message 1362 1381 try: 1363 1382 # Check protocol compliance by verifying required methods exist 1364 - if not (hasattr(sample_type, 'from_data') and 1365 - hasattr(sample_type, 'from_bytes') and 1366 - callable(getattr(sample_type, 'from_data', None)) and 1367 - callable(getattr(sample_type, 'from_bytes', None))): 1383 + if not ( 1384 + hasattr(sample_type, "from_data") 1385 + and hasattr(sample_type, "from_bytes") 1386 + and callable(getattr(sample_type, "from_data", None)) 1387 + and callable(getattr(sample_type, "from_bytes", None)) 1388 + ): 1368 1389 raise TypeError( 1369 1390 f"{sample_type.__name__} does not satisfy the Packable protocol. " 1370 1391 "Use @packable decorator or inherit from PackableSample." ··· 1422 1443 raise KeyError(f"Schema not found: {ref}") 1423 1444 1424 1445 if isinstance(schema_json, bytes): 1425 - schema_json = schema_json.decode('utf-8') 1446 + schema_json = schema_json.decode("utf-8") 1426 1447 1427 1448 schema = json.loads(schema_json) 1428 - schema['$ref'] = _make_schema_ref(name, version) 1449 + schema["$ref"] = _make_schema_ref(name, version) 1429 1450 1430 1451 # Auto-generate stub if enabled 1431 1452 if self._stub_manager is not None: ··· 1460 1481 Yields: 1461 1482 LocalSchemaRecord for each schema. 1462 1483 """ 1463 - prefix = f'{REDIS_KEY_SCHEMA}:' 1464 - for key in self._redis.scan_iter(match=f'{prefix}*'): 1465 - key_str = key.decode('utf-8') if isinstance(key, bytes) else key 1484 + prefix = f"{REDIS_KEY_SCHEMA}:" 1485 + for key in self._redis.scan_iter(match=f"{prefix}*"): 1486 + key_str = key.decode("utf-8") if isinstance(key, bytes) else key 1466 1487 # Extract name@version from key 1467 - schema_id = key_str[len(prefix):] 1488 + schema_id = key_str[len(prefix) :] 1468 1489 1469 1490 schema_json = self._redis.get(key) 1470 1491 if schema_json is None: 1471 1492 continue 1472 1493 1473 1494 if isinstance(schema_json, bytes): 1474 - schema_json = schema_json.decode('utf-8') 1495 + schema_json = schema_json.decode("utf-8") 1475 1496 1476 1497 schema = json.loads(schema_json) 1477 1498 # Handle legacy keys that have module.Class format 1478 1499 if "." in schema_id.split("@")[0]: 1479 1500 name = schema_id.split("@")[0].rsplit(".", 1)[1] 1480 1501 version = schema_id.split("@")[1] 1481 - schema['$ref'] = _make_schema_ref(name, version) 1502 + schema["$ref"] = _make_schema_ref(name, version) 1482 1503 else: 1483 1504 # schema_id is already "name@version" 1484 1505 name, version = schema_id.rsplit("@", 1) 1485 - schema['$ref'] = _make_schema_ref(name, version) 1506 + schema["$ref"] = _make_schema_ref(name, version) 1486 1507 yield LocalSchemaRecord.from_dict(schema) 1487 1508 1488 1509 def list_schemas(self) -> list[dict]: ··· 1526 1547 1527 1548 # Fall back to dynamic type generation 1528 1549 from atdata._schema_codec import schema_to_type 1550 + 1529 1551 return schema_to_type(schema_dict) 1530 1552 1531 1553 def decode_schema_as(self, ref: str, type_hint: type[T]) -> type[T]: ··· 1557 1579 stub matches the schema to avoid runtime surprises. 1558 1580 """ 1559 1581 from typing import cast 1582 + 1560 1583 return cast(type[T], self.decode_schema(ref)) 1561 1584 1562 1585 def clear_stubs(self) -> int: ··· 1677 1700 HTTPS URL if custom endpoint is configured, otherwise unchanged. 1678 1701 Example: 's3://bucket/path' -> 'https://endpoint.com/bucket/path' 1679 1702 """ 1680 - endpoint = self.credentials.get('AWS_ENDPOINT') 1681 - if endpoint and url.startswith('s3://'): 1703 + endpoint = self.credentials.get("AWS_ENDPOINT") 1704 + if endpoint and url.startswith("s3://"): 1682 1705 # s3://bucket/path -> https://endpoint/bucket/path 1683 1706 path = url[5:] # Remove 's3://' prefix 1684 - endpoint = endpoint.rstrip('/') 1707 + endpoint = endpoint.rstrip("/") 1685 1708 return f"{endpoint}/{path}" 1686 1709 return url 1687 1710 ··· 1694 1717 return True 1695 1718 1696 1719 1697 - # 1720 + #
+15 -10
tests/conftest.py
··· 3 3 This module provides shared fixtures and sample types for the test suite. 4 4 """ 5 5 6 - import pytest 7 - from redis import Redis 8 - from typing import Optional 6 + from pathlib import Path 7 + from typing import Optional, TypeVar 9 8 10 9 import numpy as np 10 + import pytest 11 + import webdataset as wds 11 12 from numpy.typing import NDArray 13 + from redis import Redis 12 14 13 15 import atdata 14 16 ··· 41 43 42 44 Fields: name (str), value (int) 43 45 """ 46 + 44 47 name: str 45 48 value: int 46 49 ··· 51 54 52 55 Fields: data (NDArray), label (str) 53 56 """ 57 + 54 58 data: NDArray 55 59 label: str 56 60 ··· 61 65 62 66 Fields: required (str), optional_int (int|None), optional_array (NDArray|None) 63 67 """ 68 + 64 69 required: str 65 70 optional_int: Optional[int] = None 66 71 optional_array: Optional[NDArray] = None ··· 72 77 73 78 Fields: str_field, int_field, float_field, bool_field, bytes_field 74 79 """ 80 + 75 81 str_field: str 76 82 int_field: int 77 83 float_field: float ··· 85 91 86 92 Fields: tags (list[str]), scores (list[float]) 87 93 """ 94 + 88 95 tags: list[str] 89 96 scores: list[float] 90 97 ··· 95 102 96 103 Fields: id (int), content (str), score (float) 97 104 """ 105 + 98 106 id: int 99 107 content: str 100 108 score: float ··· 108 116 # Import and use these instead of duplicating TarWriter boilerplate. 109 117 # 110 118 # ============================================================================= 111 - 112 - import webdataset as wds 113 - from pathlib import Path 114 - from typing import Type, TypeVar 115 119 116 120 ST = TypeVar("ST") 117 121 ··· 150 154 """ 151 155 tar_path = tmp_path / f"{name}-000000.tar" 152 156 samples = [ 153 - SharedBasicSample(name=f"sample_{i}", value=i * 10) 154 - for i in range(num_samples) 157 + SharedBasicSample(name=f"sample_{i}", value=i * 10) for i in range(num_samples) 155 158 ] 156 159 create_tar_with_samples(tar_path, samples) 157 160 return atdata.Dataset[SharedBasicSample](url=str(tar_path)) ··· 190 193 # Fixtures 191 194 # ============================================================================= 192 195 196 + 193 197 @pytest.fixture 194 198 def redis_connection(): 195 199 """Provide a Redis connection, skip test if Redis is not available.""" ··· 208 212 Clears LocalDatasetEntry, BasicIndexEntry (legacy), and LocalSchema keys 209 213 before and after each test to ensure test isolation. 210 214 """ 215 + 211 216 def _clear_all(): 212 - for pattern in ('LocalDatasetEntry:*', 'BasicIndexEntry:*', 'LocalSchema:*'): 217 + for pattern in ("LocalDatasetEntry:*", "BasicIndexEntry:*", "LocalSchema:*"): 213 218 for key in redis_connection.scan_iter(match=pattern): 214 219 redis_connection.delete(key) 215 220
+154 -46
tests/test_atmosphere.py
··· 44 44 # Test Fixtures 45 45 # ============================================================================= 46 46 47 + 47 48 @pytest.fixture 48 49 def mock_atproto_client(): 49 50 """Create a mock atproto SDK client.""" ··· 75 76 @atdata.packable 76 77 class BasicSample: 77 78 """Simple sample type for testing.""" 79 + 78 80 name: str 79 81 value: int 80 82 ··· 82 84 @atdata.packable 83 85 class NumpySample: 84 86 """Sample type with NDArray field.""" 87 + 85 88 data: NDArray 86 89 label: str 87 90 ··· 89 92 @atdata.packable 90 93 class OptionalSample: 91 94 """Sample type with optional fields.""" 95 + 92 96 required_field: str 93 97 optional_field: Optional[int] 94 98 optional_array: Optional[NDArray] ··· 97 101 @atdata.packable 98 102 class AllTypesSample: 99 103 """Sample type with all primitive types.""" 104 + 100 105 str_field: str 101 106 int_field: int 102 107 float_field: float ··· 107 112 # ============================================================================= 108 113 # Tests for _types.py - AtUri 109 114 # ============================================================================= 115 + 110 116 111 117 class TestAtUri: 112 118 """Tests for AtUri parsing and formatting.""" ··· 161 167 # ============================================================================= 162 168 # Tests for _types.py - FieldType 163 169 # ============================================================================= 170 + 164 171 165 172 class TestFieldType: 166 173 """Tests for FieldType dataclass.""" ··· 203 210 # Tests for _types.py - FieldDef 204 211 # ============================================================================= 205 212 213 + 206 214 class TestFieldDef: 207 215 """Tests for FieldDef dataclass.""" 208 216 ··· 242 250 # ============================================================================= 243 251 # Tests for _types.py - SchemaRecord 244 252 # ============================================================================= 253 + 245 254 246 255 class TestSchemaRecord: 247 256 """Tests for SchemaRecord dataclass and to_record().""" ··· 318 327 # Check primitive field 319 328 prim_field = record["fields"][0] 320 329 assert prim_field["name"] == "primitive_field" 321 - assert prim_field["fieldType"]["$type"] == f"{LEXICON_NAMESPACE}.schemaType#primitive" 330 + assert ( 331 + prim_field["fieldType"]["$type"] 332 + == f"{LEXICON_NAMESPACE}.schemaType#primitive" 333 + ) 322 334 assert prim_field["fieldType"]["primitive"] == "int" 323 335 assert prim_field["optional"] is False 324 336 325 337 # Check ndarray field 326 338 arr_field = record["fields"][1] 327 339 assert arr_field["name"] == "array_field" 328 - assert arr_field["fieldType"]["$type"] == f"{LEXICON_NAMESPACE}.schemaType#ndarray" 340 + assert ( 341 + arr_field["fieldType"]["$type"] == f"{LEXICON_NAMESPACE}.schemaType#ndarray" 342 + ) 329 343 assert arr_field["fieldType"]["dtype"] == "float32" 330 344 assert arr_field["optional"] is True 331 345 ··· 333 347 # ============================================================================= 334 348 # Tests for _types.py - StorageLocation 335 349 # ============================================================================= 350 + 336 351 337 352 class TestStorageLocation: 338 353 """Tests for StorageLocation dataclass.""" ··· 364 379 # Tests for _types.py - DatasetRecord 365 380 # ============================================================================= 366 381 382 + 367 383 class TestDatasetRecord: 368 384 """Tests for DatasetRecord dataclass and to_record().""" 369 385 ··· 382 398 383 399 assert record["$type"] == f"{LEXICON_NAMESPACE}.record" 384 400 assert record["name"] == "TestDataset" 385 - assert record["schemaRef"] == "at://did:plc:abc/ac.foundation.dataset.sampleSchema/xyz" 401 + assert ( 402 + record["schemaRef"] 403 + == "at://did:plc:abc/ac.foundation.dataset.sampleSchema/xyz" 404 + ) 386 405 assert record["storage"]["$type"] == f"{LEXICON_NAMESPACE}.storageExternal" 387 406 assert record["storage"]["urls"] == ["s3://bucket/data.tar"] 388 407 ··· 438 457 # Tests for _types.py - LensRecord 439 458 # ============================================================================= 440 459 460 + 441 461 class TestLensRecord: 442 462 """Tests for LensRecord dataclass and to_record().""" 443 463 ··· 500 520 # Tests for client.py - AtmosphereClient 501 521 # ============================================================================= 502 522 523 + 503 524 class TestAtmosphereClient: 504 525 """Tests for AtmosphereClient.""" 505 526 ··· 539 560 assert client.is_authenticated 540 561 assert client.did == "did:plc:test123456789" 541 562 assert client.handle == "test.bsky.social" 542 - mock_atproto_client.login.assert_called_once_with("test.bsky.social", "password123") 563 + mock_atproto_client.login.assert_called_once_with( 564 + "test.bsky.social", "password123" 565 + ) 543 566 544 567 def test_login_with_session(self, mock_atproto_client): 545 568 """Login with exported session string.""" ··· 548 571 client.login_with_session("test-session-string") 549 572 550 573 assert client.is_authenticated 551 - mock_atproto_client.login.assert_called_once_with(session_string="test-session-string") 574 + mock_atproto_client.login.assert_called_once_with( 575 + session_string="test-session-string" 576 + ) 552 577 553 578 def test_export_session(self, authenticated_client, mock_atproto_client): 554 579 """Export session string.""" ··· 625 650 626 651 assert record["field"] == "value" 627 652 628 - def test_get_record_with_aturi_object(self, authenticated_client, mock_atproto_client): 653 + def test_get_record_with_aturi_object( 654 + self, authenticated_client, mock_atproto_client 655 + ): 629 656 """Get a record using AtUri object.""" 630 657 mock_response = Mock() 631 658 mock_response.value = {"$type": "test", "data": 123} ··· 653 680 mock_response.blob = mock_blob_ref 654 681 mock_atproto_client.upload_blob.return_value = mock_response 655 682 656 - result = authenticated_client.upload_blob(b"test data", mime_type="application/x-tar") 683 + result = authenticated_client.upload_blob( 684 + b"test data", mime_type="application/x-tar" 685 + ) 657 686 658 687 assert result["$type"] == "blob" 659 688 assert result["ref"]["$link"] == "bafkreitest123" ··· 673 702 mock_did_response = Mock() 674 703 mock_did_response.json.return_value = { 675 704 "service": [ 676 - {"type": "AtprotoPersonalDataServer", "serviceEndpoint": "https://pds.example.com"} 705 + { 706 + "type": "AtprotoPersonalDataServer", 707 + "serviceEndpoint": "https://pds.example.com", 708 + } 677 709 ] 678 710 } 679 711 mock_did_response.raise_for_status = Mock() ··· 692 724 def test_get_blob_pds_not_found(self, authenticated_client): 693 725 """Get blob raises when PDS cannot be resolved.""" 694 726 import requests as req_module 727 + 695 728 with patch("requests.get") as mock_get: 696 729 mock_get.side_effect = req_module.RequestException("Network error") 697 730 ··· 704 737 mock_response = Mock() 705 738 mock_response.json.return_value = { 706 739 "service": [ 707 - {"type": "AtprotoPersonalDataServer", "serviceEndpoint": "https://pds.example.com"} 740 + { 741 + "type": "AtprotoPersonalDataServer", 742 + "serviceEndpoint": "https://pds.example.com", 743 + } 708 744 ] 709 745 } 710 746 mock_response.raise_for_status = Mock() ··· 712 748 713 749 url = authenticated_client.get_blob_url("did:plc:abc", "bafkreitest") 714 750 715 - assert url == "https://pds.example.com/xrpc/com.atproto.sync.getBlob?did=did:plc:abc&cid=bafkreitest" 751 + assert ( 752 + url 753 + == "https://pds.example.com/xrpc/com.atproto.sync.getBlob?did=did:plc:abc&cid=bafkreitest" 754 + ) 716 755 717 756 def test_get_blob_url_pds_not_found(self, authenticated_client): 718 757 """Get blob URL raises when PDS cannot be resolved.""" 719 758 import requests as req_module 759 + 720 760 with patch("requests.get") as mock_get: 721 761 mock_get.side_effect = req_module.RequestException("Network error") 722 762 ··· 763 803 # Tests for schema.py - SchemaPublisher 764 804 # ============================================================================= 765 805 806 + 766 807 class TestSchemaPublisher: 767 808 """Tests for SchemaPublisher.""" 768 809 769 810 def test_publish_basic_sample(self, authenticated_client, mock_atproto_client): 770 811 """Publish a basic sample type schema.""" 771 812 mock_response = Mock() 772 - mock_response.uri = f"at://did:plc:test123456789/{LEXICON_NAMESPACE}.sampleSchema/abc" 813 + mock_response.uri = ( 814 + f"at://did:plc:test123456789/{LEXICON_NAMESPACE}.sampleSchema/abc" 815 + ) 773 816 mock_atproto_client.com.atproto.repo.create_record.return_value = mock_response 774 817 775 818 publisher = SchemaPublisher(authenticated_client) ··· 833 876 assert required["optional"] is False 834 877 assert optional["optional"] is True 835 878 836 - def test_publish_all_primitive_types(self, authenticated_client, mock_atproto_client): 879 + def test_publish_all_primitive_types( 880 + self, authenticated_client, mock_atproto_client 881 + ): 837 882 """Publish sample with all primitive types.""" 838 883 mock_response = Mock() 839 884 mock_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/abc" ··· 918 963 # Tests for records.py - DatasetPublisher 919 964 # ============================================================================= 920 965 966 + 921 967 class TestDatasetPublisher: 922 968 """Tests for DatasetPublisher.""" 923 969 ··· 950 996 """Publish dataset with auto schema publishing.""" 951 997 # Mock for schema creation 952 998 schema_response = Mock() 953 - schema_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/schema123" 999 + schema_response.uri = ( 1000 + f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/schema123" 1001 + ) 954 1002 955 1003 # Mock for dataset creation 956 1004 dataset_response = Mock() 957 - dataset_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.record/dataset456" 1005 + dataset_response.uri = ( 1006 + f"at://did:plc:test/{LEXICON_NAMESPACE}.record/dataset456" 1007 + ) 958 1008 959 1009 mock_atproto_client.com.atproto.repo.create_record.side_effect = [ 960 1010 schema_response, ··· 977 1027 # Should have called create_record twice (schema + dataset) 978 1028 assert mock_atproto_client.com.atproto.repo.create_record.call_count == 2 979 1029 980 - def test_publish_explicit_schema_uri(self, authenticated_client, mock_atproto_client): 1030 + def test_publish_explicit_schema_uri( 1031 + self, authenticated_client, mock_atproto_client 1032 + ): 981 1033 """Publish dataset with explicit schema URI (no auto publish).""" 982 1034 mock_response = Mock() 983 1035 mock_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.record/abc" ··· 1026 1078 1027 1079 # Mock create_record response 1028 1080 mock_create_response = Mock() 1029 - mock_create_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.record/blobds" 1030 - mock_atproto_client.com.atproto.repo.create_record.return_value = mock_create_response 1081 + mock_create_response.uri = ( 1082 + f"at://did:plc:test/{LEXICON_NAMESPACE}.record/blobds" 1083 + ) 1084 + mock_atproto_client.com.atproto.repo.create_record.return_value = ( 1085 + mock_create_response 1086 + ) 1031 1087 1032 1088 publisher = DatasetPublisher(authenticated_client) 1033 1089 uri = publisher.publish_with_blobs( ··· 1050 1106 assert record["name"] == "BlobStoredDataset" 1051 1107 assert "storageBlobs" in record["storage"]["$type"] 1052 1108 1053 - def test_publish_with_blobs_with_metadata(self, authenticated_client, mock_atproto_client): 1109 + def test_publish_with_blobs_with_metadata( 1110 + self, authenticated_client, mock_atproto_client 1111 + ): 1054 1112 """Publish with blobs includes metadata when provided.""" 1055 1113 mock_blob_ref = Mock() 1056 1114 mock_blob_ref.ref = Mock(link="bafkreiblob456") ··· 1062 1120 mock_atproto_client.upload_blob.return_value = mock_upload_response 1063 1121 1064 1122 mock_create_response = Mock() 1065 - mock_create_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.record/metads" 1066 - mock_atproto_client.com.atproto.repo.create_record.return_value = mock_create_response 1123 + mock_create_response.uri = ( 1124 + f"at://did:plc:test/{LEXICON_NAMESPACE}.record/metads" 1125 + ) 1126 + mock_atproto_client.com.atproto.repo.create_record.return_value = ( 1127 + mock_create_response 1128 + ) 1067 1129 1068 1130 publisher = DatasetPublisher(authenticated_client) 1069 1131 publisher.publish_with_blobs( ··· 1123 1185 "schemaRef": "at://schema", 1124 1186 "storage": { 1125 1187 "$type": f"{LEXICON_NAMESPACE}.storageExternal", 1126 - "urls": ["s3://bucket/data-{000000..000009}.tar", "s3://bucket/extra.tar"], 1188 + "urls": [ 1189 + "s3://bucket/data-{000000..000009}.tar", 1190 + "s3://bucket/extra.tar", 1191 + ], 1127 1192 }, 1128 1193 } 1129 1194 mock_atproto_client.com.atproto.repo.get_record.return_value = mock_response ··· 1134 1199 assert len(urls) == 2 1135 1200 assert "data-{000000..000009}.tar" in urls[0] 1136 1201 1137 - def test_get_urls_blob_storage_error(self, authenticated_client, mock_atproto_client): 1202 + def test_get_urls_blob_storage_error( 1203 + self, authenticated_client, mock_atproto_client 1204 + ): 1138 1205 """Get URLs raises for blob storage datasets.""" 1139 1206 mock_response = Mock() 1140 1207 mock_response.value = { ··· 1170 1237 mock_atproto_client.com.atproto.repo.get_record.return_value = mock_response 1171 1238 1172 1239 loader = DatasetLoader(authenticated_client) 1173 - metadata = loader.get_metadata(f"at://did:plc:abc/{LEXICON_NAMESPACE}.record/xyz") 1240 + metadata = loader.get_metadata( 1241 + f"at://did:plc:abc/{LEXICON_NAMESPACE}.record/xyz" 1242 + ) 1174 1243 1175 1244 assert metadata["split"] == "train" 1176 1245 assert metadata["samples"] == 10000 ··· 1187 1256 mock_atproto_client.com.atproto.repo.get_record.return_value = mock_response 1188 1257 1189 1258 loader = DatasetLoader(authenticated_client) 1190 - metadata = loader.get_metadata(f"at://did:plc:abc/{LEXICON_NAMESPACE}.record/xyz") 1259 + metadata = loader.get_metadata( 1260 + f"at://did:plc:abc/{LEXICON_NAMESPACE}.record/xyz" 1261 + ) 1191 1262 1192 1263 assert metadata is None 1193 1264 ··· 1221 1292 mock_atproto_client.com.atproto.repo.get_record.return_value = mock_response 1222 1293 1223 1294 loader = DatasetLoader(authenticated_client) 1224 - storage_type = loader.get_storage_type(f"at://did:plc:abc/{LEXICON_NAMESPACE}.record/xyz") 1295 + storage_type = loader.get_storage_type( 1296 + f"at://did:plc:abc/{LEXICON_NAMESPACE}.record/xyz" 1297 + ) 1225 1298 1226 1299 assert storage_type == "external" 1227 1300 ··· 1240 1313 mock_atproto_client.com.atproto.repo.get_record.return_value = mock_response 1241 1314 1242 1315 loader = DatasetLoader(authenticated_client) 1243 - storage_type = loader.get_storage_type(f"at://did:plc:abc/{LEXICON_NAMESPACE}.record/xyz") 1316 + storage_type = loader.get_storage_type( 1317 + f"at://did:plc:abc/{LEXICON_NAMESPACE}.record/xyz" 1318 + ) 1244 1319 1245 1320 assert storage_type == "blobs" 1246 1321 ··· 1265 1340 def test_get_blobs(self, authenticated_client, mock_atproto_client): 1266 1341 """Get blobs returns blob references from storage.""" 1267 1342 blob_refs = [ 1268 - {"ref": {"$link": "bafkreitest1"}, "mimeType": "application/x-tar", "size": 1024}, 1269 - {"ref": {"$link": "bafkreitest2"}, "mimeType": "application/x-tar", "size": 2048}, 1343 + { 1344 + "ref": {"$link": "bafkreitest1"}, 1345 + "mimeType": "application/x-tar", 1346 + "size": 1024, 1347 + }, 1348 + { 1349 + "ref": {"$link": "bafkreitest2"}, 1350 + "mimeType": "application/x-tar", 1351 + "size": 2048, 1352 + }, 1270 1353 ] 1271 1354 mock_response = Mock() 1272 1355 mock_response.value = { ··· 1287 1370 assert blobs[0]["ref"]["$link"] == "bafkreitest1" 1288 1371 assert blobs[1]["ref"]["$link"] == "bafkreitest2" 1289 1372 1290 - def test_get_blobs_external_storage_error(self, authenticated_client, mock_atproto_client): 1373 + def test_get_blobs_external_storage_error( 1374 + self, authenticated_client, mock_atproto_client 1375 + ): 1291 1376 """Get blobs raises for external URL storage datasets.""" 1292 1377 mock_response = Mock() 1293 1378 mock_response.value = { ··· 1306 1391 with pytest.raises(ValueError, match="external URL storage"): 1307 1392 loader.get_blobs(f"at://did:plc:abc/{LEXICON_NAMESPACE}.record/xyz") 1308 1393 1309 - def test_get_blobs_unknown_storage_error(self, authenticated_client, mock_atproto_client): 1394 + def test_get_blobs_unknown_storage_error( 1395 + self, authenticated_client, mock_atproto_client 1396 + ): 1310 1397 """Get blobs raises for unknown storage type.""" 1311 1398 mock_response = Mock() 1312 1399 mock_response.value = { ··· 1346 1433 mock_did_response = Mock() 1347 1434 mock_did_response.json.return_value = { 1348 1435 "service": [ 1349 - {"type": "AtprotoPersonalDataServer", "serviceEndpoint": "https://pds.example.com"} 1436 + { 1437 + "type": "AtprotoPersonalDataServer", 1438 + "serviceEndpoint": "https://pds.example.com", 1439 + } 1350 1440 ] 1351 1441 } 1352 1442 mock_did_response.raise_for_status = Mock() 1353 1443 mock_get.return_value = mock_did_response 1354 1444 1355 1445 loader = DatasetLoader(authenticated_client) 1356 - urls = loader.get_blob_urls(f"at://did:plc:abc123/{LEXICON_NAMESPACE}.record/xyz") 1446 + urls = loader.get_blob_urls( 1447 + f"at://did:plc:abc123/{LEXICON_NAMESPACE}.record/xyz" 1448 + ) 1357 1449 1358 1450 assert len(urls) == 2 1359 1451 assert "bafkreitest1" in urls[0] 1360 1452 assert "bafkreitest2" in urls[1] 1361 1453 assert "did:plc:abc123" in urls[0] 1362 1454 1363 - def test_get_urls_unknown_storage_error(self, authenticated_client, mock_atproto_client): 1455 + def test_get_urls_unknown_storage_error( 1456 + self, authenticated_client, mock_atproto_client 1457 + ): 1364 1458 """Get URLs raises for unknown storage type.""" 1365 1459 mock_response = Mock() 1366 1460 mock_response.value = { ··· 1382 1476 # ============================================================================= 1383 1477 # Tests for lens.py - LensPublisher 1384 1478 # ============================================================================= 1479 + 1385 1480 1386 1481 class TestLensPublisher: 1387 1482 """Tests for LensPublisher.""" ··· 1508 1603 1509 1604 assert len(lenses) == 1 1510 1605 1511 - def test_find_by_schemas_source_only(self, authenticated_client, mock_atproto_client): 1606 + def test_find_by_schemas_source_only( 1607 + self, authenticated_client, mock_atproto_client 1608 + ): 1512 1609 """Find lenses by source schema only.""" 1513 1610 mock_records = [ 1514 - Mock(value={"sourceSchema": "at://schema/a", "targetSchema": "at://schema/b"}), 1515 - Mock(value={"sourceSchema": "at://schema/a", "targetSchema": "at://schema/c"}), 1516 - Mock(value={"sourceSchema": "at://schema/x", "targetSchema": "at://schema/y"}), 1611 + Mock( 1612 + value={"sourceSchema": "at://schema/a", "targetSchema": "at://schema/b"} 1613 + ), 1614 + Mock( 1615 + value={"sourceSchema": "at://schema/a", "targetSchema": "at://schema/c"} 1616 + ), 1617 + Mock( 1618 + value={"sourceSchema": "at://schema/x", "targetSchema": "at://schema/y"} 1619 + ), 1517 1620 ] 1518 1621 1519 1622 mock_response = Mock() ··· 1529 1632 def test_find_by_schemas_both(self, authenticated_client, mock_atproto_client): 1530 1633 """Find lenses by both source and target schema.""" 1531 1634 mock_records = [ 1532 - Mock(value={"sourceSchema": "at://schema/a", "targetSchema": "at://schema/b"}), 1533 - Mock(value={"sourceSchema": "at://schema/a", "targetSchema": "at://schema/c"}), 1635 + Mock( 1636 + value={"sourceSchema": "at://schema/a", "targetSchema": "at://schema/b"} 1637 + ), 1638 + Mock( 1639 + value={"sourceSchema": "at://schema/a", "targetSchema": "at://schema/c"} 1640 + ), 1534 1641 ] 1535 1642 1536 1643 mock_response = Mock() ··· 1691 1798 # AtmosphereIndex Tests 1692 1799 # ============================================================================= 1693 1800 1801 + 1694 1802 class TestAtmosphereIndexEntry: 1695 1803 """Tests for AtmosphereIndexEntry wrapper.""" 1696 1804 ··· 1738 1846 """Index has all AbstractIndex protocol methods.""" 1739 1847 index = AtmosphereIndex(authenticated_client) 1740 1848 1741 - assert hasattr(index, 'insert_dataset') 1742 - assert hasattr(index, 'get_dataset') 1743 - assert hasattr(index, 'list_datasets') 1744 - assert hasattr(index, 'publish_schema') 1745 - assert hasattr(index, 'get_schema') 1746 - assert hasattr(index, 'list_schemas') 1747 - assert hasattr(index, 'decode_schema') 1849 + assert hasattr(index, "insert_dataset") 1850 + assert hasattr(index, "get_dataset") 1851 + assert hasattr(index, "list_datasets") 1852 + assert hasattr(index, "publish_schema") 1853 + assert hasattr(index, "get_schema") 1854 + assert hasattr(index, "list_schemas") 1855 + assert hasattr(index, "decode_schema") 1748 1856 1749 1857 def test_publish_schema(self, authenticated_client, mock_atproto_client): 1750 1858 """publish_schema delegates to SchemaPublisher."""
+9 -6
tests/test_cid.py
··· 176 176 parsed = parse_cid(cid) 177 177 assert parsed["hash"]["digest"] == expected_digest 178 178 179 - @pytest.mark.parametrize("malformed_cid", [ 180 - "", # empty 181 - "invalid", # not a CID 182 - "bafy123", # truncated CID 183 - "Qm123", # v0 prefix but invalid 184 - ]) 179 + @pytest.mark.parametrize( 180 + "malformed_cid", 181 + [ 182 + "", # empty 183 + "invalid", # not a CID 184 + "bafy123", # truncated CID 185 + "Qm123", # v0 prefix but invalid 186 + ], 187 + ) 185 188 def test_parse_cid_malformed_raises_valueerror(self, malformed_cid): 186 189 """Malformed CID strings raise ValueError.""" 187 190 with pytest.raises(ValueError, match="Failed to decode CID"):
+223 -186
tests/test_dataset.py
··· 27 27 ## 28 28 # Sample test cases 29 29 30 + 30 31 @dataclass 31 - class BasicTestSample( atdata.PackableSample ): 32 + class BasicTestSample(atdata.PackableSample): 32 33 name: str 33 34 position: int 34 35 value: float 36 + 35 37 36 38 @dataclass 37 - class NumpyTestSample( atdata.PackableSample ): 39 + class NumpyTestSample(atdata.PackableSample): 38 40 label: int 39 41 image: NDArray 42 + 40 43 41 44 @atdata.packable 42 45 class BasicTestSampleDecorated: ··· 44 47 position: int 45 48 value: float 46 49 50 + 47 51 @atdata.packable 48 52 class NumpyTestSampleDecorated: 49 53 label: int 50 54 image: NDArray 55 + 51 56 52 57 @atdata.packable 53 58 class NumpyOptionalSampleDecorated: 54 59 label: int 55 60 image: NDArray 56 61 embeddings: NDArray | None = None 62 + 57 63 58 64 test_cases = [ 59 65 { 60 - 'SampleType': BasicTestSample, 61 - 'sample_data': { 62 - 'name': 'Hello, world!', 63 - 'position': 42, 64 - 'value': 1024.768, 66 + "SampleType": BasicTestSample, 67 + "sample_data": { 68 + "name": "Hello, world!", 69 + "position": 42, 70 + "value": 1024.768, 65 71 }, 66 - 'sample_wds_stem': 'basic_test', 67 - 'test_parquet': True, 72 + "sample_wds_stem": "basic_test", 73 + "test_parquet": True, 68 74 }, 69 75 { 70 - 'SampleType': NumpyTestSample, 71 - 'sample_data': 72 - { 73 - 'label': 9_001, 74 - 'image': np.random.randn( 1024, 1024 ), 76 + "SampleType": NumpyTestSample, 77 + "sample_data": { 78 + "label": 9_001, 79 + "image": np.random.randn(1024, 1024), 75 80 }, 76 - 'sample_wds_stem': 'numpy_test', 77 - 'test_parquet': False, 81 + "sample_wds_stem": "numpy_test", 82 + "test_parquet": False, 78 83 }, 79 84 { 80 - 'SampleType': BasicTestSampleDecorated, 81 - 'sample_data': { 82 - 'name': 'Hello, world!', 83 - 'position': 42, 84 - 'value': 1024.768, 85 + "SampleType": BasicTestSampleDecorated, 86 + "sample_data": { 87 + "name": "Hello, world!", 88 + "position": 42, 89 + "value": 1024.768, 85 90 }, 86 - 'sample_wds_stem': 'basic_test_decorated', 87 - 'test_parquet': True, 91 + "sample_wds_stem": "basic_test_decorated", 92 + "test_parquet": True, 88 93 }, 89 94 { 90 - 'SampleType': NumpyTestSampleDecorated, 91 - 'sample_data': 92 - { 93 - 'label': 9_001, 94 - 'image': np.random.randn( 1024, 1024 ), 95 + "SampleType": NumpyTestSampleDecorated, 96 + "sample_data": { 97 + "label": 9_001, 98 + "image": np.random.randn(1024, 1024), 95 99 }, 96 - 'sample_wds_stem': 'numpy_test_decorated', 97 - 'test_parquet': False, 100 + "sample_wds_stem": "numpy_test_decorated", 101 + "test_parquet": False, 98 102 }, 99 103 { 100 - 'SampleType': NumpyOptionalSampleDecorated, 101 - 'sample_data': 102 - { 103 - 'label': 9_001, 104 - 'image': np.random.randn( 1024, 1024 ), 105 - 'embeddings': np.random.randn( 512 ), 104 + "SampleType": NumpyOptionalSampleDecorated, 105 + "sample_data": { 106 + "label": 9_001, 107 + "image": np.random.randn(1024, 1024), 108 + "embeddings": np.random.randn(512), 106 109 }, 107 - 'sample_wds_stem': 'numpy_optional_decorated', 108 - 'test_parquet': False, 110 + "sample_wds_stem": "numpy_optional_decorated", 111 + "test_parquet": False, 109 112 }, 110 113 { 111 - 'SampleType': NumpyOptionalSampleDecorated, 112 - 'sample_data': 113 - { 114 - 'label': 9_001, 115 - 'image': np.random.randn( 1024, 1024 ), 116 - 'embeddings': None, 114 + "SampleType": NumpyOptionalSampleDecorated, 115 + "sample_data": { 116 + "label": 9_001, 117 + "image": np.random.randn(1024, 1024), 118 + "embeddings": None, 117 119 }, 118 - 'sample_wds_stem': 'numpy_optional_decorated_none', 119 - 'test_parquet': False, 120 + "sample_wds_stem": "numpy_optional_decorated_none", 121 + "test_parquet": False, 120 122 }, 121 123 ] 122 124 123 125 124 126 ## Tests 125 127 128 + 126 129 @pytest.mark.parametrize( 127 - ('SampleType', 'sample_data'), 128 - [ (case['SampleType'], case['sample_data']) 129 - for case in test_cases ] 130 + ("SampleType", "sample_data"), 131 + [(case["SampleType"], case["sample_data"]) for case in test_cases], 130 132 ) 131 133 def test_create_sample( 132 - SampleType: Type[atdata.PackableSample], 133 - sample_data: atds.WDSRawSample, 134 - ): 134 + SampleType: Type[atdata.PackableSample], 135 + sample_data: atds.WDSRawSample, 136 + ): 135 137 """Test our ability to create samples from semi-structured data""" 136 138 137 - sample = SampleType.from_data( sample_data ) 138 - assert isinstance( sample, SampleType ), \ 139 - f'Did not properly form sample for test type {SampleType}' 139 + sample = SampleType.from_data(sample_data) 140 + assert isinstance(sample, SampleType), ( 141 + f"Did not properly form sample for test type {SampleType}" 142 + ) 140 143 141 144 for k, v in sample_data.items(): 142 145 cur_assertion: bool 143 - if isinstance( v, np.ndarray ): 144 - cur_assertion = np.all( getattr( sample, k ) == v ) 146 + if isinstance(v, np.ndarray): 147 + cur_assertion = np.all(getattr(sample, k) == v) 145 148 else: 146 - cur_assertion = getattr( sample, k ) == v 147 - assert cur_assertion, \ 148 - f'Did not properly incorporate property {k} of test type {SampleType}' 149 + cur_assertion = getattr(sample, k) == v 150 + assert cur_assertion, ( 151 + f"Did not properly incorporate property {k} of test type {SampleType}" 152 + ) 149 153 150 154 151 155 @pytest.mark.parametrize( 152 - ('SampleType', 'sample_data', 'sample_wds_stem'), 153 - [ (case['SampleType'], case['sample_data'], case['sample_wds_stem']) 154 - for case in test_cases ] 156 + ("SampleType", "sample_data", "sample_wds_stem"), 157 + [ 158 + (case["SampleType"], case["sample_data"], case["sample_wds_stem"]) 159 + for case in test_cases 160 + ], 155 161 ) 156 162 def test_wds( 157 - SampleType: Type[atdata.PackableSample], 158 - sample_data: atds.WDSRawSample, 159 - sample_wds_stem: str, 160 - tmp_path 161 - ): 163 + SampleType: Type[atdata.PackableSample], 164 + sample_data: atds.WDSRawSample, 165 + sample_wds_stem: str, 166 + tmp_path, 167 + ): 162 168 """Test our ability to write samples as `WebDatasets` to disk""" 163 169 164 170 ## Testing hyperparameters ··· 170 176 171 177 ## Write sharded dataset 172 178 173 - file_pattern = ( 174 - tmp_path 175 - / (f'{sample_wds_stem}' + '-{shard_id}.tar') 176 - ).as_posix() 177 - file_wds_pattern = file_pattern.format( shard_id = '%06d' ) 179 + file_pattern = (tmp_path / (f"{sample_wds_stem}" + "-{shard_id}.tar")).as_posix() 180 + file_wds_pattern = file_pattern.format(shard_id="%06d") 178 181 179 182 with wds.writer.ShardWriter( 180 - pattern = file_wds_pattern, 181 - maxcount = shard_maxcount, 183 + pattern=file_wds_pattern, 184 + maxcount=shard_maxcount, 182 185 ) as sink: 183 - 184 - for i_sample in range( n_copies ): 185 - new_sample = SampleType.from_data( sample_data ) 186 - assert isinstance( new_sample, SampleType ), \ 187 - f'Did not properly form sample for test type {SampleType}' 186 + for i_sample in range(n_copies): 187 + new_sample = SampleType.from_data(sample_data) 188 + assert isinstance(new_sample, SampleType), ( 189 + f"Did not properly form sample for test type {SampleType}" 190 + ) 188 191 189 - sink.write( new_sample.as_wds ) 190 - 192 + sink.write(new_sample.as_wds) 191 193 192 194 ## Ordered 193 195 194 196 # Read first shard, no batches 195 197 196 - first_filename = file_pattern.format( shard_id = f'{0:06d}' ) 197 - dataset = atdata.Dataset[SampleType]( first_filename ) 198 + first_filename = file_pattern.format(shard_id=f"{0:06d}") 199 + dataset = atdata.Dataset[SampleType](first_filename) 198 200 199 201 iterations_run = 0 200 - for i_iterate, cur_sample in enumerate( dataset.ordered( batch_size = None ) ): 202 + for i_iterate, cur_sample in enumerate(dataset.ordered(batch_size=None)): 203 + assert isinstance(cur_sample, SampleType), ( 204 + f"Single sample for {SampleType} written to `wds` is of wrong type" 205 + ) 201 206 202 - assert isinstance( cur_sample, SampleType ), \ 203 - f'Single sample for {SampleType} written to `wds` is of wrong type' 204 - 205 207 # Check sample values 206 - 208 + 207 209 for k, v in sample_data.items(): 208 - if isinstance( v, np.ndarray ): 209 - is_correct = np.all( getattr( cur_sample, k ) == v ) 210 + if isinstance(v, np.ndarray): 211 + is_correct = np.all(getattr(cur_sample, k) == v) 210 212 else: 211 - is_correct = getattr( cur_sample, k ) == v 212 - assert is_correct, \ 213 - f'{SampleType}: Incorrect sample value found for {k} - {type( getattr( cur_sample, k ) )}' 213 + is_correct = getattr(cur_sample, k) == v 214 + assert is_correct, ( 215 + f"{SampleType}: Incorrect sample value found for {k} - {type(getattr(cur_sample, k))}" 216 + ) 214 217 215 218 iterations_run += 1 216 219 if iterations_run >= n_iterate: 217 220 break 218 221 219 - assert iterations_run == n_iterate, \ 222 + assert iterations_run == n_iterate, ( 220 223 f"Only found {iterations_run} samples, not {n_iterate}" 224 + ) 221 225 222 226 # Read all shards, batches 223 227 224 - start_id = f'{0:06d}' 225 - end_id = f'{9:06d}' 226 - first_filename = file_pattern.format( shard_id = '{' + start_id + '..' + end_id + '}' ) 227 - dataset = atdata.Dataset[SampleType]( first_filename ) 228 + start_id = f"{0:06d}" 229 + end_id = f"{9:06d}" 230 + first_filename = file_pattern.format(shard_id="{" + start_id + ".." + end_id + "}") 231 + dataset = atdata.Dataset[SampleType](first_filename) 228 232 229 233 iterations_run = 0 230 - for i_iterate, cur_batch in enumerate( dataset.ordered( batch_size = batch_size ) ): 231 - 232 - assert isinstance( cur_batch, atdata.SampleBatch ), \ 233 - f'{SampleType}: Batch sample is not correctly a batch' 234 - 235 - assert cur_batch.sample_type == SampleType, \ 236 - f'{SampleType}: Batch `sample_type` is incorrect type' 237 - 234 + for i_iterate, cur_batch in enumerate(dataset.ordered(batch_size=batch_size)): 235 + assert isinstance(cur_batch, atdata.SampleBatch), ( 236 + f"{SampleType}: Batch sample is not correctly a batch" 237 + ) 238 + 239 + assert cur_batch.sample_type == SampleType, ( 240 + f"{SampleType}: Batch `sample_type` is incorrect type" 241 + ) 242 + 238 243 if i_iterate == 0: 239 - cur_n = len( cur_batch.samples ) 240 - assert cur_n == batch_size, \ 241 - f'{SampleType}: Batch has {cur_n} samples, not {batch_size}' 242 - 243 - assert isinstance( cur_batch.samples[0], SampleType ), \ 244 - f'{SampleType}: Batch sample of wrong type ({type( cur_batch.samples[0])})' 245 - 244 + cur_n = len(cur_batch.samples) 245 + assert cur_n == batch_size, ( 246 + f"{SampleType}: Batch has {cur_n} samples, not {batch_size}" 247 + ) 248 + 249 + assert isinstance(cur_batch.samples[0], SampleType), ( 250 + f"{SampleType}: Batch sample of wrong type ({type(cur_batch.samples[0])})" 251 + ) 252 + 246 253 # Check batch values 247 254 for k, v in sample_data.items(): 248 - cur_batch_data = getattr( cur_batch, k ) 255 + cur_batch_data = getattr(cur_batch, k) 256 + 257 + if isinstance(v, np.ndarray): 258 + assert isinstance(cur_batch_data, np.ndarray), ( 259 + f"{SampleType}: `NDArray` not carried through to batch" 260 + ) 249 261 250 - if isinstance( v, np.ndarray ): 251 - assert isinstance( cur_batch_data, np.ndarray ), \ 252 - f'{SampleType}: `NDArray` not carried through to batch' 253 - 254 - is_correct = all( 255 - [ np.all( cur_batch_data[i] == v ) 256 - for i in range( cur_batch_data.shape[0] ) ] 262 + is_correct = all( 263 + [ 264 + np.all(cur_batch_data[i] == v) 265 + for i in range(cur_batch_data.shape[0]) 266 + ] 257 267 ) 258 268 259 269 else: 260 - is_correct = all( 261 - [ cur_batch_data[i] == v 262 - for i in range( len( cur_batch_data ) ) ] 270 + is_correct = all( 271 + [cur_batch_data[i] == v for i in range(len(cur_batch_data))] 263 272 ) 264 273 265 - assert is_correct, \ 266 - f'{SampleType}: Incorrect sample value found for {k}' 274 + assert is_correct, f"{SampleType}: Incorrect sample value found for {k}" 267 275 268 276 iterations_run += 1 269 277 if iterations_run >= n_iterate: 270 278 break 271 279 272 - assert iterations_run == n_iterate, \ 280 + assert iterations_run == n_iterate, ( 273 281 f"Only found {iterations_run} samples, not {n_iterate}" 274 - 282 + ) 275 283 276 284 ## Shuffled 277 285 278 286 # Read first shard, no batches 279 287 280 - first_filename = file_pattern.format( shard_id = f'{0:06d}' ) 281 - dataset = atdata.Dataset[SampleType]( first_filename ) 288 + first_filename = file_pattern.format(shard_id=f"{0:06d}") 289 + dataset = atdata.Dataset[SampleType](first_filename) 282 290 283 291 iterations_run = 0 284 - for i_iterate, cur_sample in enumerate( dataset.shuffled( batch_size = None ) ): 285 - 286 - assert isinstance( cur_sample, SampleType ), \ 287 - f'Single sample for {SampleType} written to `wds` is of wrong type' 288 - 292 + for i_iterate, cur_sample in enumerate(dataset.shuffled(batch_size=None)): 293 + assert isinstance(cur_sample, SampleType), ( 294 + f"Single sample for {SampleType} written to `wds` is of wrong type" 295 + ) 296 + 289 297 iterations_run += 1 290 298 if iterations_run >= n_iterate: 291 299 break 292 300 293 - assert iterations_run == n_iterate, \ 301 + assert iterations_run == n_iterate, ( 294 302 f"Only found {iterations_run} samples, not {n_iterate}" 303 + ) 295 304 296 305 # Read all shards, batches 297 306 298 - start_id = f'{0:06d}' 299 - end_id = f'{9:06d}' 300 - first_filename = file_pattern.format( shard_id = '{' + start_id + '..' + end_id + '}' ) 301 - dataset = atdata.Dataset[SampleType]( first_filename ) 307 + start_id = f"{0:06d}" 308 + end_id = f"{9:06d}" 309 + first_filename = file_pattern.format(shard_id="{" + start_id + ".." + end_id + "}") 310 + dataset = atdata.Dataset[SampleType](first_filename) 302 311 303 312 iterations_run = 0 304 - for i_iterate, cur_sample in enumerate( dataset.shuffled( batch_size = batch_size ) ): 305 - 306 - assert isinstance( cur_sample, atdata.SampleBatch ), \ 307 - f'{SampleType}: Batch sample is not correctly a batch' 308 - 309 - assert cur_sample.sample_type == SampleType, \ 310 - f'{SampleType}: Batch `sample_type` is incorrect type' 311 - 313 + for i_iterate, cur_sample in enumerate(dataset.shuffled(batch_size=batch_size)): 314 + assert isinstance(cur_sample, atdata.SampleBatch), ( 315 + f"{SampleType}: Batch sample is not correctly a batch" 316 + ) 317 + 318 + assert cur_sample.sample_type == SampleType, ( 319 + f"{SampleType}: Batch `sample_type` is incorrect type" 320 + ) 321 + 312 322 if i_iterate == 0: 313 - cur_n = len( cur_sample.samples ) 314 - assert cur_n == batch_size, \ 315 - f'{SampleType}: Batch has {cur_n} samples, not {batch_size}' 316 - 317 - assert isinstance( cur_sample.samples[0], SampleType ), \ 318 - f'{SampleType}: Batch sample of wrong type ({type( cur_sample.samples[0])})' 319 - 323 + cur_n = len(cur_sample.samples) 324 + assert cur_n == batch_size, ( 325 + f"{SampleType}: Batch has {cur_n} samples, not {batch_size}" 326 + ) 327 + 328 + assert isinstance(cur_sample.samples[0], SampleType), ( 329 + f"{SampleType}: Batch sample of wrong type ({type(cur_sample.samples[0])})" 330 + ) 331 + 320 332 iterations_run += 1 321 333 if iterations_run >= n_iterate: 322 334 break 323 335 324 - assert iterations_run == n_iterate, \ 336 + assert iterations_run == n_iterate, ( 325 337 f"Only found {iterations_run} samples, not {n_iterate}" 338 + ) 339 + 326 340 327 341 # 342 + 328 343 329 344 @pytest.mark.parametrize( 330 - ('SampleType', 'sample_data', 'sample_wds_stem', 'test_parquet'), 331 - [ ( 332 - case['SampleType'], 333 - case['sample_data'], 334 - case['sample_wds_stem'], 335 - case['test_parquet'] 336 - ) 337 - for case in test_cases ] 345 + ("SampleType", "sample_data", "sample_wds_stem", "test_parquet"), 346 + [ 347 + ( 348 + case["SampleType"], 349 + case["sample_data"], 350 + case["sample_wds_stem"], 351 + case["test_parquet"], 352 + ) 353 + for case in test_cases 354 + ], 338 355 ) 339 356 def test_parquet_export( 340 - SampleType: Type[atdata.PackableSample], 341 - sample_data: atds.WDSRawSample, 342 - sample_wds_stem: str, 343 - test_parquet: bool, 344 - tmp_path 345 - ): 357 + SampleType: Type[atdata.PackableSample], 358 + sample_data: atds.WDSRawSample, 359 + sample_wds_stem: str, 360 + test_parquet: bool, 361 + tmp_path, 362 + ): 346 363 """Test our ability to export a dataset to `parquet` format""" 347 364 348 365 # Skip irrelevant test cases ··· 356 373 357 374 ## Start out by writing tar dataset 358 375 359 - wds_filename = (tmp_path / f'{sample_wds_stem}.tar').as_posix() 360 - with wds.writer.TarWriter( wds_filename ) as sink: 361 - for _ in range( n_copies_dataset ): 362 - new_sample = SampleType.from_data( sample_data ) 363 - sink.write( new_sample.as_wds ) 364 - 376 + wds_filename = (tmp_path / f"{sample_wds_stem}.tar").as_posix() 377 + with wds.writer.TarWriter(wds_filename) as sink: 378 + for _ in range(n_copies_dataset): 379 + new_sample = SampleType.from_data(sample_data) 380 + sink.write(new_sample.as_wds) 381 + 365 382 ## Now export to `parquet` 366 383 367 - dataset = atdata.Dataset[SampleType]( wds_filename ) 368 - parquet_filename = tmp_path / f'{sample_wds_stem}.parquet' 369 - dataset.to_parquet( parquet_filename ) 384 + dataset = atdata.Dataset[SampleType](wds_filename) 385 + parquet_filename = tmp_path / f"{sample_wds_stem}.parquet" 386 + dataset.to_parquet(parquet_filename) 370 387 371 - parquet_filename = tmp_path / f'{sample_wds_stem}-segments.parquet' 372 - dataset.to_parquet( parquet_filename, maxcount = n_per_file ) 388 + parquet_filename = tmp_path / f"{sample_wds_stem}-segments.parquet" 389 + dataset.to_parquet(parquet_filename, maxcount=n_per_file) 373 390 374 391 375 392 ## ··· 384 401 385 402 def test_sample_batch_attribute_error(): 386 403 """Test SampleBatch raises AttributeError for non-existent attributes.""" 404 + 387 405 @atdata.packable 388 406 class SimpleSample: 389 407 name: str ··· 398 416 399 417 def test_sample_batch_type_property(): 400 418 """Test SampleBatch.sample_type property.""" 419 + 401 420 @atdata.packable 402 421 class TypedSample: 403 422 data: str ··· 410 429 411 430 def test_dataset_batch_type_property(tmp_path): 412 431 """Test Dataset.batch_type property.""" 432 + 413 433 @atdata.packable 414 434 class BatchTypeSample: 415 435 value: int ··· 429 449 430 450 def test_dataset_shard_list_property(tmp_path): 431 451 """Test Dataset.shard_list property returns list of shard URLs.""" 452 + 432 453 @atdata.packable 433 454 class ShardListSample: 434 455 value: int ··· 474 495 475 496 with patch("atdata.dataset.requests.get", return_value=mock_response) as mock_get: 476 497 dataset = atdata.Dataset[MetadataSample]( 477 - wds_filename, 478 - metadata_url="http://example.com/metadata.msgpack" 498 + wds_filename, metadata_url="http://example.com/metadata.msgpack" 479 499 ) 480 500 481 501 # First call should fetch 482 502 metadata = dataset.metadata 483 503 assert metadata == mock_metadata 484 - mock_get.assert_called_once_with("http://example.com/metadata.msgpack", stream=True) 504 + mock_get.assert_called_once_with( 505 + "http://example.com/metadata.msgpack", stream=True 506 + ) 485 507 486 508 # Second call should use cache 487 509 metadata2 = dataset.metadata ··· 491 513 492 514 def test_dataset_metadata_property_none(tmp_path): 493 515 """Test Dataset.metadata returns None when no metadata_url is set.""" 516 + 494 517 @atdata.packable 495 518 class NoMetadataSample: 496 519 value: int ··· 506 529 507 530 def test_parquet_export_with_remainder(tmp_path): 508 531 """Test parquet export with maxcount that doesn't divide evenly.""" 532 + 509 533 @atdata.packable 510 534 class RemainderSample: 511 535 name: str ··· 527 551 528 552 # Should have created 3 segment files 529 553 import pandas as pd 554 + 530 555 segment_files = list(tmp_path.glob("remainder_output-*.parquet")) 531 556 assert len(segment_files) == 3 532 557 ··· 586 611 587 612 def test_from_bytes_invalid_msgpack(): 588 613 """Test from_bytes raises on invalid msgpack data.""" 614 + 589 615 @atdata.packable 590 616 class SimpleSample: 591 617 value: int ··· 596 622 597 623 def test_from_bytes_missing_field(): 598 624 """Test from_bytes raises when required field is missing.""" 625 + 599 626 @atdata.packable 600 627 class RequiredFieldSample: 601 628 name: str 602 629 count: int 603 630 604 631 import ormsgpack 632 + 605 633 # Only provide 'name', missing 'count' 606 634 incomplete_data = ormsgpack.packb({"name": "test"}) 607 635 ··· 611 639 612 640 def test_wrap_missing_msgpack_key(tmp_path): 613 641 """Test wrap raises ValueError on sample missing msgpack key.""" 642 + 614 643 @atdata.packable 615 644 class WrapTestSample: 616 645 value: int ··· 629 658 630 659 def test_wrap_wrong_msgpack_type(tmp_path): 631 660 """Test wrap raises ValueError when msgpack value is not bytes.""" 661 + 632 662 @atdata.packable 633 663 class WrapTypeSample: 634 664 value: int ··· 647 677 648 678 def test_wrap_corrupted_msgpack(tmp_path): 649 679 """Test wrap raises on corrupted msgpack bytes.""" 680 + 650 681 @atdata.packable 651 682 class CorruptedSample: 652 683 value: int ··· 665 696 666 697 def test_dataset_nonexistent_file(): 667 698 """Test Dataset raises on nonexistent tar file during iteration.""" 699 + 668 700 @atdata.packable 669 701 class NonexistentSample: 670 702 value: int ··· 681 713 682 714 def test_dataset_invalid_batch_size(tmp_path): 683 715 """Test Dataset raises on invalid batch_size values.""" 716 + 684 717 @atdata.packable 685 718 class BatchSizeSample: 686 719 value: int ··· 798 831 799 832 def test_dictsample_dataset_iteration(tmp_path): 800 833 """Test Dataset[DictSample] can iterate over data.""" 834 + 801 835 # Create typed sample data 802 836 @atdata.packable 803 837 class SourceSample: ··· 824 858 825 859 def test_dictsample_to_typed_via_as_type(tmp_path): 826 860 """Test converting DictSample dataset to typed via as_type.""" 861 + 827 862 @atdata.packable 828 863 class TypedSample: 829 864 text: str ··· 854 889 855 890 def test_packable_auto_registers_dictsample_lens(): 856 891 """Test @packable decorator auto-registers lens from DictSample.""" 892 + 857 893 @atdata.packable 858 894 class AutoLensSample: 859 895 name: str ··· 874 910 875 911 def test_dictsample_batched_iteration(tmp_path): 876 912 """Test Dataset[DictSample] works with batched iteration.""" 913 + 877 914 @atdata.packable 878 915 class BatchSource: 879 916 text: str ··· 899 936 assert batch_count == 3 # 10 samples / 4 per batch = 2 full + 1 partial 900 937 901 938 902 - ## 939 + ##
+22 -16
tests/test_helpers.py
··· 9 9 class TestArraySerialization: 10 10 """Test array_to_bytes and bytes_to_array round-trip serialization.""" 11 11 12 - @pytest.mark.parametrize("dtype", [ 13 - np.float32, 14 - np.float64, 15 - np.int32, 16 - np.int64, 17 - np.uint8, 18 - np.bool_, 19 - np.complex64, 20 - ]) 12 + @pytest.mark.parametrize( 13 + "dtype", 14 + [ 15 + np.float32, 16 + np.float64, 17 + np.int32, 18 + np.int64, 19 + np.uint8, 20 + np.bool_, 21 + np.complex64, 22 + ], 23 + ) 21 24 def test_dtype_preservation(self, dtype): 22 25 """Verify dtype is preserved through serialization.""" 23 26 original = np.array([1, 2, 3], dtype=dtype) ··· 27 30 assert restored.dtype == original.dtype 28 31 np.testing.assert_array_equal(restored, original) 29 32 30 - @pytest.mark.parametrize("shape", [ 31 - (10,), 32 - (3, 4), 33 - (2, 3, 4), 34 - (1, 1, 1, 1), 35 - ]) 33 + @pytest.mark.parametrize( 34 + "shape", 35 + [ 36 + (10,), 37 + (3, 4), 38 + (2, 3, 4), 39 + (1, 1, 1, 1), 40 + ], 41 + ) 36 42 def test_shape_preservation(self, shape): 37 43 """Verify shape is preserved through serialization.""" 38 44 original = np.random.rand(*shape).astype(np.float32) ··· 73 79 original = np.random.rand(10, 10).astype(np.float32) 74 80 non_contiguous = original[::2, ::2] # Strided view 75 81 76 - assert not non_contiguous.flags['C_CONTIGUOUS'] 82 + assert not non_contiguous.flags["C_CONTIGUOUS"] 77 83 78 84 serialized = array_to_bytes(non_contiguous) 79 85 restored = bytes_to_array(serialized)
+4 -1
tests/test_hf_api.py
··· 855 855 mock_index = Mock() 856 856 mock_index.data_store = mock_store 857 857 mock_entry = Mock() 858 - mock_entry.data_urls = ["s3://my-bucket/train-000.tar", "s3://my-bucket/train-001.tar"] 858 + mock_entry.data_urls = [ 859 + "s3://my-bucket/train-000.tar", 860 + "s3://my-bucket/train-001.tar", 861 + ] 859 862 mock_entry.schema_ref = "local://schemas/test@1.0.0" 860 863 mock_index.get_dataset.return_value = mock_entry 861 864
+20 -5
tests/test_integration.py
··· 16 16 @atdata.packable 17 17 class IntegrationTestSample: 18 18 """Sample type for integration tests.""" 19 + 19 20 name: str 20 21 value: int 21 22 ··· 39 40 "name": "test_integration.IntegrationTestSample", 40 41 "version": "1.0.0", 41 42 "fields": [ 42 - {"name": "name", "fieldType": {"$type": "local#primitive", "primitive": "str"}, "optional": False}, 43 - {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 43 + { 44 + "name": "name", 45 + "fieldType": {"$type": "local#primitive", "primitive": "str"}, 46 + "optional": False, 47 + }, 48 + { 49 + "name": "value", 50 + "fieldType": {"$type": "local#primitive", "primitive": "int"}, 51 + "optional": False, 52 + }, 44 53 ], 45 54 } 46 55 ··· 86 95 "version": "2.1.0", 87 96 "description": "A sample with specific version", 88 97 "fields": [ 89 - {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 98 + { 99 + "name": "value", 100 + "fieldType": {"$type": "local#primitive", "primitive": "int"}, 101 + "optional": False, 102 + }, 90 103 ], 91 104 } 92 105 ··· 97 110 98 111 with patch("atdata.atmosphere.DatasetPublisher") as MockPublisher: 99 112 mock_publisher = MockPublisher.return_value 100 - mock_publisher.publish_with_urls.return_value = Mock(__str__=lambda s: "at://result") 113 + mock_publisher.publish_with_urls.return_value = Mock( 114 + __str__=lambda s: "at://result" 115 + ) 101 116 102 117 promote_to_atmosphere(local_entry, mock_local_index, mock_client) 103 118 ··· 173 188 "value": { 174 189 "name": "test_integration.IntegrationTestSample", 175 190 "version": "1.0.0", # Different version 176 - } 191 + }, 177 192 } 178 193 ] 179 194
+69 -25
tests/test_integration_atmosphere.py
··· 33 33 @atdata.packable 34 34 class AtmoSample: 35 35 """Sample for atmosphere tests.""" 36 + 36 37 name: str 37 38 value: int 38 39 ··· 40 41 @atdata.packable 41 42 class AtmoNDArraySample: 42 43 """Sample with NDArray for atmosphere tests.""" 44 + 43 45 label: str 44 46 data: NDArray 45 47 ··· 84 86 """Full workflow: login → publish schema → publish dataset.""" 85 87 # Setup mock responses 86 88 schema_response = Mock() 87 - schema_response.uri = f"at://did:plc:integration123/{LEXICON_NAMESPACE}.sampleSchema/schema123" 89 + schema_response.uri = ( 90 + f"at://did:plc:integration123/{LEXICON_NAMESPACE}.sampleSchema/schema123" 91 + ) 88 92 89 93 dataset_response = Mock() 90 - dataset_response.uri = f"at://did:plc:integration123/{LEXICON_NAMESPACE}.dataset/dataset456" 94 + dataset_response.uri = ( 95 + f"at://did:plc:integration123/{LEXICON_NAMESPACE}.dataset/dataset456" 96 + ) 91 97 92 98 mock_atproto_client.com.atproto.repo.create_record.side_effect = [ 93 99 schema_response, ··· 135 141 client.login_with_session("saved-session-string") 136 142 137 143 assert client.is_authenticated 138 - mock_atproto_client.login.assert_called_with(session_string="saved-session-string") 144 + mock_atproto_client.login.assert_called_with( 145 + session_string="saved-session-string" 146 + ) 139 147 140 148 def test_session_round_trip(self, mock_atproto_client): 141 149 """Export then import session should maintain auth.""" ··· 188 196 "name": "FoundSchema", 189 197 "version": "2.0.0", 190 198 "fields": [ 191 - {"name": "field1", "fieldType": {"$type": f"{LEXICON_NAMESPACE}.schemaType#primitive", "primitive": "str"}, "optional": False} 192 - ] 199 + { 200 + "name": "field1", 201 + "fieldType": { 202 + "$type": f"{LEXICON_NAMESPACE}.schemaType#primitive", 203 + "primitive": "str", 204 + }, 205 + "optional": False, 206 + } 207 + ], 193 208 } 194 209 mock_atproto_client.com.atproto.repo.get_record.return_value = mock_response 195 210 ··· 203 218 class TestAtmosphereIndex: 204 219 """Tests for AtmosphereIndex AbstractIndex compliance.""" 205 220 206 - def test_index_list_datasets_yields_entries(self, authenticated_client, mock_atproto_client): 221 + def test_index_list_datasets_yields_entries( 222 + self, authenticated_client, mock_atproto_client 223 + ): 207 224 """list_datasets should yield AtmosphereIndexEntry objects.""" 208 225 mock_record = Mock() 209 226 mock_record.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.dataset/d1" 210 227 mock_record.value = { 211 228 "name": "listed-dataset", 212 229 "schemaRef": "at://schema", 213 - "storage": {"$type": f"{LEXICON_NAMESPACE}.storageExternal", "urls": ["s3://data.tar"]}, 230 + "storage": { 231 + "$type": f"{LEXICON_NAMESPACE}.storageExternal", 232 + "urls": ["s3://data.tar"], 233 + }, 214 234 } 215 235 216 236 mock_response = Mock() ··· 229 249 record = { 230 250 "name": "test-dataset", 231 251 "schemaRef": "at://did:plc:schema/schema/key", 232 - "storage": {"$type": f"{LEXICON_NAMESPACE}.storageExternal", "urls": ["s3://data.tar"]}, 252 + "storage": { 253 + "$type": f"{LEXICON_NAMESPACE}.storageExternal", 254 + "urls": ["s3://data.tar"], 255 + }, 233 256 } 234 257 235 258 entry = AtmosphereIndexEntry("at://test/dataset/key", record) ··· 247 270 record = { 248 271 "name": "meta-dataset", 249 272 "schemaRef": "at://schema", 250 - "storage": {"$type": f"{LEXICON_NAMESPACE}.storageExternal", "urls": ["s3://data.tar"]}, 273 + "storage": { 274 + "$type": f"{LEXICON_NAMESPACE}.storageExternal", 275 + "urls": ["s3://data.tar"], 276 + }, 251 277 "metadata": packed_meta, 252 278 } 253 279 ··· 261 287 record = { 262 288 "name": "no-meta", 263 289 "schemaRef": "at://schema", 264 - "storage": {"$type": f"{LEXICON_NAMESPACE}.storageExternal", "urls": ["s3://data.tar"]}, 290 + "storage": { 291 + "$type": f"{LEXICON_NAMESPACE}.storageExternal", 292 + "urls": ["s3://data.tar"], 293 + }, 265 294 } 266 295 267 296 entry = AtmosphereIndexEntry("at://test/dataset/key", record) ··· 300 329 "schemaRef": "at://schema", 301 330 "storage": { 302 331 "$type": f"{LEXICON_NAMESPACE}.storageExternal", 303 - "urls": ["https://cdn.example.com/data-000.tar", "https://cdn.example.com/data-001.tar"], 332 + "urls": [ 333 + "https://cdn.example.com/data-000.tar", 334 + "https://cdn.example.com/data-001.tar", 335 + ], 304 336 }, 305 337 } 306 338 ··· 339 371 def test_publish_ndarray_schema(self, authenticated_client, mock_atproto_client): 340 372 """Schema with NDArray field should publish correctly.""" 341 373 mock_response = Mock() 342 - mock_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/ndarray" 374 + mock_response.uri = ( 375 + f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/ndarray" 376 + ) 343 377 mock_atproto_client.com.atproto.repo.create_record.return_value = mock_response 344 378 345 379 publisher = SchemaPublisher(authenticated_client) ··· 438 472 with pytest.raises(ValueError, match="Not authenticated"): 439 473 store.write_shards(mock_ds, prefix="test") 440 474 441 - def test_write_shards_uploads_blobs(self, authenticated_client, mock_atproto_client, tmp_path): 475 + def test_write_shards_uploads_blobs( 476 + self, authenticated_client, mock_atproto_client, tmp_path 477 + ): 442 478 """write_shards uploads each shard as a blob.""" 443 479 from atdata.atmosphere import PDSBlobStore 444 480 import webdataset as wds ··· 452 488 ds = atdata.Dataset[AtmoSample](str(tar_path)) 453 489 454 490 # Mock upload_blob to return a blob reference 455 - authenticated_client.upload_blob = Mock(return_value={ 456 - "$type": "blob", 457 - "ref": {"$link": "bafyrei123abc"}, 458 - "mimeType": "application/x-tar", 459 - "size": 1024, 460 - }) 491 + authenticated_client.upload_blob = Mock( 492 + return_value={ 493 + "$type": "blob", 494 + "ref": {"$link": "bafyrei123abc"}, 495 + "mimeType": "application/x-tar", 496 + "size": 1024, 497 + } 498 + ) 461 499 462 500 store = PDSBlobStore(client=authenticated_client) 463 501 urls = store.write_shards(ds, prefix="test/v1", maxcount=100) ··· 473 511 # First arg should be bytes (tar data) 474 512 assert isinstance(call_args.args[0], bytes) 475 513 476 - def test_read_url_transforms_at_uri(self, authenticated_client, mock_atproto_client): 514 + def test_read_url_transforms_at_uri( 515 + self, authenticated_client, mock_atproto_client 516 + ): 477 517 """read_url transforms AT URIs to HTTP URLs.""" 478 518 from atdata.atmosphere import PDSBlobStore 479 519 ··· 486 526 487 527 assert "https://pds.example.com" in url 488 528 assert "bafyrei123" in url 489 - authenticated_client.get_blob_url.assert_called_once_with("did:plc:abc", "bafyrei123") 529 + authenticated_client.get_blob_url.assert_called_once_with( 530 + "did:plc:abc", "bafyrei123" 531 + ) 490 532 491 533 def test_read_url_passes_non_at_uri(self, authenticated_client): 492 534 """read_url returns non-AT URIs unchanged.""" ··· 512 554 from atdata._sources import BlobSource 513 555 514 556 store = PDSBlobStore(client=authenticated_client) 515 - source = store.create_source([ 516 - "at://did:plc:abc/blob/bafyrei111", 517 - "at://did:plc:abc/blob/bafyrei222", 518 - ]) 557 + source = store.create_source( 558 + [ 559 + "at://did:plc:abc/blob/bafyrei111", 560 + "at://did:plc:abc/blob/bafyrei222", 561 + ] 562 + ) 519 563 520 564 assert isinstance(source, BlobSource) 521 565 assert len(source.blob_refs) == 2
+22 -4
tests/test_integration_atmosphere_live.py
··· 52 52 """Skip test if credentials not available.""" 53 53 handle, password = get_test_credentials() 54 54 if not handle or not password: 55 - pytest.skip("Live test credentials not configured (set ATDATA_TEST_HANDLE and ATDATA_TEST_APP_PASSWORD)") 55 + pytest.skip( 56 + "Live test credentials not configured (set ATDATA_TEST_HANDLE and ATDATA_TEST_APP_PASSWORD)" 57 + ) 56 58 57 59 58 60 ## ··· 62 64 @atdata.packable 63 65 class LiveTestSample: 64 66 """Simple sample for live tests.""" 67 + 65 68 name: str 66 69 value: int 67 70 ··· 69 72 @atdata.packable 70 73 class LiveTestArraySample: 71 74 """Sample with NDArray for live tests.""" 75 + 72 76 label: str 73 77 data: NDArray 74 78 ··· 215 219 216 220 def test_publish_schema(self, live_client, unique_name): 217 221 """Should publish a schema to ATProto.""" 222 + 218 223 # Create a unique sample type for this test 219 224 @atdata.packable 220 225 class UniqueTestSample: ··· 241 246 242 247 def test_publish_and_retrieve_schema(self, live_client, unique_name): 243 248 """Should publish then retrieve a schema by URI.""" 249 + 244 250 @atdata.packable 245 251 class RetrievableTestSample: 246 252 field1: str ··· 265 271 266 272 def test_schema_with_ndarray_field(self, live_client, unique_name): 267 273 """Should publish schema with NDArray field type.""" 274 + 268 275 @atdata.packable 269 276 class ArrayTestSample: 270 277 label: str ··· 296 303 297 304 def test_publish_dataset_with_urls(self, live_client, unique_name): 298 305 """Should publish a dataset record with external URLs.""" 306 + 299 307 # First publish a schema 300 308 @atdata.packable 301 309 class DatasetTestSample: ··· 327 335 328 336 def test_publish_and_retrieve_dataset(self, live_client, unique_name): 329 337 """Should publish then retrieve a dataset.""" 338 + 330 339 @atdata.packable 331 340 class RetrievableDatasetSample: 332 341 value: int ··· 360 369 assert dataset["name"] == unique_name 361 370 assert dataset["description"] == "Retrievable test dataset" 362 371 363 - def test_to_dataset_with_fake_urls_fails_on_iteration(self, live_client, unique_name): 372 + def test_to_dataset_with_fake_urls_fails_on_iteration( 373 + self, live_client, unique_name 374 + ): 364 375 """Attempting to iterate a dataset with fake URLs should fail. 365 376 366 377 This test documents a known limitation: we can publish and retrieve ··· 369 380 1. Real external URLs (e.g., S3 with test data) 370 381 2. ATProto blob storage support (not yet implemented) 371 382 """ 383 + 372 384 @atdata.packable 373 385 class IterationTestSample: 374 386 value: int ··· 560 572 561 573 def test_index_publish_schema(self, live_index, unique_name): 562 574 """Should publish schema via AtmosphereIndex.""" 575 + 563 576 @atdata.packable 564 577 class IndexTestSample: 565 578 data: str ··· 573 586 574 587 def test_index_get_schema(self, live_index, unique_name): 575 588 """Should retrieve schema via AtmosphereIndex.""" 589 + 576 590 @atdata.packable 577 591 class GetSchemaTestSample: 578 592 field: int ··· 599 613 """Should raise on getting non-existent record.""" 600 614 loader = SchemaLoader(live_client) 601 615 602 - fake_uri = f"at://{live_client.did}/{LEXICON_NAMESPACE}.sampleSchema/nonexistent12345" 616 + fake_uri = ( 617 + f"at://{live_client.did}/{LEXICON_NAMESPACE}.sampleSchema/nonexistent12345" 618 + ) 603 619 604 620 with pytest.raises(Exception): 605 621 loader.get(fake_uri) ··· 629 645 schemas_deleted = cleanup_test_schemas(live_client) 630 646 datasets_deleted = cleanup_test_datasets(live_client) 631 647 632 - print(f"\nCleanup: deleted {schemas_deleted} schemas, {datasets_deleted} datasets") 648 + print( 649 + f"\nCleanup: deleted {schemas_deleted} schemas, {datasets_deleted} datasets" 650 + ) 633 651 634 652 # Just verify cleanup ran without error 635 653 assert True
+27 -14
tests/test_integration_cross_backend.py
··· 29 29 @atdata.packable 30 30 class CrossBackendSample: 31 31 """Sample for cross-backend tests.""" 32 + 32 33 name: str 33 34 value: int 34 35 ··· 36 37 @atdata.packable 37 38 class CrossBackendArraySample: 38 39 """Sample with NDArray for cross-backend tests.""" 40 + 39 41 label: str 40 42 data: NDArray 41 43 ··· 116 118 117 119 assert isinstance(entry, IndexEntry) 118 120 assert entry.name == "atmo-dataset" 119 - assert entry.schema_ref == "at://did:plc:test/ac.foundation.dataset.sampleSchema/abc" 121 + assert ( 122 + entry.schema_ref 123 + == "at://did:plc:test/ac.foundation.dataset.sampleSchema/abc" 124 + ) 120 125 assert entry.data_urls == ["s3://bucket/atmo.tar"] 121 126 assert entry.metadata is None 122 127 123 128 def test_entries_work_with_common_function(self): 124 129 """Both entry types should work with functions accepting IndexEntry.""" 130 + 125 131 def process_entry(entry: IndexEntry) -> dict: 126 132 return { 127 133 "name": entry.name, ··· 294 300 assert schema["version"] == "2.0.0" 295 301 assert len(schema["fields"]) == 2 296 302 297 - def test_atmosphere_index_get_schema( 298 - self, atmosphere_index, mock_atproto_client 299 - ): 303 + def test_atmosphere_index_get_schema(self, atmosphere_index, mock_atproto_client): 300 304 """AtmosphereIndex should retrieve schemas.""" 301 305 mock_response = Mock() 302 306 mock_response.value = { 303 307 "$type": f"{LEXICON_NAMESPACE}.sampleSchema", 304 308 "name": "RetrievedSchema", 305 309 "version": "1.0.0", 306 - "fields": [{"name": "field1", "fieldType": {"$type": f"{LEXICON_NAMESPACE}.schemaType#primitive", "primitive": "str"}, "optional": False}], 310 + "fields": [ 311 + { 312 + "name": "field1", 313 + "fieldType": { 314 + "$type": f"{LEXICON_NAMESPACE}.schemaType#primitive", 315 + "primitive": "str", 316 + }, 317 + "optional": False, 318 + } 319 + ], 307 320 } 308 321 mock_atproto_client.com.atproto.repo.get_record.return_value = mock_response 309 322 ··· 332 345 333 346 def test_ndarray_schema_field_structure(self, local_index): 334 347 """NDArray fields should be represented consistently.""" 335 - schema_ref = local_index.publish_schema(CrossBackendArraySample, version="1.0.0") 348 + schema_ref = local_index.publish_schema( 349 + CrossBackendArraySample, version="1.0.0" 350 + ) 336 351 schema = local_index.get_schema(schema_ref) 337 352 338 353 field_names = {f["name"] for f in schema["fields"]} ··· 343 358 data_field = next(f for f in schema["fields"] if f["name"] == "data") 344 359 field_type = data_field["fieldType"] 345 360 # Field type should indicate it's an ndarray 346 - assert "ndarray" in field_type.get("$type", "").lower() or \ 347 - field_type.get("primitive") == "ndarray" 361 + assert ( 362 + "ndarray" in field_type.get("$type", "").lower() 363 + or field_type.get("primitive") == "ndarray" 364 + ) 348 365 349 366 350 367 class TestCrossBackendSchemaResolution: ··· 357 374 assert schema_ref.startswith("atdata://local/sampleSchema/") 358 375 assert "CrossBackendSample" in schema_ref 359 376 360 - def test_atmosphere_schema_ref_format( 361 - self, atmosphere_index, mock_atproto_client 362 - ): 377 + def test_atmosphere_schema_ref_format(self, atmosphere_index, mock_atproto_client): 363 378 """Atmosphere schema refs should use at:// URI scheme.""" 364 379 mock_response = Mock() 365 380 mock_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/abc" ··· 483 498 count = self.count_datasets(local_index) 484 499 assert count >= 3 485 500 486 - def test_count_works_with_atmosphere( 487 - self, atmosphere_index, mock_atproto_client 488 - ): 501 + def test_count_works_with_atmosphere(self, atmosphere_index, mock_atproto_client): 489 502 """Dataset count function should work with AtmosphereIndex.""" 490 503 mock_records = [] 491 504 for i in range(5):
+331 -96
tests/test_integration_dynamic_types.py
··· 16 16 import webdataset as wds 17 17 18 18 import atdata 19 - from atdata._schema_codec import schema_to_type, generate_stub, clear_type_cache, get_cached_types 19 + from atdata._schema_codec import ( 20 + schema_to_type, 21 + generate_stub, 22 + clear_type_cache, 23 + get_cached_types, 24 + ) 20 25 import atdata.local as atlocal 21 26 22 27 ··· 27 32 @dataclass 28 33 class SimpleSample(atdata.PackableSample): 29 34 """Simple sample for testing.""" 35 + 30 36 name: str 31 37 value: int 32 38 score: float ··· 35 41 @dataclass 36 42 class ArraySample(atdata.PackableSample): 37 43 """Sample with NDArray field.""" 44 + 38 45 label: str 39 46 image: NDArray 40 47 ··· 42 49 @dataclass 43 50 class OptionalSample(atdata.PackableSample): 44 51 """Sample with optional fields.""" 52 + 45 53 name: str 46 54 value: int 47 55 extra: str | None = None ··· 51 59 @dataclass 52 60 class ListSample(atdata.PackableSample): 53 61 """Sample with list fields.""" 62 + 54 63 tags: list[str] 55 64 scores: list[float] 56 65 ··· 80 89 "name": "SimpleSample", 81 90 "version": "1.0.0", 82 91 "fields": [ 83 - {"name": "name", "fieldType": {"$type": "local#primitive", "primitive": "str"}, "optional": False}, 84 - {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 85 - {"name": "score", "fieldType": {"$type": "local#primitive", "primitive": "float"}, "optional": False}, 86 - ] 92 + { 93 + "name": "name", 94 + "fieldType": {"$type": "local#primitive", "primitive": "str"}, 95 + "optional": False, 96 + }, 97 + { 98 + "name": "value", 99 + "fieldType": {"$type": "local#primitive", "primitive": "int"}, 100 + "optional": False, 101 + }, 102 + { 103 + "name": "score", 104 + "fieldType": {"$type": "local#primitive", "primitive": "float"}, 105 + "optional": False, 106 + }, 107 + ], 87 108 } 88 109 89 110 SampleType = schema_to_type(schema) ··· 100 121 "name": "ArraySample", 101 122 "version": "1.0.0", 102 123 "fields": [ 103 - {"name": "label", "fieldType": {"$type": "local#primitive", "primitive": "str"}, "optional": False}, 104 - {"name": "image", "fieldType": {"$type": "local#ndarray", "dtype": "float32"}, "optional": False}, 105 - ] 124 + { 125 + "name": "label", 126 + "fieldType": {"$type": "local#primitive", "primitive": "str"}, 127 + "optional": False, 128 + }, 129 + { 130 + "name": "image", 131 + "fieldType": {"$type": "local#ndarray", "dtype": "float32"}, 132 + "optional": False, 133 + }, 134 + ], 106 135 } 107 136 108 137 SampleType = schema_to_type(schema) ··· 119 148 "name": "OptionalSample", 120 149 "version": "1.0.0", 121 150 "fields": [ 122 - {"name": "name", "fieldType": {"$type": "local#primitive", "primitive": "str"}, "optional": False}, 123 - {"name": "extra", "fieldType": {"$type": "local#primitive", "primitive": "str"}, "optional": True}, 124 - ] 151 + { 152 + "name": "name", 153 + "fieldType": {"$type": "local#primitive", "primitive": "str"}, 154 + "optional": False, 155 + }, 156 + { 157 + "name": "extra", 158 + "fieldType": {"$type": "local#primitive", "primitive": "str"}, 159 + "optional": True, 160 + }, 161 + ], 125 162 } 126 163 127 164 SampleType = schema_to_type(schema) ··· 141 178 "name": "ListSample", 142 179 "version": "1.0.0", 143 180 "fields": [ 144 - {"name": "tags", "fieldType": {"$type": "local#array", "items": {"$type": "local#primitive", "primitive": "str"}}, "optional": False}, 145 - {"name": "scores", "fieldType": {"$type": "local#array", "items": {"$type": "local#primitive", "primitive": "float"}}, "optional": False}, 146 - ] 181 + { 182 + "name": "tags", 183 + "fieldType": { 184 + "$type": "local#array", 185 + "items": {"$type": "local#primitive", "primitive": "str"}, 186 + }, 187 + "optional": False, 188 + }, 189 + { 190 + "name": "scores", 191 + "fieldType": { 192 + "$type": "local#array", 193 + "items": {"$type": "local#primitive", "primitive": "float"}, 194 + }, 195 + "optional": False, 196 + }, 197 + ], 147 198 } 148 199 149 200 SampleType = schema_to_type(schema) ··· 158 209 "name": "AllPrimitives", 159 210 "version": "1.0.0", 160 211 "fields": [ 161 - {"name": "s", "fieldType": {"$type": "local#primitive", "primitive": "str"}, "optional": False}, 162 - {"name": "i", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 163 - {"name": "f", "fieldType": {"$type": "local#primitive", "primitive": "float"}, "optional": False}, 164 - {"name": "b", "fieldType": {"$type": "local#primitive", "primitive": "bool"}, "optional": False}, 165 - {"name": "raw", "fieldType": {"$type": "local#primitive", "primitive": "bytes"}, "optional": False}, 166 - ] 212 + { 213 + "name": "s", 214 + "fieldType": {"$type": "local#primitive", "primitive": "str"}, 215 + "optional": False, 216 + }, 217 + { 218 + "name": "i", 219 + "fieldType": {"$type": "local#primitive", "primitive": "int"}, 220 + "optional": False, 221 + }, 222 + { 223 + "name": "f", 224 + "fieldType": {"$type": "local#primitive", "primitive": "float"}, 225 + "optional": False, 226 + }, 227 + { 228 + "name": "b", 229 + "fieldType": {"$type": "local#primitive", "primitive": "bool"}, 230 + "optional": False, 231 + }, 232 + { 233 + "name": "raw", 234 + "fieldType": {"$type": "local#primitive", "primitive": "bytes"}, 235 + "optional": False, 236 + }, 237 + ], 167 238 } 168 239 169 240 SampleType = schema_to_type(schema) ··· 197 268 "name": "SimpleSample", 198 269 "version": "1.0.0", 199 270 "fields": [ 200 - {"name": "name", "fieldType": {"$type": "local#primitive", "primitive": "str"}, "optional": False}, 201 - {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 202 - {"name": "score", "fieldType": {"$type": "local#primitive", "primitive": "float"}, "optional": False}, 203 - ] 271 + { 272 + "name": "name", 273 + "fieldType": {"$type": "local#primitive", "primitive": "str"}, 274 + "optional": False, 275 + }, 276 + { 277 + "name": "value", 278 + "fieldType": {"$type": "local#primitive", "primitive": "int"}, 279 + "optional": False, 280 + }, 281 + { 282 + "name": "score", 283 + "fieldType": {"$type": "local#primitive", "primitive": "float"}, 284 + "optional": False, 285 + }, 286 + ], 204 287 } 205 288 206 289 DynamicType = schema_to_type(schema) ··· 231 314 "name": "ArraySample", 232 315 "version": "1.0.0", 233 316 "fields": [ 234 - {"name": "label", "fieldType": {"$type": "local#primitive", "primitive": "str"}, "optional": False}, 235 - {"name": "image", "fieldType": {"$type": "local#ndarray", "dtype": "float32"}, "optional": False}, 236 - ] 317 + { 318 + "name": "label", 319 + "fieldType": {"$type": "local#primitive", "primitive": "str"}, 320 + "optional": False, 321 + }, 322 + { 323 + "name": "image", 324 + "fieldType": {"$type": "local#ndarray", "dtype": "float32"}, 325 + "optional": False, 326 + }, 327 + ], 237 328 } 238 329 239 330 DynamicType = schema_to_type(schema) ··· 258 349 "name": "SimpleSample", 259 350 "version": "1.0.0", 260 351 "fields": [ 261 - {"name": "name", "fieldType": {"$type": "local#primitive", "primitive": "str"}, "optional": False}, 262 - {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 263 - {"name": "score", "fieldType": {"$type": "local#primitive", "primitive": "float"}, "optional": False}, 264 - ] 352 + { 353 + "name": "name", 354 + "fieldType": {"$type": "local#primitive", "primitive": "str"}, 355 + "optional": False, 356 + }, 357 + { 358 + "name": "value", 359 + "fieldType": {"$type": "local#primitive", "primitive": "int"}, 360 + "optional": False, 361 + }, 362 + { 363 + "name": "score", 364 + "fieldType": {"$type": "local#primitive", "primitive": "float"}, 365 + "optional": False, 366 + }, 367 + ], 265 368 } 266 369 267 370 DynamicType = schema_to_type(schema) ··· 284 387 "name": "CachedSample", 285 388 "version": "1.0.0", 286 389 "fields": [ 287 - {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 288 - ] 390 + { 391 + "name": "value", 392 + "fieldType": {"$type": "local#primitive", "primitive": "int"}, 393 + "optional": False, 394 + }, 395 + ], 289 396 } 290 397 291 398 Type1 = schema_to_type(schema) ··· 299 406 "name": "VersionedSample", 300 407 "version": "1.0.0", 301 408 "fields": [ 302 - {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 303 - ] 409 + { 410 + "name": "value", 411 + "fieldType": {"$type": "local#primitive", "primitive": "int"}, 412 + "optional": False, 413 + }, 414 + ], 304 415 } 305 416 schema2 = { 306 417 "name": "VersionedSample", 307 418 "version": "2.0.0", 308 419 "fields": [ 309 - {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 310 - ] 420 + { 421 + "name": "value", 422 + "fieldType": {"$type": "local#primitive", "primitive": "int"}, 423 + "optional": False, 424 + }, 425 + ], 311 426 } 312 427 313 428 Type1 = schema_to_type(schema1) ··· 322 437 "name": "FieldSample", 323 438 "version": "1.0.0", 324 439 "fields": [ 325 - {"name": "a", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 326 - ] 440 + { 441 + "name": "a", 442 + "fieldType": {"$type": "local#primitive", "primitive": "int"}, 443 + "optional": False, 444 + }, 445 + ], 327 446 } 328 447 schema2 = { 329 448 "name": "FieldSample", 330 449 "version": "1.0.0", 331 450 "fields": [ 332 - {"name": "b", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 333 - ] 451 + { 452 + "name": "b", 453 + "fieldType": {"$type": "local#primitive", "primitive": "int"}, 454 + "optional": False, 455 + }, 456 + ], 334 457 } 335 458 336 459 Type1 = schema_to_type(schema1) ··· 344 467 "name": "NoCacheSample", 345 468 "version": "1.0.0", 346 469 "fields": [ 347 - {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 348 - ] 470 + { 471 + "name": "value", 472 + "fieldType": {"$type": "local#primitive", "primitive": "int"}, 473 + "optional": False, 474 + }, 475 + ], 349 476 } 350 477 351 478 Type1 = schema_to_type(schema, use_cache=False) ··· 360 487 "name": "ClearableSample", 361 488 "version": "1.0.0", 362 489 "fields": [ 363 - {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 364 - ] 490 + { 491 + "name": "value", 492 + "fieldType": {"$type": "local#primitive", "primitive": "int"}, 493 + "optional": False, 494 + }, 495 + ], 365 496 } 366 497 367 498 Type1 = schema_to_type(schema) ··· 377 508 "name": "TrackedSample", 378 509 "version": "1.0.0", 379 510 "fields": [ 380 - {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 381 - ] 511 + { 512 + "name": "value", 513 + "fieldType": {"$type": "local#primitive", "primitive": "int"}, 514 + "optional": False, 515 + }, 516 + ], 382 517 } 383 518 384 519 schema_to_type(schema) ··· 450 585 schema = { 451 586 "version": "1.0.0", 452 587 "fields": [ 453 - {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 454 - ] 588 + { 589 + "name": "value", 590 + "fieldType": {"$type": "local#primitive", "primitive": "int"}, 591 + "optional": False, 592 + }, 593 + ], 455 594 } 456 595 457 596 with pytest.raises(ValueError, match="must have a 'name'"): ··· 459 598 460 599 def test_schema_without_fields_raises(self): 461 600 """Schema without fields should raise ValueError.""" 462 - schema = { 463 - "name": "EmptySample", 464 - "version": "1.0.0", 465 - "fields": [] 466 - } 601 + schema = {"name": "EmptySample", "version": "1.0.0", "fields": []} 467 602 468 603 with pytest.raises(ValueError, match="must have at least one field"): 469 604 schema_to_type(schema) ··· 474 609 "name": "BadFieldSample", 475 610 "version": "1.0.0", 476 611 "fields": [ 477 - {"fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 478 - ] 612 + { 613 + "fieldType": {"$type": "local#primitive", "primitive": "int"}, 614 + "optional": False, 615 + }, 616 + ], 479 617 } 480 618 481 619 # Raises KeyError during cache key generation or ValueError during field processing ··· 488 626 "name": "UnknownPrimitive", 489 627 "version": "1.0.0", 490 628 "fields": [ 491 - {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "complex128"}, "optional": False}, 492 - ] 629 + { 630 + "name": "value", 631 + "fieldType": { 632 + "$type": "local#primitive", 633 + "primitive": "complex128", 634 + }, 635 + "optional": False, 636 + }, 637 + ], 493 638 } 494 639 495 640 with pytest.raises(ValueError, match="Unknown primitive type"): ··· 501 646 "name": "UnknownType", 502 647 "version": "1.0.0", 503 648 "fields": [ 504 - {"name": "value", "fieldType": {"$type": "local#custom"}, "optional": False}, 505 - ] 649 + { 650 + "name": "value", 651 + "fieldType": {"$type": "local#custom"}, 652 + "optional": False, 653 + }, 654 + ], 506 655 } 507 656 508 657 with pytest.raises(ValueError, match="Unknown field type kind"): ··· 518 667 "name": "OptionalArraySample", 519 668 "version": "1.0.0", 520 669 "fields": [ 521 - {"name": "name", "fieldType": {"$type": "local#primitive", "primitive": "str"}, "optional": False}, 522 - {"name": "embedding", "fieldType": {"$type": "local#ndarray", "dtype": "float32"}, "optional": True}, 523 - ] 670 + { 671 + "name": "name", 672 + "fieldType": {"$type": "local#primitive", "primitive": "str"}, 673 + "optional": False, 674 + }, 675 + { 676 + "name": "embedding", 677 + "fieldType": {"$type": "local#ndarray", "dtype": "float32"}, 678 + "optional": True, 679 + }, 680 + ], 524 681 } 525 682 526 683 DynamicType = schema_to_type(schema) ··· 530 687 with wds.writer.TarWriter(str(tar_path)) as sink: 531 688 for i in range(6): 532 689 if i % 2 == 0: 533 - sample = OptionalSample(name=f"s_{i}", value=i, embedding=np.zeros(4, dtype=np.float32)) 690 + sample = OptionalSample( 691 + name=f"s_{i}", value=i, embedding=np.zeros(4, dtype=np.float32) 692 + ) 534 693 else: 535 694 sample = OptionalSample(name=f"s_{i}", value=i, embedding=None) 536 695 sink.write(sample.as_wds) ··· 550 709 "name": "NestedListSample", 551 710 "version": "1.0.0", 552 711 "fields": [ 553 - {"name": "matrix", "fieldType": { 554 - "$type": "local#array", 555 - "items": {"$type": "local#array", "items": {"$type": "local#primitive", "primitive": "int"}} 556 - }, "optional": False}, 557 - ] 712 + { 713 + "name": "matrix", 714 + "fieldType": { 715 + "$type": "local#array", 716 + "items": { 717 + "$type": "local#array", 718 + "items": {"$type": "local#primitive", "primitive": "int"}, 719 + }, 720 + }, 721 + "optional": False, 722 + }, 723 + ], 558 724 } 559 725 560 726 DynamicType = schema_to_type(schema) ··· 591 757 "name": "SimpleSample", 592 758 "version": "1.0.0", 593 759 "fields": [ 594 - {"name": "name", "fieldType": {"$type": "local#primitive", "primitive": "str"}, "optional": False}, 595 - {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 596 - ] 760 + { 761 + "name": "name", 762 + "fieldType": {"$type": "local#primitive", "primitive": "str"}, 763 + "optional": False, 764 + }, 765 + { 766 + "name": "value", 767 + "fieldType": {"$type": "local#primitive", "primitive": "int"}, 768 + "optional": False, 769 + }, 770 + ], 597 771 } 598 772 599 773 stub = generate_stub(schema) ··· 609 783 "name": "ArraySample", 610 784 "version": "1.0.0", 611 785 "fields": [ 612 - {"name": "image", "fieldType": {"$type": "local#ndarray", "dtype": "float32"}, "optional": False}, 613 - ] 786 + { 787 + "name": "image", 788 + "fieldType": {"$type": "local#ndarray", "dtype": "float32"}, 789 + "optional": False, 790 + }, 791 + ], 614 792 } 615 793 616 794 stub = generate_stub(schema) ··· 624 802 "name": "OptionalSample", 625 803 "version": "1.0.0", 626 804 "fields": [ 627 - {"name": "name", "fieldType": {"$type": "local#primitive", "primitive": "str"}, "optional": False}, 628 - {"name": "extra", "fieldType": {"$type": "local#primitive", "primitive": "str"}, "optional": True}, 629 - ] 805 + { 806 + "name": "name", 807 + "fieldType": {"$type": "local#primitive", "primitive": "str"}, 808 + "optional": False, 809 + }, 810 + { 811 + "name": "extra", 812 + "fieldType": {"$type": "local#primitive", "primitive": "str"}, 813 + "optional": True, 814 + }, 815 + ], 630 816 } 631 817 632 818 stub = generate_stub(schema) ··· 641 827 "name": "ListSample", 642 828 "version": "1.0.0", 643 829 "fields": [ 644 - {"name": "tags", "fieldType": {"$type": "local#array", "items": {"$type": "local#primitive", "primitive": "str"}}, "optional": False}, 645 - ] 830 + { 831 + "name": "tags", 832 + "fieldType": { 833 + "$type": "local#array", 834 + "items": {"$type": "local#primitive", "primitive": "str"}, 835 + }, 836 + "optional": False, 837 + }, 838 + ], 646 839 } 647 840 648 841 stub = generate_stub(schema) ··· 655 848 "name": "MySample", 656 849 "version": "2.1.0", 657 850 "fields": [ 658 - {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 659 - ] 851 + { 852 + "name": "value", 853 + "fieldType": {"$type": "local#primitive", "primitive": "int"}, 854 + "optional": False, 855 + }, 856 + ], 660 857 } 661 858 662 859 stub = generate_stub(schema) ··· 671 868 "name": "ImportSample", 672 869 "version": "1.0.0", 673 870 "fields": [ 674 - {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 675 - ] 871 + { 872 + "name": "value", 873 + "fieldType": {"$type": "local#primitive", "primitive": "int"}, 874 + "optional": False, 875 + }, 876 + ], 676 877 } 677 878 678 879 stub = generate_stub(schema) ··· 686 887 "name": "AllPrimitives", 687 888 "version": "1.0.0", 688 889 "fields": [ 689 - {"name": "s", "fieldType": {"$type": "local#primitive", "primitive": "str"}, "optional": False}, 690 - {"name": "i", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 691 - {"name": "f", "fieldType": {"$type": "local#primitive", "primitive": "float"}, "optional": False}, 692 - {"name": "b", "fieldType": {"$type": "local#primitive", "primitive": "bool"}, "optional": False}, 693 - {"name": "raw", "fieldType": {"$type": "local#primitive", "primitive": "bytes"}, "optional": False}, 694 - ] 890 + { 891 + "name": "s", 892 + "fieldType": {"$type": "local#primitive", "primitive": "str"}, 893 + "optional": False, 894 + }, 895 + { 896 + "name": "i", 897 + "fieldType": {"$type": "local#primitive", "primitive": "int"}, 898 + "optional": False, 899 + }, 900 + { 901 + "name": "f", 902 + "fieldType": {"$type": "local#primitive", "primitive": "float"}, 903 + "optional": False, 904 + }, 905 + { 906 + "name": "b", 907 + "fieldType": {"$type": "local#primitive", "primitive": "bool"}, 908 + "optional": False, 909 + }, 910 + { 911 + "name": "raw", 912 + "fieldType": {"$type": "local#primitive", "primitive": "bytes"}, 913 + "optional": False, 914 + }, 915 + ], 695 916 } 696 917 697 918 stub = generate_stub(schema) ··· 708 929 "name": "NestedSample", 709 930 "version": "1.0.0", 710 931 "fields": [ 711 - {"name": "matrix", "fieldType": { 712 - "$type": "local#array", 713 - "items": {"$type": "local#array", "items": {"$type": "local#primitive", "primitive": "int"}} 714 - }, "optional": False}, 715 - ] 932 + { 933 + "name": "matrix", 934 + "fieldType": { 935 + "$type": "local#array", 936 + "items": { 937 + "$type": "local#array", 938 + "items": {"$type": "local#primitive", "primitive": "int"}, 939 + }, 940 + }, 941 + "optional": False, 942 + }, 943 + ], 716 944 } 717 945 718 946 stub = generate_stub(schema) ··· 725 953 "name": "RefSample", 726 954 "version": "1.0.0", 727 955 "fields": [ 728 - {"name": "nested", "fieldType": {"$type": "local#ref", "ref": "local://schemas/Other@1.0.0"}, "optional": False}, 729 - ] 956 + { 957 + "name": "nested", 958 + "fieldType": { 959 + "$type": "local#ref", 960 + "ref": "local://schemas/Other@1.0.0", 961 + }, 962 + "optional": False, 963 + }, 964 + ], 730 965 } 731 966 732 967 stub = generate_stub(schema)
+20 -7
tests/test_integration_e2e.py
··· 27 27 @atdata.packable 28 28 class SimpleSample: 29 29 """Basic sample with primitive types only.""" 30 + 30 31 name: str 31 32 value: int 32 33 score: float ··· 36 37 @atdata.packable 37 38 class NDArraySample: 38 39 """Sample with multiple NDArray fields of different shapes.""" 40 + 39 41 label: int 40 42 image: NDArray 41 43 features: NDArray ··· 44 46 @atdata.packable 45 47 class OptionalNDArraySample: 46 48 """Sample with optional NDArray fields.""" 49 + 47 50 label: int 48 51 image: NDArray 49 52 embeddings: NDArray | None = None ··· 52 55 @atdata.packable 53 56 class BytesSample: 54 57 """Sample with bytes field.""" 58 + 55 59 name: str 56 60 raw_data: bytes 57 61 ··· 59 63 @atdata.packable 60 64 class ListSample: 61 65 """Sample with list fields.""" 66 + 62 67 tags: list[str] 63 68 scores: list[float] 64 69 ids: list[int] ··· 67 72 @dataclass 68 73 class InheritanceSample(atdata.PackableSample): 69 74 """Sample using inheritance syntax instead of decorator.""" 75 + 70 76 title: str 71 77 count: int 72 78 measurements: NDArray ··· 101 107 ] 102 108 103 109 104 - def create_optional_samples(n: int, include_optional: bool) -> list[OptionalNDArraySample]: 110 + def create_optional_samples( 111 + n: int, include_optional: bool 112 + ) -> list[OptionalNDArraySample]: 105 113 """Create samples with or without optional embeddings.""" 106 114 return [ 107 115 OptionalNDArraySample( 108 116 label=i, 109 117 image=np.random.randn(32, 32).astype(np.float32), 110 - embeddings=np.random.randn(64).astype(np.float32) if include_optional else None, 118 + embeddings=np.random.randn(64).astype(np.float32) 119 + if include_optional 120 + else None, 111 121 ) 112 122 for i in range(n) 113 123 ] ··· 134 144 sink.write(sample.as_wds) 135 145 136 146 n_shards = (len(samples) + samples_per_shard - 1) // samples_per_shard 137 - brace_pattern = ( 138 - base_path / f"shard-{{000000..{n_shards - 1:06d}}}.tar" 139 - ).as_posix() 147 + brace_pattern = (base_path / f"shard-{{000000..{n_shards - 1:06d}}}.tar").as_posix() 140 148 return brace_pattern, n_shards 141 149 142 150 ··· 241 249 for original, loaded_sample in zip(samples, loaded): 242 250 assert loaded_sample.label == original.label 243 251 np.testing.assert_array_almost_equal(loaded_sample.image, original.image) 244 - np.testing.assert_array_almost_equal(loaded_sample.features, original.features) 252 + np.testing.assert_array_almost_equal( 253 + loaded_sample.features, original.features 254 + ) 245 255 246 256 def test_ndarray_batch_stacking(self, tmp_path): 247 257 """NDArray fields should stack into batch dimension.""" ··· 309 319 310 320 def test_mixed_dtypes(self, tmp_path): 311 321 """Various numpy dtypes should serialize correctly.""" 322 + 312 323 @atdata.packable 313 324 class MultiDtypeSample: 314 325 f32: NDArray ··· 657 668 658 669 # At least two passes should differ (very high probability with 100 samples) 659 670 # Note: This could theoretically fail, but probability is astronomically low 660 - assert passes[0] != passes[1] or passes[1] != passes[2] or passes[0] != passes[2] 671 + assert ( 672 + passes[0] != passes[1] or passes[1] != passes[2] or passes[0] != passes[2] 673 + ) 661 674 662 675 def test_batch_size_one(self, tmp_path): 663 676 """batch_size=1 should return single-element batches."""
+7 -4
tests/test_integration_edge_cases.py
··· 9 9 - All primitive type variations 10 10 """ 11 11 12 - from pathlib import Path 13 - 14 12 import numpy as np 15 13 from numpy.typing import NDArray 16 14 ··· 28 26 @atdata.packable 29 27 class EmptyCompatSample: 30 28 """Sample type for empty dataset tests.""" 29 + 31 30 id: int 32 31 33 32 34 33 @atdata.packable 35 34 class AllPrimitivesSample: 36 35 """Sample with all primitive types.""" 36 + 37 37 str_field: str 38 38 int_field: int 39 39 float_field: float ··· 44 44 @atdata.packable 45 45 class OptionalFieldsSample: 46 46 """Sample with optional fields.""" 47 + 47 48 required_str: str 48 49 optional_str: str | None 49 50 optional_int: int | None ··· 54 55 @atdata.packable 55 56 class ListFieldsSample: 56 57 """Sample with list fields.""" 58 + 57 59 str_list: list[str] 58 60 int_list: list[int] 59 61 float_list: list[float] ··· 63 65 @atdata.packable 64 66 class UnicodeSample: 65 67 """Sample with unicode content.""" 68 + 66 69 text: str 67 70 label: str 68 71 ··· 70 73 @atdata.packable 71 74 class NDArraySample: 72 75 """Sample with NDArray field.""" 76 + 73 77 label: str 74 78 data: NDArray 75 79 ··· 368 372 tar_path = tmp_path / "emoji-000000.tar" 369 373 370 374 sample = UnicodeSample( 371 - text="Hello World! Have a great day!", 372 - label="with-emoji" 375 + text="Hello World! Have a great day!", label="with-emoji" 373 376 ) 374 377 create_tar_with_samples(tar_path, [sample]) 375 378
+15 -5
tests/test_integration_error_handling.py
··· 28 28 @atdata.packable 29 29 class ErrorTestSample: 30 30 """Sample for error handling tests.""" 31 + 31 32 name: str 32 33 value: int 33 34 ··· 110 111 """Tar with invalid msgpack should raise on iteration.""" 111 112 tar_path = tmp_path / "corrupted-000000.tar" 112 113 113 - import io 114 - 115 114 # Create tar with invalid msgpack data 116 115 with tarfile.open(tar_path, "w") as tar: 117 116 # Add a valid key file ··· 146 145 info = tarfile.TarInfo(name="test.txt") 147 146 info.size = len(data) 148 147 import io 148 + 149 149 tar.addfile(info, fileobj=io.BytesIO(data)) 150 150 151 151 # Truncate the file ··· 183 183 from redis import Redis, ConnectionError 184 184 185 185 # Create index with invalid Redis connection 186 - bad_redis = Redis(host="nonexistent.invalid.host", port=9999, socket_timeout=0.1) 186 + bad_redis = Redis( 187 + host="nonexistent.invalid.host", port=9999, socket_timeout=0.1 188 + ) 187 189 188 190 index = LocalIndex(redis=bad_redis) 189 191 ··· 227 229 assert not client.is_authenticated 228 230 229 231 from atdata.atmosphere import SchemaPublisher 232 + 230 233 publisher = SchemaPublisher(client) 231 234 232 235 with pytest.raises(ValueError, match="authenticated"): ··· 264 267 client._session = {"did": "did:plc:test123"} # Mark as authenticated 265 268 266 269 from atdata.atmosphere import SchemaPublisher 270 + 267 271 publisher = SchemaPublisher(client) 268 272 269 273 # Should propagate the API error ··· 339 343 client = AtmosphereClient(_client=mock_client) 340 344 341 345 from atdata.atmosphere import SchemaPublisher 346 + 342 347 publisher = SchemaPublisher(client) 343 348 344 349 try: ··· 374 379 # Now use a good file - should still work 375 380 good_tar = tmp_path / "good-000000.tar" 376 381 import webdataset as wds 382 + 377 383 with wds.writer.TarWriter(str(good_tar)) as writer: 378 384 sample = ErrorTestSample(name="good", value=42) 379 385 writer.write(sample.as_wds) ··· 421 427 """Special characters in version should be handled.""" 422 428 index = LocalIndex(redis=clean_redis) 423 429 424 - schema_ref = index.publish_schema(ErrorTestSample, version="1.0.0-beta+build.123") 430 + schema_ref = index.publish_schema( 431 + ErrorTestSample, version="1.0.0-beta+build.123" 432 + ) 425 433 schema = index.get_schema(schema_ref) 426 434 427 435 assert schema["version"] == "1.0.0-beta+build.123" ··· 615 623 # Mock the client after source creation 616 624 with patch.object(source, "_get_client") as mock_get_client: 617 625 mock_client = Mock() 618 - mock_client.get_object.side_effect = ConnectTimeoutError(endpoint_url="s3://test") 626 + mock_client.get_object.side_effect = ConnectTimeoutError( 627 + endpoint_url="s3://test" 628 + ) 619 629 mock_get_client.return_value = mock_client 620 630 621 631 # Use full S3 URI as returned by shard_list
+11 -1
tests/test_integration_lens.py
··· 26 26 @atdata.packable 27 27 class FullRecord: 28 28 """Complete record with all fields.""" 29 + 29 30 id: int 30 31 name: str 31 32 email: str ··· 37 38 @atdata.packable 38 39 class ProfileView: 39 40 """View with profile information only.""" 41 + 40 42 name: str 41 43 email: str 42 44 age: int ··· 45 47 @atdata.packable 46 48 class NameView: 47 49 """Minimal view with just name.""" 50 + 48 51 name: str 49 52 50 53 51 54 @atdata.packable 52 55 class ScoredRecord: 53 56 """Record with score and embedding.""" 57 + 54 58 id: int 55 59 score: float 56 60 embedding: NDArray ··· 59 63 @atdata.packable 60 64 class OptionalFieldSample: 61 65 """Sample with optional fields.""" 66 + 62 67 name: str 63 68 value: int 64 69 extra: str | None = None ··· 68 73 @atdata.packable 69 74 class OptionalView: 70 75 """View of optional sample.""" 76 + 71 77 name: str 72 78 extra: str | None = None 73 79 ··· 148 154 149 155 150 156 @optional_to_view.putter 151 - def optional_to_view_put(view: OptionalView, source: OptionalFieldSample) -> OptionalFieldSample: 157 + def optional_to_view_put( 158 + view: OptionalView, source: OptionalFieldSample 159 + ) -> OptionalFieldSample: 152 160 """Update optional sample from view.""" 153 161 return OptionalFieldSample( 154 162 name=view.name, ··· 486 494 487 495 def test_unregistered_lens_raises(self): 488 496 """Querying unregistered lens should raise ValueError.""" 497 + 489 498 @atdata.packable 490 499 class UnknownSource: 491 500 x: int ··· 536 545 537 546 def test_ndarray_transformation_lens(self): 538 547 """Lens that transforms NDArray values.""" 548 + 539 549 @atdata.packable 540 550 class RawData: 541 551 values: NDArray
+59 -49
tests/test_integration_local.py
··· 28 28 @dataclass 29 29 class WorkflowSample(atdata.PackableSample): 30 30 """Sample for workflow tests.""" 31 + 31 32 name: str 32 33 value: int 33 34 score: float ··· 36 37 @dataclass 37 38 class ArrayWorkflowSample(atdata.PackableSample): 38 39 """Sample with array for workflow tests.""" 40 + 39 41 label: str 40 42 data: NDArray 41 43 ··· 43 45 @dataclass 44 46 class MetadataSample(atdata.PackableSample): 45 47 """Sample for metadata workflow tests.""" 48 + 46 49 id: int 47 50 content: str 48 51 ··· 61 64 """ 62 65 with mock_aws(): 63 66 import boto3 64 - creds = { 65 - 'AWS_ACCESS_KEY_ID': 'testing', 66 - 'AWS_SECRET_ACCESS_KEY': 'testing' 67 - } 67 + 68 + creds = {"AWS_ACCESS_KEY_ID": "testing", "AWS_SECRET_ACCESS_KEY": "testing"} 68 69 s3_client = boto3.client( 69 - 's3', 70 - aws_access_key_id=creds['AWS_ACCESS_KEY_ID'], 71 - aws_secret_access_key=creds['AWS_SECRET_ACCESS_KEY'], 72 - region_name='us-east-1' 70 + "s3", 71 + aws_access_key_id=creds["AWS_ACCESS_KEY_ID"], 72 + aws_secret_access_key=creds["AWS_SECRET_ACCESS_KEY"], 73 + region_name="us-east-1", 73 74 ) 74 - bucket_name = 'integration-test-bucket' 75 + bucket_name = "integration-test-bucket" 75 76 s3_client.create_bucket(Bucket=bucket_name) 76 77 yield { 77 - 'credentials': creds, 78 - 'bucket': bucket_name, 79 - 'hive_path': f'{bucket_name}/datasets', 80 - 's3_client': s3_client 78 + "credentials": creds, 79 + "bucket": bucket_name, 80 + "hive_path": f"{bucket_name}/datasets", 81 + "s3_client": s3_client, 81 82 } 82 83 83 84 ··· 124 125 """Full workflow: init repo → publish schema → insert → query entry.""" 125 126 # Initialize repo 126 127 repo = atlocal.Repo( 127 - s3_credentials=mock_s3['credentials'], 128 - hive_path=mock_s3['hive_path'], 129 - redis=clean_redis 128 + s3_credentials=mock_s3["credentials"], 129 + hive_path=mock_s3["hive_path"], 130 + redis=clean_redis, 130 131 ) 131 132 132 133 # Publish schema first ··· 154 155 def test_multiple_datasets_same_schema(self, mock_s3, clean_redis, tmp_path): 155 156 """Insert multiple datasets with same schema type.""" 156 157 repo = atlocal.Repo( 157 - s3_credentials=mock_s3['credentials'], 158 - hive_path=mock_s3['hive_path'], 159 - redis=clean_redis 158 + s3_credentials=mock_s3["credentials"], 159 + hive_path=mock_s3["hive_path"], 160 + redis=clean_redis, 160 161 ) 161 162 162 163 # Create multiple datasets ··· 187 188 def test_different_schema_types(self, mock_s3, clean_redis, tmp_path): 188 189 """Insert datasets with different schema types.""" 189 190 repo = atlocal.Repo( 190 - s3_credentials=mock_s3['credentials'], 191 - hive_path=mock_s3['hive_path'], 192 - redis=clean_redis 191 + s3_credentials=mock_s3["credentials"], 192 + hive_path=mock_s3["hive_path"], 193 + redis=clean_redis, 193 194 ) 194 195 195 196 # Different sample types ··· 429 430 # Should be a generator 430 431 entries = index.entries 431 432 import types 433 + 432 434 assert isinstance(entries, types.GeneratorType) 433 435 434 436 # Can iterate partially ··· 449 451 def test_metadata_preserved_through_insert(self, mock_s3, clean_redis, tmp_path): 450 452 """Metadata should be preserved when inserting dataset.""" 451 453 repo = atlocal.Repo( 452 - s3_credentials=mock_s3['credentials'], 453 - hive_path=mock_s3['hive_path'], 454 - redis=clean_redis 454 + s3_credentials=mock_s3["credentials"], 455 + hive_path=mock_s3["hive_path"], 456 + redis=clean_redis, 455 457 ) 456 458 457 459 ds = create_workflow_dataset(tmp_path, n_samples=5) ··· 518 520 @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") 519 521 @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") 520 522 @pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") 521 - def test_cache_local_true_produces_valid_entry(self, mock_s3, clean_redis, tmp_path): 523 + def test_cache_local_true_produces_valid_entry( 524 + self, mock_s3, clean_redis, tmp_path 525 + ): 522 526 """cache_local=True should produce valid index entry.""" 523 527 repo = atlocal.Repo( 524 - s3_credentials=mock_s3['credentials'], 525 - hive_path=mock_s3['hive_path'], 526 - redis=clean_redis 528 + s3_credentials=mock_s3["credentials"], 529 + hive_path=mock_s3["hive_path"], 530 + redis=clean_redis, 527 531 ) 528 532 529 533 ds = create_workflow_dataset(tmp_path, n_samples=10) ··· 536 540 @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") 537 541 @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") 538 542 @pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") 539 - def test_cache_local_false_produces_valid_entry(self, mock_s3, clean_redis, tmp_path): 543 + def test_cache_local_false_produces_valid_entry( 544 + self, mock_s3, clean_redis, tmp_path 545 + ): 540 546 """cache_local=False should produce valid index entry.""" 541 547 repo = atlocal.Repo( 542 - s3_credentials=mock_s3['credentials'], 543 - hive_path=mock_s3['hive_path'], 544 - redis=clean_redis 548 + s3_credentials=mock_s3["credentials"], 549 + hive_path=mock_s3["hive_path"], 550 + redis=clean_redis, 545 551 ) 546 552 547 553 ds = create_workflow_dataset(tmp_path, n_samples=10) ··· 557 563 def test_both_modes_produce_same_structure(self, mock_s3, clean_redis, tmp_path): 558 564 """Both cache modes should produce entries with same structure.""" 559 565 repo = atlocal.Repo( 560 - s3_credentials=mock_s3['credentials'], 561 - hive_path=mock_s3['hive_path'], 562 - redis=clean_redis 566 + s3_credentials=mock_s3["credentials"], 567 + hive_path=mock_s3["hive_path"], 568 + redis=clean_redis, 563 569 ) 564 570 565 571 ds1 = create_workflow_dataset(tmp_path / "cached", n_samples=10) 566 572 ds2 = create_workflow_dataset(tmp_path / "direct", n_samples=10) 567 573 568 574 entry1, _ = repo.insert(ds1, name="cached-mode", cache_local=True, maxcount=100) 569 - entry2, _ = repo.insert(ds2, name="direct-mode", cache_local=False, maxcount=100) 575 + entry2, _ = repo.insert( 576 + ds2, name="direct-mode", cache_local=False, maxcount=100 577 + ) 570 578 571 579 # Both should have valid structure 572 580 assert entry1.schema_ref == entry2.schema_ref # Same type ··· 598 606 ) 599 607 600 608 # Required properties 601 - assert hasattr(entry, 'name') 602 - assert hasattr(entry, 'schema_ref') 603 - assert hasattr(entry, 'data_urls') 604 - assert hasattr(entry, 'metadata') 605 - assert hasattr(entry, 'cid') 609 + assert hasattr(entry, "name") 610 + assert hasattr(entry, "schema_ref") 611 + assert hasattr(entry, "data_urls") 612 + assert hasattr(entry, "metadata") 613 + assert hasattr(entry, "cid") 606 614 607 615 # Values accessible 608 616 assert entry.name == "props-test" ··· 629 637 @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") 630 638 @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") 631 639 @pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") 632 - def test_large_dataset_creates_multiple_shards(self, mock_s3, clean_redis, tmp_path): 640 + def test_large_dataset_creates_multiple_shards( 641 + self, mock_s3, clean_redis, tmp_path 642 + ): 633 643 """Large dataset should create multiple shard files.""" 634 644 repo = atlocal.Repo( 635 - s3_credentials=mock_s3['credentials'], 636 - hive_path=mock_s3['hive_path'], 637 - redis=clean_redis 645 + s3_credentials=mock_s3["credentials"], 646 + hive_path=mock_s3["hive_path"], 647 + redis=clean_redis, 638 648 ) 639 649 640 650 # Create dataset with many samples ··· 662 672 def test_single_shard_no_brace_notation(self, mock_s3, clean_redis, tmp_path): 663 673 """Small dataset should result in single shard without brace notation.""" 664 674 repo = atlocal.Repo( 665 - s3_credentials=mock_s3['credentials'], 666 - hive_path=mock_s3['hive_path'], 667 - redis=clean_redis 675 + s3_credentials=mock_s3["credentials"], 676 + hive_path=mock_s3["hive_path"], 677 + redis=clean_redis, 668 678 ) 669 679 670 680 ds = create_workflow_dataset(tmp_path, n_samples=5)
+59 -19
tests/test_integration_promotion.py
··· 28 28 @atdata.packable 29 29 class PromotionSample: 30 30 """Sample for promotion tests.""" 31 + 31 32 name: str 32 33 value: int 33 34 ··· 35 36 @atdata.packable 36 37 class PromotionArraySample: 37 38 """Sample with NDArray for promotion tests.""" 39 + 38 40 label: str 39 41 features: NDArray 40 42 ··· 110 112 111 113 # Setup mock responses for atmosphere operations 112 114 schema_response = Mock() 113 - schema_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/promoted-schema" 115 + schema_response.uri = ( 116 + f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/promoted-schema" 117 + ) 114 118 115 119 dataset_response = Mock() 116 - dataset_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.dataset/promoted-dataset" 120 + dataset_response.uri = ( 121 + f"at://did:plc:test/{LEXICON_NAMESPACE}.dataset/promoted-dataset" 122 + ) 117 123 118 124 mock_atproto_client.com.atproto.repo.create_record.side_effect = [ 119 125 schema_response, ··· 124 130 mock_list_response = Mock() 125 131 mock_list_response.records = [] 126 132 mock_list_response.cursor = None 127 - mock_atproto_client.com.atproto.repo.list_records.return_value = mock_list_response 133 + mock_atproto_client.com.atproto.repo.list_records.return_value = ( 134 + mock_list_response 135 + ) 128 136 129 137 # Promote 130 138 result_uri = promote_to_atmosphere( ··· 146 154 mock_list_response = Mock() 147 155 mock_list_response.records = [] 148 156 mock_list_response.cursor = None 149 - mock_atproto_client.com.atproto.repo.list_records.return_value = mock_list_response 157 + mock_atproto_client.com.atproto.repo.list_records.return_value = ( 158 + mock_list_response 159 + ) 150 160 151 161 schema_response = Mock() 152 162 schema_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/s1" ··· 177 187 mock_list_response = Mock() 178 188 mock_list_response.records = [] 179 189 mock_list_response.cursor = None 180 - mock_atproto_client.com.atproto.repo.list_records.return_value = mock_list_response 190 + mock_atproto_client.com.atproto.repo.list_records.return_value = ( 191 + mock_list_response 192 + ) 181 193 182 194 schema_response = Mock() 183 195 schema_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/s1" ··· 215 227 216 228 # Patch _find_existing_schema to return an existing schema URI 217 229 with patch("atdata.promote._find_existing_schema") as mock_find: 218 - mock_find.return_value = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/existing" 230 + mock_find.return_value = ( 231 + f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/existing" 232 + ) 219 233 220 234 # Only dataset should be created (schema exists) 221 235 dataset_response = Mock() 222 236 dataset_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.dataset/d1" 223 - mock_atproto_client.com.atproto.repo.create_record.return_value = dataset_response 237 + mock_atproto_client.com.atproto.repo.create_record.return_value = ( 238 + dataset_response 239 + ) 224 240 225 241 promote_to_atmosphere(local_entry, local_index, authenticated_client) 226 242 ··· 228 244 assert mock_atproto_client.com.atproto.repo.create_record.call_count == 1 229 245 230 246 # Verify it was the dataset call 231 - call_kwargs = mock_atproto_client.com.atproto.repo.create_record.call_args.kwargs 247 + call_kwargs = ( 248 + mock_atproto_client.com.atproto.repo.create_record.call_args.kwargs 249 + ) 232 250 assert "dataset" in call_kwargs["data"]["collection"] 233 251 234 252 def test_creates_schema_when_not_found( ··· 241 259 mock_list_response = Mock() 242 260 mock_list_response.records = [] 243 261 mock_list_response.cursor = None 244 - mock_atproto_client.com.atproto.repo.list_records.return_value = mock_list_response 262 + mock_atproto_client.com.atproto.repo.list_records.return_value = ( 263 + mock_list_response 264 + ) 245 265 246 266 # Both schema and dataset should be created 247 267 schema_response = Mock() ··· 277 297 mock_list_response = Mock() 278 298 mock_list_response.records = [existing_schema] 279 299 mock_list_response.cursor = None 280 - mock_atproto_client.com.atproto.repo.list_records.return_value = mock_list_response 300 + mock_atproto_client.com.atproto.repo.list_records.return_value = ( 301 + mock_list_response 302 + ) 281 303 282 304 # Both should be created (version mismatch) 283 305 schema_response = Mock() 284 - schema_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/v1new" 306 + schema_response.uri = ( 307 + f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/v1new" 308 + ) 285 309 286 310 dataset_response = Mock() 287 311 dataset_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.dataset/d1" ··· 314 338 mock_list_response = Mock() 315 339 mock_list_response.records = [] 316 340 mock_list_response.cursor = None 317 - mock_atproto_client.com.atproto.repo.list_records.return_value = mock_list_response 341 + mock_atproto_client.com.atproto.repo.list_records.return_value = ( 342 + mock_list_response 343 + ) 318 344 319 345 schema_response = Mock() 320 346 schema_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/s1" ··· 337 363 # The metadata should be in the record (may be msgpack encoded) 338 364 assert "metadata" in record 339 365 340 - def test_none_metadata_handled(self, clean_redis, authenticated_client, mock_atproto_client): 366 + def test_none_metadata_handled( 367 + self, clean_redis, authenticated_client, mock_atproto_client 368 + ): 341 369 """Entry without metadata should promote successfully.""" 342 370 index = LocalIndex(redis=clean_redis) 343 371 schema_ref = index.publish_schema(PromotionSample, version="1.0.0") ··· 354 382 mock_list_response = Mock() 355 383 mock_list_response.records = [] 356 384 mock_list_response.cursor = None 357 - mock_atproto_client.com.atproto.repo.list_records.return_value = mock_list_response 385 + mock_atproto_client.com.atproto.repo.list_records.return_value = ( 386 + mock_list_response 387 + ) 358 388 359 389 schema_response = Mock() 360 390 schema_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/s1" ··· 423 453 dataset_responses[2], # Third dataset 424 454 ] 425 455 426 - with patch("atdata.promote._find_existing_schema", side_effect=mock_find_existing): 456 + with patch( 457 + "atdata.promote._find_existing_schema", side_effect=mock_find_existing 458 + ): 427 459 # Promote all three 428 460 for i, entry in enumerate(entries): 429 461 promote_to_atmosphere(entry, index, authenticated_client) ··· 462 494 mock_list_response = Mock() 463 495 mock_list_response.records = [] 464 496 mock_list_response.cursor = None 465 - mock_atproto_client.com.atproto.repo.list_records.return_value = mock_list_response 497 + mock_atproto_client.com.atproto.repo.list_records.return_value = ( 498 + mock_list_response 499 + ) 466 500 467 501 schema_response = Mock() 468 502 schema_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/s1" ··· 543 577 mock_list_response = Mock() 544 578 mock_list_response.records = [] 545 579 mock_list_response.cursor = None 546 - mock_atproto_client.com.atproto.repo.list_records.return_value = mock_list_response 580 + mock_atproto_client.com.atproto.repo.list_records.return_value = ( 581 + mock_list_response 582 + ) 547 583 548 584 schema_response = Mock() 549 585 schema_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/s1" ··· 578 614 mock_list_response = Mock() 579 615 mock_list_response.records = [] 580 616 mock_list_response.cursor = None 581 - mock_atproto_client.com.atproto.repo.list_records.return_value = mock_list_response 617 + mock_atproto_client.com.atproto.repo.list_records.return_value = ( 618 + mock_list_response 619 + ) 582 620 583 621 schema_response = Mock() 584 622 schema_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/s1" ··· 616 654 mock_list_response = Mock() 617 655 mock_list_response.records = [] 618 656 mock_list_response.cursor = None 619 - mock_atproto_client.com.atproto.repo.list_records.return_value = mock_list_response 657 + mock_atproto_client.com.atproto.repo.list_records.return_value = ( 658 + mock_list_response 659 + ) 620 660 621 661 schema_response = Mock() 622 662 schema_response.uri = f"at://did:plc:test/{LEXICON_NAMESPACE}.sampleSchema/s1"
+73 -76
tests/test_lens.py
··· 16 16 ## 17 17 # Tests 18 18 19 + 19 20 def test_lens(): 20 21 """Test a lens between sample types""" 21 22 ··· 31 32 class View: 32 33 name: str 33 34 height: float 34 - 35 + 35 36 @atdata.lens 36 - def polite( s: Source ) -> View: 37 + def polite(s: Source) -> View: 37 38 return View( 38 - name = s.name, 39 - height = s.height, 39 + name=s.name, 40 + height=s.height, 40 41 ) 41 - 42 + 42 43 @polite.putter 43 - def polite_update( v: View, s: Source ) -> Source: 44 + def polite_update(v: View, s: Source) -> Source: 44 45 return Source( 45 - name = v.name, 46 - height = v.height, 46 + name=v.name, 47 + height=v.height, 47 48 # 48 - age = s.age, 49 + age=s.age, 49 50 ) 50 - 51 + 51 52 # Test with an example sample 52 53 53 54 test_source = Source( 54 - name = 'Hello World', 55 - age = 42, 56 - height = 182.9, 55 + name="Hello World", 56 + age=42, 57 + height=182.9, 57 58 ) 58 59 correct_view = View( 59 - name = test_source.name, 60 - height = test_source.height, 60 + name=test_source.name, 61 + height=test_source.height, 61 62 ) 62 63 63 - test_view = polite( test_source ) 64 - assert test_view == correct_view, \ 65 - f'Incorrect lens behavior: {test_view}, and not {correct_view}' 64 + test_view = polite(test_source) 65 + assert test_view == correct_view, ( 66 + f"Incorrect lens behavior: {test_view}, and not {correct_view}" 67 + ) 66 68 67 69 # This lens should be well-behaved 68 70 69 71 update_view = View( 70 - name = 'Now Taller', 71 - height = 192.9, 72 + name="Now Taller", 73 + height=192.9, 72 74 ) 73 75 74 - x = polite( polite.put( update_view, test_source ) ) 75 - assert x == update_view, \ 76 - f'Violation of GetPut: {x} =/= {update_view}' 77 - 78 - y = polite.put( polite( test_source ), test_source ) 79 - assert y == test_source, \ 80 - f'Violation of PutGet: {y} =/= {test_source}' 76 + x = polite(polite.put(update_view, test_source)) 77 + assert x == update_view, f"Violation of GetPut: {x} =/= {update_view}" 78 + 79 + y = polite.put(polite(test_source), test_source) 80 + assert y == test_source, f"Violation of PutGet: {y} =/= {test_source}" 81 81 82 82 # PutPut law: put(v2, put(v1, s)) = put(v2, s) 83 83 another_view = View( 84 - name = 'Different Name', 85 - height = 165.0, 84 + name="Different Name", 85 + height=165.0, 86 86 ) 87 - z1 = polite.put( another_view, polite.put( update_view, test_source ) ) 88 - z2 = polite.put( another_view, test_source ) 89 - assert z1 == z2, \ 90 - f'Violation of PutPut: {z1} =/= {z2}' 87 + z1 = polite.put(another_view, polite.put(update_view, test_source)) 88 + z2 = polite.put(another_view, test_source) 89 + assert z1 == z2, f"Violation of PutPut: {z1} =/= {z2}" 91 90 92 - def test_conversion( tmp_path ): 91 + 92 + def test_conversion(tmp_path): 93 93 """Test automatic interconversion between sample types""" 94 94 95 95 @dataclass 96 - class Source( atdata.PackableSample ): 96 + class Source(atdata.PackableSample): 97 97 name: str 98 98 height: float 99 99 favorite_pizza: str 100 100 favorite_image: NDArray 101 - 101 + 102 102 @dataclass 103 - class View( atdata.PackableSample ): 103 + class View(atdata.PackableSample): 104 104 name: str 105 105 favorite_pizza: str 106 106 favorite_image: NDArray 107 - 107 + 108 108 @atdata.lens 109 - def polite( s: Source ) -> View: 109 + def polite(s: Source) -> View: 110 110 return View( 111 - name = s.name, 112 - favorite_pizza = s.favorite_pizza, 113 - favorite_image = s.favorite_image, 111 + name=s.name, 112 + favorite_pizza=s.favorite_pizza, 113 + favorite_image=s.favorite_image, 114 114 ) 115 - 115 + 116 116 # Map a test sample through the view 117 117 test_source = Source( 118 - name = 'Larry', 119 - height = 42., 120 - favorite_pizza = 'pineapple', 121 - favorite_image = np.random.randn( 224, 224 ) 118 + name="Larry", 119 + height=42.0, 120 + favorite_pizza="pineapple", 121 + favorite_image=np.random.randn(224, 224), 122 122 ) 123 - test_view = polite( test_source ) 123 + test_view = polite(test_source) 124 124 125 125 # Create a test dataset 126 126 127 127 k_test = 100 128 - test_filename = ( 129 - tmp_path 130 - / 'test-source.tar' 131 - ).as_posix() 128 + test_filename = (tmp_path / "test-source.tar").as_posix() 132 129 133 - with wds.writer.TarWriter( test_filename ) as dest: 134 - for i in range( k_test ): 130 + with wds.writer.TarWriter(test_filename) as dest: 131 + for i in range(k_test): 135 132 # Create a new copied sample 136 133 cur_sample = Source( 137 - name = test_source.name, 138 - height = test_source.height, 139 - favorite_pizza = test_source.favorite_pizza, 140 - favorite_image = test_source.favorite_image, 134 + name=test_source.name, 135 + height=test_source.height, 136 + favorite_pizza=test_source.favorite_pizza, 137 + favorite_image=test_source.favorite_image, 141 138 ) 142 - dest.write( cur_sample.as_wds ) 143 - 139 + dest.write(cur_sample.as_wds) 140 + 144 141 # Try reading the test dataset 145 142 146 - ds = ( 147 - atdata.Dataset[Source]( test_filename ) 148 - .as_type( View ) 149 - ) 143 + ds = atdata.Dataset[Source](test_filename).as_type(View) 150 144 151 - assert ds.sample_type == View, \ 152 - 'Auto-mapped' 145 + assert ds.sample_type == View, "Auto-mapped" 153 146 154 147 sample: View | None = None 155 - for sample in ds.ordered( batch_size = None ): 148 + for sample in ds.ordered(batch_size=None): 156 149 # Load only the first sample 157 150 break 158 151 159 - assert sample is not None, \ 160 - 'Did not load any samples from `Source` dataset' 152 + assert sample is not None, "Did not load any samples from `Source` dataset" 161 153 162 - assert sample.name == test_view.name, \ 163 - f'Divergence on auto-mapped dataset: `name` should be {test_view.name}, but is {sample.name}' 164 - assert sample.favorite_pizza == test_view.favorite_pizza, \ 165 - f'Divergence on auto-mapped dataset: `favorite_pizza` should be {test_view.favorite_pizza}, but is {sample.favorite_pizza}' 166 - assert np.all( sample.favorite_image == test_view.favorite_image ), \ 167 - 'Divergence on auto-mapped dataset: `favorite_image`' 154 + assert sample.name == test_view.name, ( 155 + f"Divergence on auto-mapped dataset: `name` should be {test_view.name}, but is {sample.name}" 156 + ) 157 + assert sample.favorite_pizza == test_view.favorite_pizza, ( 158 + f"Divergence on auto-mapped dataset: `favorite_pizza` should be {test_view.favorite_pizza}, but is {sample.favorite_pizza}" 159 + ) 160 + assert np.all(sample.favorite_image == test_view.favorite_image), ( 161 + "Divergence on auto-mapped dataset: `favorite_image`" 162 + ) 168 163 169 164 170 165 ## ··· 173 168 174 169 def test_lens_get_method(): 175 170 """Test calling lens.get() explicitly instead of lens().""" 171 + 176 172 @atdata.packable 177 173 class GetSource: 178 174 value: int ··· 197 193 198 194 def test_lens_trivial_putter(): 199 195 """Test lens without explicit putter uses trivial putter.""" 196 + 200 197 @atdata.packable 201 198 class TrivialSource: 202 199 a: int ··· 237 234 network.transform(UnregisteredSource, UnregisteredView) 238 235 239 236 240 - ## 237 + ##
+279 -181
tests/test_local.py
··· 26 26 ## 27 27 # Test fixtures (redis_connection and clean_redis are in conftest.py) 28 28 29 + 29 30 @pytest.fixture 30 31 def mock_s3(): 31 32 """Provide a mock S3 environment using moto. ··· 36 37 """ 37 38 with mock_aws(): 38 39 # Create S3 credentials dict (no endpoint_url for moto) 39 - creds = { 40 - 'AWS_ACCESS_KEY_ID': 'testing', 41 - 'AWS_SECRET_ACCESS_KEY': 'testing' 42 - } 40 + creds = {"AWS_ACCESS_KEY_ID": "testing", "AWS_SECRET_ACCESS_KEY": "testing"} 43 41 44 42 # Create S3 client and bucket 45 43 import boto3 44 + 46 45 s3_client = boto3.client( 47 - 's3', 48 - aws_access_key_id=creds['AWS_ACCESS_KEY_ID'], 49 - aws_secret_access_key=creds['AWS_SECRET_ACCESS_KEY'], 50 - region_name='us-east-1' 46 + "s3", 47 + aws_access_key_id=creds["AWS_ACCESS_KEY_ID"], 48 + aws_secret_access_key=creds["AWS_SECRET_ACCESS_KEY"], 49 + region_name="us-east-1", 51 50 ) 52 51 53 - bucket_name = 'test-bucket' 52 + bucket_name = "test-bucket" 54 53 s3_client.create_bucket(Bucket=bucket_name) 55 54 56 55 yield { 57 - 'credentials': creds, 58 - 'bucket': bucket_name, 59 - 'hive_path': f'{bucket_name}/datasets', 60 - 's3_client': s3_client 56 + "credentials": creds, 57 + "bucket": bucket_name, 58 + "hive_path": f"{bucket_name}/datasets", 59 + "s3_client": s3_client, 61 60 } 62 61 63 62 ··· 83 82 Note: This matches SharedBasicSample in conftest.py but is kept local 84 83 because tests verify class name behavior. 85 84 """ 85 + 86 86 name: str 87 87 value: int 88 88 ··· 93 93 94 94 Note: Similar to SharedNumpySample but kept local for test isolation. 95 95 """ 96 + 96 97 label: str 97 98 data: NDArray 98 99 99 100 100 - def make_simple_dataset(tmp_path: Path, num_samples: int = 10, name: str = "test") -> atdata.Dataset: 101 + def make_simple_dataset( 102 + tmp_path: Path, num_samples: int = 10, name: str = "test" 103 + ) -> atdata.Dataset: 101 104 """Create a SimpleTestSample dataset for testing.""" 102 105 dataset_path = tmp_path / f"{name}-dataset-000000.tar" 103 106 with wds.writer.TarWriter(str(dataset_path)) as sink: ··· 107 110 return atdata.Dataset[SimpleTestSample](url=str(dataset_path)) 108 111 109 112 110 - def make_array_dataset(tmp_path: Path, num_samples: int = 3, array_shape: tuple = (10, 10)) -> atdata.Dataset: 113 + def make_array_dataset( 114 + tmp_path: Path, num_samples: int = 3, array_shape: tuple = (10, 10) 115 + ) -> atdata.Dataset: 111 116 """Create an ArrayTestSample dataset for testing.""" 112 117 dataset_path = tmp_path / "array-dataset-000000.tar" 113 118 with wds.writer.TarWriter(str(dataset_path)) as sink: ··· 120 125 121 126 ## 122 127 # Helper function tests 128 + 123 129 124 130 def test_kind_str_for_sample_type(): 125 131 """Test that sample types are converted to correct fully-qualified string identifiers. ··· 149 155 result = atlocal._s3_env(env_file) 150 156 151 157 assert result == { 152 - 'AWS_ENDPOINT': 'http://localhost:9000', 153 - 'AWS_ACCESS_KEY_ID': 'minioadmin', 154 - 'AWS_SECRET_ACCESS_KEY': 'minioadmin' 158 + "AWS_ENDPOINT": "http://localhost:9000", 159 + "AWS_ACCESS_KEY_ID": "minioadmin", 160 + "AWS_SECRET_ACCESS_KEY": "minioadmin", 155 161 } 156 162 157 163 158 - @pytest.mark.parametrize("missing_field,env_content", [ 159 - ("AWS_ENDPOINT", "AWS_ACCESS_KEY_ID=minioadmin\nAWS_SECRET_ACCESS_KEY=minioadmin\n"), 160 - ("AWS_ACCESS_KEY_ID", "AWS_ENDPOINT=http://localhost:9000\nAWS_SECRET_ACCESS_KEY=minioadmin\n"), 161 - ("AWS_SECRET_ACCESS_KEY", "AWS_ENDPOINT=http://localhost:9000\nAWS_ACCESS_KEY_ID=minioadmin\n"), 162 - ]) 164 + @pytest.mark.parametrize( 165 + "missing_field,env_content", 166 + [ 167 + ( 168 + "AWS_ENDPOINT", 169 + "AWS_ACCESS_KEY_ID=minioadmin\nAWS_SECRET_ACCESS_KEY=minioadmin\n", 170 + ), 171 + ( 172 + "AWS_ACCESS_KEY_ID", 173 + "AWS_ENDPOINT=http://localhost:9000\nAWS_SECRET_ACCESS_KEY=minioadmin\n", 174 + ), 175 + ( 176 + "AWS_SECRET_ACCESS_KEY", 177 + "AWS_ENDPOINT=http://localhost:9000\nAWS_ACCESS_KEY_ID=minioadmin\n", 178 + ), 179 + ], 180 + ) 163 181 def test_s3_env_missing_required_field(tmp_path, missing_field, env_content): 164 182 """Test that loading S3 credentials fails when a required field is missing. 165 183 ··· 179 197 Should create a properly configured S3FileSystem instance using dict credentials. 180 198 """ 181 199 creds = { 182 - 'AWS_ENDPOINT': 'http://localhost:9000', 183 - 'AWS_ACCESS_KEY_ID': 'minioadmin', 184 - 'AWS_SECRET_ACCESS_KEY': 'minioadmin' 200 + "AWS_ENDPOINT": "http://localhost:9000", 201 + "AWS_ACCESS_KEY_ID": "minioadmin", 202 + "AWS_SECRET_ACCESS_KEY": "minioadmin", 185 203 } 186 204 187 205 fs = atlocal._s3_from_credentials(creds) 188 206 189 207 assert isinstance(fs, atlocal.S3FileSystem) 190 - assert fs.endpoint_url == 'http://localhost:9000' 191 - assert fs.key == 'minioadmin' 192 - assert fs.secret == 'minioadmin' 208 + assert fs.endpoint_url == "http://localhost:9000" 209 + assert fs.key == "minioadmin" 210 + assert fs.secret == "minioadmin" 193 211 194 212 195 213 def test_s3_from_credentials_with_path(tmp_path): ··· 207 225 fs = atlocal._s3_from_credentials(env_file) 208 226 209 227 assert isinstance(fs, atlocal.S3FileSystem) 210 - assert fs.endpoint_url == 'http://localhost:9000' 211 - assert fs.key == 'minioadmin' 212 - assert fs.secret == 'minioadmin' 228 + assert fs.endpoint_url == "http://localhost:9000" 229 + assert fs.key == "minioadmin" 230 + assert fs.secret == "minioadmin" 213 231 214 232 215 233 ## 216 234 # LocalDatasetEntry tests 235 + 217 236 218 237 def test_local_dataset_entry_creation(): 219 238 """Test creating a LocalDatasetEntry with explicit values. ··· 315 334 original_entry.write_to(clean_redis) 316 335 317 336 # Read back from Redis 318 - retrieved_entry = atlocal.LocalDatasetEntry.from_redis(clean_redis, original_entry.cid) 337 + retrieved_entry = atlocal.LocalDatasetEntry.from_redis( 338 + clean_redis, original_entry.cid 339 + ) 319 340 320 341 assert retrieved_entry.name == original_entry.name 321 342 assert retrieved_entry.schema_ref == original_entry.schema_ref ··· 356 377 index = atlocal.Index() 357 378 358 379 # Check protocol methods exist 359 - assert hasattr(index, 'insert_dataset') 360 - assert hasattr(index, 'get_dataset') 361 - assert hasattr(index, 'list_datasets') 362 - assert hasattr(index, 'publish_schema') 363 - assert hasattr(index, 'get_schema') 364 - assert hasattr(index, 'list_schemas') 365 - assert hasattr(index, 'decode_schema') 380 + assert hasattr(index, "insert_dataset") 381 + assert hasattr(index, "get_dataset") 382 + assert hasattr(index, "list_datasets") 383 + assert hasattr(index, "publish_schema") 384 + assert hasattr(index, "get_schema") 385 + assert hasattr(index, "list_schemas") 386 + assert hasattr(index, "decode_schema") 366 387 367 388 # Check they are callable 368 389 assert callable(index.insert_dataset) ··· 373 394 ## 374 395 # Index tests 375 396 397 + 376 398 def test_index_init_default_redis(): 377 399 """Test creating an Index with default Redis connection. 378 400 ··· 401 423 402 424 Should pass custom kwargs to Redis constructor when creating a new connection. 403 425 """ 404 - index = atlocal.Index(host='localhost', port=6379, db=0) 426 + index = atlocal.Index(host="localhost", port=6379, db=0) 405 427 406 428 assert index._redis is not None 407 429 assert isinstance(index._redis, Redis) ··· 415 437 index = atlocal.Index(redis=clean_redis) 416 438 417 439 ds = atdata.Dataset[SimpleTestSample]( 418 - url="s3://bucket/dataset.tar", 419 - metadata_url="s3://bucket/metadata.msgpack" 440 + url="s3://bucket/dataset.tar", metadata_url="s3://bucket/metadata.msgpack" 420 441 ) 421 442 422 443 entry = index.add_entry(ds, name="test-dataset") ··· 442 463 ds = atdata.Dataset[SimpleTestSample](url="s3://bucket/dataset.tar") 443 464 444 465 entry = index.add_entry( 445 - ds, 446 - name="test-dataset", 447 - schema_ref="local://schemas/custom.Schema@2.0.0" 466 + ds, name="test-dataset", schema_ref="local://schemas/custom.Schema@2.0.0" 448 467 ) 449 468 450 469 assert entry.schema_ref == "local://schemas/custom.Schema@2.0.0" ··· 460 479 ds = atdata.Dataset[SimpleTestSample](url="s3://bucket/dataset.tar") 461 480 462 481 entry = index.add_entry( 463 - ds, 464 - name="test-dataset", 465 - metadata={"version": "1.0", "author": "test"} 482 + ds, name="test-dataset", metadata={"version": "1.0", "author": "test"} 466 483 ) 467 484 468 485 assert entry.metadata == {"version": "1.0", "author": "test"} ··· 593 610 ## 594 611 # AbstractIndex protocol method tests 595 612 613 + 596 614 def test_index_insert_dataset(clean_redis): 597 615 """Test insert_dataset protocol method.""" 598 616 index = atlocal.Index(redis=clean_redis) ··· 667 685 Should create a Repo with S3FileSystem and set hive_path and hive_bucket. 668 686 """ 669 687 creds = { 670 - 'AWS_ENDPOINT': 'http://localhost:9000', 671 - 'AWS_ACCESS_KEY_ID': 'minioadmin', 672 - 'AWS_SECRET_ACCESS_KEY': 'minioadmin' 688 + "AWS_ENDPOINT": "http://localhost:9000", 689 + "AWS_ACCESS_KEY_ID": "minioadmin", 690 + "AWS_SECRET_ACCESS_KEY": "minioadmin", 673 691 } 674 692 675 693 repo = atlocal.Repo(s3_credentials=creds, hive_path="test-bucket/datasets") ··· 710 728 Should raise ValueError when s3_credentials is provided but hive_path is None. 711 729 """ 712 730 creds = { 713 - 'AWS_ENDPOINT': 'http://localhost:9000', 714 - 'AWS_ACCESS_KEY_ID': 'minioadmin', 715 - 'AWS_SECRET_ACCESS_KEY': 'minioadmin' 731 + "AWS_ENDPOINT": "http://localhost:9000", 732 + "AWS_ACCESS_KEY_ID": "minioadmin", 733 + "AWS_SECRET_ACCESS_KEY": "minioadmin", 716 734 } 717 735 718 736 with pytest.raises(ValueError, match="Must specify hive path"): ··· 726 744 Should set hive_bucket to the first component of hive_path. 727 745 """ 728 746 creds = { 729 - 'AWS_ENDPOINT': 'http://localhost:9000', 730 - 'AWS_ACCESS_KEY_ID': 'minioadmin', 731 - 'AWS_SECRET_ACCESS_KEY': 'minioadmin' 747 + "AWS_ENDPOINT": "http://localhost:9000", 748 + "AWS_ACCESS_KEY_ID": "minioadmin", 749 + "AWS_SECRET_ACCESS_KEY": "minioadmin", 732 750 } 733 751 734 752 repo = atlocal.Repo(s3_credentials=creds, hive_path="my-bucket/path/to/datasets") ··· 752 770 ## 753 771 # Repo tests - Insert functionality 754 772 773 + 755 774 @pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") 756 775 def test_repo_insert_without_s3(): 757 776 """Test that inserting a dataset without S3 configured raises ValueError. ··· 775 794 a new Dataset pointing to the stored copy with correct URL format. 776 795 """ 777 796 repo = atlocal.Repo( 778 - s3_credentials=mock_s3['credentials'], 779 - hive_path=mock_s3['hive_path'], 780 - redis=clean_redis 797 + s3_credentials=mock_s3["credentials"], 798 + hive_path=mock_s3["hive_path"], 799 + redis=clean_redis, 781 800 ) 782 801 783 - entry, new_ds = repo.insert(sample_dataset, name="single-shard-dataset", maxcount=100) 802 + entry, new_ds = repo.insert( 803 + sample_dataset, name="single-shard-dataset", maxcount=100 804 + ) 784 805 785 806 assert entry.cid is not None 786 807 assert entry.cid.startswith("bafy") ··· 788 809 assert len(entry.data_urls) > 0 789 810 assert "SimpleTestSample" in entry.schema_ref 790 811 assert len(repo.index.all_entries) == 1 791 - assert '.tar' in new_ds.url 792 - assert new_ds.url.startswith(mock_s3['hive_path']) 812 + assert ".tar" in new_ds.url 813 + assert new_ds.url.startswith(mock_s3["hive_path"]) 793 814 794 815 795 816 @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") ··· 803 824 """ 804 825 ds = make_simple_dataset(tmp_path, num_samples=50, name="large") 805 826 repo = atlocal.Repo( 806 - s3_credentials=mock_s3['credentials'], 807 - hive_path=mock_s3['hive_path'], 808 - redis=clean_redis 827 + s3_credentials=mock_s3["credentials"], 828 + hive_path=mock_s3["hive_path"], 829 + redis=clean_redis, 809 830 ) 810 831 811 832 entry, new_ds = repo.insert(ds, name="multi-shard-dataset", maxcount=10) 812 833 813 834 assert entry.cid is not None 814 835 assert len(entry.data_urls) > 0 815 - assert '{' in new_ds.url and '}' in new_ds.url 836 + assert "{" in new_ds.url and "}" in new_ds.url 816 837 817 838 818 839 @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") ··· 827 848 ds._metadata = {"description": "test dataset", "version": "1.0"} 828 849 829 850 repo = atlocal.Repo( 830 - s3_credentials=mock_s3['credentials'], 831 - hive_path=mock_s3['hive_path'], 832 - redis=clean_redis 851 + s3_credentials=mock_s3["credentials"], 852 + hive_path=mock_s3["hive_path"], 853 + redis=clean_redis, 833 854 ) 834 855 835 856 entry, new_ds = repo.insert(ds, name="metadata-dataset", maxcount=100) ··· 849 870 """ 850 871 ds = make_simple_dataset(tmp_path, num_samples=5) 851 872 repo = atlocal.Repo( 852 - s3_credentials=mock_s3['credentials'], 853 - hive_path=mock_s3['hive_path'], 854 - redis=clean_redis 873 + s3_credentials=mock_s3["credentials"], 874 + hive_path=mock_s3["hive_path"], 875 + redis=clean_redis, 855 876 ) 856 877 857 878 entry, new_ds = repo.insert(ds, name="no-metadata-dataset", maxcount=100) ··· 869 890 Should write tar shards directly to S3 without local caching. 870 891 """ 871 892 repo = atlocal.Repo( 872 - s3_credentials=mock_s3['credentials'], 873 - hive_path=mock_s3['hive_path'], 874 - redis=clean_redis 893 + s3_credentials=mock_s3["credentials"], 894 + hive_path=mock_s3["hive_path"], 895 + redis=clean_redis, 875 896 ) 876 897 877 - entry, new_ds = repo.insert(sample_dataset, name="direct-write", cache_local=False, maxcount=100) 898 + entry, new_ds = repo.insert( 899 + sample_dataset, name="direct-write", cache_local=False, maxcount=100 900 + ) 878 901 879 902 assert entry.cid is not None 880 903 assert len(entry.data_urls) > 0 ··· 890 913 local cache files after copying. 891 914 """ 892 915 repo = atlocal.Repo( 893 - s3_credentials=mock_s3['credentials'], 894 - hive_path=mock_s3['hive_path'], 895 - redis=clean_redis 916 + s3_credentials=mock_s3["credentials"], 917 + hive_path=mock_s3["hive_path"], 918 + redis=clean_redis, 896 919 ) 897 920 898 - entry, new_ds = repo.insert(sample_dataset, name="cached-write", cache_local=True, maxcount=100) 921 + entry, new_ds = repo.insert( 922 + sample_dataset, name="cached-write", cache_local=True, maxcount=100 923 + ) 899 924 900 925 assert entry.cid is not None 901 926 assert len(entry.data_urls) > 0 ··· 911 936 and CID. 912 937 """ 913 938 repo = atlocal.Repo( 914 - s3_credentials=mock_s3['credentials'], 915 - hive_path=mock_s3['hive_path'], 916 - redis=clean_redis 939 + s3_credentials=mock_s3["credentials"], 940 + hive_path=mock_s3["hive_path"], 941 + redis=clean_redis, 917 942 ) 918 943 919 944 entry, new_ds = repo.insert(sample_dataset, name="indexed-dataset", maxcount=100) ··· 936 961 Should create different CIDs for datasets with different URLs. 937 962 """ 938 963 repo = atlocal.Repo( 939 - s3_credentials=mock_s3['credentials'], 940 - hive_path=mock_s3['hive_path'], 941 - redis=clean_redis 964 + s3_credentials=mock_s3["credentials"], 965 + hive_path=mock_s3["hive_path"], 966 + redis=clean_redis, 942 967 ) 943 968 944 969 entry1, new_ds1 = repo.insert(sample_dataset, name="dataset1", maxcount=100) ··· 965 990 966 991 ds = atdata.Dataset[SimpleTestSample](url=str(dataset_path)) 967 992 repo = atlocal.Repo( 968 - s3_credentials=mock_s3['credentials'], 969 - hive_path=mock_s3['hive_path'], 970 - redis=clean_redis 993 + s3_credentials=mock_s3["credentials"], 994 + hive_path=mock_s3["hive_path"], 995 + redis=clean_redis, 971 996 ) 972 997 973 998 # Empty datasets succeed because WebDataset creates a shard file regardless 974 999 entry, new_ds = repo.insert(ds, name="empty-dataset", maxcount=100) 975 1000 assert entry.cid is not None 976 - assert '.tar' in new_ds.url 1001 + assert ".tar" in new_ds.url 977 1002 978 1003 979 1004 @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") ··· 985 1010 Should return a Dataset[T] with the same sample type as the input dataset. 986 1011 """ 987 1012 repo = atlocal.Repo( 988 - s3_credentials=mock_s3['credentials'], 989 - hive_path=mock_s3['hive_path'], 990 - redis=clean_redis 1013 + s3_credentials=mock_s3["credentials"], 1014 + hive_path=mock_s3["hive_path"], 1015 + redis=clean_redis, 991 1016 ) 992 1017 993 1018 entry, new_ds = repo.insert(sample_dataset, name="typed-dataset", maxcount=100) ··· 1006 1031 """ 1007 1032 ds = make_simple_dataset(tmp_path, num_samples=30, name="large") 1008 1033 repo = atlocal.Repo( 1009 - s3_credentials=mock_s3['credentials'], 1010 - hive_path=mock_s3['hive_path'], 1011 - redis=clean_redis 1034 + s3_credentials=mock_s3["credentials"], 1035 + hive_path=mock_s3["hive_path"], 1036 + redis=clean_redis, 1012 1037 ) 1013 1038 1014 1039 entry, new_ds = repo.insert(ds, name="sharded-dataset", maxcount=5) 1015 1040 1016 - assert '{' in new_ds.url and '}' in new_ds.url 1041 + assert "{" in new_ds.url and "}" in new_ds.url 1017 1042 1018 1043 1019 1044 @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") ··· 1026 1051 """ 1027 1052 ds = make_array_dataset(tmp_path, num_samples=3, array_shape=(10, 10)) 1028 1053 repo = atlocal.Repo( 1029 - s3_credentials=mock_s3['credentials'], 1030 - hive_path=mock_s3['hive_path'], 1031 - redis=clean_redis 1054 + s3_credentials=mock_s3["credentials"], 1055 + hive_path=mock_s3["hive_path"], 1056 + redis=clean_redis, 1032 1057 ) 1033 1058 1034 1059 entry, new_ds = repo.insert(ds, name="array-dataset", maxcount=100) ··· 1039 1064 1040 1065 ## 1041 1066 # Integration tests 1067 + 1042 1068 1043 1069 @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") 1044 1070 @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") ··· 1050 1076 from the Index. 1051 1077 """ 1052 1078 repo = atlocal.Repo( 1053 - s3_credentials=mock_s3['credentials'], 1054 - hive_path=mock_s3['hive_path'], 1055 - redis=clean_redis 1079 + s3_credentials=mock_s3["credentials"], 1080 + hive_path=mock_s3["hive_path"], 1081 + redis=clean_redis, 1056 1082 ) 1057 1083 1058 1084 entry, new_ds = repo.insert(sample_dataset, name="integrated-dataset", maxcount=100) ··· 1073 1099 retrievable from the index. 1074 1100 """ 1075 1101 repo = atlocal.Repo( 1076 - s3_credentials=mock_s3['credentials'], 1077 - hive_path=mock_s3['hive_path'], 1078 - redis=clean_redis 1102 + s3_credentials=mock_s3["credentials"], 1103 + hive_path=mock_s3["hive_path"], 1104 + redis=clean_redis, 1079 1105 ) 1080 1106 1081 1107 entry1, _ = repo.insert(sample_dataset, name="dataset-a", maxcount=100) ··· 1105 1131 array_ds = make_array_dataset(tmp_path, num_samples=3, array_shape=(5, 5)) 1106 1132 1107 1133 repo = atlocal.Repo( 1108 - s3_credentials=mock_s3['credentials'], 1109 - hive_path=mock_s3['hive_path'], 1110 - redis=clean_redis 1134 + s3_credentials=mock_s3["credentials"], 1135 + hive_path=mock_s3["hive_path"], 1136 + redis=clean_redis, 1111 1137 ) 1112 1138 1113 1139 entry1, _ = repo.insert(simple_ds, name="simple-dataset", maxcount=100) ··· 1168 1194 ## 1169 1195 # S3DataStore tests 1170 1196 1197 + 1171 1198 def test_s3_datastore_init(): 1172 1199 """Test creating an S3DataStore.""" 1173 1200 creds = { 1174 - 'AWS_ENDPOINT': 'http://localhost:9000', 1175 - 'AWS_ACCESS_KEY_ID': 'minioadmin', 1176 - 'AWS_SECRET_ACCESS_KEY': 'minioadmin' 1201 + "AWS_ENDPOINT": "http://localhost:9000", 1202 + "AWS_ACCESS_KEY_ID": "minioadmin", 1203 + "AWS_SECRET_ACCESS_KEY": "minioadmin", 1177 1204 } 1178 1205 1179 1206 store = atlocal.S3DataStore(credentials=creds, bucket="test-bucket") ··· 1185 1212 1186 1213 def test_s3_datastore_supports_streaming(): 1187 1214 """Test that S3DataStore reports streaming support.""" 1188 - creds = { 1189 - 'AWS_ACCESS_KEY_ID': 'test', 1190 - 'AWS_SECRET_ACCESS_KEY': 'test' 1191 - } 1215 + creds = {"AWS_ACCESS_KEY_ID": "test", "AWS_SECRET_ACCESS_KEY": "test"} 1192 1216 1193 1217 store = atlocal.S3DataStore(credentials=creds, bucket="test") 1194 1218 ··· 1197 1221 1198 1222 def test_s3_datastore_read_url(): 1199 1223 """Test that read_url returns URL unchanged without custom endpoint.""" 1200 - creds = { 1201 - 'AWS_ACCESS_KEY_ID': 'test', 1202 - 'AWS_SECRET_ACCESS_KEY': 'test' 1203 - } 1224 + creds = {"AWS_ACCESS_KEY_ID": "test", "AWS_SECRET_ACCESS_KEY": "test"} 1204 1225 1205 1226 store = atlocal.S3DataStore(credentials=creds, bucket="test") 1206 1227 ··· 1211 1232 def test_s3_datastore_read_url_with_custom_endpoint(): 1212 1233 """Test that read_url transforms s3:// to https:// with custom endpoint.""" 1213 1234 creds = { 1214 - 'AWS_ACCESS_KEY_ID': 'test', 1215 - 'AWS_SECRET_ACCESS_KEY': 'test', 1216 - 'AWS_ENDPOINT': 'https://abc123.r2.cloudflarestorage.com' 1235 + "AWS_ACCESS_KEY_ID": "test", 1236 + "AWS_SECRET_ACCESS_KEY": "test", 1237 + "AWS_ENDPOINT": "https://abc123.r2.cloudflarestorage.com", 1217 1238 } 1218 1239 1219 1240 store = atlocal.S3DataStore(credentials=creds, bucket="test") ··· 1224 1245 assert store.read_url(url) == expected 1225 1246 1226 1247 # Trailing slash on endpoint should be handled 1227 - creds['AWS_ENDPOINT'] = 'https://endpoint.example.com/' 1248 + creds["AWS_ENDPOINT"] = "https://endpoint.example.com/" 1228 1249 store2 = atlocal.S3DataStore(credentials=creds, bucket="test") 1229 - assert store2.read_url(url) == "https://endpoint.example.com/my-bucket/path/to/data.tar" 1250 + assert ( 1251 + store2.read_url(url) 1252 + == "https://endpoint.example.com/my-bucket/path/to/data.tar" 1253 + ) 1230 1254 1231 1255 # Non-s3 URLs should be passed through unchanged 1232 1256 https_url = "https://example.com/data.tar" ··· 1240 1264 ds = make_simple_dataset(tmp_path, num_samples=5) 1241 1265 1242 1266 store = atlocal.S3DataStore( 1243 - credentials=mock_s3['credentials'], 1244 - bucket=mock_s3['bucket'] 1267 + credentials=mock_s3["credentials"], bucket=mock_s3["bucket"] 1245 1268 ) 1246 1269 1247 1270 urls = store.write_shards(ds, prefix="test/data", maxcount=100) 1248 1271 1249 1272 assert len(urls) >= 1 1250 1273 assert all(url.startswith("s3://") for url in urls) 1251 - assert all(mock_s3['bucket'] in url for url in urls) 1274 + assert all(mock_s3["bucket"] in url for url in urls) 1252 1275 1253 1276 1254 1277 @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") ··· 1258 1281 ds = make_simple_dataset(tmp_path, num_samples=5) 1259 1282 1260 1283 store = atlocal.S3DataStore( 1261 - credentials=mock_s3['credentials'], 1262 - bucket=mock_s3['bucket'] 1284 + credentials=mock_s3["credentials"], bucket=mock_s3["bucket"] 1263 1285 ) 1264 1286 1265 1287 urls = store.write_shards(ds, prefix="cached/data", cache_local=True, maxcount=100) ··· 1279 1301 ds = make_simple_dataset(tmp_path, num_samples=5) 1280 1302 1281 1303 store = atlocal.S3DataStore( 1282 - credentials=mock_s3['credentials'], 1283 - bucket=mock_s3['bucket'] 1304 + credentials=mock_s3["credentials"], bucket=mock_s3["bucket"] 1284 1305 ) 1285 1306 index = atlocal.Index(redis=clean_redis, data_store=store) 1286 1307 ··· 1303 1324 ds = make_simple_dataset(tmp_path, num_samples=3) 1304 1325 1305 1326 store = atlocal.S3DataStore( 1306 - credentials=mock_s3['credentials'], 1307 - bucket=mock_s3['bucket'] 1327 + credentials=mock_s3["credentials"], bucket=mock_s3["bucket"] 1308 1328 ) 1309 1329 index = atlocal.Index(redis=clean_redis, data_store=store) 1310 1330 1311 1331 entry = index.insert_dataset( 1312 - ds, 1313 - name="my-dataset", 1314 - prefix="custom/path/v1", 1315 - maxcount=100 1332 + ds, name="my-dataset", prefix="custom/path/v1", maxcount=100 1316 1333 ) 1317 1334 1318 1335 assert "custom/path/v1" in entry.data_urls[0] ··· 1334 1351 def test_index_data_store_property(mock_s3, clean_redis): 1335 1352 """Test that Index.data_store property returns the data store.""" 1336 1353 store = atlocal.S3DataStore( 1337 - credentials=mock_s3['credentials'], 1338 - bucket=mock_s3['bucket'] 1354 + credentials=mock_s3["credentials"], bucket=mock_s3["bucket"] 1339 1355 ) 1340 1356 index = atlocal.Index(redis=clean_redis, data_store=store) 1341 1357 ··· 1352 1368 ## 1353 1369 # Schema storage tests 1354 1370 1371 + 1355 1372 def test_publish_schema(clean_redis): 1356 1373 """Test publishing a schema to Redis.""" 1357 1374 index = atlocal.Index(redis=clean_redis) ··· 1368 1385 index = atlocal.Index(redis=clean_redis) 1369 1386 1370 1387 schema_ref = index.publish_schema( 1371 - SimpleTestSample, 1372 - version="2.0.0", 1373 - description="A simple test sample type" 1388 + SimpleTestSample, version="2.0.0", description="A simple test sample type" 1374 1389 ) 1375 1390 1376 1391 schema = index.get_schema(schema_ref) 1377 - assert schema.get('description') == "A simple test sample type" 1392 + assert schema.get("description") == "A simple test sample type" 1378 1393 1379 1394 1380 1395 def test_publish_schema_auto_increment(clean_redis): ··· 1411 1426 schema = index.get_schema(schema_ref) 1412 1427 1413 1428 # Should use the class docstring 1414 - assert schema.get('description') == SimpleTestSample.__doc__ 1429 + assert schema.get("description") == SimpleTestSample.__doc__ 1415 1430 1416 1431 1417 1432 def test_get_schema(clean_redis): ··· 1421 1436 schema_ref = index.publish_schema(SimpleTestSample, version="1.0.0") 1422 1437 schema = index.get_schema(schema_ref) 1423 1438 1424 - assert schema['name'] == 'SimpleTestSample' 1425 - assert schema['version'] == '1.0.0' 1426 - assert len(schema['fields']) == 2 # name and value fields 1427 - assert schema['$ref'] == schema_ref 1439 + assert schema["name"] == "SimpleTestSample" 1440 + assert schema["version"] == "1.0.0" 1441 + assert len(schema["fields"]) == 2 # name and value fields 1442 + assert schema["$ref"] == schema_ref 1428 1443 1429 1444 1430 1445 def test_get_schema_not_found(clean_redis): ··· 1461 1476 schemas = list(index.list_schemas()) 1462 1477 assert len(schemas) == 2 1463 1478 1464 - names = {s['name'] for s in schemas} 1465 - assert 'SimpleTestSample' in names 1466 - assert 'ArrayTestSample' in names 1479 + names = {s["name"] for s in schemas} 1480 + assert "SimpleTestSample" in names 1481 + assert "ArrayTestSample" in names 1467 1482 1468 1483 1469 1484 def test_schema_field_types(clean_redis): ··· 1474 1489 schema = index.get_schema(schema_ref) 1475 1490 1476 1491 # Find name field (should be str) 1477 - name_field = next(f for f in schema['fields'] if f['name'] == 'name') 1478 - assert name_field['fieldType']['primitive'] == 'str' 1479 - assert name_field['optional'] is False 1492 + name_field = next(f for f in schema["fields"] if f["name"] == "name") 1493 + assert name_field["fieldType"]["primitive"] == "str" 1494 + assert name_field["optional"] is False 1480 1495 1481 1496 # Find value field (should be int) 1482 - value_field = next(f for f in schema['fields'] if f['name'] == 'value') 1483 - assert value_field['fieldType']['primitive'] == 'int' 1497 + value_field = next(f for f in schema["fields"] if f["name"] == "value") 1498 + assert value_field["fieldType"]["primitive"] == "int" 1484 1499 1485 1500 1486 1501 def test_schema_ndarray_field(clean_redis): ··· 1491 1506 schema = index.get_schema(schema_ref) 1492 1507 1493 1508 # Find data field (should be ndarray) 1494 - data_field = next(f for f in schema['fields'] if f['name'] == 'data') 1495 - assert data_field['fieldType']['$type'] == 'local#ndarray' 1496 - assert data_field['fieldType']['dtype'] == 'float32' 1509 + data_field = next(f for f in schema["fields"] if f["name"] == "data") 1510 + assert data_field["fieldType"]["$type"] == "local#ndarray" 1511 + assert data_field["fieldType"]["dtype"] == "float32" 1497 1512 1498 1513 1499 1514 def test_decode_schema(clean_redis): ··· 1518 1533 1519 1534 # Check fields exist 1520 1535 import numpy as np 1536 + 1521 1537 instance = ReconstructedType(label="test", data=np.zeros((3, 3))) 1522 1538 assert instance.label == "test" 1523 1539 assert instance.data.shape == (3, 3) ··· 1556 1572 schema_v1 = index.get_schema(ref_v1) 1557 1573 schema_v2 = index.get_schema(ref_v2) 1558 1574 1559 - assert schema_v1['version'] == '1.0.0' 1560 - assert schema_v2['version'] == '2.0.0' 1575 + assert schema_v1["version"] == "1.0.0" 1576 + assert schema_v2["version"] == "2.0.0" 1561 1577 1562 1578 1563 1579 ## 1564 1580 # Schema codec tests 1581 + 1565 1582 1566 1583 def test_schema_codec_type_caching(): 1567 1584 """Test that schema_to_type caches generated types.""" ··· 1573 1590 schema = { 1574 1591 "name": "CacheTestSample", 1575 1592 "version": "1.0.0", 1576 - "fields": [{"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}], 1593 + "fields": [ 1594 + { 1595 + "name": "value", 1596 + "fieldType": {"$type": "local#primitive", "primitive": "int"}, 1597 + "optional": False, 1598 + } 1599 + ], 1577 1600 } 1578 1601 1579 1602 # First call creates and caches type ··· 1596 1619 clear_type_cache() 1597 1620 schema = { 1598 1621 "version": "1.0.0", 1599 - "fields": [{"name": "value", "fieldType": {"$type": "#primitive", "primitive": "int"}, "optional": False}], 1622 + "fields": [ 1623 + { 1624 + "name": "value", 1625 + "fieldType": {"$type": "#primitive", "primitive": "int"}, 1626 + "optional": False, 1627 + } 1628 + ], 1600 1629 } 1601 1630 1602 1631 with pytest.raises(ValueError, match="must have a 'name' field"): ··· 1626 1655 schema = { 1627 1656 "name": "BadFieldSample", 1628 1657 "version": "1.0.0", 1629 - "fields": [{"fieldType": {"$type": "#primitive", "primitive": "int"}, "optional": False}], 1658 + "fields": [ 1659 + { 1660 + "fieldType": {"$type": "#primitive", "primitive": "int"}, 1661 + "optional": False, 1662 + } 1663 + ], 1630 1664 } 1631 1665 1632 1666 # Raises KeyError from cache key generation (accesses f['name']) or ··· 1643 1677 schema = { 1644 1678 "name": "UnknownPrimitiveSample", 1645 1679 "version": "1.0.0", 1646 - "fields": [{"name": "value", "fieldType": {"$type": "#primitive", "primitive": "unknown_type"}, "optional": False}], 1680 + "fields": [ 1681 + { 1682 + "name": "value", 1683 + "fieldType": {"$type": "#primitive", "primitive": "unknown_type"}, 1684 + "optional": False, 1685 + } 1686 + ], 1647 1687 } 1648 1688 1649 1689 with pytest.raises(ValueError, match="Unknown primitive type"): ··· 1658 1698 schema = { 1659 1699 "name": "UnknownKindSample", 1660 1700 "version": "1.0.0", 1661 - "fields": [{"name": "value", "fieldType": {"$type": "#unknown_kind"}, "optional": False}], 1701 + "fields": [ 1702 + { 1703 + "name": "value", 1704 + "fieldType": {"$type": "#unknown_kind"}, 1705 + "optional": False, 1706 + } 1707 + ], 1662 1708 } 1663 1709 1664 1710 with pytest.raises(ValueError, match="Unknown field type kind"): ··· 1673 1719 schema = { 1674 1720 "name": "RefSample", 1675 1721 "version": "1.0.0", 1676 - "fields": [{"name": "other", "fieldType": {"$type": "#ref", "ref": "other.Schema"}, "optional": False}], 1722 + "fields": [ 1723 + { 1724 + "name": "other", 1725 + "fieldType": {"$type": "#ref", "ref": "other.Schema"}, 1726 + "optional": False, 1727 + } 1728 + ], 1677 1729 } 1678 1730 1679 1731 with pytest.raises(ValueError, match="Schema references.*not yet supported"): ··· 1689 1741 "name": "AllPrimitivesSample", 1690 1742 "version": "1.0.0", 1691 1743 "fields": [ 1692 - {"name": "s", "fieldType": {"$type": "#primitive", "primitive": "str"}, "optional": False}, 1693 - {"name": "i", "fieldType": {"$type": "#primitive", "primitive": "int"}, "optional": False}, 1694 - {"name": "f", "fieldType": {"$type": "#primitive", "primitive": "float"}, "optional": False}, 1695 - {"name": "b", "fieldType": {"$type": "#primitive", "primitive": "bool"}, "optional": False}, 1696 - {"name": "by", "fieldType": {"$type": "#primitive", "primitive": "bytes"}, "optional": False}, 1744 + { 1745 + "name": "s", 1746 + "fieldType": {"$type": "#primitive", "primitive": "str"}, 1747 + "optional": False, 1748 + }, 1749 + { 1750 + "name": "i", 1751 + "fieldType": {"$type": "#primitive", "primitive": "int"}, 1752 + "optional": False, 1753 + }, 1754 + { 1755 + "name": "f", 1756 + "fieldType": {"$type": "#primitive", "primitive": "float"}, 1757 + "optional": False, 1758 + }, 1759 + { 1760 + "name": "b", 1761 + "fieldType": {"$type": "#primitive", "primitive": "bool"}, 1762 + "optional": False, 1763 + }, 1764 + { 1765 + "name": "by", 1766 + "fieldType": {"$type": "#primitive", "primitive": "bytes"}, 1767 + "optional": False, 1768 + }, 1697 1769 ], 1698 1770 } 1699 1771 ··· 1716 1788 "name": "OptionalSample", 1717 1789 "version": "1.0.0", 1718 1790 "fields": [ 1719 - {"name": "required", "fieldType": {"$type": "#primitive", "primitive": "str"}, "optional": False}, 1720 - {"name": "optional_str", "fieldType": {"$type": "#primitive", "primitive": "str"}, "optional": True}, 1791 + { 1792 + "name": "required", 1793 + "fieldType": {"$type": "#primitive", "primitive": "str"}, 1794 + "optional": False, 1795 + }, 1796 + { 1797 + "name": "optional_str", 1798 + "fieldType": {"$type": "#primitive", "primitive": "str"}, 1799 + "optional": True, 1800 + }, 1721 1801 ], 1722 1802 } 1723 1803 ··· 1742 1822 "name": "ArraySample", 1743 1823 "version": "1.0.0", 1744 1824 "fields": [ 1745 - {"name": "data", "fieldType": {"$type": "#ndarray", "dtype": "float32"}, "optional": False}, 1825 + { 1826 + "name": "data", 1827 + "fieldType": {"$type": "#ndarray", "dtype": "float32"}, 1828 + "optional": False, 1829 + }, 1746 1830 ], 1747 1831 } 1748 1832 ··· 1762 1846 "name": "ListSample", 1763 1847 "version": "1.0.0", 1764 1848 "fields": [ 1765 - {"name": "tags", "fieldType": {"$type": "#array", "items": {"$type": "#primitive", "primitive": "str"}}, "optional": False}, 1849 + { 1850 + "name": "tags", 1851 + "fieldType": { 1852 + "$type": "#array", 1853 + "items": {"$type": "#primitive", "primitive": "str"}, 1854 + }, 1855 + "optional": False, 1856 + }, 1766 1857 ], 1767 1858 } 1768 1859 ··· 1780 1871 schema = { 1781 1872 "name": "NoCacheSample", 1782 1873 "version": "1.0.0", 1783 - "fields": [{"name": "value", "fieldType": {"$type": "#primitive", "primitive": "int"}, "optional": False}], 1874 + "fields": [ 1875 + { 1876 + "name": "value", 1877 + "fieldType": {"$type": "#primitive", "primitive": "int"}, 1878 + "optional": False, 1879 + } 1880 + ], 1784 1881 } 1785 1882 1786 1883 Type1 = schema_to_type(schema, use_cache=False) ··· 1818 1915 ref = index.publish_schema(SimpleTestSample, version="1.0.0") 1819 1916 1820 1917 # Get schema should trigger stub generation 1821 - schema = index.get_schema(ref) 1918 + _schema = index.get_schema(ref) 1822 1919 1823 1920 # Check stub was created (in local/ subdirectory for namespacing) 1824 1921 stub_path = stub_dir / "local" / "SimpleTestSample_1_0_0.py" ··· 1859 1956 1860 1957 # Small delay to ensure different mtime if regenerated 1861 1958 import time 1959 + 1862 1960 time.sleep(0.01) 1863 1961 1864 1962 # Second call should not regenerate
+36 -19
tests/test_promote.py
··· 15 15 @atdata.packable 16 16 class PromoteTestSample: 17 17 """Sample type for promotion tests.""" 18 + 18 19 name: str 19 20 value: int 20 21 ··· 35 36 "value": { 36 37 "name": "test_promote.PromoteTestSample", 37 38 "version": "1.0.0", 38 - } 39 + }, 39 40 } 40 41 ] 41 42 42 43 result = _find_existing_schema( 43 - mock_client, 44 - "test_promote.PromoteTestSample", 45 - "1.0.0" 44 + mock_client, "test_promote.PromoteTestSample", "1.0.0" 46 45 ) 47 46 48 47 assert result == "at://did:plc:test/ac.foundation.dataset.sampleSchema/abc" ··· 59 58 "value": { 60 59 "name": "other.OtherSample", 61 60 "version": "1.0.0", 62 - } 61 + }, 63 62 } 64 63 ] 65 64 66 65 result = _find_existing_schema( 67 - mock_client, 68 - "test_promote.PromoteTestSample", 69 - "1.0.0" 66 + mock_client, "test_promote.PromoteTestSample", "1.0.0" 70 67 ) 71 68 72 69 assert result is None ··· 83 80 "value": { 84 81 "name": "test_promote.PromoteTestSample", 85 82 "version": "2.0.0", # Different version 86 - } 83 + }, 87 84 } 88 85 ] 89 86 90 87 result = _find_existing_schema( 91 - mock_client, 92 - "test_promote.PromoteTestSample", 93 - "1.0.0" 88 + mock_client, "test_promote.PromoteTestSample", "1.0.0" 94 89 ) 95 90 96 91 assert result is None ··· 125 120 126 121 with patch("atdata.atmosphere.SchemaPublisher") as MockPublisher: 127 122 mock_publisher = MockPublisher.return_value 128 - mock_publisher.publish.return_value = Mock(__str__=lambda s: "at://new/schema/uri") 123 + mock_publisher.publish.return_value = Mock( 124 + __str__=lambda s: "at://new/schema/uri" 125 + ) 129 126 130 127 result = _find_or_publish_schema( 131 128 PromoteTestSample, ··· 167 164 "name": "test_promote.PromoteTestSample", 168 165 "version": "1.0.0", 169 166 "fields": [ 170 - {"name": "name", "fieldType": {"$type": "local#primitive", "primitive": "str"}, "optional": False}, 171 - {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 167 + { 168 + "name": "name", 169 + "fieldType": {"$type": "local#primitive", "primitive": "str"}, 170 + "optional": False, 171 + }, 172 + { 173 + "name": "value", 174 + "fieldType": {"$type": "local#primitive", "primitive": "int"}, 175 + "optional": False, 176 + }, 172 177 ], 173 178 } 174 179 ··· 207 212 "name": "TestSample", 208 213 "version": "1.0.0", 209 214 "fields": [ 210 - {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 215 + { 216 + "name": "value", 217 + "fieldType": {"$type": "local#primitive", "primitive": "int"}, 218 + "optional": False, 219 + }, 211 220 ], 212 221 } 213 222 ··· 218 227 219 228 with patch("atdata.atmosphere.DatasetPublisher") as MockPublisher: 220 229 mock_publisher = MockPublisher.return_value 221 - mock_publisher.publish_with_urls.return_value = Mock(__str__=lambda s: "at://result") 230 + mock_publisher.publish_with_urls.return_value = Mock( 231 + __str__=lambda s: "at://result" 232 + ) 222 233 223 234 promote_to_atmosphere( 224 235 entry, ··· 247 258 "name": "TestSample", 248 259 "version": "1.0.0", 249 260 "fields": [ 250 - {"name": "value", "fieldType": {"$type": "local#primitive", "primitive": "int"}, "optional": False}, 261 + { 262 + "name": "value", 263 + "fieldType": {"$type": "local#primitive", "primitive": "int"}, 264 + "optional": False, 265 + }, 251 266 ], 252 267 } 253 268 ··· 262 277 263 278 with patch("atdata.atmosphere.DatasetPublisher") as MockPublisher: 264 279 mock_publisher = MockPublisher.return_value 265 - mock_publisher.publish_with_urls.return_value = Mock(__str__=lambda s: "at://result") 280 + mock_publisher.publish_with_urls.return_value = Mock( 281 + __str__=lambda s: "at://result" 282 + ) 266 283 267 284 with patch("atdata.dataset.Dataset"): 268 285 promote_to_atmosphere(
+8 -2
tests/test_protocols.py
··· 309 309 { 310 310 "name": "dataset-b", 311 311 "schemaRef": "at://schema", 312 - "storage": {"$type": "ac.foundation.dataset.storageExternal", "urls": ["url2"]}, 312 + "storage": { 313 + "$type": "ac.foundation.dataset.storageExternal", 314 + "urls": ["url2"], 315 + }, 313 316 }, 314 317 ), 315 318 ] ··· 343 346 { 344 347 "name": "ds2", 345 348 "schemaRef": "at://s", 346 - "storage": {"$type": "ac.foundation.dataset.storageExternal", "urls": ["s3://b/1.tar"]}, 349 + "storage": { 350 + "$type": "ac.foundation.dataset.storageExternal", 351 + "urls": ["s3://b/1.tar"], 352 + }, 347 353 }, 348 354 ), 349 355 ]
+35 -22
tests/test_sources.py
··· 1 1 """Tests for data source implementations.""" 2 2 3 - import io 4 - import tarfile 5 3 from pathlib import Path 6 4 from unittest.mock import Mock, patch, MagicMock 7 5 ··· 17 15 @atdata.packable 18 16 class SourceTestSample: 19 17 """Simple sample for testing data sources.""" 18 + 20 19 name: str 21 20 value: int 22 21 ··· 98 97 def test_dataset_integration(self, tmp_path): 99 98 """URLSource works with Dataset.""" 100 99 tar_path = tmp_path / "test.tar" 101 - create_test_tar(tar_path, [ 102 - {"name": "sample1", "value": 1}, 103 - {"name": "sample2", "value": 2}, 104 - ]) 100 + create_test_tar( 101 + tar_path, 102 + [ 103 + {"name": "sample1", "value": 1}, 104 + {"name": "sample2", "value": 2}, 105 + ], 106 + ) 105 107 106 108 source = URLSource(str(tar_path)) 107 109 ds = atdata.Dataset[SourceTestSample](source) ··· 130 132 131 133 def test_from_urls(self): 132 134 """from_urls parses S3 URLs correctly.""" 133 - source = S3Source.from_urls([ 134 - "s3://bucket/path/a.tar", 135 - "s3://bucket/path/b.tar", 136 - ]) 135 + source = S3Source.from_urls( 136 + [ 137 + "s3://bucket/path/a.tar", 138 + "s3://bucket/path/b.tar", 139 + ] 140 + ) 137 141 138 142 assert source.bucket == "bucket" 139 143 assert source.keys == ["path/a.tar", "path/b.tar"] ··· 164 168 def test_from_urls_multiple_buckets(self): 165 169 """from_urls raises when URLs span buckets.""" 166 170 with pytest.raises(ValueError, match="same bucket"): 167 - S3Source.from_urls([ 168 - "s3://bucket-a/data.tar", 169 - "s3://bucket-b/data.tar", 170 - ]) 171 + S3Source.from_urls( 172 + [ 173 + "s3://bucket-a/data.tar", 174 + "s3://bucket-b/data.tar", 175 + ] 176 + ) 171 177 172 178 def test_from_credentials(self): 173 179 """from_credentials creates source from dict.""" ··· 299 305 300 306 def test_list_shards(self): 301 307 """list_shards returns AT URIs.""" 302 - source = BlobSource(blob_refs=[ 303 - {"did": "did:plc:abc", "cid": "bafyrei111"}, 304 - {"did": "did:plc:abc", "cid": "bafyrei222"}, 305 - ]) 308 + source = BlobSource( 309 + blob_refs=[ 310 + {"did": "did:plc:abc", "cid": "bafyrei111"}, 311 + {"did": "did:plc:abc", "cid": "bafyrei222"}, 312 + ] 313 + ) 306 314 assert source.list_shards() == [ 307 315 "at://did:plc:abc/blob/bafyrei111", 308 316 "at://did:plc:abc/blob/bafyrei222", ··· 310 318 311 319 def test_from_refs_simple_format(self): 312 320 """from_refs accepts simple {did, cid} format.""" 313 - source = BlobSource.from_refs([ 314 - {"did": "did:plc:abc", "cid": "bafyrei123"}, 315 - ]) 321 + source = BlobSource.from_refs( 322 + [ 323 + {"did": "did:plc:abc", "cid": "bafyrei123"}, 324 + ] 325 + ) 316 326 assert len(source.blob_refs) == 1 317 327 assert source.blob_refs[0]["did"] == "did:plc:abc" 318 328 assert source.blob_refs[0]["cid"] == "bafyrei123" ··· 368 378 ) 369 379 370 380 url = source._get_blob_url("did:plc:abc", "bafyrei123") 371 - assert url == "https://pds.example.com/xrpc/com.atproto.sync.getBlob?did=did:plc:abc&cid=bafyrei123" 381 + assert ( 382 + url 383 + == "https://pds.example.com/xrpc/com.atproto.sync.getBlob?did=did:plc:abc&cid=bafyrei123" 384 + ) 372 385 373 386 def test_shards_fetches_blobs(self): 374 387 """shards property fetches blobs via HTTP."""