Mirror of https://github.com/roostorg/osprey
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

add postgres-backed labels service example & unify docker compose names (#31)

Authored by ayu and committed by GitHub.
75065595 dca6bada

+573 -277
+40 -23
docker-compose.yaml
··· 45 45 minio: 46 46 image: minio/minio:latest 47 47 container_name: minio 48 + hostname: minio 48 49 ports: 49 50 - "9000:9000" # minio API 50 51 - "9001:9001" # minio Console ··· 82 83 kafka-topics --bootstrap-server kafka:29092 --list 83 84 " 84 85 85 - osprey_worker: 86 - container_name: osprey_worker 86 + osprey-worker: 87 + container_name: osprey-worker 88 + hostname: osprey-worker 87 89 build: 88 90 context: . 89 91 dockerfile: osprey_worker/Dockerfile ··· 94 96 condition: service_completed_successfully 95 97 bigtable: 96 98 condition: service_healthy 97 - bigtable_initializer: 99 + bigtable-initializer: 98 100 condition: service_completed_successfully 99 101 minio: 100 102 condition: service_healthy ··· 106 108 environment: 107 109 - PYTHONPATH=/osprey 108 110 - PORT=5000 111 + - POSTGRES_HOSTS={"osprey_db":"postgresql://osprey:FoolishPassword@postgres:5432/osprey"} 109 112 - OSPREY_INPUT_STREAM_SOURCE=kafka 110 113 - OSPREY_STDOUT_OUTPUT_SINK=True 111 114 - OSPREY_KAFKA_BOOTSTRAP_SERVERS=["kafka:29092"] ··· 125 128 - OSPREY_MINIO_SECRET_KEY=minioadmin123 126 129 - OSPREY_MINIO_SECURE=false 127 130 - OSPREY_MINIO_EXECUTION_RESULTS_BUCKET=execution-output 128 - - SNOWFLAKE_API_ENDPOINT=http://snowflake:8080 131 + - SNOWFLAKE_API_ENDPOINT=http://snowflake-id-worker:8088 129 132 - OSPREY_RULES_PATH=./example_rules 130 133 volumes: 131 134 - ./osprey_worker:/osprey/osprey_worker 132 135 - ./osprey_rpc:/osprey/osprey_rpc 133 136 - ./example_rules:/osprey/example_rules 134 137 - ./entrypoint.sh:/osprey/entrypoint.sh 135 - osprey_ui_api: 136 - container_name: osprey_ui_api 138 + 139 + osprey-ui-api: 140 + container_name: osprey-ui-api 137 141 build: 138 142 context: . 
139 143 dockerfile: osprey_worker/Dockerfile 140 144 depends_on: 141 - - osprey_worker 145 + - osprey-worker 142 146 - druid-broker 143 147 - postgres 144 - - snowflake 148 + - snowflake-id-worker 145 149 - bigtable 146 - - bigtable_initializer 150 + - bigtable-initializer 147 151 ports: 148 152 - "5004:5004" 149 153 command: ["osprey-ui-api"] ··· 160 164 - OSPREY_RULES_PATH=/osprey/example_rules 161 165 - OSPREY_DISABLE_VALIDATION_EXPORTER=true 162 166 - BIGTABLE_EMULATOR_HOST=bigtable:8361 163 - - SNOWFLAKE_API_ENDPOINT=http://snowflake:8080 167 + - SNOWFLAKE_API_ENDPOINT=http://snowflake-id-worker:8088 164 168 - SNOWFLAKE_EPOCH=1420070400000 165 169 volumes: 166 170 - ./osprey_worker:/osprey/osprey_worker 167 171 - ./osprey_rpc:/osprey/osprey_rpc 168 172 - ./example_rules:/osprey/example_rules 169 173 170 - osprey_ui: 171 - container_name: osprey_ui 174 + osprey-ui: 175 + container_name: osprey-ui 176 + hostname: osprey-ui 172 177 build: 173 178 context: . 174 179 dockerfile: osprey_ui/Dockerfile 175 180 depends_on: 176 - - osprey_ui_api 181 + - osprey-ui-api 177 182 ports: 178 183 - "5002:5002" 179 184 environment: 180 185 - NODE_ENV=development 181 - - REACT_APP_API_BASE_URL=http://osprey_ui_api:5004 186 + - REACT_APP_API_BASE_URL=http://localhost:5004 182 187 volumes: 183 188 - ./osprey_ui:/app 184 189 - /app/node_modules 185 190 186 - snowflake: 187 - container_name: snowflake_id_worker 191 + snowflake-id-worker: 192 + hostname: snowflake-id-worker 193 + container_name: snowflake-id-worker 188 194 image: ghcr.io/ayubun/snowflake-id-worker:0 189 195 ports: 190 - - "8080:8080" 196 + - "8088:8088" 191 197 environment: 192 198 - WORKER_ID=0 193 199 - DATA_CENTER_ID=0 194 200 - EPOCH=1420070400000 201 + - PORT=8088 195 202 restart: unless-stopped 196 203 197 204 bigtable: 198 - container_name: bigtable_emulator 205 + hostname: bigtable 206 + container_name: bigtable 199 207 image: gcr.io/google.com/cloudsdktool/cloud-sdk:latest 200 208 ports: 201 209 - 
"8361:8361" ··· 210 218 retries: 5 211 219 restart: unless-stopped 212 220 213 - bigtable_initializer: 214 - container_name: bigtable_initializer 221 + bigtable-initializer: 222 + container_name: bigtable-initializer 215 223 image: gcr.io/google.com/cloudsdktool/cloud-sdk:latest 216 224 depends_on: 217 225 bigtable: ··· 221 229 command: ["/bin/bash", "/init-bigtable.sh"] 222 230 223 231 # Optional test data generator - run with: 224 - # docker compose --profile test_data up kafka_test_data_producer -d 225 - kafka_test_data_producer: 232 + # docker compose --profile test_data up kafka-test-data-producer -d 233 + kafka-test-data-producer: 226 234 image: confluentinc/cp-kafka:7.4.0 227 - container_name: kafka_test_data 235 + hostname: kafka-test-data-producer 236 + container_name: kafka-test-data-producer 228 237 depends_on: 229 238 kafka: 230 239 condition: service_healthy ··· 232 241 condition: service_completed_successfully 233 242 profiles: 234 243 - test_data 244 + - test-data 235 245 environment: 236 246 KAFKA_TOPIC: "osprey.actions_input" 237 247 KAFKA_BROKER: "kafka:29092" ··· 242 252 command: ["/osprey/example_data/generate_test_data.sh"] 243 253 244 254 postgres: 255 + hostname: postgres 245 256 container_name: postgres 246 257 image: postgres:latest 247 258 ports: ··· 256 267 # DRUID, HERE BE DRAGONS 257 268 # Need 3.5 or later for container nodes 258 269 druid-zookeeper: 270 + hostname: druid-zookeeper 259 271 container_name: druid-zookeeper 260 272 image: zookeeper:3.5.10 261 273 ports: ··· 265 277 266 278 druid-coordinator: 267 279 image: apache/druid:34.0.0 280 + hostname: druid-coordinator 268 281 container_name: druid-coordinator 269 282 volumes: 270 283 - druid_shared:/opt/shared ··· 282 295 druid-broker: 283 296 image: apache/druid:34.0.0 284 297 container_name: druid-broker 298 + hostname: druid-broker 285 299 volumes: 286 300 - broker_var:/opt/druid/var 287 301 depends_on: ··· 298 312 druid-historical: 299 313 image: apache/druid:34.0.0 300 314 
container_name: druid-historical 315 + hostname: druid-historical 301 316 volumes: 302 317 - druid_shared:/opt/shared 303 318 - historical_var:/opt/druid/var ··· 315 330 druid-middlemanager: 316 331 image: apache/druid:34.0.0 317 332 container_name: druid-middlemanager 333 + hostname: druid-middlemanager 318 334 volumes: 319 335 - druid_shared:/opt/shared 320 336 - middle_var:/opt/druid/var ··· 333 349 druid-router: 334 350 image: apache/druid:34.0.0 335 351 container_name: druid-router 352 + hostname: druid-router 336 353 volumes: 337 354 - router_var:/opt/druid/var 338 355 depends_on:
+1 -1
druid/environment
··· 29 29 30 30 druid_extensions_loadList=["druid-histogram", "druid-datasketches", "druid-lookups-cached-global", "postgresql-metadata-storage", "druid-multi-stage-query", "druid-kafka-indexing-service"] 31 31 32 - druid_zk_service_host=zookeeper 32 + druid_zk_service_host=druid-zookeeper 33 33 34 34 druid_metadata_storage_host= 35 35 druid_metadata_storage_type=postgresql
+8
example_plugins/src/register_plugins.py
··· 3 3 from osprey.engine.udf.base import UDFBase 4 4 from osprey.worker.adaptor.plugin_manager import hookimpl_osprey 5 5 from osprey.worker.lib.config import Config 6 + from osprey.worker.lib.storage.labels import LabelsServiceBase 6 7 from osprey.worker.sinks.sink.output_sink import BaseOutputSink, StdoutOutputSink 8 + from services.labels_service import PostgresLabelsService 7 9 from udfs.ban_user import BanUser 8 10 from udfs.text_contains import TextContains 9 11 ··· 16 18 @hookimpl_osprey 17 19 def register_output_sinks(config: Config) -> Sequence[BaseOutputSink]: 18 20 return [StdoutOutputSink(log_sampler=None)] 21 + 22 + 23 + @hookimpl_osprey 24 + def register_labels_service_or_provider(config: Config) -> LabelsServiceBase: 25 + """Register a PostgreSQL-backed labels service.""" 26 + return PostgresLabelsService()
+125
example_plugins/src/services/labels_service.py
··· 1 + from contextlib import contextmanager 2 + from typing import Any, Generator 3 + 4 + from osprey.engine.language_types.entities import EntityT 5 + from osprey.worker.lib.osprey_shared.labels import EntityLabels 6 + from osprey.worker.lib.osprey_shared.logging import get_logger 7 + from osprey.worker.lib.storage.labels import LabelsServiceBase 8 + from osprey.worker.lib.storage.postgres import Model, init_from_config, scoped_session 9 + from sqlalchemy import Column, String, select 10 + from sqlalchemy.dialects.postgresql import JSONB, insert 11 + 12 + logger = get_logger(__name__) 13 + 14 + 15 + class EntityLabelsModel(Model): 16 + """SQLAlchemy model for storing entity labels in PostgreSQL""" 17 + 18 + __tablename__ = 'entity_labels' 19 + 20 + entity_key = Column(String, primary_key=True) 21 + labels = Column(JSONB, nullable=False) 22 + 23 + def __str__(self) -> str: 24 + return f'EntityLabelsModel(entity_key={self.entity_key}, labels={self.labels})' 25 + 26 + 27 + class PostgresLabelsService(LabelsServiceBase): 28 + """ 29 + PostgreSQL-backed implementation of LabelsServiceBase. 30 + 31 + This service stores entity labels in a PostgreSQL database using SQLAlchemy. 32 + It provides atomic read-modify-write operations through database transactions. 33 + """ 34 + 35 + def __init__(self, database: str = 'osprey_db') -> None: 36 + """ 37 + Initialize the PostgreSQL labels service. 38 + Note: This will not init the postgres connection; To do that, 39 + initialize() must be called (which is called by the LabelsProvider 40 + by default) 41 + 42 + Args: 43 + database: The database name to use. Defaults to 'osprey_db'. 
44 + """ 45 + super().__init__() 46 + self._database_name: str = database 47 + 48 + def initialize(self) -> None: 49 + init_from_config(self._database_name) 50 + logger.info(f'Initialized PostgresLabelsService with database: {self._database_name}') 51 + 52 + def read_labels(self, entity: EntityT[Any]) -> EntityLabels: 53 + """ 54 + Read labels for an entity from PostgreSQL. 55 + 56 + Returns an empty EntityLabels if the entity has no labels. 57 + """ 58 + entity_key = str(entity) 59 + 60 + with scoped_session(database=self._database_name) as session: 61 + stmt = select(EntityLabelsModel).where(EntityLabelsModel.entity_key == entity_key) 62 + result = session.scalars(stmt).first() 63 + 64 + if result is None: 65 + logger.debug(f'No labels found for entity {entity_key}') 66 + return EntityLabels() 67 + 68 + labels = EntityLabels.deserialize(result.labels) 69 + logger.debug(f'Read labels for entity {entity_key}', result) 70 + return labels 71 + 72 + @contextmanager 73 + def read_modify_write_labels_atomically(self, entity: EntityT[Any]) -> Generator[EntityLabels, None, None]: 74 + """ 75 + Context manager for atomic read-modify-write operations. 76 + 77 + This context manager: 78 + 1. Opens a database transaction 79 + 2. Acquires a row-level lock using SELECT FOR UPDATE 80 + 3. Reads and returns the current labels 81 + 4. Yields control to the caller (LabelsProvider) 82 + 5. The caller modifies the labels IN PLACE 83 + 6. On exit, writes the modified labels and commits the transaction 84 + 85 + The key insight: The caller modifies the yielded labels object directly, 86 + and this context manager persists those changes atomically. 
87 + 88 + For systems that don't need locking (e.g. in-memory stores), this can 89 + be simplified to: 90 + ```py 91 + labels = self.read_labels(entity) 92 + yield labels 93 + # write the labels here ``` 94 + """ 95 + entity_key = str(entity) 96 + 97 + with scoped_session(commit=False, database=self._database_name) as session: 98 + try: 99 + # Use SELECT FOR UPDATE to acquire a row-level lock 100 + stmt = select(EntityLabelsModel).where(EntityLabelsModel.entity_key == entity_key).with_for_update() 101 + result = session.scalars(stmt).first() 102 + 103 + if result is None: 104 + labels = EntityLabels() 105 + else: 106 + labels = EntityLabels.deserialize(result.labels) 107 + 108 + # Yield control - The default LabelsProvider will modify the labels IN PLACE 109 + yield labels 110 + 111 + # After yield, write the modified labels back 112 + labels_dict = labels.serialize() 113 + upsert_stmt = insert(EntityLabelsModel).values(entity_key=entity_key, labels=labels_dict) 114 + upsert_stmt = upsert_stmt.on_conflict_do_update( 115 + index_elements=['entity_key'], set_={EntityLabelsModel.labels: labels_dict} 116 + ) 117 + session.execute(upsert_stmt) 118 + 119 + session.commit() 120 + logger.debug(f'Committed atomic read-modify-write for entity {entity_key}', labels_dict) 121 + 122 + except Exception: 123 + session.rollback() 124 + logger.error(f'Rolled back atomic read-modify-write for entity {entity_key}') 125 + raise
+3 -3
example_plugins/src/udfs/ban_user.py
··· 3 3 4 4 from osprey.engine.executor.custom_extracted_features import CustomExtractedFeature 5 5 from osprey.engine.executor.execution_context import ExecutionContext 6 - from osprey.engine.language_types.effects import EffectToCustomExtractedFeatureBase 6 + from osprey.engine.language_types.effects import EffectBase, EffectToCustomExtractedFeatureBase 7 7 from osprey.engine.stdlib.udfs.categories import UdfCategories 8 8 from osprey.engine.udf.arguments import ArgumentsBase 9 9 from osprey.engine.udf.base import UDFBase ··· 53 53 ) 54 54 55 55 56 - class BanUser(UDFBase[BanUserArguments, BanUserEffect]): 56 + class BanUser(UDFBase[BanUserArguments, EffectBase]): 57 57 category = UdfCategories.ENGINE 58 58 59 - def execute(self, execution_context: ExecutionContext, arguments: BanUserArguments) -> BanUserEffect: 59 + def execute(self, execution_context: ExecutionContext, arguments: BanUserArguments) -> EffectBase: 60 60 return synthesize_effect(arguments)
+5
example_rules/config/labels.yaml
··· 1 + labels: 2 + meow: 3 + valid_for: [User] 4 + connotation: positive 5 + description: testing label
+1 -1
example_rules/main.sml
··· 1 1 Import(rules=['models/base.sml']) 2 2 3 - Require(rule='rules/post_contains_hello.sml') 3 + Require(rule='rules/post_contains_hello.sml')
+4 -1
example_rules/rules/post_contains_hello.sml
··· 15 15 16 16 WhenRules( 17 17 rules_any=[ContainsHello], 18 - then=[BanUser(entity=UserId, comment='User said "hello"')], 18 + then=[ 19 + BanUser(entity=UserId, comment='User said "hello"'), 20 + LabelAdd(entity=UserId, label='meow'), 21 + ], 19 22 )
+3 -3
osprey_ui/src/types/LabelTypes.tsx
··· 1 1 export enum LabelStatus { 2 - ADDED, 3 2 REMOVED, 4 - MANUALLY_ADDED, 3 + ADDED, 5 4 MANUALLY_REMOVED, 5 + MANUALLY_ADDED, 6 6 } 7 7 8 8 export const LabelStatusAPIMapping: Record<LabelStatus, string> = { ··· 44 44 45 45 export interface LabelMutationDetails { 46 46 added: string[]; 47 - dropped: string[]; 47 + updated: string[]; 48 48 removed: string[]; 49 49 unchanged: string[]; 50 50 }
+6
osprey_worker/src/osprey/engine/language_types/entities.py
··· 36 36 def to_post_execution_value(self) -> _T: 37 37 return self.id 38 38 39 + def __str__(self) -> str: 40 + return f'{self.type}/{self.id}' 41 + 42 + def __repr__(self) -> str: 43 + return f"EntityT[{type(self.id)}](type='{self.type}', id={self.id})" 44 + 39 45 @staticmethod 40 46 def _internal_post_execution_type(cls: Type['PostExecutionConvertible[_U]']) -> Type[_U]: 41 47 # Since we leave PostExecutionConvertible with a generic variable, override how we determine our type to give
+2 -2
osprey_worker/src/osprey/engine/stdlib/udfs/labels.py
··· 1 1 from dataclasses import dataclass 2 - from datetime import datetime 2 + from datetime import datetime, timezone 3 3 from enum import Enum, auto 4 4 from typing import Any, Optional, Sequence 5 5 ··· 170 170 desired_manual = _ManualType.get(arguments.manual) 171 171 desired_delay = TimeDeltaT.inner_from_optional(arguments.min_label_age) 172 172 label_state = entity_labels.labels.get(arguments.label) 173 - now = datetime.now() 173 + now = datetime.now(timezone.utc) 174 174 175 175 if label_state is not None: 176 176 # Check to see if all reasons have expired, if so, the label should be considered as expired.
+6 -6
osprey_worker/src/osprey/worker/adaptor/hookspecs/osprey_hooks.py
··· 9 9 from osprey.worker.adaptor.constants import OSPREY_ADAPTOR 10 10 from osprey.worker.lib.action_proto_deserializer import ActionProtoDeserializer 11 11 from osprey.worker.lib.storage.labels import LabelsProvider, LabelsServiceBase 12 - from osprey.worker.sinks.sink.input_stream import BaseInputStream 13 12 from osprey.worker.sinks.utils.acking_contexts import BaseAckingContext 14 13 15 14 if TYPE_CHECKING: 16 15 from osprey.worker.lib.config import Config 17 16 from osprey.worker.lib.storage.stored_execution_result import ExecutionResultStore 17 + from osprey.worker.sinks.sink.input_stream import BaseInputStream 18 18 from osprey.worker.sinks.sink.output_sink import BaseOutputSink 19 19 20 20 hookspec: pluggy.HookspecMarker = pluggy.HookspecMarker(OSPREY_ADAPTOR) ··· 56 56 57 57 58 58 @hookspec(firstresult=True) 59 - def register_labels_service(config: Config) -> LabelsServiceBase | LabelsProvider: 60 - """Register a labels service backend. This can be achieved by implementing a labels service 61 - and utilizing the provided labels provider, or by overriding the labels provider to fit your 62 - business needs""" 63 - raise NotImplementedError('register_labels_service must be implemented by the plugin') 59 + def register_labels_service_or_provider(config: Config) -> LabelsServiceBase | LabelsProvider: 60 + """Register a labels service or labels provider. This can be achieved by implementing a labels 61 + service base and utilizing the provided labels provider, or by overriding the labels provider to 62 + fit your needs""" 63 + raise NotImplementedError('register_labels_service_or_provider must be implemented by the plugin')
+44 -20
osprey_worker/src/osprey/worker/adaptor/plugin_manager.py
··· 12 12 from osprey.worker.adaptor.constants import OSPREY_ADAPTOR 13 13 from osprey.worker.adaptor.hookspecs import osprey_hooks 14 14 from osprey.worker.lib.action_proto_deserializer import ActionProtoDeserializer 15 + from osprey.worker.lib.singletons import LABELS_PROVIDER 15 16 from osprey.worker.lib.storage.labels import LabelsProvider, LabelsServiceBase 16 17 from osprey.worker.sinks.sink.input_stream import BaseInputStream 17 18 from osprey.worker.sinks.sink.output_sink import BaseOutputSink, LabelOutputSink, MultiOutputSink ··· 39 40 return sum(seq, []) 40 41 41 42 42 - def has_labels_service() -> bool: 43 - return hasattr(plugin_manager.hook, 'register_labels_service') 43 + def _labels_service_or_provider_is_registered() -> bool: 44 + return hasattr(plugin_manager.hook, 'register_labels_service_or_provider') 44 45 45 46 46 47 def bootstrap_udfs() -> tuple[UDFRegistry, UDFHelpers]: ··· 48 49 udf_helpers = UDFHelpers() 49 50 50 51 udfs: List[Type[UDFBase[Any, Any]]] = flatten(plugin_manager.hook.register_udfs()) 51 - udf_registry = UDFRegistry.with_udfs(*udfs) 52 52 53 53 for udf in udfs: 54 54 if issubclass(udf, HasHelper): 55 55 udf_helpers.set_udf_helper(udf, udf.create_provider()) 56 56 57 57 # Label udfs should only be registered if the labels provider is available 58 - if has_labels_service(): 58 + labels_provider = LABELS_PROVIDER.instance() 59 + if labels_provider: 59 60 # Imports kinda circular. Imports here are to avoid that. 
60 61 from osprey.engine.stdlib.udfs.labels import HasLabel, LabelAdd, LabelRemove 61 62 62 63 udfs.extend([HasLabel, LabelAdd, LabelRemove]) 63 64 64 - provider_or_service: LabelsProvider | LabelsServiceBase = plugin_manager.hook.register_labels_service() 65 - if isinstance(provider_or_service, LabelsProvider): 66 - labels_provider = provider_or_service 67 - else: 68 - labels_provider = LabelsProvider(labels_service=provider_or_service) 69 - 70 65 udf_helpers.set_udf_helper(HasLabel, labels_provider) 71 66 67 + udf_registry = UDFRegistry.with_udfs(*udfs) 68 + 72 69 return udf_registry, udf_helpers 73 70 74 71 ··· 77 74 sinks = flatten(plugin_manager.hook.register_output_sinks(config=config)) 78 75 79 76 # Label udfs should only be registered if the labels provider is available 80 - if has_labels_service(): 81 - sinks.append(LabelOutputSink(bootstrap_labels_provider())) 77 + labels_provider = LABELS_PROVIDER.instance() 78 + if labels_provider: 79 + sinks.append(LabelOutputSink(labels_provider)) 82 80 83 81 return MultiOutputSink(sinks) 84 82 85 83 86 - def bootstrap_labels_provider() -> LabelsProvider: 84 + def bootstrap_labels_provider(config: Config) -> LabelsProvider: 87 85 """ 88 - Generates a bootstrapped label provider using the registered plugin. 86 + NOTE: If you are looking to get a labels provider to use within Osprey, 87 + it is best practice to reference the LABELS_PROVIDER singleton in 88 + `osprey_worker/src/osprey/worker/lib/singletons.py` by calling 89 + `LABELS_PROVIDER.instance()` 90 + 91 + This way, any statefulness that implementers want for `LabelsServiceBase` 92 + or `LabelsProvider` will be respected across a given Osprey worker. 93 + 94 + - 95 + 96 + Generates a bootstrapped labels provider using the registered plugin. 97 + This will also call `initialize()` on the LabelsProvider, which will call 98 + `initialize()` on the LabelsServiceBase by default. 
99 + 100 + This will throw an assertion error if `register_labels_service_or_provider` 101 + does not exist, i.e. in the event that a labels service / provider was not 102 + configured. 89 103 """ 90 104 load_all_osprey_plugins() 91 - if not has_labels_service(): 92 - raise NotImplementedError('Labels provider assumes register_labels_service is implemented.') 93 - provider_or_service: LabelsProvider | LabelsServiceBase = plugin_manager.hook.register_labels_service() 94 - if isinstance(provider_or_service, LabelsProvider): 95 - return provider_or_service 96 - return LabelsProvider(provider_or_service) 105 + if not _labels_service_or_provider_is_registered(): 106 + raise NotImplementedError( 107 + 'bootstrap_labels_provider() assumes `register_labels_service_or_provider` is implemented.' 108 + ) 109 + provider_or_service = plugin_manager.hook.register_labels_service_or_provider(config=config) 110 + assert isinstance(provider_or_service, LabelsProvider) or isinstance(provider_or_service, LabelsServiceBase), ( 111 + f"invariant: `register_labels_service_or_provider` has an invalid return type: '{type(provider_or_service)}';", 112 + "expected 'LabelsServiceBase' or 'LabelsProvider'", 113 + ) 114 + if isinstance(provider_or_service, LabelsServiceBase): 115 + provider = LabelsProvider(provider_or_service) 116 + provider.initialize() 117 + return provider 118 + # if we reach here, then a provider was supplied by the hook ! 119 + provider_or_service.initialize() 120 + return provider_or_service 97 121 98 122 99 123 def bootstrap_ast_validators() -> None:
+7 -18
osprey_worker/src/osprey/worker/cli/sinks.py
··· 6 6 # do not move this below other imports 7 7 patch_all(ddtrace_args={'cassandra': True, 'psycopg': True}) 8 8 9 - from osprey.worker.sinks.input_stream_chooser import get_rules_sink_input_stream 10 - from osprey.worker.sinks.sink.output_sink import LabelOutputSink 11 - 12 9 import signal 13 10 from uuid import uuid1 14 11 15 12 # this is required to avoid memory leaks with gRPC 16 13 from gevent import config as gevent_config 17 14 from osprey.worker.adaptor.plugin_manager import bootstrap_output_sinks 15 + from osprey.worker.sinks.input_stream_chooser import get_rules_sink_input_stream 18 16 19 17 gevent_config.track_greenlet_tree = False 20 18 ··· 40 38 from osprey.worker.lib.osprey_engine import bootstrap_engine, bootstrap_engine_with_helpers, get_sources_provider 41 39 from osprey.worker.lib.osprey_shared.logging import get_logger 42 40 from osprey.worker.lib.publisher import PubSubPublisher 43 - from osprey.worker.lib.singletons import CONFIG 41 + from osprey.worker.lib.singletons import CONFIG, LABELS_PROVIDER 44 42 from osprey.worker.lib.storage import postgres 45 43 from osprey.worker.lib.storage.bigtable import osprey_bigtable 46 44 from osprey.worker.lib.storage.bulk_label_task import BulkLabelTask ··· 248 246 249 247 @cli.command() 250 248 @click.option('--pooled/--no-pooled', default=True, help='Whether to run multiple bulk label sinks in a pool') 251 - @click.option( 252 - '--send-status-webhook/--no-send-status-webhook', 253 - default=True, 254 - help='Whether to send status webhook to channel specified in rules repo', 255 - ) 256 - def run_bulk_label_sink(pooled: bool, send_status_webhook: bool) -> None: 249 + def run_bulk_label_sink(pooled: bool) -> None: 257 250 config = init_config() 258 251 259 252 sentry_dsn = config.get_str(CONFIG_SENTRY_OTHER_SINKS_DSN, '') ··· 268 261 analytics_pubsub_topic_id = config.get_str('PUBSUB_ANALYTICS_EVENT_TOPIC_ID', 'osprey-analytics') 269 262 analytics_publisher = 
PubSubPublisher(analytics_pubsub_project_id, analytics_pubsub_topic_id) 270 263 271 - osprey_webhook_pubsub_project = config.get_str('PUBSUB_OSPREY_WEBHOOKS_PROJECT_ID', 'osprey-dev') 272 - osprey_webhook_pubsub_topic = config.get_str('PUBSUB_OSPREY_WEBHOOKS_TOPIC_ID', 'osprey-webhooks') 273 - webhooks_publisher = PubSubPublisher(osprey_webhook_pubsub_project, osprey_webhook_pubsub_topic) 274 - 275 - event_effects_output_sink = LabelOutputSink(engine, analytics_publisher, webhooks_publisher) 276 - 277 264 def factory() -> BulkLabelSink: 278 265 # NOTE: It's very important the input stream is created per-webhook sink 279 266 postgres_source = PostgresInputStream(BulkLabelTask, tags=['sink:bulklabelsink']) 267 + labels_provider = LABELS_PROVIDER.instance() 268 + if not labels_provider: 269 + raise AssertionError('BulkLabelSink cannot be instantiated without a labels provider') 280 270 return BulkLabelSink( 281 271 input_stream=postgres_source, 282 - event_effects_output_sink=event_effects_output_sink, 272 + labels_provider=labels_provider, 283 273 engine=engine, 284 274 analytics_publisher=analytics_publisher, 285 - send_status_webhook=send_status_webhook, 286 275 ) 287 276 288 277 if pooled:
+18 -74
osprey_worker/src/osprey/worker/lib/cli.py
··· 3 3 from pathlib import Path # noqa: E402 4 4 5 5 from osprey.engine.language_types.entities import EntityT 6 - from osprey.worker.adaptor.plugin_manager import bootstrap_labels_provider 7 6 from osprey.worker.lib.osprey_shared.labels import EntityLabelMutation, LabelStatus 8 - from osprey.worker.lib.patcher import patch_all # noqa: E402 7 + from osprey.worker.lib.patcher import patch_all 8 + from osprey.worker.lib.singletons import LABELS_PROVIDER # noqa: E402 9 9 10 10 patch_all() # please ensure this occurs before *any* other imports ! 11 11 12 12 13 13 import datetime # noqa: E402 14 - import ipaddress # noqa: E402 15 - import logging # noqa: E402 16 14 import os # noqa: E402 17 - import subprocess # noqa: E402 18 - from typing import Any, Optional, Set, Union # noqa: E402 15 + from typing import Any, Optional, Set # noqa: E402 19 16 20 17 import click # noqa: E402 21 - from click import Context, Parameter, ParamType # noqa: E402 22 18 from osprey.worker.lib.osprey_logging import configure_logging # noqa: E402 23 19 24 20 configure_logging() ··· 31 27 ) 32 28 from osprey.worker.lib.storage import ( # noqa: E402 33 29 access_audit_log, # noqa: E402 34 - entity_label_webhook, 35 30 labels, 36 31 stored_execution_result, 37 32 ) ··· 39 34 40 35 41 36 @click.group() 42 - def cli() -> None: 37 + def cli(): 43 38 pass 44 39 45 40 ··· 86 81 @cli.command() 87 82 @click.option('--auto-import/--no-auto-import', '-i', default=True) 88 83 def shell(auto_import: str) -> None: 89 - import os 90 84 import sys 91 85 92 86 osprey_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ··· 123 117 namespace_overrides = { 124 118 'labels': labels, 125 119 'access_audit_log': access_audit_log, 126 - 'entity_label_webhook': entity_label_webhook, 127 120 'stored_execution_result': stored_execution_result, 128 121 'EntityT': EntityT, 129 122 # 'Entity': Entity, ··· 250 243 reason_name=reason or 'CliLabelMutationWithoutEffects', 251 244 
status=label_status, 252 245 description=description or 'Manually changed from the command line for debugging.', 253 - expires_at=(datetime.datetime.now() + datetime.timedelta(seconds=5)), 246 + expires_at=(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(seconds=5)), 254 247 ) 255 248 else: 256 249 mutation = EntityLabelMutation( ··· 260 253 description=description or 'Manually changed from the command line for debugging.', 261 254 ) 262 255 263 - result = bootstrap_labels_provider().apply_entity_label_mutations( 264 - entity=EntityT(type=entity_type, id=entity_id), mutations=[mutation] 256 + provider = LABELS_PROVIDER.instance() 257 + assert provider is not None, ( 258 + 'this CLI cannot be used because no labels service / provider is supplied for this osprey instance' 265 259 ) 266 260 261 + result = provider.apply_entity_label_mutations(entity=EntityT(type=entity_type, id=entity_id), mutations=[mutation]) 262 + 267 263 print(result) 268 264 269 265 ··· 306 302 """ 307 303 entity_ids = get_lines_from_file_as_set(file_path=entity_ids_file_path) 308 304 # I found that it *generally* took ~10ms per request; Multiply by 10.05 for 5% latency headroom 309 - expire_timestamp = datetime.datetime.now() + datetime.timedelta(milliseconds=int(len(entity_ids) * 10.05)) 305 + expire_timestamp = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta( 306 + milliseconds=int(len(entity_ids) * 10.05) 307 + ) 310 308 print(f'Found {len(entity_ids)} entity IDs to label.\nETA: {int(len(entity_ids) * 10.05 / 100)} second(s)') 311 309 if expire_instantly: 312 310 mutation = EntityLabelMutation( ··· 325 323 ) 326 324 327 325 progress_tracker: CliCommandProgressTracker = CliCommandProgressTracker(total_actions=len(entity_ids)) 328 - provider = bootstrap_labels_provider() 326 + provider = LABELS_PROVIDER.instance() 327 + assert provider is not None, ( 328 + 'this code cannot be used because no labels service / provider is supplied for this osprey instance' 329 
+ ) 329 330 for entity_id in entity_ids: 330 - provider.apply_entity_label_mutations( 331 + _ = provider.apply_entity_label_mutations( 331 332 entity=EntityT(type=entity_type, id=entity_id), 332 333 mutations=[mutation], 333 334 ) 334 335 progress_tracker.increment() 335 336 336 337 print(f'Bulk labelling complete! Total labels applied: {progress_tracker.total_actions}') 337 - 338 - 339 - class IpAddress(ParamType): 340 - def convert(self, value: Union[str], param: Optional[Parameter], ctx: Optional[Context]) -> Optional[str]: 341 - """Check that the value parses as an ip V4 or V6 address""" 342 - ipaddress.IPv4Address(value) 343 - return value 344 - 345 - 346 - @cli.command() 347 - @click.option('--from-sub-gcp-project', required=True, help='Source GCP project for the subscription') 348 - @click.option('--from-subscription', required=True, help='Subscription ID to read from') 349 - @click.option('--target-destination-gcp-topic', required=True, help='Target GCP project for the topic') 350 - @click.option('--destination-topic', required=True, help='Destination topic to publish to') 351 - def restream_subscription( 352 - from_sub_gcp_project: str, from_subscription: str, target_destination_gcp_topic: str, destination_topic: str 353 - ) -> None: 354 - """ 355 - Run the restreamer to restream from a subscription to a topic. 356 - This is most useful for running DLQs. 357 - This command runs ./pubsub_restream with the specified options. 
358 - """ 359 - # Unset PUBSUB_EMULATOR_HOST if it is set 360 - if 'PUBSUB_EMULATOR_HOST' in os.environ: 361 - logging.info('Unsetting PUBSUB_EMULATOR_HOST') 362 - del os.environ['PUBSUB_EMULATOR_HOST'] 363 - 364 - cargo_bin_path = os.path.expanduser('~/.cargo/bin') 365 - if cargo_bin_path not in os.environ['PATH']: 366 - logging.info('Adding cargo bin path to PATH: %s', cargo_bin_path) 367 - os.environ['PATH'] += os.pathsep + cargo_bin_path 368 - 369 - command = [ 370 - 'pubsub_restream', 371 - '--use-gcloud-auth', 372 - f'--gcp-project={from_sub_gcp_project}', 373 - f'--subscription={from_subscription}', 374 - f'--dst-gcp-project={target_destination_gcp_topic}', 375 - f'--dst-topic={destination_topic}', 376 - ] 377 - 378 - logging.info('Running the pubsub restreaming command: %s', ' '.join(command)) 379 - 380 - try: 381 - result = subprocess.run(command, capture_output=True, text=True, check=True) 382 - logging.info('Command output:\n%s', result.stdout) 383 - except subprocess.CalledProcessError as e: 384 - logging.error('Command failed with return code %d', e.returncode) 385 - logging.error('Command stderr:\n%s', e.stderr) 386 - sys.exit(e.returncode) 387 - except Exception as e: 388 - logging.error('An unexpected error occurred: %s', str(e)) 389 - sys.exit(1) 390 - 391 - 392 - if __name__ == '__main__': 393 - cli()
+1
osprey_worker/src/osprey/worker/lib/data_exporters/validation_result_exporter.py
··· 47 47 48 48 def get_validation_result_exporter() -> BaseValidationResultExporter: 49 49 """setup and returns the validation result exporter that will run during source updates""" 50 + 50 51 config = CONFIG.instance() 51 52 52 53 # Use null exporter if disabled (for development)
+180 -15
osprey_worker/src/osprey/worker/lib/osprey_shared/labels.py
··· 1 1 import copy 2 2 from collections import UserDict 3 3 from dataclasses import dataclass, field, replace 4 - from datetime import datetime, timedelta 4 + from datetime import datetime, timedelta, timezone 5 5 from enum import Enum, IntEnum 6 - from typing import Dict, Optional 6 + from typing import Any, Dict, Self 7 7 8 8 from osprey.worker.lib.osprey_shared.logging import get_logger 9 9 from osprey.worker.lib.utils.request_utils import SessionWithRetries ··· 15 15 16 16 17 17 logger = get_logger(__name__) 18 + 19 + 20 + def _guarantee_utc_timezone_awareness(dt: datetime | None) -> datetime | None: 21 + if dt is None: 22 + return None 23 + if dt.tzinfo is None: 24 + return dt.replace(tzinfo=timezone.utc) 25 + return dt 18 26 19 27 20 28 class MutationDropReason(IntEnum): ··· 85 93 description: str = '' 86 94 """why the label was mutated""" 87 95 features: dict[str, str] = field(default_factory=dict) 88 - """features are injected into the description as k/v's, similar to how fstrings work. for example, 96 + """features are injected into the description as k/v's, similar to how fstrings work. for example, 89 97 the {you} in 'hello {you}' would be substituted as 'person' with a feature dict of {'you': 'person'}""" 90 98 created_at: datetime | None = None 91 99 """ ··· 94 102 expires_at: datetime | None = None 95 103 """marks when this label reason 'expires' 96 104 97 - if a LabelState.MANUALLY_REMOVED is applied with a reason that has a 1 day expiration, then 105 + if a LabelState.MANUALLY_REMOVED is applied with a reason that has a 1 day expiration, then 98 106 for 1 day, the label cannot be applied via LabelState.ADDED. all LabelState.ADDED attempts will be dropped. 99 107 100 108 if a given label state has multiple label reasons, all reasons would need to expire before the status/state 101 - is considered expired, too. 109 + is considered expired, too. 
102 110 """ 103 111 104 112 def is_expired(self) -> bool: 105 - return bool(self.expires_at is not None and self.expires_at + timedelta(seconds=5) < datetime.now()) 113 + return bool(self.expires_at is not None and self.expires_at + timedelta(seconds=5) < datetime.now(timezone.utc)) 114 + 115 + def serialize(self) -> dict[str, Any]: 116 + """ 117 + serialize LabelReason to a JSON-compatible dict. 118 + converts datetime objects to ISO format strings. 119 + """ 120 + created_at = _guarantee_utc_timezone_awareness(self.created_at) 121 + expires_at = _guarantee_utc_timezone_awareness(self.expires_at) 122 + return { 123 + 'pending': self.pending, 124 + 'description': self.description, 125 + 'features': self.features, 126 + 'created_at': created_at.isoformat() if created_at else None, 127 + 'expires_at': expires_at.isoformat() if expires_at else None, 128 + } 129 + 130 + @classmethod 131 + def deserialize(cls, d: dict[str, Any]) -> Self: 132 + """ 133 + deserialize a dict into a LabelReason object. 134 + converts ISO format strings back to datetime objects. 135 + """ 136 + created_at = _guarantee_utc_timezone_awareness( 137 + datetime.fromisoformat(d['created_at']) if d.get('created_at') else None 138 + ) 139 + expires_at = _guarantee_utc_timezone_awareness( 140 + datetime.fromisoformat(d['expires_at']) if d.get('expires_at') else None 141 + ) 142 + return cls( 143 + pending=d.get('pending', False), 144 + description=d.get('description', ''), 145 + features=d.get('features', {}), 146 + created_at=created_at, 147 + expires_at=expires_at, 148 + ) 106 149 107 150 108 151 @dataclass ··· 159 202 def __repr__(self): 160 203 return f'LabelReasons({self.data})' 161 204 205 + def serialize(self) -> dict[str, dict[str, Any]]: 206 + """ 207 + serialize LabelReasons to a JSON-compatible dict. 208 + returns a dict mapping reason names to serialized LabelReason dicts. 
209 + """ 210 + return {reason_name: reason.serialize() for reason_name, reason in self.items()} 211 + 212 + @classmethod 213 + def deserialize(cls, d: dict[str, dict[str, Any]]) -> Self: 214 + """ 215 + deserialize a dict into a LabelReasons object. 216 + expects a dict mapping reason names to LabelReason dicts. 217 + """ 218 + 219 + deserialized_reasons: dict[str, LabelReason] = {} 220 + for reason_name, reason_data in d.items(): 221 + try: 222 + deserialized_reasons[reason_name] = LabelReason.deserialize(reason_data) 223 + except Exception as e: 224 + raise TypeError(f'could not create LabelReasons from dict: failed to deserialize {reason_name}', e) 225 + 226 + return cls(deserialized_reasons) 227 + 162 228 163 229 @dataclass 164 230 class LabelStateInner: 165 231 status: LabelStatus 166 232 reasons: LabelReasons 167 233 234 + def serialize(self) -> dict[str, Any]: 235 + """ 236 + serialize LabelStateInner to a JSON-compatible dict. 237 + """ 238 + return { 239 + 'status': self.status.value, 240 + 'reasons': self.reasons.serialize(), 241 + } 242 + 243 + @classmethod 244 + def deserialize(cls, d: dict[str, Any]) -> Self: 245 + """ 246 + deserialize a dict into a LabelStateInner object. 
247 + """
248 + try:
249 + status = LabelStatus(d['status'])
250 + reasons = LabelReasons.deserialize(d['reasons'])
251 + return cls(status=status, reasons=reasons)
252 + except Exception as e:
253 + raise TypeError(f'could not create LabelStateInner from dict: {d}', e)
254 +
168 255
169 256 @dataclass
170 257 class LabelState:
171 258 status: LabelStatus
172 - """statuses dictate the way the current state behaves; certain statuses have priority over others
259 + """statuses dictate the way the current state behaves; certain statuses have priority over others
173 260 (see LabelStatus for more info)"""
174 261
175 262 reasons: LabelReasons
···
177 264 reasons are why this label state was applied; it is a dict because there may be multiple,
178 265 with each reason being distinct based on its reason name.
179 266
180 - reasons applied under the same name are merged (assuming the status has not changed),
267 + reasons applied under the same name are merged (assuming the status has not changed),
181 268 with precedence given to newer created_at timestamps.
182 269 """
183 270
···
199 286 if the weights are the *same*, then a merge of reasons is performed, which can also cause the expiration to be delayed.
200 287 """ 201 288 if not self.reasons: 202 - AssertionError(f'invariant: the label state {self} did not have any associated reasons') 203 - expires_at = datetime.min 289 + raise AssertionError(f'invariant: the label state {self} did not have any associated reasons') 290 + expires_at = datetime.min.replace(tzinfo=timezone.utc) 204 291 for reason in self.reasons.values(): 205 292 if reason.expires_at is None: 206 293 return None ··· 215 302 ) 216 303 217 304 def is_expired(self) -> bool: 218 - return bool(self.expires_at is not None and self.expires_at + timedelta(seconds=5) < datetime.now()) 305 + return bool(self.expires_at is not None and self.expires_at + timedelta(seconds=5) < datetime.now(timezone.utc)) 219 306 220 307 def _shift_current_state_to_previous_state(self) -> None: 221 308 if not self.reasons: ··· 253 340 254 341 return None 255 342 343 + def serialize(self) -> dict[str, Any]: 344 + """ 345 + serialize LabelState to a JSON-compatible dict. 346 + """ 347 + return { 348 + 'status': self.status.value, 349 + 'reasons': self.reasons.serialize(), 350 + 'previous_states': [prev_state.serialize() for prev_state in self.previous_states], 351 + } 352 + 353 + @classmethod 354 + def deserialize(cls, d: dict[str, Any]) -> Self: 355 + """ 356 + deserialize a dict into a LabelState object. 
357 + """ 358 + 359 + try: 360 + status = LabelStatus(d['status']) 361 + reasons = LabelReasons.deserialize(d['reasons']) 362 + previous_states = [ 363 + LabelStateInner.deserialize(prev_state_data) for prev_state_data in d.get('previous_states', []) 364 + ] 365 + return cls(status=status, reasons=reasons, previous_states=previous_states) 366 + except Exception as e: 367 + raise TypeError(f'could not create LabelState from dict: {d}', e) 368 + 256 369 257 370 @dataclass 258 371 class EntityLabels: ··· 261 374 labels: Dict[str, LabelState] = field(default_factory=dict) 262 375 """a mapping of label names to their current states'""" 263 376 377 + def serialize(self) -> dict[str, Any]: 378 + """ 379 + given the current EntityLabels object, returns a dict that is 380 + json-serializable via json.dumps() 381 + """ 382 + return {'labels': {k: v.serialize() for k, v in self.labels.items()}} 383 + 384 + @classmethod 385 + def deserialize(cls, d: dict[str, dict[str, Any]]) -> Self: 386 + """ 387 + given a dict, deserializes it into an EntityLabels object 388 + """ 389 + if 'labels' in d: 390 + d = d['labels'] 391 + 392 + try: 393 + return cls(labels={k: LabelState.deserialize(v) for k, v in d.items()}) 394 + except Exception as e: 395 + raise TypeError(f'could not create EntityLabels from dict: {d};', e) 396 + 264 397 265 398 @dataclass 266 399 class EntityLabelMutation: ··· 291 424 pending=self.pending, 292 425 description=self.description, 293 426 features=self.features, 294 - created_at=datetime.now(), 295 - expires_at=self.expires_at, 427 + created_at=datetime.now(timezone.utc), 428 + expires_at=_guarantee_utc_timezone_awareness(self.expires_at), 296 429 ) 430 + 431 + def serialize(self) -> dict[str, Any]: 432 + expires_at = _guarantee_utc_timezone_awareness(self.expires_at) 433 + return { 434 + 'label_name': self.label_name, 435 + 'reason_name': self.reason_name, 436 + 'status': self.status, 437 + 'pending': self.pending, 438 + 'description': self.description, 439 + 
'features': self.features, 440 + 'expires_at': expires_at.isoformat() if expires_at else None, 441 + } 297 442 298 443 299 444 @dataclass ··· 301 446 mutation: EntityLabelMutation 302 447 reason: MutationDropReason 303 448 449 + def serialize(self) -> dict[str, Any]: 450 + return { 451 + 'mutation': self.mutation.serialize(), 452 + 'reason': self.reason, 453 + } 454 + 304 455 305 456 @dataclass 306 457 class EntityLabelMutationsResult: ··· 309 460 all of the entity's labels post-mutation 310 461 """ 311 462 312 - old_entity_labels: Optional[EntityLabels] = None 463 + old_entity_labels: EntityLabels 313 464 """ 314 465 all of the entity's labels pre-mutation 315 466 """ ··· 326 477 327 478 labels_updated: list[str] = field(default_factory=list) 328 479 """ 329 - labels that had their state updated. this can include simply updating or 480 + labels that had their state updated. this can include simply updating or 330 481 appending to the reason 331 482 """ 332 483 ··· 335 486 mutations that were dropped for one reason or another. each dropped mutation is 336 487 given a drop reason 337 488 """ 489 + 490 + def serialize(self) -> dict[str, Any]: 491 + """ 492 + the only place this is currently needed is for the ui, which expects a specific json blob 493 + """ 494 + return { 495 + 'mutation_result': { 496 + 'added': self.labels_added, 497 + 'removed': self.labels_removed, 498 + 'updated': self.labels_updated, 499 + 'unchanged': list(set(mut.mutation.label_name for mut in self.dropped_mutations)), 500 + }, 501 + **self.new_entity_labels.serialize(), 502 + }
+39
osprey_worker/src/osprey/worker/lib/singletons.py
··· 4 4 from osprey.engine.stdlib import get_config_registry 5 5 from osprey.worker.lib.config import Config 6 6 from osprey.worker.lib.singleton import Singleton 7 + from osprey.worker.lib.storage.labels import LabelsProvider 7 8 8 9 if TYPE_CHECKING: 9 10 from osprey.worker.lib.osprey_engine import OspreyEngine ··· 20 21 21 22 22 23 ENGINE: Singleton['OspreyEngine'] = Singleton(_init_engine) 24 + 25 + 26 + def _init_labels_provider() -> LabelsProvider | None: 27 + """ 28 + a helper method to initialize the labels provider for the LABELS_PROVIDER singleton 29 + """ 30 + # the plugin manager imports this file to reference the labels provider singleton; 31 + # therefore, we need these to not cause circular imports 32 + from osprey.worker.adaptor.plugin_manager import ( 33 + _labels_service_or_provider_is_registered, 34 + bootstrap_labels_provider, 35 + ) 36 + # from osprey.worker.lib.singletons import CONFIG 37 + 38 + if not _labels_service_or_provider_is_registered(): 39 + return None 40 + config = CONFIG.instance() 41 + return bootstrap_labels_provider(config) 42 + 43 + 44 + LABELS_PROVIDER: Singleton['LabelsProvider | None'] = Singleton(_init_labels_provider) 45 + """ 46 + A Singleton that holds a `LabelsProvider`, if one is registered by the plugin manager. 47 + 48 + If not, this Singleton will hold `None`. This makes it always safe to call `LABELS_PROVIDER.instance()`, 49 + and enforces type checking rules when dealing with labels provider code, which may or may not be 50 + supplied by users of Osprey. 51 + 52 + An example use pattern might be: 53 + ```py 54 + labels_provider: LabelsProvider | None = LABELS_PROVIDER.instance() 55 + if labels_provider: 56 + # do labels provider things 57 + ``` 58 + 59 + Because this is a Singleton, implementers of `LabelsServiceBase` / `LabelsProvider` can implement statefulness 60 + and expect that the statefulness will be present across all references within a given Osprey worker. 61 + """
-59
osprey_worker/src/osprey/worker/lib/storage/bigtable.py
··· 89 89 return t 90 90 91 91 92 - class DataServicesBigTable(BigTableClient): 93 - """ 94 - A BigTable client wrapper for the data services BigTable instance 95 - """ 96 - 97 - def __init__(self) -> None: 98 - CONFIG.instance().register_configuration_callback(self.init_from_config) 99 - 100 - def init_from_config(self, config: Config) -> None: 101 - """Initialize this bigtable client once configuration is available.""" 102 - config = CONFIG.instance() 103 - self._gcp_project = config.get_str('DATA_SERVICES_GCP_PROJECT_ID', 'osprey-dev') 104 - self._instance_id = config.get_str('DATA_SERVICES_BIGTABLE_INSTANCE_ID', 'derived-sinks-ml-instance-dev') 105 - self._admin_enabled = config.get_bool('DATA_SERVICES_BIGTABLE_ADMIN_ENABLED', True) 106 - 107 - def table(self, table_name: str) -> Table: 108 - """ 109 - Get a Table instance for the requested table 110 - """ 111 - t = self._instance.table(table_name) 112 - pin_override( 113 - t, 114 - service='osprey-bigtable-client', 115 - tags={'bigtable_instance': self._instance.instance_id, 'table_id': t.table_id}, 116 - ) 117 - return t 118 - 119 - 120 - class DataStreamBigTable(BigTableClient): 121 - """ 122 - A BigTable client wrapper for the data services BigTable instance 123 - Instance: `stream` 124 - """ 125 - 126 - def __init__(self) -> None: 127 - CONFIG.instance().register_configuration_callback(self.init_from_config) 128 - 129 - def init_from_config(self, config: Config) -> None: 130 - """Initialize this bigtable client once configuration is available.""" 131 - config = CONFIG.instance() 132 - self._gcp_project = config.get_str('DATA_GCP_PROJECT_ID', 'osprey-dev') 133 - self._instance_id = config.get_str('DATA_STREAM_BIGTABLE_INSTANCE_ID', 'stream') 134 - self._admin_enabled = config.get_bool('DATA_STREAM_BIGTABLE_ADMIN_ENABLED', True) 135 - 136 - def table(self, table_name: str) -> Table: 137 - """ 138 - Get a Table instance for the requested table 139 - """ 140 - t = self._instance.table(table_name) 141 - 
pin_override( 142 - t, 143 - service='osprey-bigtable-client', 144 - tags={'bigtable_instance': self._instance.instance_id, 'table_id': t.table_id}, 145 - ) 146 - return t 147 - 148 - 149 92 osprey_bigtable = OspreyBigTable() 150 - data_services_bigtable = DataServicesBigTable() 151 - data_stream_bigtable = DataStreamBigTable()
+47 -21
osprey_worker/src/osprey/worker/lib/storage/labels.py
··· 25 25 26 26 27 27 class LabelsServiceBase(ABC): 28 - @abstractmethod 29 - def write_labels(self, entity: EntityT[Any], labels: EntityLabels) -> None: 30 - """ 31 - A standard write to the labels service that attempts to write the value to the primary key. 28 + """ 29 + An abstract class to represent an implementable labels service backend. 32 30 33 - This method may be retried upon exceptions, so keep that in mind when adding potentially 34 - non-idempotent behaviour. 31 + With the default LabelsProvider, read_labels and batch_read_labels are called 32 + *during* rule executions (or by the ui api). 33 + 34 + read_modify_write_labels_atomically is called post-rule execution (or by the ui api). 35 + """ 36 + 37 + def initialize(self) -> None: 35 38 """ 36 - raise NotImplementedError() 39 + This method will be called after the initialization of this labels service base. Any side effects 40 + that implementers may want, i.e. connecting to an external service, should be placed here. 41 + """ 42 + pass 37 43 38 44 @abstractmethod 39 45 def read_labels(self, entity: EntityT[Any]) -> EntityLabels: ··· 71 77 72 78 @abstractmethod 73 79 @contextmanager 74 - def get_labels_atomically(self, entity: EntityT[Any]) -> Generator[EntityLabels, None, None]: 80 + def read_modify_write_labels_atomically(self, entity: EntityT[Any]) -> Generator[EntityLabels, None, None]: 75 81 """ 76 - Context manager for atomic read-modify-write operations. 77 - Implementations should ensure the entity key is locked/in a transaction. 82 + Context manager for atomic read-modify-write operations. This generator should yield EntityLabels upon reading 83 + and should write the EntityLabels post-yield. 84 + 85 + IMPORTANT: Implementations should ensure the entity key is locked/in a transaction so that other read-modify-write 86 + calls (even across multiple workers) must wait. 87 + 88 + This code may be retried upon exceptions, so keep that in mind when adding potentially 89 + non-idempotent behaviour. 
78 90 """ 79 91 pass 80 92 ··· 83 95 def __init__(self, labels_service: LabelsServiceBase): 84 96 self._labels_service = labels_service 85 97 98 + def initialize(self) -> None: 99 + """ 100 + This method will be called after the initialization of this labels provider. Any side effects 101 + that implementers may want, i.e. connecting to an external service, should be placed here. 102 + """ 103 + self._labels_service.initialize() 104 + 86 105 def _get_mutations_by_label_name_and_drop_conflicts( 87 106 self, mutations: Sequence[EntityLabelMutation] 88 107 ) -> tuple[dict[str, list[EntityLabelMutation]], list[DroppedEntityLabelMutation]]: ··· 141 160 return desired_states_by_label_name 142 161 143 162 def _compute_new_labels_from_mutations( 144 - self, old_labels: EntityLabels, mutations: Sequence[EntityLabelMutation] 163 + self, labels: EntityLabels, mutations: Sequence[EntityLabelMutation] 145 164 ) -> EntityLabelMutationsResult: 165 + """ 166 + given an entity's labels and a set of mutations, modify the labels based on the mutations' desired states. 167 + 168 + **this method WILL modify the labels object that is passed into it**. 
169 + it will also return the pre-modification labels in EntityLabelMutationsResult.old_entity_labels
170 + """
146 171 (mutations_by_label_name, dropped_mutations) = self._get_mutations_by_label_name_and_drop_conflicts(mutations)
147 172 desired_states_by_label_name: dict[str, LabelStateInner] = self._get_desired_states_by_label_name(
148 173 mutations_by_label_name
···
153 178 added: list[str] = []
154 179 removed: list[str] = []
155 180 updated: list[str] = []
156 - new_labels = copy.deepcopy(old_labels)
181 + old_labels = copy.deepcopy(labels)
157 182 for label_name, desired_state in desired_states_by_label_name.items():
158 - if label_name not in new_labels.labels:
159 - new_labels.labels[label_name] = LabelState.from_inner(desired_state)
183 + if label_name not in labels.labels:
184 + labels.labels[label_name] = LabelState.from_inner(desired_state)
160 185 added.append(label_name)
161 186 continue
162 - current_state = new_labels.labels[label_name]
187 + current_state = labels.labels[label_name]
163 188 prev_status = current_state.status
164 189 drop_reason = current_state.try_apply_desired_state(desired_state)
165 190 if drop_reason:
···
182 207
183 208 # finally, return the result!
184 209 return EntityLabelMutationsResult(
185 - new_entity_labels=new_labels,
210 + new_entity_labels=labels,
186 211 old_entity_labels=old_labels,
187 212 labels_added=added,
188 213 labels_removed=removed,
···
199 224 def apply_entity_label_mutations(
200 225 self, entity: EntityT[Any], mutations: Sequence[EntityLabelMutation]
201 226 ) -> EntityLabelMutationsResult:
202 - with self._labels_service.get_labels_atomically(entity) as old_labels:
203 - result = self._compute_new_labels_from_mutations(old_labels, mutations)
204 -
205 - self._labels_service.write_labels(entity, result.new_entity_labels)
206 -
227 + try:
228 + with self._labels_service.read_modify_write_labels_atomically(entity) as entity_labels:
229 + result = self._compute_new_labels_from_mutations(entity_labels, mutations)
207 230 return result
231 + except Exception as e:
232 + logger.error(f'Could not read-modify-write labels for entity {entity.__repr__()}:', e)
233 + raise e
208 234
209 235 def cache_ttl(self) -> Optional[timedelta]:
210 236 return timedelta(minutes=1)
+1 -2
osprey_worker/src/osprey/worker/lib/storage/postgres.py
··· 8 8 9 9 from flask import Flask, has_request_context # noqa: E402 10 10 from osprey.worker.lib.config import Config # noqa: E402 11 + from osprey.worker.lib.singletons import CONFIG # noqa: E402 11 12 from sqlalchemy import MetaData # noqa: E402 12 13 from sqlalchemy.engine.url import make_url # noqa: E402 13 14 from sqlalchemy.ext.declarative import declarative_base # noqa: E402 14 15 from sqlalchemy.orm import Session, sessionmaker # noqa: E402 15 16 from sqlalchemy.orm.scoping import ThreadLocalRegistry # type: ignore # missing stub # noqa: E402 16 - 17 - from ..singletons import CONFIG # noqa: E402 18 17 19 18 metadata = MetaData() 20 19 Model = declarative_base(name='Model', metadata=metadata)
+17 -13
osprey_worker/src/osprey/worker/lib/storage/tests/test_labels.py
··· 21 21 def __init__(self): 22 22 self.storage: dict[tuple[str, str], EntityLabels] = {} 23 23 24 - def write_labels(self, entity: EntityT[Any], labels: EntityLabels) -> None: 25 - key = (entity.type, entity.id) 26 - self.storage[key] = labels 27 - 28 24 def read_labels(self, entity: EntityT[Any]) -> EntityLabels: 29 25 key = (entity.type, entity.id) 30 26 return self.storage.get(key, EntityLabels()) 31 27 32 - def get_labels_atomically(self, entity: EntityT[Any]): 33 - """Context manager that yields labels for atomic operations""" 28 + def read_modify_write_labels_atomically(self, entity: EntityT[Any]): 29 + """Context manager that yields labels for atomic read-modify-write operations""" 34 30 from contextlib import contextmanager 35 31 36 32 @contextmanager 37 33 def _context(): 38 - yield self.read_labels(entity) 34 + # Read current labels 35 + labels = self.read_labels(entity) 36 + # Yield for modification 37 + yield labels 38 + # Write modified labels back 39 + key = (entity.type, entity.id) 40 + self.storage[key] = labels 39 41 40 42 return _context() 41 43 ··· 347 349 348 350 349 351 def test_compute_new_labels_from_mutations_preserves_old_labels(labels_provider: LabelsProvider, now: datetime): 350 - """Test that old_labels reference is preserved in result""" 352 + """Test that old_labels snapshot is captured in result before modifications""" 351 353 old_labels = EntityLabels( 352 354 labels={ 353 355 'existing_label': LabelState( ··· 370 372 371 373 result = labels_provider._compute_new_labels_from_mutations(old_labels, mutations) 372 374 373 - # Old labels should be preserved 374 - assert result.old_entity_labels == old_labels 375 - # Old labels should not be modified 376 - assert 'new_label' not in old_labels.labels 377 - assert 'new_label' in result.new_entity_labels.labels 375 + # result.old_entity_labels should be a snapshot before modifications 376 + assert 'new_label' not in result.old_entity_labels.labels 377 + assert 'existing_label' in 
result.old_entity_labels.labels 378 + # The input labels parameter IS modified in place 379 + assert 'new_label' in old_labels.labels 380 + # result.new_entity_labels should reference the same modified object 381 + assert result.new_entity_labels is old_labels 378 382 379 383 380 384 def test_compute_new_labels_from_mutations_empty_mutations(labels_provider: LabelsProvider, now: datetime):
+6 -4
osprey_worker/src/osprey/worker/sinks/sink/bulk_label_sink.py
··· 5 5 import sentry_sdk 6 6 from osprey.engine.language_types.entities import EntityT 7 7 from osprey.engine.language_types.labels import LabelStatus 8 - from osprey.worker.adaptor.plugin_manager import bootstrap_labels_provider 9 8 from osprey.worker.lib.bulk_label import TaskStatus 10 9 from osprey.worker.lib.discovery.exceptions import ServiceUnavailable 11 10 from osprey.worker.lib.instruments import metrics ··· 14 13 from osprey.worker.lib.osprey_shared.logging import get_logger 15 14 from osprey.worker.lib.pigeon.exceptions import RPCException 16 15 from osprey.worker.lib.publisher import BasePublisher 16 + from osprey.worker.lib.singletons import LABELS_PROVIDER 17 17 from osprey.worker.lib.storage.bulk_label_task import BASE_DELAY_SECONDS, MAX_ATTEMPTS, BulkLabelTask 18 18 from osprey.worker.lib.storage.labels import LabelsProvider 19 19 from osprey.worker.sinks.sink.input_stream import BaseInputStream ··· 69 69 labels_provider: LabelsProvider, 70 70 engine: OspreyEngine, 71 71 analytics_publisher: BasePublisher, 72 - send_status_webhook: bool = True, 73 72 ): 74 73 self._input_stream = input_stream 75 74 self._labels_provider = labels_provider 76 75 self._engine = engine 77 76 self._metric_tags = [f'sink:{self.__class__.__name__}'] 78 77 self._analytics_publisher = analytics_publisher 79 - self._send_status_webhook = send_status_webhook 80 78 81 79 def run(self) -> None: 82 80 for task in self._input_stream: ··· 424 422 rows_rolled_back = 0 425 423 426 424 feature_name_to_entity_type_mapping = engine.get_feature_name_to_entity_type_mapping() 427 - labels_provider = bootstrap_labels_provider() 425 + labels_provider = LABELS_PROVIDER.instance() 426 + if labels_provider is None: 427 + raise NotImplementedError( 428 + 'this code cannot be used because no labels service / provider is supplied for this osprey instance' 429 + ) 428 430 feature_name = task.dimension 429 431 entity_type = feature_name_to_entity_type_mapping[feature_name] 430 432
+9 -11
osprey_worker/src/osprey/worker/ui_api/osprey/views/entities.py
··· 1 1 from typing import Any 2 2 3 3 from flask import Blueprint, abort, jsonify 4 - from osprey.worker.adaptor.plugin_manager import bootstrap_labels_provider, has_labels_service 5 - from osprey.worker.lib.osprey_shared.labels import EntityLabelMutation 4 + from osprey.worker.lib.osprey_shared.labels import EntityLabelMutation, EntityLabelMutationsResult 5 + from osprey.worker.lib.singletons import LABELS_PROVIDER 6 6 from osprey.worker.ui_api.osprey.lib.abilities import ( 7 7 CanMutateEntities, 8 8 CanMutateLabels, ··· 30 30 def get_labels_for_entity(request_model: GetLabelsForEntityRequest) -> Any: 31 31 require_ability_with_request(request_model, CanViewLabelsForEntity) 32 32 33 - if not has_labels_service(): 33 + labels_provider = LABELS_PROVIDER.instance() 34 + if not labels_provider: 34 35 return { 35 36 'labels': {}, 36 37 # this field is deprecated 37 38 'expires_at': None, 38 39 } 39 - 40 - labels_provider = bootstrap_labels_provider() 41 40 42 41 entity_labels = labels_provider.get_from_service(key=request_model.entity) 43 42 # Filter out all but the allowed labels ··· 47 46 if hasattr(entity_labels, 'labels'): 48 47 for label_name, label_state in entity_labels.labels.items(): 49 48 if ability and ability.item_is_allowed(label_name): 50 - response_labels[label_name] = label_state 49 + response_labels[label_name] = label_state.serialize() 51 50 52 51 return { 53 52 'labels': response_labels, ··· 69 68 require_ability_with_request(request_model, CanMutateEntities) 70 69 require_ability_with_request(request_model, CanMutateLabels) 71 70 72 - if not has_labels_service(): 71 + labels_provider = LABELS_PROVIDER.instance() 72 + if not labels_provider: 73 73 return abort(501, 'Labels Provider Not Found') 74 74 75 - labels_provider = bootstrap_labels_provider() 76 - 77 75 can_mutate_labels_ability = get_current_user().get_ability(CanMutateLabels) 78 76 # We can make this assertion because of the above line that requires CanMutateLabel for the request 79 77 
assert can_mutate_labels_ability is not None ··· 92 90 ) 93 91 mutations.append(entity_mutation) 94 92 95 - result = labels_provider.apply_entity_label_mutations(request_model.entity, mutations) 93 + result: EntityLabelMutationsResult = labels_provider.apply_entity_label_mutations(request_model.entity, mutations) 96 94 97 - return result 95 + return result.serialize()