A loose federation of distributed, typed datasets
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

updates to local indexing

+131 -37
+131 -37
src/atdata/local.py
··· 8 8 Dataset, 9 9 ) 10 10 11 + import os 11 12 from pathlib import Path 12 13 from uuid import uuid4 13 14 from tempfile import TemporaryDirectory 15 + import shutil 14 16 import subprocess 15 17 from dotenv import dotenv_values 16 18 import msgpack ··· 42 44 Optional, 43 45 Dict, 44 46 Type, 45 - TypeVar 47 + TypeVar, 48 + Generator, 46 49 ) 47 50 48 51 T = TypeVar( 'T', bound = PackableSample ) ··· 54 57 def _kind_str_for_sample_type( st: Type[PackableSample] ) -> str: 55 58 """TODO""" 56 59 return f'{st.__module__}.{st.__name__}' 60 + 61 + def _decode_bytes_dict( d: dict[bytes, bytes] ) -> dict[str, str]: 62 + """TODO""" 63 + return { 64 + k.decode('utf-8'): v.decode('utf-8') 65 + for k, v in d.items() 66 + } 57 67 58 68 59 69 ## ··· 159 169 ## 160 170 161 171 def insert( self, ds: Dataset[T], 172 + # 173 + cache_local: bool = False, 174 + # 162 175 **kwargs 163 - ) -> Dataset[T]: 176 + ) -> tuple[BasicIndexEntry, Dataset[T]]: 164 177 """TODO""" 165 178 179 + assert self.s3_credentials is not None 166 180 assert self.hive_bucket is not None 167 181 assert self.hive_path is not None 168 182 169 - with TemporaryDirectory() as tmpdir: 183 + new_uuid = str( uuid4() ) 170 184 171 - # Mount S3 filesystem 172 - mount_path = Path( tmpdir ) / 'atdata-local' / self.hive_bucket 173 - mount_cmd = [ 174 - 's3fs', 175 - self.hive_bucket, 176 - mount_path.as_posix() 177 - ] 178 - subprocess.run( mount_cmd, env = self.s3_credentials ) 185 + hive_fs = _s3_from_credentials( self.s3_credentials ) 179 186 180 - new_uuid = uuid4() 187 + # Write metadata 188 + metadata_path = ( 189 + self.hive_path 190 + / 'metadata' 191 + / f'atdata-metadata--{new_uuid}.msgpack' 192 + ) 193 + metadata_path.parent.mkdir( parents = True, exist_ok = True ) 181 194 182 - # Write metadata 183 - metadata_path = ( 184 - mount_path 185 - / 'metadata' 186 - / f'atdata-metadata--{new_uuid}.msgpack' 187 - ) 188 - with open( metadata_path, 'wb' ) as f: 189 - if ds.metadata is not None: 190 - # TODO Figure out how to make linting work better here 191 - f.write( msgpack.packb( ds.metadata ) ) 195 + if ds.metadata is not None: 196 + with hive_fs.open( metadata_path, 'wb' ) as f: 197 + # TODO Figure out how to make linting work better here 198 + f.write( msgpack.packb( ds.metadata ) ) 199 + 200 + 201 + # Write data 202 + shard_pattern = ( 203 + self.hive_path 204 + / f'atdata--{new_uuid}--%06d.tar' 205 + ).as_posix() 192 206 193 - # Write data 194 - shard_pattern = (mount_path / f'atdata--{new_uuid}--%06d.tar').as_posix() 207 + with TemporaryDirectory() as temp_dir: 208 + 209 + if cache_local: 210 + def _writer_opener( p: str ): 211 + local_cache_path = Path( temp_dir ) / p 212 + local_cache_path.parent.mkdir( parents = True, exist_ok = True ) 213 + return open( local_cache_path, 'wb' ) 214 + writer_opener = _writer_opener 215 + 216 + def _writer_post( p: str ): 217 + local_cache_path = Path( temp_dir ) / p 218 + 219 + # Copy to S3 220 + print( 'Copying file to s3 ...', end = '' ) 221 + with open( local_cache_path, 'rb' ) as f_in: 222 + with hive_fs.open( p, 'wb' ) as f_out: 223 + # TODO Linting issues 224 + f_out.write( f_in.read() ) 225 + print( ' done.' ) 226 + 227 + # Delete local cache file 228 + print( 'Deleting local cache file ...', end = '' ) 229 + os.remove( local_cache_path ) 230 + print( ' done.' ) 231 + writer_post = _writer_post 232 + 233 + else: 234 + writer_opener = lambda s: hive_fs.open( s, 'wb' ) 235 + writer_post = lambda s: written_shards.append( s ) 236 + 195 237 written_shards = [] 196 238 with wds.writer.ShardWriter( shard_pattern, 197 - post = lambda s: written_shards.append( s ), 239 + # opener = lambda s: hive_fs.open( s, 'wb' ), 240 + # post = lambda s: written_shards.append( s ), 241 + opener = writer_opener, 242 + post = writer_post, 198 243 **kwargs 199 244 ) as sink: 200 245 for sample in ds.ordered( batch_size = None ): 201 246 sink.write( sample.as_wds ) 202 - 203 - # Return created dataset 247 + 248 + # with TemporaryDirectory() as tmpdir: 249 + 250 + # # Mount S3 filesystem 251 + # mount_path = Path( tmpdir ) / 'atdata-s3' / self.hive_bucket 252 + # mount_path.mkdir( parents = True, exist_ok = True ) 253 + # s3fs_cmd = shutil.which( 's3fs' ) 254 + # mount_cmd = [ 255 + # s3fs_cmd, 256 + # self.hive_bucket, 257 + # mount_path.as_posix() 258 + # ] 259 + # result = subprocess.run( mount_cmd, env = self.s3_credentials ) 260 + # print( result ) 261 + 262 + # new_uuid = str( uuid4() ) 263 + 264 + # # Write metadata 265 + # metadata_path = ( 266 + # mount_path 267 + # / 'metadata' 268 + # / f'atdata-metadata--{new_uuid}.msgpack' 269 + # ) 270 + # metadata_path.parent.mkdir( parents = True, exist_ok = True ) 271 + # with open( metadata_path, 'wb' ) as f: 272 + # if ds.metadata is not None: 273 + # # TODO Figure out how to make linting work better here 274 + # f.write( msgpack.packb( ds.metadata ) ) 275 + 276 + # # Write data 277 + # shard_pattern = (Path( tmpdir ) / 'atdata-cache' / f'atdata--{new_uuid}--%06d.tar').as_posix() 278 + # written_shards = [] 279 + # with wds.writer.ShardWriter( shard_pattern, 280 + # opener = lambda s: 281 + # post = lambda s: written_shards.append( s ), 282 + # **kwargs 283 + # ) as sink: 284 + # for sample in ds.ordered( batch_size = None ): 285 + # sink.write( sample.as_wds ) 286 + 287 + # Make a new Dataset object for the written dataset copy 204 288 shard_s3_format = ( 205 289 ( 206 290 self.hive_path 207 291 / f'atdata--{new_uuid}' 208 292 ).as_posix() 209 293 ) + '--{shard_id}.tar' 210 - metadata_s3_path = ( 211 - self.hive_path 212 - / 'metadata' 213 - / f'atdata-metadata--{new_uuid}.msgpack' 214 - ) 215 294 shard_id_braced = '{' + f'{0:06d}..{len( written_shards ) - 1:06d}' + '}' 216 - return Dataset( 295 + 296 + new_dataset = Dataset[ds.sample_type]( 217 297 url = shard_s3_format.format( shard_id = shard_id_braced ), 218 298 metadata_url = metadata_path.as_posix(), 219 299 ) 220 300 301 + # Add to index 302 + new_entry = self.index.add_entry( new_dataset, uuid = new_uuid ) 303 + 304 + return new_entry, new_dataset 305 + 221 306 222 307 class Index: 223 308 """TODO""" ··· 240 325 # TODO this only works / is necessary for `redis_om`` 241 326 # Migrator().run() 242 327 243 - def list( self ): 328 + @property 329 + def all_entries( self ) -> list[BasicIndexEntry]: 330 + """TODO""" 331 + return list( self.entries ) 332 + 333 + @property 334 + def entries( self ) -> Generator[BasicIndexEntry, None, None]: 244 335 """TODO""" 245 336 ## 246 - ret = [] 247 337 for key in self._redis.scan_iter( match = 'BasicIndexEntry:*' ): 248 - ret.append( self._redis.hgetall( key ) ) 249 - return ret 338 + # TODO typing issue for `redis` 339 + cur_entry_data = _decode_bytes_dict( self._redis.hgetall( key ) ) 340 + cur_entry = BasicIndexEntry( **cur_entry_data ) 341 + yield cur_entry 342 + 343 + return 250 344 251 345 def add_entry( self, ds: Dataset, 252 346 uuid: str | None = None,