1"""Tests for the HuggingFace Datasets-style API (_hf_api.py)."""
2
3##
4# Imports
5
6import pytest
7
8import numpy as np
9import webdataset as wds
10
11import atdata
12from atdata._hf_api import (
13 load_dataset,
14 DatasetDict,
15 _is_brace_pattern,
16 _is_glob_pattern,
17 _is_remote_url,
18 _detect_split_from_path,
19 _shards_to_wds_url,
20 _expand_local_glob,
21 _resolve_shards,
22 _resolve_data_files,
23 _group_shards_by_split,
24 _is_indexed_path,
25 _parse_indexed_path,
26)
27from unittest.mock import Mock
28
29from numpy.typing import NDArray
30
31
##
# Test sample types


@atdata.packable
class SimpleTestSample:
    """Simple sample type for testing."""

    text: str
    label: int


@atdata.packable
class NumpyTestSample:
    """Sample type with numpy arrays for testing."""

    embedding: NDArray
    label: int


##
# Helper function tests


class TestIsBracePattern:
    """Tests for _is_brace_pattern()."""

    def test_range_pattern(self):
        assert _is_brace_pattern("data-{000000..000099}.tar") is True

    def test_list_pattern(self):
        assert _is_brace_pattern("data-{train,test,val}.tar") is True

    def test_no_pattern(self):
        assert _is_brace_pattern("data-000000.tar") is False

    def test_empty_braces(self):
        # Empty braces are not valid WebDataset brace notation
        assert _is_brace_pattern("data-{}.tar") is False

    def test_nested_path_with_pattern(self):
        assert _is_brace_pattern("path/to/data-{000..099}.tar") is True


class TestIsGlobPattern:
    """Tests for _is_glob_pattern()."""

    def test_asterisk(self):
        assert _is_glob_pattern("data-*.tar") is True

    def test_question_mark(self):
        assert _is_glob_pattern("data-00000?.tar") is True

    def test_no_pattern(self):
        assert _is_glob_pattern("data-000000.tar") is False

    def test_path_with_glob(self):
        assert _is_glob_pattern("path/to/*.tar") is True


class TestIsRemoteUrl:
    """Tests for _is_remote_url()."""

    def test_s3_url(self):
        assert _is_remote_url("s3://bucket/path/data.tar") is True

    def test_https_url(self):
        assert _is_remote_url("https://example.com/data.tar") is True

    def test_http_url(self):
        assert _is_remote_url("http://example.com/data.tar") is True

    def test_gs_url(self):
        assert _is_remote_url("gs://bucket/path/data.tar") is True

    def test_az_url(self):
        assert _is_remote_url("az://container/path/data.tar") is True

    def test_local_absolute_path(self):
        assert _is_remote_url("/local/path/data.tar") is False

    def test_local_relative_path(self):
        assert _is_remote_url("./data/data.tar") is False

    def test_windows_path(self):
        assert _is_remote_url("C:\\data\\data.tar") is False


class TestDetectSplitFromPath:
    """Tests for _detect_split_from_path()."""

    def test_train_in_filename(self):
        assert _detect_split_from_path("dataset-train-000000.tar") == "train"

    def test_test_in_filename(self):
        assert _detect_split_from_path("dataset-test-000000.tar") == "test"

    def test_validation_in_filename(self):
        assert _detect_split_from_path("dataset-validation-000000.tar") == "validation"

    def test_val_in_filename(self):
        assert _detect_split_from_path("dataset-val-000000.tar") == "validation"

    def test_dev_in_filename(self):
        assert _detect_split_from_path("dataset-dev-000000.tar") == "validation"

    def test_train_directory(self):
        assert _detect_split_from_path("train/shard-000000.tar") == "train"

    def test_test_directory(self):
        assert _detect_split_from_path("test/shard-000000.tar") == "test"

    def test_no_split_detected(self):
        assert _detect_split_from_path("dataset-000000.tar") is None

    def test_case_insensitive(self):
        assert _detect_split_from_path("dataset-TRAIN-000000.tar") == "train"
        assert _detect_split_from_path("dataset-Train-000000.tar") == "train"

    def test_training_variant(self):
        assert _detect_split_from_path("dataset-training-000000.tar") == "train"

    def test_testing_variant(self):
        assert _detect_split_from_path("dataset-testing-000000.tar") == "test"


