Mirror of https://github.com/roostorg/osprey github.com/roostorg/osprey
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

at b6706a7f83a8d0c9068f026d2a5f0c7a7be250f4 125 lines 5.0 kB view raw
from contextlib import contextmanager
from typing import Any, Generator

from osprey.engine.language_types.entities import EntityT
from osprey.worker.lib.osprey_shared.labels import EntityLabels
from osprey.worker.lib.osprey_shared.logging import get_logger
from osprey.worker.lib.storage.labels import LabelsServiceBase
from osprey.worker.lib.storage.postgres import Model, init_from_config, scoped_session
from sqlalchemy import Column, String, select
from sqlalchemy.dialects.postgresql import JSONB, insert

logger = get_logger(__name__)


class EntityLabelsModel(Model):
    """SQLAlchemy model for storing entity labels in PostgreSQL"""

    __tablename__ = 'entity_labels'

    # Primary key: the services below use `str(entity)` as the key.
    entity_key = Column(String, primary_key=True)
    # Serialized labels payload (the output of `EntityLabels.serialize()`).
    labels = Column(JSONB, nullable=False)

    def __str__(self) -> str:
        return f'EntityLabelsModel(entity_key={self.entity_key}, labels={self.labels})'


class PostgresLabelsService(LabelsServiceBase):
    """
    PostgreSQL-backed implementation of LabelsServiceBase.

    This service stores entity labels in a PostgreSQL database using SQLAlchemy.
    It provides atomic read-modify-write operations through database transactions.
    """

    def __init__(self, database: str = 'osprey_db') -> None:
        """
        Initialize the PostgreSQL labels service.
        Note: This will not init the postgres connection; To do that,
        initialize() must be called (which is called by the LabelsProvider
        by default)

        Args:
            database: The database name to use. Defaults to 'osprey_db'.
        """
        super().__init__()
        self._database_name: str = database

    def initialize(self) -> None:
        """Initialize the postgres connection for the configured database."""
        init_from_config(self._database_name)
        logger.info(f'Initialized PostgresLabelsService with database: {self._database_name}')

    def read_labels(self, entity: EntityT[Any]) -> EntityLabels:
        """
        Read labels for an entity from PostgreSQL.

        Args:
            entity: The entity to look up; `str(entity)` is used as the row key.

        Returns:
            The deserialized labels, or an empty EntityLabels if the entity
            has no labels.
        """
        entity_key = str(entity)

        with scoped_session(database=self._database_name) as session:
            stmt = select(EntityLabelsModel).where(EntityLabelsModel.entity_key == entity_key)
            result = session.scalars(stmt).first()

            if result is None:
                logger.debug(f'No labels found for entity {entity_key}')
                return EntityLabels()

            labels = EntityLabels.deserialize(result.labels)
            # Use %-style placeholders for the extra argument: passing a
            # positional arg to a message without placeholders makes the
            # logging machinery fail to format the record.
            logger.debug('Read labels for entity %s: %s', entity_key, result)
            return labels

    @contextmanager
    def read_modify_write_labels_atomically(self, entity: EntityT[Any]) -> Generator[EntityLabels, None, None]:
        """
        Context manager for atomic read-modify-write operations.

        This context manager:
        1. Opens a database transaction
        2. Acquires a row-level lock using SELECT FOR UPDATE
        3. Reads and returns the current labels
        4. Yields control to the caller (LabelsProvider)
        5. The caller modifies the labels IN PLACE
        6. On exit, writes the modified labels and commits the transaction

        The key insight: The caller modifies the yielded labels object directly,
        and this context manager persists those changes atomically.

        For systems that don't need locking (e.g., in-memory stores), this can
        be simplified to:
        ```py
        labels = self.read_labels(entity)
        yield labels
        # write the labels here
        ```

        Args:
            entity: The entity whose labels are being modified; `str(entity)`
                is used as the row key.

        Yields:
            The current labels for the entity (an empty EntityLabels when no
            row exists yet). Mutations made to this object are persisted on
            exit.

        Raises:
            Exception: Any exception raised by the caller's block or by the
                database is re-raised after the transaction is rolled back.
        """
        entity_key = str(entity)

        with scoped_session(commit=False, database=self._database_name) as session:
            try:
                # Use SELECT FOR UPDATE to acquire a row-level lock
                # NOTE(review): FOR UPDATE only locks rows the query returns. When
                # the entity has no row yet, nothing is locked, so two concurrent
                # first-time writers are not serialized and the later upsert
                # overwrites the earlier one - confirm whether that is acceptable.
                stmt = select(EntityLabelsModel).where(EntityLabelsModel.entity_key == entity_key).with_for_update()
                result = session.scalars(stmt).first()

                if result is None:
                    labels = EntityLabels()
                else:
                    labels = EntityLabels.deserialize(result.labels)

                # Yield control - The default LabelsProvider will modify the labels IN PLACE
                yield labels

                # After yield, write the modified labels back
                labels_dict = labels.serialize()
                upsert_stmt = insert(EntityLabelsModel).values(entity_key=entity_key, labels=labels_dict)
                upsert_stmt = upsert_stmt.on_conflict_do_update(
                    index_elements=['entity_key'], set_={EntityLabelsModel.labels: labels_dict}
                )
                session.execute(upsert_stmt)

                session.commit()
                # %-style placeholders so the payload is actually rendered
                # instead of triggering a log-formatting error.
                logger.debug('Committed atomic read-modify-write for entity %s: %s', entity_key, labels_dict)

            except Exception:
                # Covers both exceptions thrown in at the yield by the caller's
                # `with` body and database errors raised during the write-back.
                session.rollback()
                logger.error(f'Rolled back atomic read-modify-write for entity {entity_key}')
                raise