···11+from contextlib import contextmanager
22+from typing import Any, Generator
33+44+from osprey.engine.language_types.entities import EntityT
55+from osprey.worker.lib.osprey_shared.labels import EntityLabels
66+from osprey.worker.lib.osprey_shared.logging import get_logger
77+from osprey.worker.lib.storage.labels import LabelsServiceBase
88+from osprey.worker.lib.storage.postgres import Model, init_from_config, scoped_session
99+from sqlalchemy import Column, String, select
1010+from sqlalchemy.dialects.postgresql import JSONB, insert
1111+1212+logger = get_logger(__name__)
class EntityLabelsModel(Model):
    """ORM row type for the ``entity_labels`` table: one JSONB labels blob per entity key."""

    __tablename__ = 'entity_labels'

    # The stringified entity uniquely identifies a row.
    entity_key = Column(String, primary_key=True)
    # Serialized EntityLabels payload (see EntityLabels.serialize()).
    labels = Column(JSONB, nullable=False)

    def __str__(self) -> str:
        return 'EntityLabelsModel(entity_key={}, labels={})'.format(self.entity_key, self.labels)
class PostgresLabelsService(LabelsServiceBase):
    """
    PostgreSQL-backed implementation of LabelsServiceBase.

    This service stores entity labels in a PostgreSQL database using SQLAlchemy.
    It provides atomic read-modify-write operations through database transactions.
    """

    def __init__(self, database: str = 'osprey_db') -> None:
        """
        Initialize the PostgreSQL labels service.
        Note: This will not init the postgres connection; To do that,
        initialize() must be called (which is called by the LabelsProvider
        by default)

        Args:
            database: The database name to use. Defaults to 'osprey_db'.
        """
        super().__init__()
        self._database_name: str = database

    def initialize(self) -> None:
        """Initialize the postgres connection for the configured database."""
        init_from_config(self._database_name)
        logger.info('Initialized PostgresLabelsService with database: %s', self._database_name)

    def read_labels(self, entity: EntityT[Any]) -> EntityLabels:
        """
        Read labels for an entity from PostgreSQL.

        Returns an empty EntityLabels if the entity has no labels.
        """
        entity_key = str(entity)

        with scoped_session(database=self._database_name) as session:
            stmt = select(EntityLabelsModel).where(EntityLabelsModel.entity_key == entity_key)
            result = session.scalars(stmt).first()

            if result is None:
                logger.debug('No labels found for entity %s', entity_key)
                return EntityLabels()

            labels = EntityLabels.deserialize(result.labels)
            # Fixed: `result` was previously passed as a stray positional argument
            # alongside an f-string message; logging would then try (and fail) to
            # %-format it into a message with no placeholders. Use lazy %-args.
            logger.debug('Read labels for entity %s: %s', entity_key, result)
            return labels

    @contextmanager
    def read_modify_write_labels_atomically(self, entity: EntityT[Any]) -> Generator[EntityLabels, None, None]:
        """
        Context manager for atomic read-modify-write operations.

        This context manager:
        1. Opens a database transaction
        2. Acquires a row-level lock using SELECT FOR UPDATE
        3. Reads and returns the current labels
        4. Yields control to the caller (LabelsProvider)
        5. The caller modifies the labels IN PLACE
        6. On exit, writes the modified labels and commits the transaction

        The key insight: The caller modifies the yielded labels object directly,
        and this context manager persists those changes atomically.

        For systems that don't need locking (e.g., in-memory stores), this can
        be simplified to:
        ```py
        labels = self.read_labels(entity)
        yield labels
        # write the labels here
        ```
        """
        entity_key = str(entity)

        with scoped_session(commit=False, database=self._database_name) as session:
            try:
                # Use SELECT FOR UPDATE to acquire a row-level lock
                stmt = select(EntityLabelsModel).where(EntityLabelsModel.entity_key == entity_key).with_for_update()
                result = session.scalars(stmt).first()

                if result is None:
                    labels = EntityLabels()
                else:
                    labels = EntityLabels.deserialize(result.labels)

                # Yield control - The default LabelsProvider will modify the labels IN PLACE
                yield labels

                # After yield, write the modified labels back
                labels_dict = labels.serialize()
                upsert_stmt = insert(EntityLabelsModel).values(entity_key=entity_key, labels=labels_dict)
                upsert_stmt = upsert_stmt.on_conflict_do_update(
                    index_elements=['entity_key'], set_={EntityLabelsModel.labels: labels_dict}
                )
                session.execute(upsert_stmt)

                session.commit()
                # Lazy %-args (the dict used to be a stray positional arg to an f-string call).
                logger.debug('Committed atomic read-modify-write for entity %s: %s', entity_key, labels_dict)

            except Exception:
                session.rollback()
                # logger.exception records the active traceback, unlike logger.error.
                logger.exception(f'Rolled back atomic read-modify-write for entity {entity_key}')
                raise
