A loose federation of distributed, typed datasets
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

feat(api): streamline Dataset API with DictSample default and auto-lens registration

- Add DictSample class for schema-less dataset exploration with dict/attr access
- Update load_dataset to default to DictSample when no type specified
- Auto-register lens from DictSample in @packable decorator for seamless as_type()
- Add comprehensive tests for DictSample creation, serialization, and iteration
- Remove outdated human-review notebook

Closes #338, #339, #340, #341, #342

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

+443 -561
.chainlink/issues.db

This is a binary file and will not be displayed.

+5
CHANGELOG.md
··· 25 25 - **Comprehensive integration test suite**: 593 tests covering E2E flows, error handling, edge cases 26 26 27 27 ### Changed 28 + - Streamline Dataset API with DictSample default type (#338) 29 + - Add tests for DictSample and new API (#342) 30 + - Update load_dataset default type to DictSample (#341) 31 + - Update @packable to auto-register DictSample lens (#340) 32 + - Implement DictSample class with __getattr__ and __getitem__ (#339) 28 33 - Fix failing tests in test_integration_error_handling.py (#337) 29 34 - v0.2.2 beta release improvements (#326) 30 35 - Document to_parquet() memory usage (#336)
-525
prototyping/human-review-01.ipynb
··· 1 - { 2 - "cells": [ 3 - { 4 - "cell_type": "code", 5 - "execution_count": 1, 6 - "id": "df3f0691", 7 - "metadata": {}, 8 - "outputs": [], 9 - "source": [ 10 - "import numpy as np\n", 11 - "from numpy.typing import NDArray\n", 12 - "import atdata\n", 13 - "from atdata.local import LocalDatasetEntry, S3DataStore, Index\n", 14 - "import webdataset as wds" 15 - ] 16 - }, 17 - { 18 - "cell_type": "code", 19 - "execution_count": 2, 20 - "id": "1f7ea651", 21 - "metadata": {}, 22 - "outputs": [], 23 - "source": [ 24 - "@atdata.packable\n", 25 - "class TrainingSample:\n", 26 - " \"\"\"A sample containing features and label for training.\"\"\"\n", 27 - " features: NDArray\n", 28 - " label: int\n", 29 - "\n", 30 - "@atdata.packable\n", 31 - "class TextSample:\n", 32 - " \"\"\"A sample containing text data.\"\"\"\n", 33 - " text: str\n", 34 - " category: str" 35 - ] 36 - }, 37 - { 38 - "cell_type": "markdown", 39 - "id": "55549f64", 40 - "metadata": {}, 41 - "source": [ 42 - "x = TextSample(\n", 43 - " text = 'Hello',\n", 44 - " category = 'test',\n", 45 - ")" 46 - ] 47 - }, 48 - { 49 - "cell_type": "markdown", 50 - "id": "462a780b", 51 - "metadata": {}, 52 - "source": [ 53 - "---" 54 - ] 55 - }, 56 - { 57 - "cell_type": "code", 58 - "execution_count": 3, 59 - "id": "ed0821b9", 60 - "metadata": {}, 61 - "outputs": [ 62 - { 63 - "name": "stdout", 64 - "output_type": "stream", 65 - "text": [ 66 - "Bucket: analysis-hive\n", 67 - "Supports streaming: True\n", 68 - "LocalIndex connected\n" 69 - ] 70 - } 71 - ], 72 - "source": [ 73 - "from redis import Redis\n", 74 - "\n", 75 - "# Connect to S3\n", 76 - "store = S3DataStore( '.credentials/r2-analysis-hive.env',\n", 77 - " bucket = \"analysis-hive\"\n", 78 - ")\n", 79 - "\n", 80 - "print(f\"Bucket: {store.bucket}\")\n", 81 - "print(f\"Supports streaming: {store.supports_streaming()}\")\n", 82 - "\n", 83 - "# Connect to Redis\n", 84 - "index = Index(\n", 85 - " data_store = store,\n", 86 - " auto_stubs = True,\n", 87 - ")\n", 88 - "\n", 89 - "print(\"LocalIndex connected\")" 90 - ] 91 - }, 92 - { 93 - "cell_type": "markdown", 94 - "id": "1b5b7e2c", 95 - "metadata": {}, 96 - "source": [ 97 - "TextSample = index.decode_schema( 'atdata://local/sampleSchema/TextSample@1.0.1' )" 98 - ] 99 - }, 100 - { 101 - "cell_type": "code", 102 - "execution_count": 4, 103 - "id": "301ded22", 104 - "metadata": {}, 105 - "outputs": [], 106 - "source": [ 107 - "x = TextSample(\n", 108 - " text = 'hello',\n", 109 - " category = 'test',\n", 110 - ")" 111 - ] 112 - }, 113 - { 114 - "cell_type": "code", 115 - "execution_count": 5, 116 - "id": "51829873", 117 - "metadata": {}, 118 - "outputs": [ 119 - { 120 - "name": "stdout", 121 - "output_type": "stream", 122 - "text": [ 123 - "Published schema: atdata://local/sampleSchema/TrainingSample@1.0.0\n", 124 - " - TrainingSample v1.0.0\n", 125 - "Schema fields: ['features', 'label']\n", 126 - "Decoded type: TrainingSample\n" 127 - ] 128 - } 129 - ], 130 - "source": [ 131 - "# Publish a schema\n", 132 - "schema_ref = index.publish_schema(TrainingSample, version=\"1.0.0\")\n", 133 - "print(f\"Published schema: {schema_ref}\")\n", 134 - "\n", 135 - "# List all schemas\n", 136 - "for schema in index.list_schemas():\n", 137 - " print(f\" - {schema.get('name', 'Unknown')} v{schema.get('version', '?')}\")\n", 138 - "\n", 139 - "# Get schema record\n", 140 - "schema_record = index.get_schema(schema_ref)\n", 141 - "print(f\"Schema fields: {[f['name'] for f in schema_record.get('fields', [])]}\")\n", 142 - "\n", 143 - "# Decode schema back to a PackableSample class\n", 144 - "decoded_type = index.decode_schema(schema_ref)\n", 145 - "print(f\"Decoded type: {decoded_type.__name__}\")" 146 - ] 147 - }, 148 - { 149 - "cell_type": "code", 150 - "execution_count": 6, 151 - "id": "fadbddaa", 152 - "metadata": {}, 153 - "outputs": [ 154 - { 155 - "name": "stdout", 156 - "output_type": "stream", 157 - "text": [ 158 - "Published schema: atdata://local/sampleSchema/TextSample@1.0.1\n", 159 - " - TrainingSample v1.0.0\n", 160 - " - TextSample v1.0.1\n", 161 - "Schema fields: ['text', 'category']\n", 162 - "Decoded type: TextSample\n" 163 - ] 164 - } 165 - ], 166 - "source": [ 167 - "# Publish a schema\n", 168 - "schema_ref_2 = index.publish_schema(TextSample, version=\"1.0.1\")\n", 169 - "print(f\"Published schema: {schema_ref_2}\")\n", 170 - "\n", 171 - "# List all schemas\n", 172 - "for schema in index.list_schemas():\n", 173 - " print(f\" - {schema.get('name', 'Unknown')} v{schema.get('version', '?')}\")\n", 174 - "\n", 175 - "# Get schema record\n", 176 - "schema_record = index.get_schema(schema_ref_2)\n", 177 - "print(f\"Schema fields: {[f['name'] for f in schema_record.get('fields', [])]}\")\n", 178 - "\n", 179 - "# Decode schema back to a PackableSample class\n", 180 - "decoded_type = index.decode_schema(schema_ref_2)\n", 181 - "print(f\"Decoded type: {decoded_type.__name__}\")" 182 - ] 183 - }, 184 - { 185 - "cell_type": "code", 186 - "execution_count": 7, 187 - "id": "afdc07f5", 188 - "metadata": {}, 189 - "outputs": [], 190 - "source": [ 191 - "del TextSample" 192 - ] 193 - }, 194 - { 195 - "cell_type": "code", 196 - "execution_count": 8, 197 - "id": "5a3122bd", 198 - "metadata": {}, 199 - "outputs": [ 200 - { 201 - "data": { 202 - "text/plain": [ 203 - "_atdata_generated_TextSample_1_0_1.TextSample" 204 - ] 205 - }, 206 - "execution_count": 8, 207 - "metadata": {}, 208 - "output_type": "execute_result" 209 - } 210 - ], 211 - "source": [ 212 - "index.load_schema( 'atdata://local/sampleSchema/TextSample@1.0.1' )" 213 - ] 214 - }, 215 - { 216 - "cell_type": "code", 217 - "execution_count": 9, 218 - "id": "492f1d37", 219 - "metadata": {}, 220 - "outputs": [], 221 - "source": [ 222 - "TextSample = index.types.TextSample" 223 - ] 224 - }, 225 - { 226 - "cell_type": "code", 227 - "execution_count": 12, 228 - "id": "04cf7394", 229 - "metadata": {}, 230 - "outputs": [], 231 - "source": [ 232 - "x = TextSample(\n", 233 - " text = 'hello',\n", 234 - " category = 'test',\n", 235 - ")" 236 - ] 237 - }, 238 - { 239 - "cell_type": "code", 240 - "execution_count": null, 241 - "id": "a5022c2f", 242 - "metadata": {}, 243 - "outputs": [], 244 - "source": [ 245 - "@atdata.packable\n", 246 - "class LocalTextSample:\n", 247 - " content: str\n", 248 - " \"Test\"\n", 249 - " category: str\n", 250 - " \"stuff\"\n", 251 - "\n", 252 - "@atdata.lens\n", 253 - "def _convert_text_sample( s: TextSample ) -> LocalTextSample:\n", 254 - " return LocalTextSample(\n", 255 - " content = s.text,\n", 256 - " category = s.category,\n", 257 - " )" 258 - ] 259 - }, 260 - { 261 - "cell_type": "markdown", 262 - "id": "74d785df", 263 - "metadata": {}, 264 - "source": [ 265 - "Notes:\n", 266 - "\n", 267 - "* We get linting errors here on `@atdata.lens` because `LocalTextSample` doesn't show up as a subclass of `PackableSample`; is there a way to resolve this?" 268 - ] 269 - }, 270 - { 271 - "cell_type": "code", 272 - "execution_count": 22, 273 - "id": "e70b084a", 274 - "metadata": {}, 275 - "outputs": [], 276 - "source": [ 277 - "y = _convert_text_sample( x )" 278 - ] 279 - }, 280 - { 281 - "cell_type": "markdown", 282 - "id": "08b8d647", 283 - "metadata": {}, 284 - "source": [ 285 - "---" 286 - ] 287 - }, 288 - { 289 - "cell_type": "code", 290 - "execution_count": 24, 291 - "id": "55d944d0", 292 - "metadata": {}, 293 - "outputs": [ 294 - { 295 - "name": "stdout", 296 - "output_type": "stream", 297 - "text": [ 298 - "# writing data/TextSample_test-000000.tar 0 0.0 GB 0\n", 299 - "# writing data/TextSample_test-000001.tar 1000 0.0 GB 1000\n", 300 - "# writing data/TextSample_test-000002.tar 1000 0.0 GB 2000\n", 301 - "# writing data/TextSample_test-000003.tar 1000 0.0 GB 3000\n", 302 - "# writing data/TextSample_test-000004.tar 1000 0.0 GB 4000\n", 303 - "# writing data/TextSample_test-000005.tar 1000 0.0 GB 5000\n", 304 - "# writing data/TextSample_test-000006.tar 1000 0.0 GB 6000\n", 305 - "# writing data/TextSample_test-000007.tar 1000 0.0 GB 7000\n", 306 - "# writing data/TextSample_test-000008.tar 1000 0.0 GB 8000\n", 307 - "# writing data/TextSample_test-000009.tar 1000 0.0 GB 9000\n" 308 - ] 309 - } 310 - ], 311 - "source": [ 312 - "import webdataset as wds\n", 313 - "from uuid import uuid4\n", 314 - "\n", 315 - "data_pattern = 'data/TextSample_test-%06d.tar'\n", 316 - "\n", 317 - "with wds.writer.ShardWriter( data_pattern, maxcount = 1_000 ) as sink:\n", 318 - " for i in range( 10_000 ):\n", 319 - " new_sample = TextSample(\n", 320 - " text = str( uuid4() ),\n", 321 - " category = 'test',\n", 322 - " )\n", 323 - " sink.write( new_sample.as_wds )" 324 - ] 325 - }, 326 - { 327 - "cell_type": "code", 328 - "execution_count": 30, 329 - "id": "ab656145", 330 - "metadata": {}, 331 - "outputs": [], 332 - "source": [ 333 - "from atdata import Dataset\n", 334 - "ds = Dataset[TextSample]( 'data/TextSample_test-{000000..000009}.tar' )\n", 335 - "x = next( iter( ds.ordered( batch_size = None ) ) )" 336 - ] 337 - }, 338 - { 339 - "cell_type": "markdown", 340 - "id": "2bb32688", 341 - "metadata": {}, 342 - "source": [ 343 - "Notes:\n", 344 - "\n", 345 - "* We should make the default for `Dataset.ordered` and `Dataset.shuffled` be to have `batch_size` be `None`, rather than 1." 346 - ] 347 - }, 348 - { 349 - "cell_type": "code", 350 - "execution_count": 32, 351 - "id": "4ebfcc63", 352 - "metadata": {}, 353 - "outputs": [ 354 - { 355 - "name": "stdout", 356 - "output_type": "stream", 357 - "text": [ 358 - "# writing analysis-hive/prototyping/data--4a5ff662-803b-4700-81f4-45f288f6e565--000000.tar 0 0.0 GB 0\n" 359 - ] 360 - } 361 - ], 362 - "source": [ 363 - "entry = index.insert_dataset( ds, \n", 364 - " name = 'proto-text-samples-2',\n", 365 - " prefix = 'prototyping',\n", 366 - " schema_ref = 'atdata://local/sampleSchema/TextSample@1.0.1',\n", 367 - ")" 368 - ] 369 - }, 370 - { 371 - "cell_type": "code", 372 - "execution_count": 33, 373 - "id": "e74d68f6", 374 - "metadata": {}, 375 - "outputs": [ 376 - { 377 - "data": { 378 - "text/plain": [ 379 - "LocalDatasetEntry(_name='proto-text-samples-2', _schema_ref='atdata://local/sampleSchema/TextSample@1.0.1', _data_urls=['s3://analysis-hive/prototyping/data--4a5ff662-803b-4700-81f4-45f288f6e565--000000.tar'], _metadata=None)" 380 - ] 381 - }, 382 - "execution_count": 33, 383 - "metadata": {}, 384 - "output_type": "execute_result" 385 - } 386 - ], 387 - "source": [ 388 - "entry" 389 - ] 390 - }, 391 - { 392 - "cell_type": "markdown", 393 - "id": "a51090c3", 394 - "metadata": {}, 395 - "source": [ 396 - "Notes:\n", 397 - "\n", 398 - "* We should make sure that the `s3` URI-scheme here is properly used\n", 399 - " * Should we be using the `https` URI since actually this is doing data streaming with `wds`? Or does this indicate that we should think more deeply about the `Dataset` API design and generalizing how we're setting up the `wds` data streaming ...\n", 400 - " * No matter what, we're definitely going to want to make sure that we incorporate the actual host details of the `LocalIndex`'s `S3DataStore` for this, since the S3 host is definitely not local.\n", 401 - " * Should there be underscores here? These feel like public properties ..." 402 - ] 403 - }, 404 - { 405 - "cell_type": "markdown", 406 - "id": "90872fe7", 407 - "metadata": {}, 408 - "source": [ 409 - "---" 410 - ] 411 - }, 412 - { 413 - "cell_type": "code", 414 - "execution_count": 35, 415 - "id": "4a2736f0", 416 - "metadata": {}, 417 - "outputs": [ 418 - { 419 - "ename": "ValueError", 420 - "evalue": "('s3://analysis-hive/prototyping/data--4a5ff662-803b-4700-81f4-45f288f6e565--000000.tar: no gopen handler defined', 's3://analysis-hive/prototyping/data--4a5ff662-803b-4700-81f4-45f288f6e565--000000.tar')", 421 - "output_type": "error", 422 - "traceback": [ 423 - "\u001b[31m---------------------------------------------------------------------------\u001b[39m", 424 - "\u001b[31mValueError\u001b[39m Traceback (most recent call last)", 425 - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[35]\u001b[39m\u001b[32m, line 10\u001b[39m\n\u001b[32m 4\u001b[39m ds = load_dataset( \u001b[33m\"\u001b[39m\u001b[33m@local/proto-text-samples-2\u001b[39m\u001b[33m\"\u001b[39m,\n\u001b[32m 5\u001b[39m index = index,\n\u001b[32m 6\u001b[39m split = \u001b[33m'\u001b[39m\u001b[33mtrain\u001b[39m\u001b[33m'\u001b[39m,\n\u001b[32m 7\u001b[39m )\n\u001b[32m 9\u001b[39m \u001b[38;5;66;03m# The index resolves the dataset name to URLs and schema\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m10\u001b[39m \u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mds\u001b[49m\u001b[43m.\u001b[49m\u001b[43mshuffled\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m32\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 11\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mbreak\u001b[39;49;00m\n", 426 - "\u001b[36mFile \u001b[39m\u001b[32m~/git-forecast/atdata/.venv/lib/python3.12/site-packages/webdataset/pipeline.py:105\u001b[39m, in \u001b[36mDataPipeline.iterator\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 103\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mself\u001b[39m.repetitions):\n\u001b[32m 104\u001b[39m count = \u001b[32m0\u001b[39m\n\u001b[32m--> \u001b[39m\u001b[32m105\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43msample\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43miterator1\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 106\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01myield\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43msample\u001b[49m\n\u001b[32m 107\u001b[39m \u001b[43m \u001b[49m\u001b[43mcount\u001b[49m\u001b[43m \u001b[49m\u001b[43m+\u001b[49m\u001b[43m=\u001b[49m\u001b[43m \u001b[49m\u001b[32;43m1\u001b[39;49m\n", 427 - "\u001b[36mFile \u001b[39m\u001b[32m~/git-forecast/atdata/.venv/lib/python3.12/site-packages/webdataset/filters.py:520\u001b[39m, in \u001b[36m_map\u001b[39m\u001b[34m(data, f, handler)\u001b[39m\n\u001b[32m 505\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_map\u001b[39m(data, f, handler=reraise_exception):\n\u001b[32m 506\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m 507\u001b[39m \u001b[33;03m Map samples through a function.\u001b[39;00m\n\u001b[32m 508\u001b[39m \n\u001b[32m (...)\u001b[39m\u001b[32m 518\u001b[39m \u001b[33;03m Exception: If the handler doesn't handle an exception.\u001b[39;00m\n\u001b[32m 519\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m520\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43msample\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 521\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mtry\u001b[39;49;00m\u001b[43m:\u001b[49m\n\u001b[32m 522\u001b[39m \u001b[43m \u001b[49m\u001b[43mresult\u001b[49m\u001b[43m \u001b[49m\u001b[43m=\u001b[49m\u001b[43m \u001b[49m\u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[43msample\u001b[49m\u001b[43m)\u001b[49m\n", 428 - "\u001b[36mFile \u001b[39m\u001b[32m~/git-forecast/atdata/.venv/lib/python3.12/site-packages/webdataset/filters.py:783\u001b[39m, in \u001b[36m_batched\u001b[39m\u001b[34m(data, batchsize, collation_fn, partial)\u001b[39m\n\u001b[32m 770\u001b[39m \u001b[38;5;250m\u001b[39m\u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m 771\u001b[39m \u001b[33;03mCreate batches of the given size.\u001b[39;00m\n\u001b[32m 772\u001b[39m \n\u001b[32m (...)\u001b[39m\u001b[32m 780\u001b[39m \u001b[33;03m Batches of samples.\u001b[39;00m\n\u001b[32m 781\u001b[39m \u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m 782\u001b[39m batch = []\n\u001b[32m--> \u001b[39m\u001b[32m783\u001b[39m \u001b[43m\u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43msample\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 784\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[43m>\u001b[49m\u001b[43m=\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatchsize\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 785\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mcollation_fn\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mis\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mnot\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m:\u001b[49m\n", 429 - "\u001b[36mFile \u001b[39m\u001b[32m~/git-forecast/atdata/.venv/lib/python3.12/site-packages/webdataset/filters.py:358\u001b[39m, in \u001b[36m_shuffle\u001b[39m\u001b[34m(data, bufsize, initial, rng, seed, handler)\u001b[39m\n\u001b[32m 356\u001b[39m initial = \u001b[38;5;28mmin\u001b[39m(initial, bufsize)\n\u001b[32m 357\u001b[39m buf = []\n\u001b[32m--> \u001b[39m\u001b[32m358\u001b[39m \u001b[43m\u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43msample\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 359\u001b[39m \u001b[43m \u001b[49m\u001b[43mbuf\u001b[49m\u001b[43m.\u001b[49m\u001b[43mappend\u001b[49m\u001b[43m(\u001b[49m\u001b[43msample\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 360\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mbuf\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[43m<\u001b[49m\u001b[43m \u001b[49m\u001b[43mbufsize\u001b[49m\u001b[43m:\u001b[49m\n", 430 - "\u001b[36mFile \u001b[39m\u001b[32m~/git-forecast/atdata/.venv/lib/python3.12/site-packages/webdataset/tariterators.py:230\u001b[39m, in \u001b[36mgroup_by_keys\u001b[39m\u001b[34m(data, keys, lcase, suffixes, handler)\u001b[39m\n\u001b[32m 214\u001b[39m \u001b[38;5;250m\u001b[39m\u001b[33;03m\"\"\"Group tarfile contents by keys and yield samples.\u001b[39;00m\n\u001b[32m 215\u001b[39m \n\u001b[32m 216\u001b[39m \u001b[33;03mArgs:\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 227\u001b[39m \u001b[33;03m Iterator over samples.\u001b[39;00m\n\u001b[32m 228\u001b[39m \u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m 229\u001b[39m current_sample = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m230\u001b[39m \u001b[43m\u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mfilesample\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 231\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mtry\u001b[39;49;00m\u001b[43m:\u001b[49m\n\u001b[32m 232\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01massert\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43misinstance\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mfilesample\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mdict\u001b[39;49m\u001b[43m)\u001b[49m\n", 431 - "\u001b[36mFile \u001b[39m\u001b[32m~/git-forecast/atdata/.venv/lib/python3.12/site-packages/webdataset/tariterators.py:178\u001b[39m, in \u001b[36mtar_file_expander\u001b[39m\u001b[34m(data, handler, select_files, rename_files, eof_value)\u001b[39m\n\u001b[32m 159\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mtar_file_expander\u001b[39m(\n\u001b[32m 160\u001b[39m data: Iterable[Dict[\u001b[38;5;28mstr\u001b[39m, Any]],\n\u001b[32m 161\u001b[39m handler: Callable[[\u001b[38;5;167;01mException\u001b[39;00m], \u001b[38;5;28mbool\u001b[39m] = reraise_exception,\n\u001b[32m (...)\u001b[39m\u001b[32m 164\u001b[39m eof_value: Optional[Any] = {},\n\u001b[32m 165\u001b[39m ) -> Iterator[Dict[\u001b[38;5;28mstr\u001b[39m, Any]]:\n\u001b[32m 166\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Expand tar files.\u001b[39;00m\n\u001b[32m 167\u001b[39m \n\u001b[32m 168\u001b[39m \u001b[33;03m Args:\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 176\u001b[39m \u001b[33;03m A stream of samples.\u001b[39;00m\n\u001b[32m 177\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m178\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43msource\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 179\u001b[39m \u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m \u001b[49m\u001b[43m=\u001b[49m\u001b[43m \u001b[49m\u001b[43msource\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43murl\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\n\u001b[32m 180\u001b[39m \u001b[43m \u001b[49m\u001b[43mlocal_path\u001b[49m\u001b[43m \u001b[49m\u001b[43m=\u001b[49m\u001b[43m \u001b[49m\u001b[43msource\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mlocal_path\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n", 432 - "\u001b[36mFile \u001b[39m\u001b[32m~/git-forecast/atdata/.venv/lib/python3.12/site-packages/webdataset/tariterators.py:103\u001b[39m, in \u001b[36murl_opener\u001b[39m\u001b[34m(data, handler, **kw)\u001b[39m\n\u001b[32m 101\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m exn:\n\u001b[32m 102\u001b[39m exn.args = exn.args + (url,)\n\u001b[32m--> \u001b[39m\u001b[32m103\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[43mhandler\u001b[49m\u001b[43m(\u001b[49m\u001b[43mexn\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[32m 104\u001b[39m \u001b[38;5;28;01mcontinue\u001b[39;00m\n\u001b[32m 105\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n", 433 - "\u001b[36mFile \u001b[39m\u001b[32m~/git-forecast/atdata/.venv/lib/python3.12/site-packages/webdataset/handlers.py:31\u001b[39m, in \u001b[36mreraise_exception\u001b[39m\u001b[34m(exn)\u001b[39m\n\u001b[32m 22\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mreraise_exception\u001b[39m(exn):\n\u001b[32m 23\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Re-raise the given exception.\u001b[39;00m\n\u001b[32m 24\u001b[39m \n\u001b[32m 25\u001b[39m \u001b[33;03m Args:\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 29\u001b[39m \u001b[33;03m The input exception.\u001b[39;00m\n\u001b[32m 30\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m31\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m exn\n", 434 - "\u001b[36mFile \u001b[39m\u001b[32m~/git-forecast/atdata/.venv/lib/python3.12/site-packages/webdataset/tariterators.py:98\u001b[39m, in \u001b[36murl_opener\u001b[39m\u001b[34m(data, handler, **kw)\u001b[39m\n\u001b[32m 96\u001b[39m url = sample[\u001b[33m\"\u001b[39m\u001b[33murl\u001b[39m\u001b[33m\"\u001b[39m]\n\u001b[32m 97\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m---> \u001b[39m\u001b[32m98\u001b[39m stream = \u001b[43mgopen\u001b[49m\u001b[43m.\u001b[49m\u001b[43mgopen\u001b[49m\u001b[43m(\u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkw\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 99\u001b[39m sample.update(stream=stream)\n\u001b[32m 100\u001b[39m \u001b[38;5;28;01myield\u001b[39;00m sample\n", 435 - "\u001b[36mFile \u001b[39m\u001b[32m~/git-forecast/atdata/.venv/lib/python3.12/site-packages/webdataset/gopen.py:591\u001b[39m, in \u001b[36mgopen\u001b[39m\u001b[34m(url, mode, bufsize, **kw)\u001b[39m\n\u001b[32m 589\u001b[39m handler = gopen_schemes[\u001b[33m\"\u001b[39m\u001b[33m__default__\u001b[39m\u001b[33m\"\u001b[39m]\n\u001b[32m 590\u001b[39m handler = gopen_schemes.get(pr.scheme, handler)\n\u001b[32m--> \u001b[39m\u001b[32m591\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mhandler\u001b[49m\u001b[43m(\u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbufsize\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkw\u001b[49m\u001b[43m)\u001b[49m\n", 436 - "\u001b[36mFile \u001b[39m\u001b[32m~/git-forecast/atdata/.venv/lib/python3.12/site-packages/webdataset/gopen.py:446\u001b[39m, in \u001b[36mgopen_error\u001b[39m\u001b[34m(url, *args, **kw)\u001b[39m\n\u001b[32m 435\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mgopen_error\u001b[39m(url, *args, **kw):\n\u001b[32m 436\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Raise a value error.\u001b[39;00m\n\u001b[32m 437\u001b[39m \n\u001b[32m 438\u001b[39m \u001b[33;03m Args:\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 444\u001b[39m \u001b[33;03m ValueError: Always raised with the URL and a message\u001b[39;00m\n\u001b[32m 445\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m446\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00murl\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m: no gopen handler defined\u001b[39m\u001b[33m\"\u001b[39m)\n", 437 - "\u001b[31mValueError\u001b[39m: ('s3://analysis-hive/prototyping/data--4a5ff662-803b-4700-81f4-45f288f6e565--000000.tar: no gopen handler defined', 's3://analysis-hive/prototyping/data--4a5ff662-803b-4700-81f4-45f288f6e565--000000.tar')" 438 - ] 439 - } 440 - ], 441 - "source": [ 442 - "from atdata import load_dataset\n", 443 - "\n", 444 - "# Load from local index\n", 445 - "ds = load_dataset( \"@local/proto-text-samples-2\",\n", 446 - " index = index,\n", 447 - " split = 'train',\n", 448 - ")\n", 449 - "\n", 450 - "# The index resolves the dataset name to URLs and schema\n", 451 - "for batch in ds.shuffled(batch_size=32):\n", 452 - " break" 453 - ] 454 - }, 455 - { 456 - "cell_type": "markdown", 457 - "id": "3b30bb49", 458 - "metadata": {}, 459 - "source": [ 460 - "Notes:\n", 461 - "\n", 462 - "* This is also getting linting errors on `load_dataset` that there are no matching overloads." 463 - ] 464 - }, 465 - { 466 - "cell_type": "code", 467 - "execution_count": 36, 468 - "id": "5c2afcd2", 469 - "metadata": {}, 470 - "outputs": [ 471 - { 472 - "data": { 473 - "text/plain": [ 474 - "'s3://analysis-hive/prototyping/data--4a5ff662-803b-4700-81f4-45f288f6e565--000000.tar'" 475 - ] 476 - }, 477 - "execution_count": 36, 478 - "metadata": {}, 479 - "output_type": "execute_result" 480 - } 481 - ], 482 - "source": [ 483 - "ds.url" 484 - ] 485 - }, 486 - { 487 - "cell_type": "markdown", 488 - "id": "2e4238ba", 489 - "metadata": {}, 490 - "source": [ 491 - "Notes:\n", 492 - "\n", 493 - "* We're getting linting errors because of the protocol use for `AbstractIndex`; better to subclass, or is there a way for this to get the protocol adherence?\n", 494 - "* The S3 URI error is showing up here now because of how dataset loading works! The data is uploaded correctly on my end, but it can't be accessed because of this URI not being the correct way to access the data for `wds` streaming over `https`; we should think of how best to encode this!" 495 - ] 496 - }, 497 - { 498 - "cell_type": "markdown", 499 - "id": "2bbedcd2", 500 - "metadata": {}, 501 - "source": [] 502 - } 503 - ], 504 - "metadata": { 505 - "kernelspec": { 506 - "display_name": "atdata", 507 - "language": "python", 508 - "name": "python3" 509 - }, 510 - "language_info": { 511 - "codemirror_mode": { 512 - "name": "ipython", 513 - "version": 3 514 - }, 515 - "file_extension": ".py", 516 - "mimetype": "text/x-python", 517 - "name": "python", 518 - "nbconvert_exporter": "python", 519 - "pygments_lexer": "ipython3", 520 - "version": "3.12.11" 521 - } 522 - }, 523 - "nbformat": 4, 524 - "nbformat_minor": 5 525 - }
+1
src/atdata/__init__.py
··· 39 39 # Expose components 40 40 41 41 from .dataset import ( 42 + DictSample as DictSample, 42 43 PackableSample as PackableSample, 43 44 SampleBatch as SampleBatch, 44 45 Dataset as Dataset,
+41 -32
src/atdata/_hf_api.py
··· 40 40 overload, 41 41 ) 42 42 43 - from .dataset import Dataset, PackableSample 43 + from .dataset import Dataset, PackableSample, DictSample 44 44 45 45 if TYPE_CHECKING: 46 46 from ._protocols import AbstractIndex ··· 472 472 # Main load_dataset function 473 473 474 474 475 + # Overload: explicit type with split -> Dataset[ST] 475 476 @overload 476 477 def load_dataset( 477 478 path: str, ··· 484 485 ) -> Dataset[ST]: ... 485 486 486 487 488 + # Overload: explicit type without split -> DatasetDict[ST] 487 489 @overload 488 490 def load_dataset( 489 491 path: str, ··· 496 498 ) -> DatasetDict[ST]: ... 497 499 498 500 501 + # Overload: no type with split -> Dataset[DictSample] 499 502 @overload 500 503 def load_dataset( 501 504 path: str, ··· 504 507 split: str, 505 508 data_files: str | list[str] | dict[str, str | list[str]] | None = None, 506 509 streaming: bool = False, 507 - index: "AbstractIndex", 508 - ) -> Dataset[PackableSample]: ... 510 + index: Optional["AbstractIndex"] = None, 511 + ) -> Dataset[DictSample]: ... 509 512 510 513 514 + # Overload: no type without split -> DatasetDict[DictSample] 511 515 @overload 512 516 def load_dataset( 513 517 path: str, ··· 516 520 split: None = None, 517 521 data_files: str | list[str] | dict[str, str | list[str]] | None = None, 518 522 streaming: bool = False, 519 - index: "AbstractIndex", 520 - ) -> DatasetDict[PackableSample]: ... 523 + index: Optional["AbstractIndex"] = None, 524 + ) -> DatasetDict[DictSample]: ... 521 525 522 526 523 527 def load_dataset( ··· 536 540 returns either a single Dataset or a DatasetDict depending on the split 537 541 parameter. 538 542 543 + When no ``sample_type`` is provided, returns a ``Dataset[DictSample]`` that 544 + provides dynamic dict-like access to fields. Use ``.as_type(MyType)`` to 545 + convert to a typed schema. 546 + 539 547 Args: 540 548 path: Path to dataset. Can be: 541 549 - Index lookup: "@handle/dataset-name" or "@local/dataset-name" ··· 545 553 - Remote URL: "s3://bucket/path/data-*.tar" 546 554 - Single file: "path/to/data.tar" 547 555 548 - sample_type: The PackableSample subclass defining the schema. Can be 549 - None if index is provided - the type will be resolved from the 550 - schema stored in the index. 556 + sample_type: The PackableSample subclass defining the schema. If None, 557 + returns ``Dataset[DictSample]`` with dynamic field access. Can also 558 + be resolved from an index when using @handle/dataset syntax. 551 559 552 560 split: Which split to load. If None, returns a DatasetDict with all 553 561 detected splits. If specified (e.g., "train", "test"), returns ··· 563 571 pipelines, so this parameter primarily signals intent. 564 572 565 573 index: Optional AbstractIndex for dataset lookup. Required when using 566 - @handle/dataset syntax or when sample_type is None. Can be a 567 - LocalIndex or AtmosphereIndex. 574 + @handle/dataset syntax. When provided with an indexed path, the 575 + schema can be auto-resolved from the index. 568 576 569 577 Returns: 570 - If split is None: DatasetDict[ST] with all detected splits. 571 - If split is specified: Dataset[ST] for that split. 578 + If split is None: DatasetDict with all detected splits. 579 + If split is specified: Dataset for that split. 580 + Type is ``ST`` if sample_type provided, otherwise ``DictSample``. 572 581 573 582 Raises: 574 - ValueError: If the specified split is not found, or if sample_type 575 - is None without an index. 583 + ValueError: If the specified split is not found. 576 584 FileNotFoundError: If no data files are found at the path. 577 585 KeyError: If dataset not found in index. 578 586 579 587 Example: 580 - >>> # Load from local path with explicit type 588 + >>> # Load without type - get DictSample for exploration 589 + >>> ds = load_dataset("./data/train.tar", split="train") 590 + >>> for sample in ds.ordered(): 591 + ... print(sample.keys()) # Explore fields 592 + ... print(sample["text"]) # Dict-style access 593 + ... print(sample.label) # Attribute access 594 + >>> 595 + >>> # Convert to typed schema 596 + >>> typed_ds = ds.as_type(TextData) 597 + >>> 598 + >>> # Or load with explicit type directly 581 599 >>> train_ds = load_dataset("./data/train-*.tar", TextData, split="train") 582 600 >>> 583 601 >>> # Load from index with auto-type resolution 584 602 >>> index = LocalIndex() 585 603 >>> ds = load_dataset("@local/my-dataset", index=index, split="train") 586 - >>> 587 - >>> # Load all splits 588 - >>> ds_dict = load_dataset("./data/", TextData) 589 - >>> train_ds = ds_dict["train"] 590 604 """ 591 605 # Handle @handle/dataset indexed path resolution 592 606 if _is_indexed_path(path): ··· 599 613 data_urls, schema_ref = _resolve_indexed_path(path, index) 600 614 601 615 # Resolve sample_type from schema if not provided 602 - if sample_type is None: 603 - sample_type = index.decode_schema(schema_ref) 616 + resolved_type: Type = sample_type if sample_type is not None else index.decode_schema(schema_ref) 604 617 605 618 # For indexed datasets, we treat all URLs as a single "train" split 606 619 url = _shards_to_wds_url(data_urls) 607 - ds = Dataset[sample_type](url) 620 + ds = Dataset[resolved_type](url) 608 621 609 622 if split is not None: 610 623 # Indexed datasets are single-split by default 611 624 return ds 612 625 613 - return DatasetDict({"train": ds}, sample_type=sample_type, streaming=streaming) 626 + return DatasetDict({"train": ds}, sample_type=resolved_type, streaming=streaming) 614 627 615 - # Validate sample_type for non-indexed paths 616 - if sample_type is None: 617 - raise ValueError( 618 - "sample_type is required for non-indexed paths. " 619 - "Use @handle/dataset with an index for auto-type resolution." 620 - ) 628 + # Use DictSample as default when no type specified 629 + resolved_type = sample_type if sample_type is not None else DictSample 621 630 622 631 # Resolve path to split -> shard URL mapping 623 632 splits_shards = _resolve_shards(path, data_files) ··· 626 635 raise FileNotFoundError(f"No data files found at path: {path}") 627 636 628 637 # Build Dataset for each split 629 - datasets: dict[str, Dataset[ST]] = {} 638 + datasets: dict[str, Dataset] = {} 630 639 for split_name, shards in splits_shards.items(): 631 640 url = _shards_to_wds_url(shards) 632 - ds = Dataset[sample_type](url) 641 + ds = Dataset[resolved_type](url) 633 642 datasets[split_name] = ds 634 643 635 644 # Return single Dataset or DatasetDict ··· 641 650 ) 642 651 return datasets[split] 643 652 644 - return DatasetDict(datasets, sample_type=sample_type, streaming=streaming) 653 + return DatasetDict(datasets, sample_type=resolved_type, streaming=streaming) 645 654 646 655 647 656 ##
+173
src/atdata/dataset.py
··· 106 106 return any( x == NDArray for x in t.__args__ ) 107 107 return False 108 108 109 + class DictSample: 110 + """Dynamic sample type providing dict-like access to raw msgpack data. 111 + 112 + This class is the default sample type for datasets when no explicit type is 113 + specified. It stores the raw unpacked msgpack data and provides both 114 + attribute-style (``sample.field``) and dict-style (``sample["field"]``) 115 + access to fields. 116 + 117 + ``DictSample`` is useful for: 118 + - Exploring datasets without defining a schema first 119 + - Working with datasets that have variable schemas 120 + - Prototyping before committing to a typed schema 121 + 122 + To convert to a typed schema, use ``Dataset.as_type()`` with a 123 + ``@packable``-decorated class. Every ``@packable`` class automatically 124 + registers a lens from ``DictSample``, making this conversion seamless. 125 + 126 + Example: 127 + >>> ds = load_dataset("path/to/data.tar") # Returns Dataset[DictSample] 128 + >>> for sample in ds.ordered(): 129 + ... print(sample.some_field) # Attribute access 130 + ... print(sample["other_field"]) # Dict access 131 + ... print(sample.keys()) # Inspect available fields 132 + ... 133 + >>> # Convert to typed schema 134 + >>> typed_ds = ds.as_type(MyTypedSample) 135 + 136 + Note: 137 + NDArray fields are stored as raw bytes in DictSample. They are only 138 + converted to numpy arrays when accessed through a typed sample class. 139 + """ 140 + 141 + __slots__ = ('_data',) 142 + 143 + def __init__(self, _data: dict[str, Any] | None = None, **kwargs: Any) -> None: 144 + """Create a DictSample from a dictionary or keyword arguments. 145 + 146 + Args: 147 + _data: Raw data dictionary. If provided, kwargs are ignored. 148 + **kwargs: Field values if _data is not provided. 149 + """ 150 + if _data is not None: 151 + object.__setattr__(self, '_data', _data) 152 + else: 153 + object.__setattr__(self, '_data', kwargs) 154 + 155 + @classmethod 156 + def from_data(cls, data: dict[str, Any]) -> 'DictSample': 157 + """Create a DictSample from unpacked msgpack data. 158 + 159 + Args: 160 + data: Dictionary with field names as keys. 161 + 162 + Returns: 163 + New DictSample instance wrapping the data. 164 + """ 165 + return cls(_data=data) 166 + 167 + @classmethod 168 + def from_bytes(cls, bs: bytes) -> 'DictSample': 169 + """Create a DictSample from raw msgpack bytes. 170 + 171 + Args: 172 + bs: Raw bytes from a msgpack-serialized sample. 173 + 174 + Returns: 175 + New DictSample instance with the unpacked data. 176 + """ 177 + return cls.from_data(ormsgpack.unpackb(bs)) 178 + 179 + def __getattr__(self, name: str) -> Any: 180 + """Access a field by attribute name. 181 + 182 + Args: 183 + name: Field name to access. 184 + 185 + Returns: 186 + The field value. 187 + 188 + Raises: 189 + AttributeError: If the field doesn't exist. 190 + """ 191 + # Avoid infinite recursion for _data lookup 192 + if name == '_data': 193 + raise AttributeError(name) 194 + try: 195 + return self._data[name] 196 + except KeyError: 197 + raise AttributeError( 198 + f"'{type(self).__name__}' has no field '{name}'. " 199 + f"Available fields: {list(self._data.keys())}" 200 + ) from None 201 + 202 + def __getitem__(self, key: str) -> Any: 203 + """Access a field by dict key. 204 + 205 + Args: 206 + key: Field name to access. 207 + 208 + Returns: 209 + The field value. 210 + 211 + Raises: 212 + KeyError: If the field doesn't exist. 213 + """ 214 + return self._data[key] 215 + 216 + def __contains__(self, key: str) -> bool: 217 + """Check if a field exists.""" 218 + return key in self._data 219 + 220 + def keys(self) -> list[str]: 221 + """Return list of field names.""" 222 + return list(self._data.keys()) 223 + 224 + def values(self) -> list[Any]: 225 + """Return list of field values.""" 226 + return list(self._data.values()) 227 + 228 + def items(self) -> list[tuple[str, Any]]: 229 + """Return list of (field_name, value) tuples.""" 230 + return list(self._data.items()) 231 + 232 + def get(self, key: str, default: Any = None) -> Any: 233 + """Get a field value with optional default. 234 + 235 + Args: 236 + key: Field name to access. 237 + default: Value to return if field doesn't exist. 238 + 239 + Returns: 240 + The field value or default. 241 + """ 242 + return self._data.get(key, default) 243 + 244 + def to_dict(self) -> dict[str, Any]: 245 + """Return a copy of the underlying data dictionary.""" 246 + return dict(self._data) 247 + 248 + @property 249 + def packed(self) -> bytes: 250 + """Pack this sample's data into msgpack bytes. 251 + 252 + Returns: 253 + Raw msgpack bytes representing this sample's data. 254 + """ 255 + return msgpack.packb(self._data) 256 + 257 + @property 258 + def as_wds(self) -> 'WDSRawSample': 259 + """Pack this sample's data for writing to WebDataset. 260 + 261 + Returns: 262 + A dictionary with ``__key__`` and ``msgpack`` fields. 263 + """ 264 + return { 265 + '__key__': str(uuid.uuid1(0, 0)), 266 + 'msgpack': self.packed, 267 + } 268 + 269 + def __repr__(self) -> str: 270 + fields = ', '.join(f'{k}=...' for k in self._data.keys()) 271 + return f'DictSample({fields})' 272 + 273 + 109 274 @dataclass 110 275 class PackableSample( ABC ): 111 276 """Base class for samples that can be serialized with msgpack. ··· 804 969 attr.__qualname__ = attr.__qualname__.replace( 805 970 old_qualname_prefix, class_name, 1 806 971 ) 972 + 973 + # Auto-register lens from DictSample to this type 974 + # This enables ds.as_type(MyType) when ds is Dataset[DictSample] 975 + def _dict_to_typed(ds: DictSample) -> as_packable: 976 + return as_packable.from_data(ds._data) 977 + 978 + _dict_lens = Lens(_dict_to_typed) 979 + LensNetwork().register(_dict_lens) 807 980 808 981 ## 809 982
+202
tests/test_dataset.py
··· 697 697 list(dataset.ordered(batch_size=0)) 698 698 699 699 700 + ## 701 + # DictSample tests 702 + 703 + 704 + def test_dictsample_creation(): 705 + """Test DictSample can be created with keyword args or dict.""" 706 + # From keyword args 707 + ds1 = atdata.DictSample(name="test", value=42) 708 + assert ds1.name == "test" 709 + assert ds1.value == 42 710 + 711 + # From _data dict 712 + ds2 = atdata.DictSample(_data={"name": "test2", "value": 100}) 713 + assert ds2.name == "test2" 714 + assert ds2.value == 100 715 + 716 + 717 + def test_dictsample_getattr(): 718 + """Test DictSample attribute access.""" 719 + sample = atdata.DictSample(text="hello", label=1) 720 + 721 + assert sample.text == "hello" 722 + assert sample.label == 1 723 + 724 + # Non-existent attribute raises AttributeError 725 + with pytest.raises(AttributeError, match="has no field"): 726 + _ = sample.nonexistent 727 + 728 + 729 + def test_dictsample_getitem(): 730 + """Test DictSample dict-style access.""" 731 + sample = atdata.DictSample(text="hello", label=1) 732 + 733 + assert sample["text"] == "hello" 734 + assert sample["label"] == 1 735 + 736 + # Non-existent key raises KeyError 737 + with pytest.raises(KeyError): 738 + _ = sample["nonexistent"] 739 + 740 + 741 + def test_dictsample_dict_methods(): 742 + """Test DictSample dict-like methods.""" 743 + sample = atdata.DictSample(a=1, b=2, c=3) 744 + 745 + assert set(sample.keys()) == {"a", "b", "c"} 746 + assert set(sample.values()) == {1, 2, 3} 747 + assert set(sample.items()) == {("a", 1), ("b", 2), ("c", 3)} 748 + assert "a" in sample 749 + assert "x" not in sample 750 + assert sample.get("a") == 1 751 + assert sample.get("x", "default") == "default" 752 + 753 + 754 + def test_dictsample_to_dict(): 755 + """Test DictSample.to_dict returns a copy.""" 756 + sample = atdata.DictSample(name="test", value=42) 757 + d = sample.to_dict() 758 + 759 + assert d == {"name": "test", "value": 42} 760 + # Should be a copy 761 + d["name"] = "modified" 762 + assert sample.name == "test" 763 + 764 + 765 + def test_dictsample_serialization(): 766 + """Test DictSample can be serialized and deserialized.""" 767 + original = atdata.DictSample(text="hello", count=42) 768 + 769 + # Serialize 770 + packed = original.packed 771 + 772 + # Deserialize 773 + restored = atdata.DictSample.from_bytes(packed) 774 + 775 + assert restored.text == "hello" 776 + assert restored.count == 42 777 + 778 + 779 + def test_dictsample_as_wds(): 780 + """Test DictSample.as_wds produces valid WebDataset format.""" 781 + sample = atdata.DictSample(name="test", value=123) 782 + wds_dict = sample.as_wds 783 + 784 + assert "__key__" in wds_dict 785 + assert "msgpack" in wds_dict 786 + assert isinstance(wds_dict["msgpack"], bytes) 787 + 788 + 789 + def test_dictsample_repr(): 790 + """Test DictSample has a useful repr.""" 791 + sample = atdata.DictSample(name="test", value=42) 792 + repr_str = repr(sample) 793 + 794 + assert "DictSample" in repr_str 795 + assert "name" in repr_str 796 + assert "value" in repr_str 797 + 798 + 799 + def test_dictsample_dataset_iteration(tmp_path): 800 + """Test Dataset[DictSample] can iterate over data.""" 801 + # Create typed sample data 802 + @atdata.packable 803 + class SourceSample: 804 + text: str 805 + label: int 806 + 807 + wds_filename = (tmp_path / "dictsample_test.tar").as_posix() 808 + with wds.writer.TarWriter(wds_filename) as sink: 809 + for i in range(5): 810 + sample = SourceSample(text=f"item_{i}", label=i) 811 + sink.write(sample.as_wds) 812 + 813 + # Read as DictSample 814 + dataset = atdata.Dataset[atdata.DictSample](wds_filename) 815 + 816 + samples = list(dataset.ordered()) 817 + assert len(samples) == 5 818 + 819 + for i, sample in enumerate(samples): 820 + assert isinstance(sample, atdata.DictSample) 821 + assert sample.text == f"item_{i}" 822 + assert sample["label"] == i 823 + 824 + 825 + def test_dictsample_to_typed_via_as_type(tmp_path): 826 + """Test converting DictSample dataset to typed via as_type.""" 827 + @atdata.packable 828 + class TypedSample: 829 + text: str 830 + label: int 831 + 832 + # Create data using typed sample 833 + wds_filename = (tmp_path / "astype_test.tar").as_posix() 834 + with wds.writer.TarWriter(wds_filename) as sink: 835 + for i in range(5): 836 + sample = TypedSample(text=f"item_{i}", label=i) 837 + sink.write(sample.as_wds) 838 + 839 + # Load as DictSample first 840 + ds_dict = atdata.Dataset[atdata.DictSample](wds_filename) 841 + 842 + # Convert to typed 843 + ds_typed = ds_dict.as_type(TypedSample) 844 + 845 + # Verify typed iteration works 846 + samples = list(ds_typed.ordered()) 847 + assert len(samples) == 5 848 + 849 + for i, sample in enumerate(samples): 850 + assert isinstance(sample, TypedSample) 851 + assert sample.text == f"item_{i}" 852 + assert sample.label == i 853 + 854 + 855 + def test_packable_auto_registers_dictsample_lens(): 856 + """Test @packable decorator auto-registers lens from DictSample.""" 857 + @atdata.packable 858 + class AutoLensSample: 859 + name: str 860 + value: int 861 + 862 + # The lens should be registered automatically 863 + network = atdata.LensNetwork() 864 + lens = network.transform(atdata.DictSample, AutoLensSample) 865 + 866 + # Test the lens works 867 + dict_sample = atdata.DictSample(name="test", value=42) 868 + typed_sample = lens(dict_sample) 869 + 870 + assert isinstance(typed_sample, AutoLensSample) 871 + assert typed_sample.name == "test" 872 + assert typed_sample.value == 42 873 + 874 + 875 + def test_dictsample_batched_iteration(tmp_path): 876 + """Test Dataset[DictSample] works with batched iteration.""" 877 + @atdata.packable 878 + class BatchSource: 879 + text: str 880 + value: int 881 + 882 + wds_filename = (tmp_path / "batch_dictsample_test.tar").as_posix() 883 + with wds.writer.TarWriter(wds_filename) as sink: 884 + for i in range(10): 885 + sample = BatchSource(text=f"item_{i}", value=i) 886 + sink.write(sample.as_wds) 887 + 888 + # Read as DictSample with batching 889 + dataset = atdata.Dataset[atdata.DictSample](wds_filename) 890 + 891 + batch_count = 0 892 + for batch in dataset.ordered(batch_size=4): 893 + assert isinstance(batch, atdata.SampleBatch) 894 + assert len(batch.samples) <= 4 895 + for sample in batch.samples: 896 + assert isinstance(sample, atdata.DictSample) 897 + batch_count += 1 898 + 899 + assert batch_count == 3 # 10 samples / 4 per batch = 2 full + 1 partial 900 + 901 + 700 902 ##
+21 -4
tests/test_hf_api.py
··· 713 713 with pytest.raises(ValueError, match="Index required"): 714 714 load_dataset("@handle/dataset", SimpleTestSample) 715 715 716 - def test_none_sample_type_requires_index(self): 717 - """sample_type=None without index raises ValueError.""" 718 - with pytest.raises(ValueError, match="sample_type is required"): 719 - load_dataset("/path/to/data.tar", None) 716 + def test_none_sample_type_defaults_to_dictsample(self, tmp_path): 717 + """sample_type=None returns Dataset[DictSample].""" 718 + from atdata import DictSample 719 + 720 + # Create a test tar file 721 + tar_path = tmp_path / "data.tar" 722 + sample = SimpleTestSample(text="hello", label=42) 723 + with wds.writer.TarWriter(str(tar_path)) as writer: 724 + writer.write(sample.as_wds) 725 + 726 + # Load without specifying sample_type 727 + ds = load_dataset(str(tar_path), split="train") 728 + 729 + # Should return Dataset[DictSample] 730 + assert ds.sample_type == DictSample 731 + 732 + # Should be able to iterate and access fields 733 + for sample in ds.ordered(): 734 + assert sample["text"] == "hello" 735 + assert sample.label == 42 736 + break 720 737 721 738 def test_indexed_path_with_mock_index(self): 722 739 """load_dataset with indexed path uses index lookup."""