class TestShardsToWdsUrl:
    """Tests for _shards_to_wds_url()."""

    def test_single_shard(self):
        assert _shards_to_wds_url(["data.tar"]) == "data.tar"

    def test_multiple_shards_common_pattern(self):
        shards = ["data-000.tar", "data-001.tar", "data-002.tar"]
        result = _shards_to_wds_url(shards)
        # Algorithm finds longest common prefix/suffix, resulting in compact notation
        # Both "data-{000,001,002}.tar" and "data-00{0,1,2}.tar" are valid
        assert "{" in result and "}" in result
        assert ".tar" in result
        assert "data-" in result

    def test_multiple_shards_different_lengths(self):
        shards = ["data-0.tar", "data-1.tar", "data-10.tar"]
        result = _shards_to_wds_url(shards)
        # Should still produce brace notation
        assert "{" in result and "}" in result

    def test_empty_list_raises(self):
        with pytest.raises(ValueError, match="empty shard list"):
            _shards_to_wds_url([])

    def test_no_common_pattern(self):
        shards = ["train.tar", "test.tar", "val.tar"]
        result = _shards_to_wds_url(shards)
        # Falls back to space-separated or brace notation
        assert "train" in result


class TestExpandLocalGlob:
    """Tests for _expand_local_glob()."""

    def test_no_matches(self, tmp_path):
        pattern = str(tmp_path / "*.tar")
        assert _expand_local_glob(pattern) == []

    def test_matches_files(self, tmp_path):
        # Create test files
        (tmp_path / "data-000.tar").touch()
        (tmp_path / "data-001.tar").touch()
        (tmp_path / "data-002.tar").touch()

        pattern = str(tmp_path / "*.tar")
        result = _expand_local_glob(pattern)

        assert len(result) == 3
        assert all(".tar" in p for p in result)

    def test_ignores_directories(self, tmp_path):
        # Create a file and a directory
        (tmp_path / "data.tar").touch()
        (tmp_path / "subdir.tar").mkdir()

        pattern = str(tmp_path / "*.tar")
        result = _expand_local_glob(pattern)

        assert len(result) == 1

    def test_nonexistent_directory(self):
        result = _expand_local_glob("/nonexistent/path/*.tar")
        assert result == []


class TestGroupShardsBySplit:
    """Tests for _group_shards_by_split()."""

    def test_single_split(self):
        shards = [
            "train-000.tar",
            "train-001.tar",
            "train-002.tar",
        ]
        result = _group_shards_by_split(shards)
        assert "train" in result
        assert len(result["train"]) == 3

    def test_multiple_splits(self):
        shards = [
            "data-train-000.tar",
            "data-train-001.tar",
            "data-test-000.tar",
            "data-val-000.tar",
        ]
        result = _group_shards_by_split(shards)
        assert "train" in result
        assert "test" in result
        assert "validation" in result
        assert len(result["train"]) == 2
        assert len(result["test"]) == 1
        assert len(result["validation"]) == 1

    def test_no_detected_split_defaults_to_train(self):
        shards = ["shard-000.tar", "shard-001.tar"]
        result = _group_shards_by_split(shards)
        assert "train" in result
        assert len(result["train"]) == 2


class TestResolveDataFiles:
    """Tests for _resolve_data_files()."""

    def test_string_input(self, tmp_path):
        result = _resolve_data_files(str(tmp_path), "data.tar")
        assert "train" in result
        assert len(result["train"]) == 1

    def test_list_input(self, tmp_path):
        result = _resolve_data_files(str(tmp_path), ["a.tar", "b.tar"])
        assert "train" in result
        assert len(result["train"]) == 2

    def test_dict_input(self, tmp_path):
        data_files = {
            "train": ["train-000.tar", "train-001.tar"],
            "test": "test-000.tar",
        }
        result = _resolve_data_files(str(tmp_path), data_files)
        assert "train" in result
        assert "test" in result
        assert len(result["train"]) == 2
        assert len(result["test"]) == 1

    def test_resolves_relative_paths(self, tmp_path):
        result = _resolve_data_files(str(tmp_path), "subdir/data.tar")
        assert str(tmp_path) in result["train"][0]