···11from dataclasses import dataclass
22-from datetime import datetime
22+from datetime import datetime, timezone
33from enum import Enum, auto
44from typing import Any, Optional, Sequence
55···170170 desired_manual = _ManualType.get(arguments.manual)
171171 desired_delay = TimeDeltaT.inner_from_optional(arguments.min_label_age)
172172 label_state = entity_labels.labels.get(arguments.label)
173173- now = datetime.now()
173173+ now = datetime.now(timezone.utc)
174174175175 if label_state is not None:
176176 # Check to see if all reasons have expired, if so, the label should be considered as expired.
···99from osprey.worker.adaptor.constants import OSPREY_ADAPTOR
1010from osprey.worker.lib.action_proto_deserializer import ActionProtoDeserializer
1111from osprey.worker.lib.storage.labels import LabelsProvider, LabelsServiceBase
1212-from osprey.worker.sinks.sink.input_stream import BaseInputStream
1312from osprey.worker.sinks.utils.acking_contexts import BaseAckingContext
14131514if TYPE_CHECKING:
1615 from osprey.worker.lib.config import Config
1716 from osprey.worker.lib.storage.stored_execution_result import ExecutionResultStore
1717+ from osprey.worker.sinks.sink.input_stream import BaseInputStream
1818 from osprey.worker.sinks.sink.output_sink import BaseOutputSink
19192020hookspec: pluggy.HookspecMarker = pluggy.HookspecMarker(OSPREY_ADAPTOR)
@hookspec(firstresult=True)
def register_labels_service_or_provider(config: Config) -> LabelsServiceBase | LabelsProvider:
    """Register a labels service or labels provider.

    Implementations may either return a ``LabelsServiceBase`` (which the caller wraps in
    the default ``LabelsProvider``), or return a customized ``LabelsProvider`` directly
    to fit your needs.

    Args:
        config: The worker configuration available to the plugin.
    """
    raise NotImplementedError('register_labels_service_or_provider must be implemented by the plugin')
···1212from osprey.worker.adaptor.constants import OSPREY_ADAPTOR
1313from osprey.worker.adaptor.hookspecs import osprey_hooks
1414from osprey.worker.lib.action_proto_deserializer import ActionProtoDeserializer
1515+from osprey.worker.lib.singletons import LABELS_PROVIDER
1516from osprey.worker.lib.storage.labels import LabelsProvider, LabelsServiceBase
1617from osprey.worker.sinks.sink.input_stream import BaseInputStream
1718from osprey.worker.sinks.sink.output_sink import BaseOutputSink, LabelOutputSink, MultiOutputSink
···3940 return sum(seq, [])
def _labels_service_or_provider_is_registered() -> bool:
    """Return True when some plugin declares the `register_labels_service_or_provider` hook."""
    hook_namespace = plugin_manager.hook
    return hasattr(hook_namespace, 'register_labels_service_or_provider')
def bootstrap_udfs() -> tuple[UDFRegistry, UDFHelpers]:
    """Collect UDFs from all plugins and build the UDF registry plus helper bindings.

    Label UDFs (HasLabel / LabelAdd / LabelRemove) are only registered when a labels
    provider is configured for this worker.
    """
    udf_helpers = UDFHelpers()

    udfs: List[Type[UDFBase[Any, Any]]] = flatten(plugin_manager.hook.register_udfs())

    for udf in udfs:
        if issubclass(udf, HasHelper):
            udf_helpers.set_udf_helper(udf, udf.create_provider())

    # Label udfs should only be registered if the labels provider is available.
    # Explicit None check: the singleton returns an Optional, and a configured
    # provider object must not be skipped just because it happens to be falsy.
    labels_provider = LABELS_PROVIDER.instance()
    if labels_provider is not None:
        # Imports kinda circular. Imports here are to avoid that.
        from osprey.engine.stdlib.udfs.labels import HasLabel, LabelAdd, LabelRemove

        udfs.extend([HasLabel, LabelAdd, LabelRemove])
        udf_helpers.set_udf_helper(HasLabel, labels_provider)

    # Build the registry only after all conditional UDFs have been appended.
    udf_registry = UDFRegistry.with_udfs(*udfs)

    return udf_registry, udf_helpers
73707471···7774 sinks = flatten(plugin_manager.hook.register_output_sinks(config=config))
78757976 # Label udfs should only be registered if the labels provider is available
8080- if has_labels_service():
8181- sinks.append(LabelOutputSink(bootstrap_labels_provider()))
7777+ labels_provider = LABELS_PROVIDER.instance()
7878+ if labels_provider:
7979+ sinks.append(LabelOutputSink(labels_provider))
82808381 return MultiOutputSink(sinks)
def bootstrap_labels_provider(config: Config) -> LabelsProvider:
    """
    NOTE: If you are looking to get a labels provider to use within Osprey,
    it is best practice to reference the LABELS_PROVIDER singleton in
    `osprey_worker/src/osprey/worker/lib/singletons.py` by calling
    `LABELS_PROVIDER.instance()`

    This way, any statefulness that implementers want for `LabelsServiceBase`
    or `LabelsProvider` will be respected across a given Osprey worker.

    Generates a bootstrapped labels provider using the registered plugin.
    This will also call `initialize()` on the LabelsProvider, which will call
    `initialize()` on the LabelsServiceBase by default.

    Raises:
        NotImplementedError: if `register_labels_service_or_provider` is not
            implemented by any plugin, i.e. a labels service / provider was not
            configured.
        TypeError: if the hook returns something that is neither a
            LabelsServiceBase nor a LabelsProvider. (Previously an `assert`,
            which is stripped under `-O`, and whose parenthesized message was
            accidentally a tuple due to a stray comma.)
    """
    load_all_osprey_plugins()
    if not _labels_service_or_provider_is_registered():
        raise NotImplementedError(
            'bootstrap_labels_provider() assumes `register_labels_service_or_provider` is implemented.'
        )
    provider_or_service = plugin_manager.hook.register_labels_service_or_provider(config=config)
    if not isinstance(provider_or_service, (LabelsProvider, LabelsServiceBase)):
        raise TypeError(
            f"invariant: `register_labels_service_or_provider` has an invalid return type: "
            f"'{type(provider_or_service)}'; expected 'LabelsServiceBase' or 'LabelsProvider'"
        )
    if isinstance(provider_or_service, LabelsServiceBase):
        provider = LabelsProvider(provider_or_service)
    else:
        # The hook supplied a fully-formed provider directly.
        provider = provider_or_service
    provider.initialize()
    return provider
