# Mirror of https://github.com/roostorg/osprey
# github.com/roostorg/osprey
1from contextlib import contextmanager
2from typing import Any, Generator
3
4from osprey.engine.language_types.entities import EntityT
5from osprey.worker.lib.osprey_shared.labels import EntityLabels
6from osprey.worker.lib.osprey_shared.logging import get_logger
7from osprey.worker.lib.storage.labels import LabelsServiceBase
8from osprey.worker.lib.storage.postgres import Model, init_from_config, scoped_session
9from sqlalchemy import Column, String, select
10from sqlalchemy.dialects.postgresql import JSONB, insert
11
# Module-level logger, obtained via get_logger for this module's __name__.
logger = get_logger(__name__)
13
14
class EntityLabelsModel(Model):
    """
    ORM mapping for the ``entity_labels`` table.

    One row per entity: the entity's string key plus its serialized labels
    stored as a JSONB document.
    """

    __tablename__ = 'entity_labels'

    # Primary key; PostgresLabelsService fills it with str(entity).
    entity_key = Column(String, primary_key=True)
    # Serialized EntityLabels payload (JSONB); the column is NOT NULL.
    labels = Column(JSONB, nullable=False)

    def __str__(self) -> str:
        rendered = f'EntityLabelsModel(entity_key={self.entity_key}, labels={self.labels})'
        return rendered
25
26
class PostgresLabelsService(LabelsServiceBase):
    """
    PostgreSQL-backed implementation of LabelsServiceBase.

    Each entity's labels are stored as a single JSONB row in the
    ``entity_labels`` table (EntityLabelsModel), keyed by ``str(entity)``.
    Read-modify-write operations run inside one database transaction that
    holds a row-level lock on the entity's row, so concurrent writers for the
    same entity are serialized.
    """

    def __init__(self, database: str = 'osprey_db') -> None:
        """
        Initialize the PostgreSQL labels service.

        Note: this does not set up the postgres connection. To do that,
        initialize() must be called (the LabelsProvider calls it by default).

        Args:
            database: The database name to use. Defaults to 'osprey_db'.
        """
        super().__init__()
        self._database_name: str = database

    def initialize(self) -> None:
        """Set up the postgres connection for the configured database via init_from_config."""
        init_from_config(self._database_name)
        logger.info(f'Initialized PostgresLabelsService with database: {self._database_name}')

    def read_labels(self, entity: EntityT[Any]) -> EntityLabels:
        """
        Read labels for an entity from PostgreSQL (no row lock is taken).

        Args:
            entity: The entity whose labels to read; ``str(entity)`` is the row key.

        Returns:
            The deserialized EntityLabels, or an empty EntityLabels if no row
            exists for the entity.
        """
        entity_key = str(entity)

        with scoped_session(database=self._database_name) as session:
            stmt = select(EntityLabelsModel).where(EntityLabelsModel.entity_key == entity_key)
            result = session.scalars(stmt).first()

            if result is None:
                logger.debug(f'No labels found for entity {entity_key}')
                return EntityLabels()

            labels = EntityLabels.deserialize(result.labels)
            # Interpolate `result` into the message: passing it as an extra positional
            # logging arg with no placeholder breaks %-style message formatting.
            logger.debug(f'Read labels for entity {entity_key}: {result}')
            return labels

    @contextmanager
    def read_modify_write_labels_atomically(self, entity: EntityT[Any]) -> Generator[EntityLabels, None, None]:
        """
        Context manager for atomic read-modify-write operations.

        This context manager:
        1. Opens a database transaction
        2. Ensures a row exists for the entity (inserting an empty-labels row
           if there is none) so that there is always a row to lock
        3. Acquires a row-level lock using SELECT FOR UPDATE
        4. Reads the current labels and yields them to the caller (LabelsProvider)
        5. The caller modifies the labels IN PLACE
        6. On normal exit, upserts the modified labels and commits the transaction
        7. If the caller's code or the write raises, rolls back and re-raises

        The key insight: the caller modifies the yielded labels object directly,
        and this context manager persists those changes atomically.

        For systems that don't need locking (e.g., in-memory stores), this can
        be simplified to:

        ```py
        labels = self.read_labels(entity)
        yield labels
        # write the labels here
        ```
        """
        entity_key = str(entity)

        with scoped_session(commit=False, database=self._database_name) as session:
            try:
                # SELECT ... FOR UPDATE locks nothing when the row does not exist yet, so two
                # concurrent first-time writers could both start from empty labels and the
                # later upsert would silently overwrite the earlier one's changes. Inserting an
                # empty placeholder first (a no-op when the row already exists) makes a
                # concurrent transaction's insert wait on this uncommitted row (under
                # PostgreSQL's default READ COMMITTED isolation). A row is written on success
                # anyway, and a rollback removes the placeholder.
                placeholder_stmt = (
                    insert(EntityLabelsModel)
                    .values(entity_key=entity_key, labels=EntityLabels().serialize())
                    .on_conflict_do_nothing(index_elements=['entity_key'])
                )
                session.execute(placeholder_stmt)

                # Use SELECT FOR UPDATE to acquire a row-level lock
                stmt = select(EntityLabelsModel).where(EntityLabelsModel.entity_key == entity_key).with_for_update()
                result = session.scalars(stmt).first()

                # Defensive fallback; the placeholder insert above should make the row present.
                labels = EntityLabels() if result is None else EntityLabels.deserialize(result.labels)

                # Yield control - the default LabelsProvider will modify the labels IN PLACE
                yield labels

                # After yield, write the modified labels back
                labels_dict = labels.serialize()
                upsert_stmt = insert(EntityLabelsModel).values(entity_key=entity_key, labels=labels_dict)
                upsert_stmt = upsert_stmt.on_conflict_do_update(
                    index_elements=['entity_key'], set_={EntityLabelsModel.labels: labels_dict}
                )
                session.execute(upsert_stmt)

                session.commit()
                # Interpolate the payload; a stray positional logging arg has no placeholder.
                logger.debug(f'Committed atomic read-modify-write for entity {entity_key}: {labels_dict}')

            except Exception:
                session.rollback()
                logger.error(f'Rolled back atomic read-modify-write for entity {entity_key}')
                raise