class TestResolveShards:
    """Tests for _resolve_shards()."""

    def test_brace_pattern_passthrough(self):
        path = "data-{000000..000099}.tar"
        result = _resolve_shards(path)
        assert "train" in result
        assert path in result["train"]

    def test_brace_pattern_with_split_name(self):
        path = "data-train-{000..099}.tar"
        result = _resolve_shards(path)
        assert "train" in result

    def test_single_file(self):
        path = "data.tar"
        result = _resolve_shards(path)
        assert "train" in result
        assert result["train"] == [path]

    def test_with_data_files_override(self, tmp_path):
        data_files = {"train": "train.tar", "test": "test.tar"}
        result = _resolve_shards(str(tmp_path), data_files)
        assert "train" in result
        assert "test" in result

    def test_local_directory(self, tmp_path):
        # Create test tar files
        (tmp_path / "train-000.tar").touch()
        (tmp_path / "train-001.tar").touch()
        (tmp_path / "test-000.tar").touch()

        result = _resolve_shards(str(tmp_path))
        assert "train" in result
        assert "test" in result

    def test_glob_pattern(self, tmp_path):
        # Create test files
        (tmp_path / "data-000.tar").touch()
        (tmp_path / "data-001.tar").touch()

        pattern = str(tmp_path / "*.tar")
        result = _resolve_shards(pattern)
        assert "train" in result  # defaults to train when no split detected


##
# DatasetDict tests


class TestDatasetDict:
    """Tests for DatasetDict class."""

    def test_empty_init(self):
        dd = DatasetDict()
        assert len(dd) == 0

    def test_init_with_splits(self, tmp_path):
        # Create a minimal tar file for Dataset
        tar_path = tmp_path / "data.tar"
        with wds.writer.TarWriter(str(tar_path)) as sink:
            sample = SimpleTestSample(text="hello", label=1)
            sink.write(sample.as_wds)

        train_ds = atdata.Dataset[SimpleTestSample](str(tar_path))
        test_ds = atdata.Dataset[SimpleTestSample](str(tar_path))

        dd = DatasetDict({"train": train_ds, "test": test_ds})

        assert len(dd) == 2
        assert "train" in dd
        assert "test" in dd

    def test_getitem(self, tmp_path):
        tar_path = tmp_path / "data.tar"
        with wds.writer.TarWriter(str(tar_path)) as sink:
            sample = SimpleTestSample(text="hello", label=1)
            sink.write(sample.as_wds)

        train_ds = atdata.Dataset[SimpleTestSample](str(tar_path))
        dd = DatasetDict({"train": train_ds})

        assert dd["train"] is train_ds

    def test_setitem(self, tmp_path):
        tar_path = tmp_path / "data.tar"
        with wds.writer.TarWriter(str(tar_path)) as sink:
            sample = SimpleTestSample(text="hello", label=1)
            sink.write(sample.as_wds)

        dd = DatasetDict()
        train_ds = atdata.Dataset[SimpleTestSample](str(tar_path))
        dd["train"] = train_ds

        assert "train" in dd
        assert dd["train"] is train_ds

    def test_keys_values_items(self, tmp_path):
        tar_path = tmp_path / "data.tar"
        with wds.writer.TarWriter(str(tar_path)) as sink:
            sample = SimpleTestSample(text="hello", label=1)
            sink.write(sample.as_wds)

        train_ds = atdata.Dataset[SimpleTestSample](str(tar_path))
        test_ds = atdata.Dataset[SimpleTestSample](str(tar_path))

        dd = DatasetDict({"train": train_ds, "test": test_ds})

        assert set(dd.keys()) == {"train", "test"}
        assert len(list(dd.values())) == 2
        assert len(list(dd.items())) == 2

    def test_streaming_property(self):
        dd = DatasetDict(streaming=True)
        assert dd.streaming is True

        dd2 = DatasetDict(streaming=False)
        assert dd2.streaming is False

    def test_sample_type_explicit(self):
        dd = DatasetDict(sample_type=SimpleTestSample)
        assert dd.sample_type is SimpleTestSample

    def test_num_shards(self, tmp_path):
        # Create a single tar file for the train split
        train_path = tmp_path / "train.tar"
        with wds.writer.TarWriter(str(train_path)) as sink:
            sample = SimpleTestSample(text="hello", label=1)
            sink.write(sample.as_wds)

        train_ds = atdata.Dataset[SimpleTestSample](str(train_path))
        dd = DatasetDict({"train": train_ds})

        num_shards = dd.num_shards
        assert "train" in num_shards
        assert num_shards["train"] == 1