971219812299123def bootstrap_ast_validators() -> None:
+7-18
osprey_worker/src/osprey/worker/cli/sinks.py
···66# do not move this below other imports
77patch_all(ddtrace_args={'cassandra': True, 'psycopg': True})
8899-from osprey.worker.sinks.input_stream_chooser import get_rules_sink_input_stream
1010-from osprey.worker.sinks.sink.output_sink import LabelOutputSink
1111-129import signal
1310from uuid import uuid1
14111512# this is required to avoid memory leaks with gRPC
1613from gevent import config as gevent_config
1714from osprey.worker.adaptor.plugin_manager import bootstrap_output_sinks
1515+from osprey.worker.sinks.input_stream_chooser import get_rules_sink_input_stream
18161917gevent_config.track_greenlet_tree = False
2018···4038from osprey.worker.lib.osprey_engine import bootstrap_engine, bootstrap_engine_with_helpers, get_sources_provider
4139from osprey.worker.lib.osprey_shared.logging import get_logger
4240from osprey.worker.lib.publisher import PubSubPublisher
4343-from osprey.worker.lib.singletons import CONFIG
4141+from osprey.worker.lib.singletons import CONFIG, LABELS_PROVIDER
4442from osprey.worker.lib.storage import postgres
4543from osprey.worker.lib.storage.bigtable import osprey_bigtable
4644from osprey.worker.lib.storage.bulk_label_task import BulkLabelTask
···248246249247@cli.command()
250248@click.option('--pooled/--no-pooled', default=True, help='Whether to run multiple bulk label sinks in a pool')
251251-@click.option(
252252- '--send-status-webhook/--no-send-status-webhook',
253253- default=True,
254254- help='Whether to send status webhook to channel specified in rules repo',
255255-)
256256-def run_bulk_label_sink(pooled: bool, send_status_webhook: bool) -> None:
249249+def run_bulk_label_sink(pooled: bool) -> None:
257250 config = init_config()
258251259252 sentry_dsn = config.get_str(CONFIG_SENTRY_OTHER_SINKS_DSN, '')
···268261 analytics_pubsub_topic_id = config.get_str('PUBSUB_ANALYTICS_EVENT_TOPIC_ID', 'osprey-analytics')
269262 analytics_publisher = PubSubPublisher(analytics_pubsub_project_id, analytics_pubsub_topic_id)
270263271271- osprey_webhook_pubsub_project = config.get_str('PUBSUB_OSPREY_WEBHOOKS_PROJECT_ID', 'osprey-dev')
272272- osprey_webhook_pubsub_topic = config.get_str('PUBSUB_OSPREY_WEBHOOKS_TOPIC_ID', 'osprey-webhooks')
273273- webhooks_publisher = PubSubPublisher(osprey_webhook_pubsub_project, osprey_webhook_pubsub_topic)
274274-275275- event_effects_output_sink = LabelOutputSink(engine, analytics_publisher, webhooks_publisher)
276276-277264 def factory() -> BulkLabelSink:
278265 # NOTE: It's very important the input stream is created per-webhook sink
279266 postgres_source = PostgresInputStream(BulkLabelTask, tags=['sink:bulklabelsink'])
267267+ labels_provider = LABELS_PROVIDER.instance()
268268+ if not labels_provider:
269269+ raise AssertionError('BulkLabelSink cannot be instantiated without a labels provider')
280270 return BulkLabelSink(
281271 input_stream=postgres_source,
282282- event_effects_output_sink=event_effects_output_sink,
272272+ labels_provider=labels_provider,
283273 engine=engine,
284274 analytics_publisher=analytics_publisher,
285285- send_status_webhook=send_status_webhook,
286275 )
287276288277 if pooled:
+18-74
osprey_worker/src/osprey/worker/lib/cli.py
···33from pathlib import Path # noqa: E402
4455from osprey.engine.language_types.entities import EntityT
66-from osprey.worker.adaptor.plugin_manager import bootstrap_labels_provider
76from osprey.worker.lib.osprey_shared.labels import EntityLabelMutation, LabelStatus
88-from osprey.worker.lib.patcher import patch_all # noqa: E402
77+from osprey.worker.lib.patcher import patch_all
88+from osprey.worker.lib.singletons import LABELS_PROVIDER # noqa: E402
991010patch_all() # please ensure this occurs before *any* other imports !
111112121313import datetime # noqa: E402
1414-import ipaddress # noqa: E402
1515-import logging # noqa: E402
1614import os # noqa: E402
1717-import subprocess # noqa: E402
1818-from typing import Any, Optional, Set, Union # noqa: E402
1515+from typing import Any, Optional, Set # noqa: E402
19162017import click # noqa: E402
2121-from click import Context, Parameter, ParamType # noqa: E402
2218from osprey.worker.lib.osprey_logging import configure_logging # noqa: E402
23192420configure_logging()
···3127)
3228from osprey.worker.lib.storage import ( # noqa: E402
3329 access_audit_log, # noqa: E402
3434- entity_label_webhook,
3530 labels,
3631 stored_execution_result,
3732)
@click.group()
def cli() -> None:
    """Root command group for the osprey worker CLI."""
    pass
