A loose federation of distributed, typed datasets
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

implementation on local combined repo upload

+181 -15
+1
.vscode/settings.json
··· 1 1 { 2 2 "cSpell.words": [ 3 3 "atdata", 4 + "creds", 4 5 "getattr", 5 6 "hgetall", 6 7 "msgpack",
+2
pyproject.toml
··· 14 14 "ormsgpack>=1.11.0", 15 15 "pandas>=2.3.3", 16 16 "pydantic>=2.12.5", 17 + "python-dotenv>=1.2.1", 17 18 "redis-om>=0.3.5", 19 + "s3fs>=2025.12.0", 18 20 "schemamodels>=0.9.1", 19 21 "tqdm>=4.67.1", 20 22 "webdataset>=1.0.2",
+4 -3
src/atdata/dataset.py
··· 494 494 495 495 # 496 496 497 - def __init__( self, url: str ) -> None: 497 + def __init__( self, url: str, 498 + metadata_url: str | None = None, 499 + ) -> None: 498 500 """Create a dataset from a WebDataset URL. 499 501 500 502 Args: ··· 510 512 """ 511 513 512 514 self._metadata: dict[str, Any] | None = None 513 - 514 - self.metadata_url: str | None = None 515 + self.metadata_url: str | None = metadata_url 515 516 """TODO""" 516 517 517 518 # Allow addition of automatic transformation of raw underlying data
+174 -12
src/atdata/local.py
··· 8 8 Dataset, 9 9 ) 10 10 11 + from pathlib import Path 11 12 from uuid import uuid4 13 + from tempfile import TemporaryDirectory 14 + import subprocess 15 + from dotenv import dotenv_values 16 + import msgpack 12 17 13 18 # from redis_om import ( 14 19 # EmbeddedJsonModel, ··· 21 26 Redis, 22 27 ) 23 28 29 + from s3fs import ( 30 + S3FileSystem, 31 + ) 32 + 33 + import webdataset as wds 34 + 24 35 from dataclasses import ( 25 36 dataclass, 26 37 asdict, ··· 31 42 Optional, 32 43 Dict, 33 44 Type, 45 + TypeVar 34 46 ) 35 47 48 + T = TypeVar( 'T', bound = PackableSample ) 49 + 36 50 37 51 ## 38 - # Heplers 52 + # Helpers 39 53 40 54 def _kind_str_for_sample_type( st: Type[PackableSample] ) -> str: 41 55 """TODO""" ··· 61 75 uuid: str | None = field( default_factory = lambda: str( uuid4() ) ) 62 76 """TODO""" 63 77 64 - def save_to( self, redis: Redis ): 78 + def write_to( self, redis: Redis ): 65 79 """TODO""" 66 80 save_key = f'BasicIndexEntry:{self.uuid}' 67 81 # TODO figure out how to get linting to work correctly here 68 82 redis.hset( save_key, mapping = asdict( self ) ) 69 - 83 + 84 + def _s3_env( credentials_path: str | Path ) -> dict[str, Any]: 85 + """TODO""" 86 + ## 87 + credentials_path = Path( credentials_path ) 88 + env_values = dotenv_values( credentials_path ) 89 + assert 'AWS_ENDPOINT' in env_values 90 + assert 'AWS_ACCESS_KEY_ID' in env_values 91 + assert 'AWS_SECRET_ACCESS_KEY' in env_values 92 + 93 + return { 94 + k: env_values[k] 95 + for k in ( 96 + 'AWS_ENDPOINT', 97 + 'AWS_ACCESS_KEY_ID', 98 + 'AWS_SECRET_ACCESS_KEY', 99 + ) 100 + } 101 + 102 + def _s3_from_credentials( creds: str | Path | dict ) -> S3FileSystem: 103 + """TODO""" 104 + ## 105 + if not isinstance( creds, dict ): 106 + creds = _s3_env( creds ) 107 + 108 + return S3FileSystem( 109 + endpoint_url = creds['AWS_ENDPOINT'], 110 + key = creds['AWS_ACCESS_KEY_ID'], 111 + secret = creds['AWS_SECRET_ACCESS_KEY'] 112 + ) 113 + 70 114 71 115 ## 72 116 # Classes 73 117 118 + class Repo: 119 + """TODO""" 120 + 121 + ## 122 + 123 + def __init__( self, 124 + # 125 + s3_credentials: str | Path | dict[str, Any] | None = None, 126 + hive_path: str | Path | None = None, 127 + redis: Redis | None = None, 128 + # 129 + # 130 + **kwargs 131 + ) -> None: 132 + """TODO""" 133 + 134 + if s3_credentials is None: 135 + self.s3_credentials = None 136 + elif isinstance( s3_credentials, dict ): 137 + self.s3_credentials = s3_credentials 138 + else: 139 + self.s3_credentials = _s3_env( s3_credentials ) 140 + 141 + if self.s3_credentials is None: 142 + self.bucket_fs = None 143 + else: 144 + self.bucket_fs = _s3_from_credentials( self.s3_credentials ) 145 + 146 + if self.bucket_fs is not None: 147 + if hive_path is None: 148 + raise ValueError( 'Must specify hive path within bucket' ) 149 + self.hive_path = Path( hive_path ) 150 + self.hive_bucket = self.hive_path.parts[0] 151 + else: 152 + self.hive_path = None 153 + self.hive_bucket = None 154 + 155 + # 156 + 157 + self.index = Index( redis = redis ) 158 + 159 + ## 160 + 161 + def insert( self, ds: Dataset[T], 162 + **kwargs 163 + ) -> Dataset[T]: 164 + """TODO""" 165 + 166 + assert self.hive_bucket is not None 167 + assert self.hive_path is not None 168 + 169 + with TemporaryDirectory() as tmpdir: 170 + 171 + # Mount S3 filesystem 172 + mount_path = Path( tmpdir ) / 'atdata-local' / self.hive_bucket 173 + mount_cmd = [ 174 + 's3fs', 175 + self.hive_bucket, 176 + mount_path.as_posix() 177 + ] 178 + subprocess.run( mount_cmd, env = self.s3_credentials ) 179 + 180 + new_uuid = uuid4() 181 + 182 + # Write metadata 183 + metadata_path = ( 184 + mount_path 185 + / 'metadata' 186 + / f'atdata-metadata--{new_uuid}.msgpack' 187 + ) 188 + with open( metadata_path, 'wb' ) as f: 189 + if ds.metadata is not None: 190 + # TODO Figure out how to make linting work better here 191 + f.write( msgpack.packb( ds.metadata ) ) 192 + 193 + # Write data 194 + shard_pattern = (mount_path / f'atdata--{new_uuid}--%06d.tar').as_posix() 195 + written_shards = [] 196 + with wds.writer.ShardWriter( shard_pattern, 197 + post = lambda s: written_shards.append( s ), 198 + **kwargs 199 + ) as sink: 200 + for sample in ds.ordered( batch_size = None ): 201 + sink.write( sample.as_wds ) 202 + 203 + # Return created dataset 204 + shard_s3_format = ( 205 + ( 206 + self.hive_path 207 + / f'atdata--{new_uuid}' 208 + ).as_posix() 209 + ) + '--{shard_id}.tar' 210 + metadata_s3_path = ( 211 + self.hive_path 212 + / 'metadata' 213 + / f'atdata-metadata--{new_uuid}.msgpack' 214 + ) 215 + shard_id_braced = '{' + f'{0:06d}..{len( written_shards ) - 1:06d}' + '}' 216 + return Dataset( 217 + url = shard_s3_format.format( shard_id = shard_id_braced ), 218 + metadata_url = metadata_path.as_posix(), 219 + ) 220 + 221 + 74 222 class Index: 75 223 """TODO""" 76 224 ··· 86 234 if redis is not None: 87 235 self._redis = redis 88 236 else: 89 - self._redis = Redis() 237 + self._redis = Redis( **kwargs ) 90 238 91 239 # needed before we can do anything with `redis` 92 240 # TODO this only works / is necessary for `redis_om`` ··· 100 248 ret.append( self._redis.hgetall( key ) ) 101 249 return ret 102 250 103 - def add( self, ds: Dataset ) -> BasicIndexEntry: 251 + def add_entry( self, ds: Dataset, 252 + uuid: str | None = None, 253 + ) -> BasicIndexEntry: 104 254 """TODO""" 105 255 ## 106 256 temp_sample_kind = _kind_str_for_sample_type( ds.sample_type ) 107 257 108 - ret_data = BasicIndexEntry( 109 - wds_url = ds.url, 110 - sample_kind = temp_sample_kind, 111 - metadata_url = ds.metadata_url, 112 - ) 113 - ret_data.save_to( self._redis ) 258 + if uuid is None: 259 + ret_data = BasicIndexEntry( 260 + wds_url = ds.url, 261 + sample_kind = temp_sample_kind, 262 + metadata_url = ds.metadata_url, 263 + ) 264 + else: 265 + ret_data = BasicIndexEntry( 266 + wds_url = ds.url, 267 + sample_kind = temp_sample_kind, 268 + metadata_url = ds.metadata_url, 269 + uuid = uuid, 270 + ) 271 + 272 + ret_data.write_to( self._redis ) 273 + 274 + return ret_data 275 + 114 276 115 - return ret_data 277 + #