##
# load_dataset tests


class TestLoadDataset:
    """Tests for load_dataset() function."""

    def test_load_single_file_with_split(self, tmp_path):
        """Load a single tar file specifying a split."""
        tar_path = tmp_path / "data.tar"
        with wds.writer.TarWriter(str(tar_path)) as sink:
            for i in range(10):
                sample = SimpleTestSample(text=f"sample_{i}", label=i)
                sink.write(sample.as_wds)

        ds = load_dataset(str(tar_path), SimpleTestSample, split="train")

        assert isinstance(ds, atdata.Dataset)
        # Verify we can iterate
        samples = list(ds.ordered(batch_size=None))
        assert len(samples) == 10

    def test_load_returns_dataset_dict_without_split(self, tmp_path):
        """Without split parameter, returns DatasetDict."""
        tar_path = tmp_path / "data.tar"
        with wds.writer.TarWriter(str(tar_path)) as sink:
            sample = SimpleTestSample(text="hello", label=1)
            sink.write(sample.as_wds)

        result = load_dataset(str(tar_path), SimpleTestSample)

        assert isinstance(result, DatasetDict)
        assert "train" in result

    def test_load_with_data_files_dict(self, tmp_path):
        """Load with explicit data_files mapping."""
        # Create train and test files
        train_path = tmp_path / "train.tar"
        test_path = tmp_path / "test.tar"

        with wds.writer.TarWriter(str(train_path)) as sink:
            for i in range(5):
                sample = SimpleTestSample(text=f"train_{i}", label=i)
                sink.write(sample.as_wds)

        with wds.writer.TarWriter(str(test_path)) as sink:
            for i in range(3):
                sample = SimpleTestSample(text=f"test_{i}", label=i)
                sink.write(sample.as_wds)

        result = load_dataset(
            str(tmp_path),
            SimpleTestSample,
            data_files={"train": "train.tar", "test": "test.tar"},
        )

        assert isinstance(result, DatasetDict)
        assert "train" in result
        assert "test" in result

    def test_load_nonexistent_split_raises(self, tmp_path):
        """Requesting a split that doesn't exist raises ValueError."""
        tar_path = tmp_path / "train.tar"
        with wds.writer.TarWriter(str(tar_path)) as sink:
            sample = SimpleTestSample(text="hello", label=1)
            sink.write(sample.as_wds)

        with pytest.raises(ValueError, match="Split 'test' not found"):
            load_dataset(str(tar_path), SimpleTestSample, split="test")

    def test_load_directory_with_split_detection(self, tmp_path):
        """Load from directory auto-detecting splits from filenames."""
        # Create files with split names
        train_path = tmp_path / "data-train-000.tar"
        test_path = tmp_path / "data-test-000.tar"

        with wds.writer.TarWriter(str(train_path)) as sink:
            for i in range(5):
                sample = SimpleTestSample(text=f"train_{i}", label=i)
                sink.write(sample.as_wds)

        with wds.writer.TarWriter(str(test_path)) as sink:
            for i in range(3):
                sample = SimpleTestSample(text=f"test_{i}", label=i)
                sink.write(sample.as_wds)

        result = load_dataset(str(tmp_path), SimpleTestSample)

        assert isinstance(result, DatasetDict)
        assert "train" in result
        assert "test" in result

    def test_load_with_streaming_flag(self, tmp_path):
        """streaming=True sets the streaming property."""
        tar_path = tmp_path / "data.tar"
        with wds.writer.TarWriter(str(tar_path)) as sink:
            sample = SimpleTestSample(text="hello", label=1)
            sink.write(sample.as_wds)

        result = load_dataset(str(tar_path), SimpleTestSample, streaming=True)

        assert isinstance(result, DatasetDict)
        assert result.streaming is True

    def test_load_with_numpy_sample_type(self, tmp_path):
        """Load dataset with numpy arrays in samples."""
        tar_path = tmp_path / "data.tar"
        with wds.writer.TarWriter(str(tar_path)) as sink:
            for i in range(5):
                sample = NumpyTestSample(
                    embedding=np.random.randn(128).astype(np.float32), label=i
                )
                sink.write(sample.as_wds)

        ds = load_dataset(str(tar_path), NumpyTestSample, split="train")
        samples = list(ds.ordered(batch_size=None))

        assert len(samples) == 5
        assert isinstance(samples[0].embedding, np.ndarray)
        assert samples[0].embedding.shape == (128,)

    def test_load_glob_pattern(self, tmp_path):
        """Load using glob pattern."""
        # Create multiple shard files
        for i in range(3):
            shard_path = tmp_path / f"data-{i:03d}.tar"
            with wds.writer.TarWriter(str(shard_path)) as sink:
                sample = SimpleTestSample(text=f"shard_{i}", label=i)
                sink.write(sample.as_wds)

        pattern = str(tmp_path / "*.tar")
        result = load_dataset(pattern, SimpleTestSample)

        assert isinstance(result, DatasetDict)
        assert "train" in result

    def test_load_brace_notation(self, tmp_path):
        """Load using WebDataset brace notation."""
        # Create sharded files
        for i in range(3):
            shard_path = tmp_path / f"data-{i:06d}.tar"
            with wds.writer.TarWriter(str(shard_path)) as sink:
                for j in range(2):
                    sample = SimpleTestSample(text=f"shard_{i}_sample_{j}", label=j)
                    sink.write(sample.as_wds)

        # Use brace notation
        pattern = str(tmp_path / "data-{000000..000002}.tar")
        ds = load_dataset(pattern, SimpleTestSample, split="train")

        assert isinstance(ds, atdata.Dataset)
        samples = list(ds.ordered(batch_size=None))
        assert len(samples) == 6  # 3 shards * 2 samples each

    def test_load_empty_directory_raises(self, tmp_path):
        """Loading from empty directory raises FileNotFoundError."""
        empty_dir = tmp_path / "empty"
        empty_dir.mkdir()

        with pytest.raises(FileNotFoundError):
            load_dataset(str(empty_dir), SimpleTestSample)