44394540···8681@cli.command()
8782@click.option('--auto-import/--no-auto-import', '-i', default=True)
8883def shell(auto_import: str) -> None:
8989- import os
9084 import sys
91859286 osprey_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
···123117 namespace_overrides = {
124118 'labels': labels,
125119 'access_audit_log': access_audit_log,
126126- 'entity_label_webhook': entity_label_webhook,
127120 'stored_execution_result': stored_execution_result,
128121 'EntityT': EntityT,
129122 # 'Entity': Entity,
···250243 reason_name=reason or 'CliLabelMutationWithoutEffects',
251244 status=label_status,
252245 description=description or 'Manually changed from the command line for debugging.',
253253- expires_at=(datetime.datetime.now() + datetime.timedelta(seconds=5)),
246246+ expires_at=(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(seconds=5)),
254247 )
255248 else:
256249 mutation = EntityLabelMutation(
···260253 description=description or 'Manually changed from the command line for debugging.',
261254 )
262255263263- result = bootstrap_labels_provider().apply_entity_label_mutations(
264264- entity=EntityT(type=entity_type, id=entity_id), mutations=[mutation]
256256+ provider = LABELS_PROVIDER.instance()
257257+ assert provider is not None, (
258258+ 'this CLI cannot be used because no labels service / provider is supplied for this osprey instance'
265259 )
266260261261+ result = provider.apply_entity_label_mutations(entity=EntityT(type=entity_type, id=entity_id), mutations=[mutation])
262262+267263 print(result)
268264269265···306302 """
307303 entity_ids = get_lines_from_file_as_set(file_path=entity_ids_file_path)
308304 # I found that it *generally* took ~10ms per request; Multiply by 10.05 for 5% latency headroom
309309- expire_timestamp = datetime.datetime.now() + datetime.timedelta(milliseconds=int(len(entity_ids) * 10.05))
305305+ expire_timestamp = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
306306+ milliseconds=int(len(entity_ids) * 10.05)
307307+ )
310308 print(f'Found {len(entity_ids)} entity IDs to label.\nETA: {int(len(entity_ids) * 10.05 / 100)} second(s)')
311309 if expire_instantly:
312310 mutation = EntityLabelMutation(
···325323 )
326324327325 progress_tracker: CliCommandProgressTracker = CliCommandProgressTracker(total_actions=len(entity_ids))
328328- provider = bootstrap_labels_provider()
326326+ provider = LABELS_PROVIDER.instance()
327327+ assert provider is not None, (
328328+ 'this code cannot be used because no labels service / provider is supplied for this osprey instance'
329329+ )
329330 for entity_id in entity_ids:
330330- provider.apply_entity_label_mutations(
331331+ _ = provider.apply_entity_label_mutations(
331332 entity=EntityT(type=entity_type, id=entity_id),
332333 mutations=[mutation],
333334 )
334335 progress_tracker.increment()
335336336337 print(f'Bulk labelling complete! Total labels applied: {progress_tracker.total_actions}')
337337-338338-339339-class IpAddress(ParamType):
340340- def convert(self, value: Union[str], param: Optional[Parameter], ctx: Optional[Context]) -> Optional[str]:
341341- """Check that the value parses as an ip V4 or V6 address"""
342342- ipaddress.IPv4Address(value)
343343- return value
344344-345345-346346-@cli.command()
347347-@click.option('--from-sub-gcp-project', required=True, help='Source GCP project for the subscription')
348348-@click.option('--from-subscription', required=True, help='Subscription ID to read from')
349349-@click.option('--target-destination-gcp-topic', required=True, help='Target GCP project for the topic')
350350-@click.option('--destination-topic', required=True, help='Destination topic to publish to')
351351-def restream_subscription(
352352- from_sub_gcp_project: str, from_subscription: str, target_destination_gcp_topic: str, destination_topic: str
353353-) -> None:
354354- """
355355- Run the restreamer to restream from a subscription to a topic.
356356- This is most useful for running DLQs.
357357- This command runs ./pubsub_restream with the specified options.
358358- """
359359- # Unset PUBSUB_EMULATOR_HOST if it is set
360360- if 'PUBSUB_EMULATOR_HOST' in os.environ:
361361- logging.info('Unsetting PUBSUB_EMULATOR_HOST')
362362- del os.environ['PUBSUB_EMULATOR_HOST']
363363-364364- cargo_bin_path = os.path.expanduser('~/.cargo/bin')
365365- if cargo_bin_path not in os.environ['PATH']:
366366- logging.info('Adding cargo bin path to PATH: %s', cargo_bin_path)
367367- os.environ['PATH'] += os.pathsep + cargo_bin_path
368368-369369- command = [
370370- 'pubsub_restream',
371371- '--use-gcloud-auth',
372372- f'--gcp-project={from_sub_gcp_project}',
373373- f'--subscription={from_subscription}',
374374- f'--dst-gcp-project={target_destination_gcp_topic}',
375375- f'--dst-topic={destination_topic}',
376376- ]
377377-378378- logging.info('Running the pubsub restreaming command: %s', ' '.join(command))
379379-380380- try:
381381- result = subprocess.run(command, capture_output=True, text=True, check=True)
382382- logging.info('Command output:\n%s', result.stdout)
383383- except subprocess.CalledProcessError as e:
384384- logging.error('Command failed with return code %d', e.returncode)
385385- logging.error('Command stderr:\n%s', e.stderr)
386386- sys.exit(e.returncode)
387387- except Exception as e:
388388- logging.error('An unexpected error occurred: %s', str(e))
389389- sys.exit(1)
390390-391391-392392-if __name__ == '__main__':
393393- cli()
···47474848def get_validation_result_exporter() -> BaseValidationResultExporter:
4949 """setup and returns the validation result exporter that will run during source updates"""
5050+5051 config = CONFIG.instance()
51525253 # Use null exporter if disabled (for development)
···11import copy
22from collections import UserDict
33from dataclasses import dataclass, field, replace
44-from datetime import datetime, timedelta
44+from datetime import datetime, timedelta, timezone
55from enum import Enum, IntEnum
66-from typing import Dict, Optional
66+from typing import Any, Dict, Self
7788from osprey.worker.lib.osprey_shared.logging import get_logger
99from osprey.worker.lib.utils.request_utils import SessionWithRetries
···151516161717logger = get_logger(__name__)
1818+1919+2020+def _guarantee_utc_timezone_awareness(dt: datetime | None) -> datetime | None:
2121+ if dt is None:
2222+ return None
2323+ if dt.tzinfo is None:
2424+ return dt.replace(tzinfo=timezone.utc)
2525+ return dt
182619272028class MutationDropReason(IntEnum):
···8593 description: str = ''
8694 """why the label was mutated"""
8795 features: dict[str, str] = field(default_factory=dict)
8888- """features are injected into the description as k/v's, similar to how fstrings work. for example,
9696+ """features are injected into the description as k/v's, similar to how fstrings work. for example,
8997 the {you} in 'hello {you}' would be substituted as 'person' with a feature dict of {'you': 'person'}"""
9098 created_at: datetime | None = None
9199 """
···94102 expires_at: datetime | None = None
95103 """marks when this label reason 'expires'
961049797- if a LabelState.MANUALLY_REMOVED is applied with a reason that has a 1 day expiration, then
105105+ if a LabelState.MANUALLY_REMOVED is applied with a reason that has a 1 day expiration, then
98106 for 1 day, the label cannot be applied via LabelState.ADDED. all LabelState.ADDED attempts will be dropped.
99107100108 if a given label state has multiple label reasons, all reasons would need to expire before the status/state
101101- is considered expired, too.
109109+ is considered expired, too.
102110 """
    def is_expired(self) -> bool:
        # A 5-second grace period avoids flapping right at the expiry boundary.
        # Comparison is done in UTC; expires_at is expected to be tz-aware here
        # (it is normalized via _guarantee_utc_timezone_awareness on write paths).
        return bool(self.expires_at is not None and self.expires_at + timedelta(seconds=5) < datetime.now(timezone.utc))
114114+115115+ def serialize(self) -> dict[str, Any]:
116116+ """
117117+ serialize LabelReason to a JSON-compatible dict.
118118+ converts datetime objects to ISO format strings.
119119+ """
120120+ created_at = _guarantee_utc_timezone_awareness(self.created_at)
121121+ expires_at = _guarantee_utc_timezone_awareness(self.expires_at)
122122+ return {
123123+ 'pending': self.pending,
124124+ 'description': self.description,
125125+ 'features': self.features,
126126+ 'created_at': created_at.isoformat() if created_at else None,
127127+ 'expires_at': expires_at.isoformat() if expires_at else None,
128128+ }
129129+130130+ @classmethod
131131+ def deserialize(cls, d: dict[str, Any]) -> Self:
132132+ """
133133+ deserialize a dict into a LabelReason object.
134134+ converts ISO format strings back to datetime objects.
135135+ """
136136+ created_at = _guarantee_utc_timezone_awareness(
137137+ datetime.fromisoformat(d['created_at']) if d.get('created_at') else None
138138+ )
139139+ expires_at = _guarantee_utc_timezone_awareness(
140140+ datetime.fromisoformat(d['expires_at']) if d.get('expires_at') else None
141141+ )
142142+ return cls(
143143+ pending=d.get('pending', False),
144144+ description=d.get('description', ''),
145145+ features=d.get('features', {}),
146146+ created_at=created_at,
147147+ expires_at=expires_at,
148148+ )
106149107150108151@dataclass
···159202 def __repr__(self):
160203 return f'LabelReasons({self.data})'
161204205205+ def serialize(self) -> dict[str, dict[str, Any]]:
206206+ """
207207+ serialize LabelReasons to a JSON-compatible dict.
208208+ returns a dict mapping reason names to serialized LabelReason dicts.
209209+ """
210210+ return {reason_name: reason.serialize() for reason_name, reason in self.items()}
211211+212212+ @classmethod
213213+ def deserialize(cls, d: dict[str, dict[str, Any]]) -> Self:
214214+ """
215215+ deserialize a dict into a LabelReasons object.
216216+ expects a dict mapping reason names to LabelReason dicts.
217217+ """
218218+219219+ deserialized_reasons: dict[str, LabelReason] = {}
220220+ for reason_name, reason_data in d.items():
221221+ try:
222222+ deserialized_reasons[reason_name] = LabelReason.deserialize(reason_data)
223223+ except Exception as e:
224224+ raise TypeError(f'could not create LabelReasons from dict: failed to deserialize {reason_name}', e)
225225+226226+ return cls(deserialized_reasons)
@dataclass
class LabelStateInner:
    """A (status, reasons) snapshot — used to record a label's previous states."""

    status: LabelStatus
    reasons: LabelReasons

    def serialize(self) -> dict[str, Any]:
        """
        serialize LabelStateInner to a JSON-compatible dict.
        """
        return {
            'status': self.status.value,
            'reasons': self.reasons.serialize(),
        }

    @classmethod
    def deserialize(cls, d: dict[str, Any]) -> Self:
        """
        deserialize a dict into a LabelStateInner object.

        Raises:
            TypeError: if the dict cannot be deserialized. The underlying error
                is chained as __cause__ (previously it was passed as a second
                positional TypeError argument, losing the exception chain).
        """
        try:
            status = LabelStatus(d['status'])
            reasons = LabelReasons.deserialize(d['reasons'])
            return cls(status=status, reasons=reasons)
        except Exception as e:
            raise TypeError(f'could not create LabelStateInner from dict: {d}') from e
254254+168255169256@dataclass
170257class LabelState:
171258 status: LabelStatus
172172- """statuses dictate the way the current state behaves; certain statuses have priority over others
259259+ """statuses dictate the way the current state behaves; certain statuses have priority over others
173260 (see LabelStatus for more info)"""
174261175262 reasons: LabelReasons
···177264 reasons are why this label state was applied; it is a dict because there may be multiple,
178265 with each reason being distinct based on it's reason name.
179266180180- reasons applied under the same name are merged (assuming the status has not changed),
267267+ reasons applied under the same name are merged (assuming the status has not changed),
181268 with precedence given to newer creaeted_at timestamps.
182269 """
183270···199286 if the weights are the *same*, then a merge of reasons is performed, which can also cause the expiration to be delayed.
200287 """
201288 if not self.reasons:
202202- AssertionError(f'invariant: the label state {self} did not have any associated reasons')
203203- expires_at = datetime.min
289289+ raise AssertionError(f'invariant: the label state {self} did not have any associated reasons')
290290+ expires_at = datetime.min.replace(tzinfo=timezone.utc)
204291 for reason in self.reasons.values():
205292 if reason.expires_at is None:
206293 return None
···215302 )
216303217304 def is_expired(self) -> bool:
218218- return bool(self.expires_at is not None and self.expires_at + timedelta(seconds=5) < datetime.now())
305305+ return bool(self.expires_at is not None and self.expires_at + timedelta(seconds=5) < datetime.now(timezone.utc))
219306220307 def _shift_current_state_to_previous_state(self) -> None:
221308 if not self.reasons:
···253340254341 return None
255342343343+ def serialize(self) -> dict[str, Any]:
344344+ """
345345+ serialize LabelState to a JSON-compatible dict.
346346+ """
347347+ return {
348348+ 'status': self.status.value,
349349+ 'reasons': self.reasons.serialize(),
350350+ 'previous_states': [prev_state.serialize() for prev_state in self.previous_states],
351351+ }
352352+353353+ @classmethod
354354+ def deserialize(cls, d: dict[str, Any]) -> Self:
355355+ """
356356+ deserialize a dict into a LabelState object.
357357+ """
358358+359359+ try:
360360+ status = LabelStatus(d['status'])
361361+ reasons = LabelReasons.deserialize(d['reasons'])
362362+ previous_states = [
363363+ LabelStateInner.deserialize(prev_state_data) for prev_state_data in d.get('previous_states', [])
364364+ ]
365365+ return cls(status=status, reasons=reasons, previous_states=previous_states)
366366+ except Exception as e:
367367+ raise TypeError(f'could not create LabelState from dict: {d}', e)
368368+256369257370@dataclass
258371class EntityLabels:
···261374 labels: Dict[str, LabelState] = field(default_factory=dict)
262375 """a mapping of label names to their current states'"""
263376377377+ def serialize(self) -> dict[str, Any]:
378378+ """
379379+ given the current EntityLabels object, returns a dict that is
380380+ json-serializable via json.dumps()
381381+ """
382382+ return {'labels': {k: v.serialize() for k, v in self.labels.items()}}
383383+384384+ @classmethod
385385+ def deserialize(cls, d: dict[str, dict[str, Any]]) -> Self:
386386+ """
387387+ given a dict, deserializes it into an EntityLabels object
388388+ """
389389+ if 'labels' in d:
390390+ d = d['labels']
391391+392392+ try:
393393+ return cls(labels={k: LabelState.deserialize(v) for k, v in d.items()})
394394+ except Exception as e:
395395+ raise TypeError(f'could not create EntityLabels from dict: {d};', e)
396396+264397265398@dataclass
266399class EntityLabelMutation:
···291424 pending=self.pending,
292425 description=self.description,
293426 features=self.features,
294294- created_at=datetime.now(),
295295- expires_at=self.expires_at,
427427+ created_at=datetime.now(timezone.utc),
428428+ expires_at=_guarantee_utc_timezone_awareness(self.expires_at),
296429 )
430430+431431+ def serialize(self) -> dict[str, Any]:
432432+ expires_at = _guarantee_utc_timezone_awareness(self.expires_at)
433433+ return {
434434+ 'label_name': self.label_name,
435435+ 'reason_name': self.reason_name,
436436+ 'status': self.status,
437437+ 'pending': self.pending,
438438+ 'description': self.description,
439439+ 'features': self.features,
440440+ 'expires_at': expires_at.isoformat() if expires_at else None,
441441+ }
297442298443299444@dataclass
···301446 mutation: EntityLabelMutation
302447 reason: MutationDropReason
303448449449+ def serialize(self) -> dict[str, Any]:
450450+ return {
451451+ 'mutation': self.mutation.serialize(),
452452+ 'reason': self.reason,
453453+ }
454454+304455305456@dataclass
306457class EntityLabelMutationsResult:
···309460 all of the entity's labels post-mutation
310461 """
311462312312- old_entity_labels: Optional[EntityLabels] = None
463463+ old_entity_labels: EntityLabels
313464 """
314465 all of the entity's labels pre-mutation
315466 """
···326477327478 labels_updated: list[str] = field(default_factory=list)
328479 """
329329- labels that had their state updated. this can include simply updating or
480480+ labels that had their state updated. this can include simply updating or
330481 appending to the reason
331482 """
332483···335486 mutations that were dropped for one reason or another. each dropped mutation is
336487 given a drop reason
337488 """
489489+490490+ def serialize(self) -> dict[str, Any]:
491491+ """
492492+ the only place this is currently needed is for the ui, which expects a specific json blob
493493+ """
494494+ return {
495495+ 'mutation_result': {
496496+ 'added': self.labels_added,
497497+ 'removed': self.labels_removed,
498498+ 'updated': self.labels_updated,
499499+ 'unchanged': list(set(mut.mutation.label_name for mut in self.dropped_mutations)),
500500+ },
501501+ **self.new_entity_labels.serialize(),
502502+ }
+39
osprey_worker/src/osprey/worker/lib/singletons.py
···44from osprey.engine.stdlib import get_config_registry
55from osprey.worker.lib.config import Config
66from osprey.worker.lib.singleton import Singleton
77+from osprey.worker.lib.storage.labels import LabelsProvider
7889if TYPE_CHECKING:
910 from osprey.worker.lib.osprey_engine import OspreyEngine
···202121222223ENGINE: Singleton['OspreyEngine'] = Singleton(_init_engine)
def _init_labels_provider() -> LabelsProvider | None:
    """
    Initialize the labels provider for the LABELS_PROVIDER singleton.

    Returns None when no labels service or provider has been registered by the
    plugin manager, so callers must None-check the singleton's value.
    """
    # The plugin manager imports this file to reference the labels provider singleton;
    # import locally so module load does not create a circular import.
    from osprey.worker.adaptor.plugin_manager import (
        _labels_service_or_provider_is_registered,
        bootstrap_labels_provider,
    )

    if not _labels_service_or_provider_is_registered():
        return None
    config = CONFIG.instance()
    return bootstrap_labels_provider(config)


