A loose federation of distributed, typed datasets
1"""Tests for atdata._helpers module."""
2
3import numpy as np
4import pytest
5
6from atdata._helpers import array_to_bytes, bytes_to_array
7
8
class TestArraySerialization:
    """Round-trip tests for ``array_to_bytes`` and ``bytes_to_array``.

    These tests treat serialization as lossless: restored arrays are compared
    with exact equality (``assert_array_equal``) and their dtype and shape are
    checked explicitly. Random inputs come from a fixed-seed generator so any
    failure is reproducible.
    """

    @staticmethod
    def _random_float32(shape, seed=0):
        """Return a reproducible float32 array of uniform values in [0, 1)."""
        return np.random.default_rng(seed).random(shape, dtype=np.float32)

    @pytest.mark.parametrize(
        "dtype",
        [
            np.float32,
            np.float64,
            np.int32,
            np.int64,
            np.uint8,
            np.bool_,
            np.complex64,
        ],
    )
    def test_dtype_preservation(self, dtype):
        """Verify dtype and values are preserved through serialization."""
        # Include 0 so the bool case exercises both False and True.
        original = np.array([0, 1, 2, 3], dtype=dtype)
        serialized = array_to_bytes(original)
        restored = bytes_to_array(serialized)

        assert restored.dtype == original.dtype
        np.testing.assert_array_equal(restored, original)

    @pytest.mark.parametrize(
        "shape",
        [
            (10,),
            (3, 4),
            (2, 3, 4),
            (1, 1, 1, 1),
        ],
    )
    def test_shape_preservation(self, shape):
        """Verify shape, dtype and values are preserved through serialization."""
        original = self._random_float32(shape)
        serialized = array_to_bytes(original)
        restored = bytes_to_array(serialized)

        assert restored.shape == original.shape
        assert restored.dtype == original.dtype
        # A byte-level round-trip should be exact; a tolerance would hide loss.
        np.testing.assert_array_equal(restored, original)

    def test_empty_array(self):
        """Verify empty arrays serialize correctly."""
        original = np.array([], dtype=np.float32)
        serialized = array_to_bytes(original)
        restored = bytes_to_array(serialized)

        assert restored.shape == (0,)
        assert restored.dtype == np.float32

    def test_scalar_array(self):
        """Verify 0-dimensional arrays keep their shape, dtype and value."""
        original = np.array(42.0)
        serialized = array_to_bytes(original)
        restored = bytes_to_array(serialized)

        assert restored.shape == ()
        assert restored.dtype == original.dtype
        assert restored == 42.0

    def test_large_array(self):
        """Verify larger 2-D arrays round-trip exactly."""
        original = self._random_float32((100, 100))
        serialized = array_to_bytes(original)
        restored = bytes_to_array(serialized)

        assert restored.shape == original.shape
        assert restored.dtype == original.dtype
        np.testing.assert_array_equal(restored, original)

    def test_contiguous_and_noncontiguous(self):
        """Verify both contiguous arrays and strided (non-contiguous) views round-trip."""
        original = self._random_float32((10, 10))
        non_contiguous = original[::2, ::2]  # Strided view

        # Guard the preconditions so the test really covers both layouts.
        assert original.flags["C_CONTIGUOUS"]
        assert not non_contiguous.flags["C_CONTIGUOUS"]

        for arr in (original, non_contiguous):
            restored = bytes_to_array(array_to_bytes(arr))

            assert restored.shape == arr.shape
            assert restored.dtype == arr.dtype
            np.testing.assert_array_equal(restored, arr)

    def test_bytes_output_type(self):
        """Verify array_to_bytes returns bytes."""
        arr = np.array([1, 2, 3])
        result = array_to_bytes(arr)
        assert isinstance(result, bytes)

    def test_ndarray_output_type(self):
        """Verify bytes_to_array returns ndarray."""
        arr = np.array([1, 2, 3])
        serialized = array_to_bytes(arr)
        result = bytes_to_array(serialized)
        assert isinstance(result, np.ndarray)