##
# Integration tests


class TestLoadDatasetIntegration:
    """Integration tests combining multiple features."""

    def test_full_workflow_train_test_split(self, tmp_path):
        """Full workflow: create sharded dataset, load with splits, iterate."""
        # Create train shards
        for i in range(2):
            shard_path = tmp_path / f"train-{i:03d}.tar"
            with wds.writer.TarWriter(str(shard_path)) as sink:
                for j in range(5):
                    sample = SimpleTestSample(text=f"train_{i}_{j}", label=j)
                    sink.write(sample.as_wds)

        # Create test shard
        test_path = tmp_path / "test-000.tar"
        with wds.writer.TarWriter(str(test_path)) as sink:
            for j in range(3):
                sample = SimpleTestSample(text=f"test_{j}", label=j)
                sink.write(sample.as_wds)

        # Load dataset
        ds = load_dataset(str(tmp_path), SimpleTestSample)

        # Verify structure
        assert "train" in ds
        assert "test" in ds

        # Iterate train
        train_samples = list(ds["train"].ordered(batch_size=None))
        assert len(train_samples) == 10  # 2 shards * 5 samples

        # Iterate test
        test_samples = list(ds["test"].ordered(batch_size=None))
        assert len(test_samples) == 3

    def test_batched_iteration(self, tmp_path):
        """Test batched iteration through loaded dataset."""
        tar_path = tmp_path / "data.tar"
        with wds.writer.TarWriter(str(tar_path)) as sink:
            for i in range(20):
                sample = SimpleTestSample(text=f"sample_{i}", label=i % 5)
                sink.write(sample.as_wds)

        ds = load_dataset(str(tar_path), SimpleTestSample, split="train")

        batches = list(ds.ordered(batch_size=4))
        assert len(batches) == 5  # 20 samples / 4 per batch

        # Check batch structure
        first_batch = batches[0]
        assert len(first_batch.samples) == 4
        # Aggregated attributes
        labels = first_batch.label
        assert len(labels) == 4