LABELS_PROVIDER: Singleton['LabelsProvider | None'] = Singleton(_init_labels_provider)
"""
A Singleton that holds a `LabelsProvider`, if one is registered by the plugin manager.

If not, this Singleton will hold `None`. This makes it always safe to call `LABELS_PROVIDER.instance()`,
and enforces type checking rules when dealing with labels provider code, which may or may not be
supplied by users of Osprey.

An example use pattern might be:
```py
labels_provider: LabelsProvider | None = LABELS_PROVIDER.instance()
if labels_provider:
    # do labels provider things
```

Because this is a Singleton, implementers of `LabelsServiceBase` / `LabelsProvider` can implement statefulness
and expect that the statefulness will be present across all references within a given Osprey worker.
"""
···252526262727class LabelsServiceBase(ABC):
2828- @abstractmethod
2929- def write_labels(self, entity: EntityT[Any], labels: EntityLabels) -> None:
3030- """
3131- A standard write to the labels service that attempts to write the value to the primary key.
2828+ """
2929+ An abstract class to represent an implementable labels service backend.
32303333- This method may be retried upon exceptions, so keep that in mind when adding potentially
3434- non-idempotent behaviour.
3131+ With the default LabelsProvider, read_labels and batch_read_labels are called
3232+ *during* rule executions (or by the ui api).
3333+3434+ read_modify_write_labels_atomically is called post-rule execution (or by the ui api).
3535+ """
3636+3737+ def initialize(self) -> None:
3538 """
3636- raise NotImplementedError()
3939+ This method will be called after the initialization of this labels service base. Any side effects
4040+ that implementers may want, i.e. connecting to an external service, should be placed here.
4141+ """
4242+ pass
37433844 @abstractmethod
3945 def read_labels(self, entity: EntityT[Any]) -> EntityLabels:
···71777278 @abstractmethod
7379 @contextmanager
7474- def get_labels_atomically(self, entity: EntityT[Any]) -> Generator[EntityLabels, None, None]:
8080+ def read_modify_write_labels_atomically(self, entity: EntityT[Any]) -> Generator[EntityLabels, None, None]:
7581 """
7676- Context manager for atomic read-modify-write operations.
7777- Implementations should ensure the entity key is locked/in a transaction.
8282+ Context manager for atomic read-modify-write operations. This generator should yield EntityLabels upon reading
8383+ and should write the EntityLabels post-yield.
8484+8585+ IMPORTANT: Implementations should ensure the entity key is locked/in a transaction so that other read-modify-write
8686+ calls (even across multiple workers) must wait.
8787+8888+ This code may be retried upon exceptions, so keep that in mind when adding potentially
8989+ non-idempotent behaviour.
7890 """
7991 pass
    def __init__(self, labels_service: LabelsServiceBase):
        """Wrap the given labels service backend; all reads/writes delegate to it."""
        self._labels_service = labels_service
85979898+ def initialize(self) -> None:
9999+ """
100100+ This method will be called after the initialization of this labels provider. Any side effects
101101+ that implementers may want, i.e. connecting to an external service, should be placed here.
102102+ """
103103+ self._labels_service.initialize()
104104+86105 def _get_mutations_by_label_name_and_drop_conflicts(
87106 self, mutations: Sequence[EntityLabelMutation]
88107 ) -> tuple[dict[str, list[EntityLabelMutation]], list[DroppedEntityLabelMutation]]:
···141160 return desired_states_by_label_name
142161143162 def _compute_new_labels_from_mutations(
144144- self, old_labels: EntityLabels, mutations: Sequence[EntityLabelMutation]
163163+ self, labels: EntityLabels, mutations: Sequence[EntityLabelMutation]
145164 ) -> EntityLabelMutationsResult:
165165+ """
166166+ given an entity's labels and a set of mutations, modify the labels based on the mutations' desired states.
167167+168168+ **this method WILL modify the labels object that is passed into it**.
169169+ it will also return the pre-modification labels in EntityLabelMutationsResult.old_labels
170170+ """
146171 (mutations_by_label_name, dropped_mutations) = self._get_mutations_by_label_name_and_drop_conflicts(mutations)
147172 desired_states_by_label_name: dict[str, LabelStateInner] = self._get_desired_states_by_label_name(
148173 mutations_by_label_name
···153178 added: list[str] = []
154179 removed: list[str] = []
155180 updated: list[str] = []
156156- new_labels = copy.deepcopy(old_labels)
181181+ old_labels = copy.deepcopy(labels)
157182 for label_name, desired_state in desired_states_by_label_name.items():
158158- if label_name not in new_labels.labels:
159159- new_labels.labels[label_name] = LabelState.from_inner(desired_state)
183183+ if label_name not in labels.labels:
184184+ labels.labels[label_name] = LabelState.from_inner(desired_state)
160185 added.append(label_name)
161186 continue
162162- current_state = new_labels.labels[label_name]
187187+ current_state = labels.labels[label_name]
163188 prev_status = current_state.status
164189 drop_reason = current_state.try_apply_desired_state(desired_state)
165190 if drop_reason:
···182207183208 # finally, return the result! duhh :D
184209 return EntityLabelMutationsResult(
185185- new_entity_labels=new_labels,
210210+ new_entity_labels=labels,
186211 old_entity_labels=old_labels,
187212 labels_added=added,
188213 labels_removed=removed,
···199224 def apply_entity_label_mutations(
200225 self, entity: EntityT[Any], mutations: Sequence[EntityLabelMutation]
201226 ) -> EntityLabelMutationsResult:
202202- with self._labels_service.get_labels_atomically(entity) as old_labels:
203203- result = self._compute_new_labels_from_mutations(old_labels, mutations)
204204-205205- self._labels_service.write_labels(entity, result.new_entity_labels)
206206-227227+ try:
228228+ with self._labels_service.read_modify_write_labels_atomically(entity) as entity_labels:
229229+ result = self._compute_new_labels_from_mutations(entity_labels, mutations)
207230 return result
231231+ except Exception as e:
232232+ logger.error(f'Could not read-modify-write labels for entity {entity.__repr__()}:', e)
233233+ raise e
208234209235 def cache_ttl(self) -> Optional[timedelta]:
210236 return timedelta(minutes=1)