##
# Indexed path tests


class TestIsIndexedPath:
    """Tests for _is_indexed_path function."""

    def test_at_handle_path(self):
        """@handle/dataset is indexed."""
        assert _is_indexed_path("@maxine.science/mnist") is True

    def test_at_did_path(self):
        """@did:plc:abc/dataset is indexed."""
        assert _is_indexed_path("@did:plc:abc123/my-dataset") is True

    def test_local_path(self):
        """Local paths are not indexed."""
        assert _is_indexed_path("/path/to/data.tar") is False

    def test_s3_path(self):
        """S3 URLs are not indexed."""
        assert _is_indexed_path("s3://bucket/data.tar") is False

    def test_relative_path(self):
        """Relative paths are not indexed."""
        assert _is_indexed_path("./data/train.tar") is False


class TestParseIndexedPath:
    """Tests for _parse_indexed_path function."""

    def test_parse_handle_dataset(self):
        """Parse @handle/dataset format."""
        handle, name = _parse_indexed_path("@maxine.science/mnist")
        assert handle == "maxine.science"
        assert name == "mnist"

    def test_parse_did_dataset(self):
        """Parse @did:plc:xxx/dataset format."""
        handle, name = _parse_indexed_path("@did:plc:abc123/my-dataset")
        assert handle == "did:plc:abc123"
        assert name == "my-dataset"

    def test_parse_invalid_no_slash(self):
        """Invalid path without slash raises ValueError."""
        with pytest.raises(ValueError, match="Invalid indexed path format"):
            _parse_indexed_path("@handle-only")

    def test_parse_invalid_no_at(self):
        """Path without @ raises ValueError."""
        with pytest.raises(ValueError, match="Not an indexed path"):
            _parse_indexed_path("handle/dataset")

    def test_parse_invalid_empty_parts(self):
        """Empty handle or dataset raises ValueError."""
        with pytest.raises(ValueError, match="Invalid indexed path"):
            _parse_indexed_path("@/dataset")


class TestLoadDatasetWithIndex:
    """Tests for load_dataset with index parameter."""

    def test_indexed_path_requires_index(self):
        """@handle/dataset without index raises ValueError."""
        with pytest.raises(ValueError, match="Index required"):
            load_dataset("@handle/dataset", SimpleTestSample)

    def test_none_sample_type_defaults_to_dictsample(self, tmp_path):
        """sample_type=None returns Dataset[DictSample]."""
        from atdata import DictSample

        # Create a test tar file
        tar_path = tmp_path / "data.tar"
        sample = SimpleTestSample(text="hello", label=42)
        with wds.writer.TarWriter(str(tar_path)) as writer:
            writer.write(sample.as_wds)

        # Load without specifying sample_type
        ds = load_dataset(str(tar_path), split="train")

        # Should return Dataset[DictSample]
        assert ds.sample_type == DictSample

        # Should be able to iterate and access fields
        for sample in ds.ordered():
            assert sample["text"] == "hello"
            assert sample.label == 42
            break

    def test_indexed_path_with_mock_index(self):
        """load_dataset with indexed path uses index lookup."""
        mock_index = Mock()
        mock_index.data_store = None  # No data store, so no URL transformation
        mock_entry = Mock()
        mock_entry.data_urls = ["s3://bucket/data.tar"]
        mock_entry.schema_ref = "local://schemas/test@1.0.0"
        mock_index.get_dataset.return_value = mock_entry

        # No need to mock decode_schema, since sample_type is provided explicitly
        ds = load_dataset(
            "@local/my-dataset",
            SimpleTestSample,
            index=mock_index,
            split="train",
        )

        mock_index.get_dataset.assert_called_once_with("my-dataset")
        assert ds.url == "s3://bucket/data.tar"

    def test_indexed_path_auto_type_resolution(self):
        """load_dataset with sample_type=None uses decode_schema."""
        mock_index = Mock()
        mock_index.data_store = None  # No data store, so no URL transformation
        mock_entry = Mock()
        mock_entry.data_urls = ["s3://bucket/data.tar"]
        mock_entry.schema_ref = "local://schemas/test@1.0.0"
        mock_index.get_dataset.return_value = mock_entry
        mock_index.decode_schema.return_value = SimpleTestSample

        ds = load_dataset(
            "@local/my-dataset",
            None,
            index=mock_index,
            split="train",
        )

        mock_index.decode_schema.assert_called_once_with("local://schemas/test@1.0.0")
        assert ds.sample_type == SimpleTestSample

    def test_indexed_path_returns_datasetdict_without_split(self):
        """load_dataset with indexed path returns DatasetDict when split=None."""
        mock_index = Mock()
        mock_index.data_store = None  # No data store, so no URL transformation
        mock_entry = Mock()
        mock_entry.data_urls = ["s3://bucket/data.tar"]
        mock_entry.schema_ref = "local://schemas/test@1.0.0"
        mock_index.get_dataset.return_value = mock_entry

        result = load_dataset(
            "@local/my-dataset",
            SimpleTestSample,
            index=mock_index,
        )

        assert isinstance(result, DatasetDict)
        assert "train" in result

    def test_indexed_path_transforms_urls_via_data_store(self):
        """load_dataset transforms URLs through data_store.read_url() if available."""
        mock_data_store = Mock()
        mock_data_store.read_url.return_value = "https://r2.example.com/bucket/data.tar"

        mock_index = Mock()
        mock_index.data_store = mock_data_store
        mock_entry = Mock()
        mock_entry.data_urls = ["s3://bucket/data.tar"]
        mock_entry.schema_ref = "local://schemas/test@1.0.0"
        mock_index.get_dataset.return_value = mock_entry

        ds = load_dataset(
            "@local/my-dataset",
            SimpleTestSample,
            index=mock_index,
            split="train",
        )

        # Verify read_url was called to transform the URL
        mock_data_store.read_url.assert_called_once_with("s3://bucket/data.tar")
        # Verify the transformed URL is used
        assert ds.url == "https://r2.example.com/bucket/data.tar"

    def test_indexed_path_no_transform_without_data_store(self):
        """load_dataset uses URLs unchanged when index has no data_store."""
        mock_index = Mock()
        mock_index.data_store = None
        mock_entry = Mock()
        mock_entry.data_urls = ["s3://bucket/data.tar"]
        mock_entry.schema_ref = "local://schemas/test@1.0.0"
        mock_index.get_dataset.return_value = mock_entry

        ds = load_dataset(
            "@local/my-dataset",
            SimpleTestSample,
            index=mock_index,
            split="train",
        )

        # URL should be unchanged
        assert ds.url == "s3://bucket/data.tar"

    def test_indexed_path_creates_s3source_with_credentials(self):
        """load_dataset creates S3Source with credentials when S3DataStore is available."""
        from atdata.local import S3DataStore
        from atdata._sources import S3Source

        # Mock credentials for the mocked S3DataStore to report
        mock_credentials = {
            "AWS_ACCESS_KEY_ID": "test-access-key",
            "AWS_SECRET_ACCESS_KEY": "test-secret-key",
            "AWS_ENDPOINT": "https://r2.example.com",
        }

        # Mock the S3DataStore
        mock_store = Mock(spec=S3DataStore)
        mock_store.credentials = mock_credentials

        mock_index = Mock()
        mock_index.data_store = mock_store
        mock_entry = Mock()
        mock_entry.data_urls = [
            "s3://my-bucket/train-000.tar",
            "s3://my-bucket/train-001.tar",
        ]
        mock_entry.schema_ref = "local://schemas/test@1.0.0"
        mock_index.get_dataset.return_value = mock_entry

        ds = load_dataset(
            "@local/my-dataset",
            SimpleTestSample,
            index=mock_index,
            split="train",
        )

        # Verify the dataset source is an S3Source with credentials
        assert isinstance(ds.source, S3Source)
        assert ds.source.bucket == "my-bucket"
        assert ds.source.keys == ["train-000.tar", "train-001.tar"]
        assert ds.source.endpoint == "https://r2.example.com"
        assert ds.source.access_key == "test-access-key"
        assert ds.source.secret_key == "test-secret-key"