Mirror of https://github.com/roostorg/osprey github.com/roostorg/osprey
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

native support for labels (#26)

Co-authored-by: Ethan Breder <exbreder@gmail.com>

authored by

ayu
Ethan Breder
and committed by
GitHub
79a07d78 7c379c1d

+1304 -1254
+1 -10
osprey_worker/src/osprey/engine/executor/execution_context.py
··· 3 3 import traceback 4 4 from collections import defaultdict 5 5 from dataclasses import dataclass, field 6 - from datetime import datetime, timedelta 6 + from datetime import datetime 7 7 from typing import ( 8 8 TYPE_CHECKING, 9 9 Any, ··· 12 12 Iterable, 13 13 List, 14 14 Mapping, 15 - Optional, 16 15 Sequence, 17 16 Set, 18 17 Type, ··· 39 38 from osprey.engine.language_types.verdicts import VerdictEffect 40 39 from osprey.engine.utils.types import add_slots, cached_property 41 40 from osprey.rpc.common.v1.verdicts_pb2 import Verdicts 42 - from osprey.worker.lib.osprey_shared.labels import EntityMutation 43 41 from result import Result, UnwrapError 44 42 45 43 if TYPE_CHECKING: ··· 65 63 66 64 class ExternalServiceException(Exception): 67 65 """Indicates that an external service call failed or returned unexpected data.""" 68 - 69 - 70 - @add_slots 71 - @dataclass 72 - class ExtendedEntityMutation: 73 - mutation: EntityMutation 74 - delay_action_by: Optional[timedelta] 75 66 76 67 77 68 class ExecutionContext:
+6
osprey_worker/src/osprey/engine/executor/external_service_utils.py
··· 23 23 Returns a time to live for items in the cache. By default, KVs are cached indefinitely. 24 24 25 25 To have cache entries auto-expire, override this method in your external service definition. 26 + 27 + Note that timedeltas can accept negative values to represent the past, but only on the days field. 28 + You *can* use timedelta(seconds=0) to disable caching, but a negative time delta *ensures* that even 29 + if a time shift occurs (such as daylight savings), the cache_ttl will still be immediate. 30 + 31 + Therefore, to disable the read cache, it is recommended to set this to `timedelta(days=-1)` 26 32 """ 27 33 return None 28 34
+3 -22
osprey_worker/src/osprey/engine/language_types/labels.py
··· 1 1 from dataclasses import dataclass 2 2 from datetime import timedelta 3 - from enum import IntEnum 4 3 from typing import Any, List, Optional, Self, cast 5 4 6 5 from osprey.engine.executor.custom_extracted_features import CustomExtractedFeature ··· 9 8 ENTITY_LABEL_MUTATION_DIMENSION_NAME, 10 9 ENTITY_LABEL_MUTATION_DIMENSION_VALUE, 11 10 ) 11 + from osprey.worker.lib.osprey_shared.labels import LabelStatus 12 12 13 13 from .entities import EntityT 14 14 from .rules import RuleT, add_slots 15 15 16 16 17 - class LabelStatus(IntEnum): 18 - ADDED = 0 19 - REMOVED = 1 20 - MANUALLY_ADDED = 2 21 - MANUALLY_REMOVED = 3 22 - 23 - def effective_label_status(self) -> 'LabelStatus': 24 - """ 25 - Returns the effective status of the label, which is what the upstreams that are observing label 26 - status changes will see. Which is to say, the upstreams will currently not see if the label status was 27 - manually added or manually removed, just that it was added or removed. 28 - """ 29 - match self: 30 - case LabelStatus.ADDED | LabelStatus.MANUALLY_ADDED: 31 - return LabelStatus.ADDED 32 - case LabelStatus.REMOVED | LabelStatus.MANUALLY_REMOVED: 33 - return LabelStatus.REMOVED 34 - 35 - 36 17 @add_slots 37 18 @dataclass 38 19 class LabelEffect(EffectToCustomExtractedFeatureBase[List[str]]): ··· 51 32 expires_after: Optional[timedelta] = None 52 33 """If set, the label effect has a timed expiration, which means that the reason will expire after this time.""" 53 34 54 - delay_action_by: Optional[timedelta] = None 55 - """If set, the propagation of the effect to the upstream (if configured) will be delayed.S""" 35 + # delay_action_by: Optional[timedelta] = None 36 + # """If set, the propagation of the effect to the upstream (if configured via LabelsService.after_add or LabelsService.after_remove) will be delayed.""" 56 37 57 38 dependent_rule: Optional[RuleT] = None 58 39 """If set, the effect will only be applied if the dependent rule evaluates to true."""
+16 -8
osprey_worker/src/osprey/engine/stdlib/udfs/labels.py
··· 24 24 from osprey.engine.utils.get_closest_string_within_threshold import ( 25 25 get_closest_string_within_threshold, 26 26 ) 27 - from osprey.worker.lib.osprey_shared.labels import Labels 28 - from osprey.worker.lib.storage.labels import LabelProvider 27 + from osprey.worker.lib.osprey_shared.labels import EntityLabels 28 + from osprey.worker.lib.storage.labels import LabelsProvider 29 29 from result import Err, Ok, Result 30 30 31 31 32 - # TODO: move back to labels.py once we actually make it stdlib 33 32 class LabelArguments(ArgumentsBase): 34 33 entity: EntityT[Any] 35 34 """An entity to mutate a label on.""" 36 35 label: ConstExpr[str] 37 36 """The label to mutate.""" 38 - delay_action_by: Optional[TimeDeltaT] = None 39 - """Optional: Delays a label action by a specified `TimeDeltaT` time.""" 37 + # NOTE(ayubun): delayed actions are removed; they are legacy code from when discord used osprey 38 + # to trigger webhooks upon label adds/removes. 39 + # 40 + # we may eventually add something *similar* to this in the future? but i suspect 41 + # that a better abstraction would be to have any sort of "external impact" come 42 + # from verdicts, which were created to be an output (whereas labels were created 43 + # to simply store state, thus making label webhooks a leaky abstraction) 44 + # delay_action_by: Optional[TimeDeltaT] = None 45 + # """Optional: Delays a label action by a specified `TimeDeltaT` time.""" 40 46 apply_if: Optional[RuleT] = None 41 47 """Optional: Conditions that must be met for the label mutation to succeed.""" 42 48 expires_after: Optional[TimeDeltaT] = None ··· 49 55 status=status, 50 56 name=arguments.label.value, 51 57 expires_after=TimeDeltaT.inner_from_optional(arguments.expires_after), 52 - delay_action_by=TimeDeltaT.inner_from_optional(arguments.delay_action_by), 58 + # delay_action_by=TimeDeltaT.inner_from_optional(arguments.delay_action_by), 53 59 dependent_rule=arguments.apply_if, 54 60 # NOTE: This is fairly significant, if this call node has an `apply_if` ast, but 55 61 # the resolved apply_if is None, that means that the evaluation of the rule failed. ··· 123 129 desired_status: Optional[_SimpleStatus] 124 130 125 131 126 - class HasLabel(HasHelperInternal[LabelProvider], BatchableUDFBase[HasLabelArguments, bool, BatchableHasLabelArguments]): 132 + class HasLabel( 133 + HasHelperInternal[LabelsProvider], BatchableUDFBase[HasLabelArguments, bool, BatchableHasLabelArguments] 134 + ): 127 135 """Returns `True` if the specified label is currently present in a given non-expired state on a provided Entity.""" 128 136 129 137 category = UdfCategories.ENGINE ··· 157 165 validation_context.add_error(message='unknown label', span=arguments.label.argument_span, hint=hint) 158 166 159 167 def _execute( 160 - self, execution_context: ExecutionContext, arguments: BatchableHasLabelArguments, entity_labels: Labels 168 + self, execution_context: ExecutionContext, arguments: BatchableHasLabelArguments, entity_labels: EntityLabels 161 169 ) -> bool: 162 170 desired_manual = _ManualType.get(arguments.manual) 163 171 desired_delay = TimeDeltaT.inner_from_optional(arguments.min_label_age)
+2 -2
osprey_worker/src/osprey/engine/stdlib/udfs/tests/test_labels.py
··· 28 28 from osprey.engine.udf.registry import UDFRegistry 29 29 from osprey.engine.utils.proto_utils import datetime_to_timestamp 30 30 from osprey.rpc.labels.v1.service_pb2 import LabelReason, Labels, LabelState, LabelStatus 31 - from osprey.worker.lib.storage.labels import LabelProvider 31 + from osprey.worker.lib.storage.labels import LabelsProvider 32 32 33 33 if TYPE_CHECKING: 34 34 from osprey.rpc.labels.v1.service_pb2 import LabelStatusValue ··· 49 49 ] 50 50 51 51 52 - class StaticLabelProvider(LabelProvider): 52 + class StaticLabelProvider(LabelsProvider): 53 53 def __init__(self, entity_labels: Dict[EntityT[Any], Labels]) -> None: 54 54 self._entity_labels = entity_labels 55 55
+175 -253
osprey_worker/src/osprey/engine/stdlib/udfs/tests/test_rules.py
··· 13 13 RunValidationFunction, 14 14 ) 15 15 from osprey.engine.executor.execution_context import ( 16 + EntityLabelMutation, 16 17 ExecutionContext, 17 - ExtendedEntityMutation, 18 18 ) 19 19 from osprey.engine.language_types.entities import EntityT 20 20 from osprey.engine.stdlib.udfs.entity import Entity ··· 24 24 from osprey.engine.udf.arguments import ArgumentsBase 25 25 from osprey.engine.udf.base import UDFBase 26 26 from osprey.engine.udf.registry import UDFRegistry 27 - from osprey.engine.utils.proto_utils import datetime_to_timestamp 28 - from osprey.rpc.labels.v1.service_pb2 import EntityMutation, LabelStatus 27 + from osprey.rpc.labels.v1.service_pb2 import LabelStatus 29 28 from osprey.worker.sinks.sink.output_sink import _get_label_effects_from_result 30 29 31 30 # Moved here because WhenRules is not included in the MVP yet ··· 126 125 entity=E1, 127 126 label='foo', 128 127 expires_after=TimeDelta(minutes=3), 129 - delay_action_by=TimeDelta(minutes=1), 130 128 ), 131 129 # Split outputs for different entities from same WhenRules 132 - LabelRemove(entity=E2, label='garply', delay_action_by=TimeDelta(seconds=30)), 130 + LabelRemove(entity=E2, label='garply'), 133 131 LabelRemove(entity=E2, label='qux'), 134 132 ], 135 133 ) ··· 159 157 action_time=now, 160 158 ) 161 159 162 - expected: Mapping[EntityT[Any], Sequence[ExtendedEntityMutation]] = { 160 + expected: Mapping[EntityT[Any], Sequence[EntityLabelMutation]] = { 163 161 EntityT(type='MyEntity', id='entity 1'): [ 164 - ExtendedEntityMutation( 165 - mutation=EntityMutation( 166 - label_name='foo', 167 - reason_name='RSimple1', 168 - status=LabelStatus.ADDED, 169 - pending=False, 170 - description='simple rule 1', 171 - features={}, 172 - expires_at=datetime_to_timestamp(now + timedelta(seconds=45)), 173 - ), 174 - delay_action_by=None, 162 + EntityLabelMutation( 163 + label_name='foo', 164 + reason_name='RSimple1', 165 + status=LabelStatus.ADDED, 166 + pending=False, 167 + description='simple rule 1', 168 + features={}, 169 + expires_at=now + timedelta(seconds=45), 175 170 ), 176 - ExtendedEntityMutation( 177 - mutation=EntityMutation( 178 - label_name='foo', 179 - reason_name='RWithFString', 180 - status=LabelStatus.ADDED, 181 - pending=False, 182 - description='fstring rule description with name {UserName}', 183 - features={'UserName': 'Wumpus'}, 184 - expires_at=datetime_to_timestamp(now + timedelta(seconds=45)), 185 - ), 186 - delay_action_by=None, 171 + EntityLabelMutation( 172 + label_name='foo', 173 + reason_name='RWithFString', 174 + status=LabelStatus.ADDED, 175 + pending=False, 176 + description='fstring rule description with name {UserName}', 177 + features={'UserName': 'Wumpus'}, 178 + expires_at=now + timedelta(seconds=45), 187 179 ), 188 - ExtendedEntityMutation( 189 - mutation=EntityMutation( 190 - label_name='bar', 191 - reason_name='RSimple1', 192 - status=LabelStatus.REMOVED, 193 - pending=False, 194 - description='simple rule 1', 195 - features={}, 196 - expires_at=None, 197 - ), 198 - delay_action_by=None, 180 + EntityLabelMutation( 181 + label_name='bar', 182 + reason_name='RSimple1', 183 + status=LabelStatus.REMOVED, 184 + pending=False, 185 + description='simple rule 1', 186 + features={}, 187 + expires_at=None, 199 188 ), 200 - ExtendedEntityMutation( 201 - mutation=EntityMutation( 202 - label_name='bar', 203 - reason_name='RWithFString', 204 - status=LabelStatus.REMOVED, 205 - pending=False, 206 - description='fstring rule description with name {UserName}', 207 - features={'UserName': 'Wumpus'}, 208 - expires_at=None, 209 - ), 210 - delay_action_by=None, 189 + EntityLabelMutation( 190 + label_name='bar', 191 + reason_name='RWithFString', 192 + status=LabelStatus.REMOVED, 193 + pending=False, 194 + description='fstring rule description with name {UserName}', 195 + features={'UserName': 'Wumpus'}, 196 + expires_at=None, 211 197 ), 212 - ExtendedEntityMutation( 213 - mutation=EntityMutation( 214 - label_name='baz', 215 - reason_name='RSimple2', 216 - status=LabelStatus.ADDED, 217 - pending=False, 218 - description='simple rule 2', 219 - features={}, 220 - expires_at=datetime_to_timestamp(now + timedelta(minutes=5)), 221 - ), 222 - delay_action_by=None, 198 + EntityLabelMutation( 199 + label_name='baz', 200 + reason_name='RSimple2', 201 + status=LabelStatus.ADDED, 202 + pending=False, 203 + description='simple rule 2', 204 + features={}, 205 + expires_at=now + timedelta(minutes=5), 223 206 ), 224 - ExtendedEntityMutation( 225 - mutation=EntityMutation( 226 - label_name='foo', 227 - reason_name='RSimple2', 228 - status=LabelStatus.ADDED, 229 - pending=False, 230 - description='simple rule 2', 231 - features={}, 232 - expires_at=datetime_to_timestamp(now + timedelta(minutes=3)), 233 - ), 234 - delay_action_by=timedelta(minutes=1), 207 + EntityLabelMutation( 208 + label_name='foo', 209 + reason_name='RSimple2', 210 + status=LabelStatus.ADDED, 211 + pending=False, 212 + description='simple rule 2', 213 + features={}, 214 + expires_at=now + timedelta(minutes=3), 235 215 ), 236 - ExtendedEntityMutation( 237 - mutation=EntityMutation( 238 - label_name='garply', 239 - reason_name='RSimple3', 240 - status=LabelStatus.ADDED, 241 - pending=False, 242 - description='simple rule 3', 243 - features={}, 244 - expires_at=datetime_to_timestamp(now + timedelta(seconds=30)), 245 - ), 246 - delay_action_by=None, 216 + EntityLabelMutation( 217 + label_name='garply', 218 + reason_name='RSimple3', 219 + status=LabelStatus.ADDED, 220 + pending=False, 221 + description='simple rule 3', 222 + features={}, 223 + expires_at=now + timedelta(seconds=30), 247 224 ), 248 225 ], 249 226 EntityT(type='MyEntity', id='entity 2'): [ 250 - ExtendedEntityMutation( 251 - mutation=EntityMutation( 252 - label_name='garply', 253 - reason_name='RSimple1', 254 - status=LabelStatus.ADDED, 255 - pending=False, 256 - description='simple rule 1', 257 - features={}, 258 - expires_at=datetime_to_timestamp(now + timedelta(seconds=30)), 259 - ), 260 - delay_action_by=None, 227 + EntityLabelMutation( 228 + label_name='garply', 229 + reason_name='RSimple1', 230 + status=LabelStatus.ADDED, 231 + pending=False, 232 + description='simple rule 1', 233 + features={}, 234 + expires_at=now + timedelta(seconds=30), 261 235 ), 262 - ExtendedEntityMutation( 263 - mutation=EntityMutation( 264 - label_name='garply', 265 - reason_name='RSimple2', 266 - status=LabelStatus.REMOVED, 267 - pending=False, 268 - description='simple rule 2', 269 - features={}, 270 - expires_at=None, 271 - ), 272 - delay_action_by=timedelta(seconds=30), 236 + EntityLabelMutation( 237 + label_name='garply', 238 + reason_name='RSimple2', 239 + status=LabelStatus.REMOVED, 240 + pending=False, 241 + description='simple rule 2', 242 + features={}, 243 + expires_at=None, 273 244 ), 274 - ExtendedEntityMutation( 275 - mutation=EntityMutation( 276 - label_name='qux', 277 - reason_name='RSimple2', 278 - status=LabelStatus.REMOVED, 279 - pending=False, 280 - description='simple rule 2', 281 - features={}, 282 - expires_at=None, 283 - ), 284 - delay_action_by=None, 245 + EntityLabelMutation( 246 + label_name='qux', 247 + reason_name='RSimple2', 248 + status=LabelStatus.REMOVED, 249 + pending=False, 250 + description='simple rule 2', 251 + features={}, 252 + expires_at=None, 285 253 ), 286 - ExtendedEntityMutation( 287 - mutation=EntityMutation( 288 - label_name='garply', 289 - reason_name='RWithFString', 290 - status=LabelStatus.ADDED, 291 - pending=False, 292 - description='fstring rule description with name {UserName}', 293 - features={'UserName': 'Wumpus'}, 294 - expires_at=None, 295 - ), 296 - delay_action_by=None, 254 + EntityLabelMutation( 255 + label_name='garply', 256 + reason_name='RWithFString', 257 + status=LabelStatus.ADDED, 258 + pending=False, 259 + description='fstring rule description with name {UserName}', 260 + features={'UserName': 'Wumpus'}, 261 + expires_at=None, 297 262 ), 298 263 ], 299 264 } ··· 340 305 action_time=now, 341 306 ) 342 307 343 - expected: Mapping[EntityT[Any], Sequence[ExtendedEntityMutation]] = { 308 + expected: Mapping[EntityT[Any], Sequence[EntityLabelMutation]] = { 344 309 EntityT(type='MyEntity', id='entity 1'): [ 345 - ExtendedEntityMutation( 346 - mutation=EntityMutation( 347 - label_name='foo', 348 - reason_name='RT1', 349 - status=LabelStatus.ADDED, 350 - pending=False, 351 - description='rule 1', 352 - features={}, 353 - ), 354 - delay_action_by=None, 310 + EntityLabelMutation( 311 + label_name='foo', 312 + reason_name='RT1', 313 + status=LabelStatus.ADDED, 314 + pending=False, 315 + description='rule 1', 316 + features={}, 355 317 ), 356 - ExtendedEntityMutation( 357 - mutation=EntityMutation( 358 - label_name='foo', 359 - reason_name='RT2', 360 - status=LabelStatus.ADDED, 361 - pending=False, 362 - description='rule 2', 363 - features={}, 364 - ), 365 - delay_action_by=None, 318 + EntityLabelMutation( 319 + label_name='foo', 320 + reason_name='RT2', 321 + status=LabelStatus.ADDED, 322 + pending=False, 323 + description='rule 2', 324 + features={}, 366 325 ), 367 326 ], 368 327 EntityT(type='MyEntity', id='entity 3'): [ 369 - ExtendedEntityMutation( 370 - mutation=EntityMutation( 371 - label_name='baz', 372 - reason_name='RT1', 373 - status=LabelStatus.ADDED, 374 - pending=False, 375 - description='rule 1', 376 - features={}, 377 - ), 378 - delay_action_by=None, 328 + EntityLabelMutation( 329 + label_name='baz', 330 + reason_name='RT1', 331 + status=LabelStatus.ADDED, 332 + pending=False, 333 + description='rule 1', 334 + features={}, 379 335 ), 380 336 ], 381 337 EntityT(type='MyEntity', id='entity 5'): [ 382 - ExtendedEntityMutation( 383 - mutation=EntityMutation( 384 - label_name='uwu', 385 - reason_name='RT1', 386 - status=LabelStatus.ADDED, 387 - pending=False, 388 - description='rule 1', 389 - features={}, 390 - ), 391 - delay_action_by=None, 338 + EntityLabelMutation( 339 + label_name='uwu', 340 + reason_name='RT1', 341 + status=LabelStatus.ADDED, 342 + pending=False, 343 + description='rule 1', 344 + features={}, 392 345 ), 393 - ExtendedEntityMutation( 394 - mutation=EntityMutation( 395 - label_name='uwu', 396 - reason_name='RT2', 397 - status=LabelStatus.ADDED, 398 - pending=False, 399 - description='rule 2', 400 - features={}, 401 - ), 402 - delay_action_by=None, 346 + EntityLabelMutation( 347 + label_name='uwu', 348 + reason_name='RT2', 349 + status=LabelStatus.ADDED, 350 + pending=False, 351 + description='rule 2', 352 + features={}, 403 353 ), 404 354 ], 405 355 } ··· 437 387 action_time=now, 438 388 ) 439 389 440 - expected: Mapping[EntityT[Any], Sequence[ExtendedEntityMutation]] = { 390 + expected: Mapping[EntityT[Any], Sequence[EntityLabelMutation]] = { 441 391 EntityT(type='MyEntity', id='entity 1'): [ 442 - ExtendedEntityMutation( 443 - mutation=EntityMutation( 444 - label_name='bar', 445 - reason_name='R1', 446 - status=LabelStatus.ADDED, 447 - pending=False, 448 - description='rule 1', 449 - features={}, 450 - ), 451 - delay_action_by=None, 392 + EntityLabelMutation( 393 + label_name='bar', 394 + reason_name='R1', 395 + status=LabelStatus.ADDED, 396 + pending=False, 397 + description='rule 1', 398 + features={}, 452 399 ), 453 - ExtendedEntityMutation( 454 - mutation=EntityMutation( 455 - label_name='uwu', 456 - reason_name='R1', 457 - status=LabelStatus.ADDED, 458 - pending=False, 459 - description='rule 1', 460 - features={}, 461 - ), 462 - delay_action_by=None, 400 + EntityLabelMutation( 401 + label_name='uwu', 402 + reason_name='R1', 403 + status=LabelStatus.ADDED, 404 + pending=False, 405 + description='rule 1', 406 + features={}, 463 407 ), 464 - ExtendedEntityMutation( 465 - mutation=EntityMutation( 466 - label_name='uwu', 467 - reason_name='R2', 468 - status=LabelStatus.ADDED, 469 - pending=False, 470 - description='rule 2', 471 - features={}, 472 - ), 473 - delay_action_by=None, 408 + EntityLabelMutation( 409 + label_name='uwu', 410 + reason_name='R2', 411 + status=LabelStatus.ADDED, 412 + pending=False, 413 + description='rule 2', 414 + features={}, 474 415 ), 475 416 ] 476 417 } ··· 523 464 524 465 525 466 def _sort_entity_mutations( 526 - effects: Mapping[EntityT[Any], Sequence[ExtendedEntityMutation]], 527 - ) -> Mapping[EntityT[Any], Sequence[ExtendedEntityMutation]]: 467 + effects: Mapping[EntityT[Any], Sequence[EntityLabelMutation]], 468 + ) -> Mapping[EntityT[Any], Sequence[EntityLabelMutation]]: 528 469 """Sorts entity mutations so that two sets of effects can be compared easily.""" 529 470 530 - def sort_key(mutation: ExtendedEntityMutation) -> tuple: 531 - # Create a sorting key from the ExtendedEntityMutation fields 532 - entity_mutation = mutation.mutation 533 - # Extract comparable values from the pb2 EntityMutation 534 - expires_at_key = ( 535 - entity_mutation.expires_at.seconds if entity_mutation.HasField('expires_at') else 0, 536 - entity_mutation.expires_at.nanos if entity_mutation.HasField('expires_at') else 0, 537 - ) 538 - features_key = tuple(sorted(entity_mutation.features.items())) 539 - # Use days, seconds, and microseconds for precise timedelta comparison 540 - delay_key = ( 541 - mutation.delay_action_by.days if mutation.delay_action_by is not None else 0, 542 - mutation.delay_action_by.seconds if mutation.delay_action_by is not None else 0, 543 - mutation.delay_action_by.microseconds if mutation.delay_action_by is not None else 0, 544 - ) 471 + def sort_key(mutation: EntityLabelMutation) -> tuple: 472 + # Create a sorting key from the EntityLabelMutation fields 473 + # Extract comparable values 474 + expires_at_key = (mutation.expires_at.timestamp() if mutation.expires_at is not None else 0,) 475 + features_key = tuple(sorted(mutation.features.items())) 545 476 546 477 return ( 547 - entity_mutation.label_name, 548 - entity_mutation.reason_name, 549 - entity_mutation.status, 550 - entity_mutation.pending, 551 - entity_mutation.description, 478 + mutation.label_name, 479 + mutation.reason_name, 480 + mutation.status, 481 + mutation.pending, 482 + mutation.description, 552 483 features_key, 553 484 expires_at_key, 554 - delay_key, 555 485 ) 556 486 557 487 return {entity: sorted(mutations, key=sort_key) for entity, mutations in effects.items()} 558 488 559 489 560 - def _to_simple_dict(label_effects: Mapping[EntityT[Any], Sequence[ExtendedEntityMutation]]) -> Dict[object, object]: 490 + def _to_simple_dict(label_effects: Mapping[EntityT[Any], Sequence[EntityLabelMutation]]) -> Dict[object, object]: 561 491 """Converts effects to bare dicts, so py.test can display them better in failure output!""" 562 492 563 - def entity_mutation_to_dict(mutation: ExtendedEntityMutation) -> Dict[str, Any]: 564 - # Convert ExtendedEntityMutation to a comparable dict 565 - entity_mutation = mutation.mutation 566 - expires_at_dict = None 567 - if entity_mutation.HasField('expires_at'): 568 - expires_at_dict = {'seconds': entity_mutation.expires_at.seconds, 'nanos': entity_mutation.expires_at.nanos} 493 + def entity_mutation_to_dict(mutation: EntityLabelMutation) -> Dict[str, Any]: 494 + # Convert EntityLabelMutation to a comparable dict 495 + expires_at_timestamp = mutation.expires_at.timestamp() if mutation.expires_at is not None else None 569 496 570 497 return { 571 - 'mutation': { 572 - 'label_name': entity_mutation.label_name, 573 - 'reason_name': entity_mutation.reason_name, 574 - 'status': entity_mutation.status, 575 - 'pending': entity_mutation.pending, 576 - 'description': entity_mutation.description, 577 - 'features': dict(entity_mutation.features), 578 - 'expires_at': expires_at_dict, 579 - }, 580 - 'delay_action_by': mutation.delay_action_by.total_seconds() 581 - if mutation.delay_action_by is not None 582 - else None, 498 + 'label_name': mutation.label_name, 499 + 'reason_name': mutation.reason_name, 500 + 'status': mutation.status, 501 + 'pending': mutation.pending, 502 + 'description': mutation.description, 503 + 'features': dict(mutation.features), 504 + 'expires_at': expires_at_timestamp, 583 505 } 584 506 585 507 return { ··· 588 510 589 511 590 512 def _compare_effects( 591 - actual: Mapping[EntityT[Any], Sequence[ExtendedEntityMutation]], 592 - expected: Mapping[EntityT[Any], Sequence[ExtendedEntityMutation]], 513 + actual: Mapping[EntityT[Any], Sequence[EntityLabelMutation]], 514 + expected: Mapping[EntityT[Any], Sequence[EntityLabelMutation]], 593 515 ) -> bool: 594 516 """Given the actual effects from classification, and the expected effects, compare them to make sure they 595 517 are equal."""
-30
osprey_worker/src/osprey/engine/utils/proto_utils.py
··· 1 - from datetime import datetime, timedelta 2 - 3 - from google.protobuf.duration_pb2 import Duration 4 - from google.protobuf.timestamp_pb2 import Timestamp 5 - 6 - 7 - def datetime_to_timestamp(dt: datetime) -> Timestamp: 8 - timestamp = Timestamp() 9 - timestamp.FromDatetime(dt) 10 - return timestamp 11 - 12 - 13 - def optional_datetime_to_timestamp(dt: datetime | None) -> Timestamp | None: 14 - if dt is None: 15 - return None 16 - 17 - return datetime_to_timestamp(dt) 18 - 19 - 20 - def timedelta_to_duration(td: timedelta) -> Duration: 21 - duration = Duration() 22 - duration.FromTimedelta(td) 23 - return duration 24 - 25 - 26 - def optional_timedelta_to_duration(td: timedelta | None) -> Duration | None: 27 - if td is None: 28 - return None 29 - 30 - return timedelta_to_duration(td)
+1 -1
osprey_worker/src/osprey/worker/_stdlibplugin/sink_register.py
··· 1 1 from typing import List, Sequence 2 2 3 3 from kafka import KafkaProducer 4 - from osprey.worker.adaptor.plugin_manager import hookimpl_osprey, bootstrap_execution_result_store 4 + from osprey.worker.adaptor.plugin_manager import bootstrap_execution_result_store, hookimpl_osprey 5 5 from osprey.worker.lib.config import Config 6 6 from osprey.worker.sinks.sink.kafka_output_sink import KafkaOutputSink 7 7 from osprey.worker.sinks.sink.output_sink import BaseOutputSink, StdoutOutputSink
+6 -4
osprey_worker/src/osprey/worker/adaptor/hookspecs/osprey_hooks.py
··· 8 8 from osprey.engine.udf.base import UDFBase 9 9 from osprey.worker.adaptor.constants import OSPREY_ADAPTOR 10 10 from osprey.worker.lib.action_proto_deserializer import ActionProtoDeserializer 11 - from osprey.worker.lib.storage.labels import LabelProvider 11 + from osprey.worker.lib.storage.labels import LabelsProvider, LabelsServiceBase 12 12 from osprey.worker.sinks.sink.input_stream import BaseInputStream 13 13 from osprey.worker.sinks.utils.acking_contexts import BaseAckingContext 14 14 ··· 56 56 57 57 58 58 @hookspec(firstresult=True) 59 - def register_label_provider(config: Config) -> LabelProvider: 60 - """Register an execution result storage backend instance.""" 61 - raise NotImplementedError('register_label_provider must be implemented by the plugin') 59 + def register_labels_service(config: Config) -> LabelsServiceBase | LabelsProvider: 60 + """Register a labels service backend. This can be achieved by implementing a labels service 61 + and utilizing the provided labels provider, or by overriding the labels provider to fit your 62 + business needs""" 63 + raise NotImplementedError('register_labels_service must be implemented by the plugin')
+19 -12
osprey_worker/src/osprey/worker/adaptor/plugin_manager.py
··· 12 12 from osprey.worker.adaptor.constants import OSPREY_ADAPTOR 13 13 from osprey.worker.adaptor.hookspecs import osprey_hooks 14 14 from osprey.worker.lib.action_proto_deserializer import ActionProtoDeserializer 15 - from osprey.worker.lib.storage.labels import LabelProvider 15 + from osprey.worker.lib.storage.labels import LabelsProvider, LabelsServiceBase 16 16 from osprey.worker.sinks.sink.input_stream import BaseInputStream 17 17 from osprey.worker.sinks.sink.output_sink import BaseOutputSink, LabelOutputSink, MultiOutputSink 18 18 from osprey.worker.sinks.utils.acking_contexts import BaseAckingContext ··· 39 39 return sum(seq, []) 40 40 41 41 42 - def _has_labels_provider() -> bool: 43 - return hasattr(plugin_manager.hook, 'register_labels_provider') 42 + def has_labels_service() -> bool: 43 + return hasattr(plugin_manager.hook, 'register_labels_service') 44 44 45 45 46 46 def bootstrap_udfs() -> tuple[UDFRegistry, UDFHelpers]: ··· 55 55 udf_helpers.set_udf_helper(udf, udf.create_provider()) 56 56 57 57 # Label udfs should only be registered if the labels provider is available 58 - if _has_labels_provider(): 58 + if has_labels_service(): 59 59 # Imports kinda circular. Imports here are to avoid that. 60 60 from osprey.engine.stdlib.udfs.labels import HasLabel, LabelAdd, LabelRemove 61 61 62 62 udfs.extend([HasLabel, LabelAdd, LabelRemove]) 63 63 64 - labels_provider = plugin_manager.hook.register_labels_provider() 64 + provider_or_service: LabelsProvider | LabelsServiceBase = plugin_manager.hook.register_labels_service() 65 + if isinstance(provider_or_service, LabelsProvider): 66 + labels_provider = provider_or_service 67 + else: 68 + labels_provider = LabelsProvider(labels_service=provider_or_service) 69 + 65 70 udf_helpers.set_udf_helper(HasLabel, labels_provider) 66 71 67 72 return udf_registry, udf_helpers ··· 72 77 sinks = flatten(plugin_manager.hook.register_output_sinks(config=config)) 73 78 74 79 # Label udfs should only be registered if the labels provider is available 75 - if _has_labels_provider(): 76 - sinks.append(LabelOutputSink(bootstrap_label_provider())) 80 + if has_labels_service(): 81 + sinks.append(LabelOutputSink(bootstrap_labels_provider())) 77 82 78 83 return MultiOutputSink(sinks) 79 84 80 85 81 - def bootstrap_label_provider() -> LabelProvider: 86 + def bootstrap_labels_provider() -> LabelsProvider: 82 87 """ 83 88 Generates a bootstrapped label provider using the registered plugin. 84 - Calling this is not necessary if you already called bootstrap_output_sinks, but is available for convenience. 85 89 """ 86 90 load_all_osprey_plugins() 87 - if not _has_labels_provider(): 88 - raise NotImplementedError('Label provider assumes register_labels_provider is implemented.') 89 - return plugin_manager.hook.register_labels_provider() 91 + if not has_labels_service(): 92 + raise NotImplementedError('Labels provider assumes register_labels_service is implemented.') 93 + provider_or_service: LabelsProvider | LabelsServiceBase = plugin_manager.hook.register_labels_service() 94 + if isinstance(provider_or_service, LabelsProvider): 95 + return provider_or_service 96 + return LabelsProvider(provider_or_service) 90 97 91 98 92 99 def bootstrap_ast_validators() -> None:
+35 -227
osprey_worker/src/osprey/worker/lib/cli.py
··· 2 2 import sys 3 3 from pathlib import Path # noqa: E402 4 4 5 + from osprey.engine.language_types.entities import EntityT 6 + from osprey.worker.adaptor.plugin_manager import bootstrap_labels_provider 7 + from osprey.worker.lib.osprey_shared.labels import EntityLabelMutation, LabelStatus 5 8 from osprey.worker.lib.patcher import patch_all # noqa: E402 6 9 7 10 patch_all() # please ensure this occurs before *any* other imports ! ··· 12 15 import logging # noqa: E402 13 16 import os # noqa: E402 14 17 import subprocess # noqa: E402 15 - import time # noqa: E402 16 - import uuid # noqa: E402 17 - from typing import TYPE_CHECKING, Any, Optional, Set, Union # noqa: E402 18 + from typing import Any, Optional, Set, Union # noqa: E402 18 19 19 20 import click # noqa: E402 20 21 from click import Context, Parameter, ParamType # noqa: E402 ··· 24 25 25 26 # Import safety record and common protos 26 27 from osprey.engine.ast.sources import Sources # noqa: E402 27 - from osprey.engine.executor.execution_context import ExtendedEntityMutation # noqa: E402 28 - from osprey.rpc.labels.v1.service_pb2 import ( # noqa: E402 29 - Entity, 30 - EntityKey, 31 - EntityMutation, 32 - LabelStatus, 33 - ) 34 - from osprey.worker.lib.osprey_engine import bootstrap_engine # noqa: E402 35 - from osprey.worker.lib.publisher import PubSubPublisher # noqa: E402 36 - from osprey.worker.lib.singletons import CONFIG # noqa: E402 37 28 from osprey.worker.lib.sources_publisher import ( # noqa: E402 38 29 upload_dependencies_mapping, 39 30 validate_and_push, ··· 42 33 access_audit_log, # noqa: E402 43 34 entity_label_webhook, 44 35 labels, 45 - postgres, 46 36 stored_execution_result, 47 37 ) 48 38 from osprey.worker.lib.utils.click_utils import EnumChoicePb2 # noqa: E402 49 - from osprey.worker.sinks.sink.output_sink import LabelOutputSink # noqa: E402 50 - from osprey.worker.sinks.sink.output_sink_utils.constants import MutationEventType # noqa: E402 51 - 52 - if TYPE_CHECKING: 53 - from osprey.rpc.labels.v1.service_pb2 import LabelStatusValue 54 39 55 40 56 41 @click.group() ··· 140 125 'access_audit_log': access_audit_log, 141 126 'entity_label_webhook': entity_label_webhook, 142 127 'stored_execution_result': stored_execution_result, 143 - 'EntityKey': EntityKey, 144 - 'Entity': Entity, 145 - 'EntityMutation': EntityMutation, 128 + 'EntityT': EntityT, 129 + # 'Entity': Entity, 130 + 'EntityLabelMutation': EntityLabelMutation, 146 131 'LabelStatus': LabelStatus, 147 132 } 148 133 ··· 207 192 code.InteractiveConsole(namespace).interact() 208 193 209 194 210 - @cli.command() 211 - @click.argument('entity_type') 212 - @click.argument('entity_id') 213 - @click.argument('label_name') 214 - @click.argument('label_status', type=EnumChoicePb2(LabelStatus)) 215 - @click.option( 216 - '--reason', 217 - help=( 218 - 'If specified, the reason the label is being applied.' 219 - ' Should be camel case, without spaces. Defaults to "CliLabelMutation".' 220 - ), 221 - ) 222 - @click.option( 223 - '--description', 224 - help=( 225 - 'If specified, the description for why the label is being applied.' 226 - ' Should be an English sentence. Defaults to "Manually changed from the command line for debugging."' 227 - ), 228 - ) 229 - @click.option( 230 - '--expire-instantly', 231 - default=False, 232 - help=('Boolean option to make the label expire instantly. Supplying False means the label does not expire.'), 233 - ) 234 - def apply_label_without_effects( 235 - entity_type: str, 236 - entity_id: str, 237 - label_name: str, 238 - label_status: 'LabelStatusValue', 239 - reason: Optional[str], 240 - description: Optional[str], 241 - expire_instantly: bool, 242 - ) -> None: 243 - """Manually apply a label to an entity. 244 - 245 - Mainly intended to be used for debugging purposes or importing lists of labels from external sources. Does *NOT* 246 - do anything with label effects (eg, does *NOT* send webhooks for changed labels). 247 - """ 248 - if expire_instantly: 249 - mutation = EntityMutation( 250 - label_name=label_name, 251 - reason_name=reason or 'CliLabelMutationWithoutEffects', 252 - status=label_status, 253 - description=description or 'Manually changed from the command line for debugging.', 254 - expires_at=(datetime.datetime.now() + datetime.timedelta(seconds=5)), 255 - ) 256 - else: 257 - mutation = EntityMutation( 258 - label_name=label_name, 259 - reason_name=reason or 'CliLabelMutationWithoutEffects', 260 - status=label_status, 261 - description=description or 'Manually changed from the command line for debugging.', 262 - ) 263 - 264 - result = labels.apply_entity_mutation(entity_key=EntityKey(type=entity_type, id=entity_id), mutations=[mutation]) 265 - 266 - print(result) 267 - 268 - 269 - def get_event_effects_output_sink() -> LabelOutputSink: 270 - config = CONFIG.instance() 271 - config.configure_from_env() 272 - 273 - postgres.init_from_config('osprey_db') 274 - engine = bootstrap_engine() 275 - analytics_pubsub_project_id = config.get_str('PUBSUB_DATA_PROJECT_ID', 'osprey-dev') 276 - analytics_pubsub_topic_id = config.get_str('PUBSUB_ANALYTICS_EVENT_TOPIC_ID', 'osprey-analytics') 277 - analytics_publisher = PubSubPublisher(analytics_pubsub_project_id, analytics_pubsub_topic_id) 278 - 279 - osprey_webhook_pubsub_project = config.get_str('PUBSUB_OSPREY_WEBHOOKS_PROJECT_ID', 'osprey-dev') 280 - osprey_webhook_pubsub_topic = config.get_str('PUBSUB_OSPREY_WEBHOOKS_TOPIC_ID', 'osprey-webhooks') 281 - webhooks_publisher = PubSubPublisher(osprey_webhook_pubsub_project, osprey_webhook_pubsub_topic) 282 - return LabelOutputSink(engine, analytics_publisher, webhooks_publisher) 283 - 284 - 285 - @cli.command() 286 - @click.argument('entity_type') 287 - @click.argument('entity_id') 288 - @click.argument('label_name') 289 - @click.argument('label_status', type=EnumChoicePb2(LabelStatus)) 290 - @click.option( 291 - '--reason', 292 - help=( 293 - 'If specified, the reason the label is being applied.' 294 - ' Should be camel case, without spaces. Defaults to "CliLabelMutationWithEffects".' 295 - ), 296 - ) 297 - @click.option( 298 - '--description', 299 - help=( 300 - 'If specified, the description for why the label is being applied.' 301 - ' Should be an English sentence. Defaults to "Manually changed from the command line for debugging."' 302 - ), 303 - ) 304 - @click.option( 305 - '--expire-instantly', 306 - default=False, 307 - help=('Boolean option to make the label expire instantly. Supplying False means the label does not expire.'), 308 - ) 309 - @click.option( 310 - '--delay-by', 311 - default=0, 312 - help=('Number of seconds to delay the action by. Defaults to instant.'), 313 - ) 314 - def apply_label_with_effects( 315 - entity_type: str, 316 - entity_id: str, 317 - label_name: str, 318 - label_status: 'LabelStatusValue', 319 - reason: Optional[str], 320 - description: Optional[str], 321 - expire_instantly: bool, 322 - delay_by: int, 323 - ) -> None: 324 - """Manually apply a label to an entity. 325 - 326 - This method applies label effects (eg, sends webhooks for changed labels). 327 - """ 328 - if expire_instantly: 329 - mutation = EntityMutation( 330 - label_name=label_name, 331 - reason_name=reason or 'CliLabelMutationWithEffects', 332 - status=label_status, 333 - description=description or 'Manually changed from the command line for debugging.', 334 - expires_at=(datetime.datetime.now() + datetime.timedelta(seconds=5)), 335 - ) 336 - else: 337 - mutation = EntityMutation( 338 - label_name=label_name, 339 - reason_name=reason or 'CliLabelMutationWithEffects', 340 - status=label_status, 341 - description=description or 'Manually changed from the command line for debugging.', 342 - ) 343 - 344 - if not delay_by: 345 - correctly_typed_delay_by = None 346 - else: 347 - correctly_typed_delay_by = datetime.timedelta(seconds=float(delay_by)) 348 - 349 - result = get_event_effects_output_sink().apply_label_mutations( 350 - mutation_event_type=MutationEventType.MANUAL_UPDATE, 351 - mutation_event_id=str(uuid.uuid4()), 352 - entity_key=EntityKey(type=entity_type, id=entity_id), 353 - mutations=[ExtendedEntityMutation(mutation=mutation, delay_action_by=correctly_typed_delay_by)], 354 - ) 355 - 356 - time.sleep(2) 357 - 358 - print(result) 359 - 360 - 361 195 def get_lines_from_file_as_set(file_path: str) -> Set[str]: 362 196 """ 363 197 Collects all lines from a file in an unordered set. ··· 375 209 376 210 @cli.command() 377 211 @click.argument('entity_type') 378 - @click.argument('entity_ids_file_path') 212 + @click.argument('entity_id') 379 213 @click.argument('label_name') 380 214 @click.argument('label_status', type=EnumChoicePb2(LabelStatus)) 381 215 @click.option( 382 216 '--reason', 383 217 help=( 384 218 'If specified, the reason the label is being applied.' 385 - ' Should be camel case, without spaces. Defaults to "CliLabelMutationWithoutEffects".' 219 + ' Should be camel case, without spaces. Defaults to "CliLabelMutation".' 386 220 ), 387 221 ) 388 222 @click.option( ··· 397 231 default=False, 398 232 help=('Boolean option to make the label expire instantly. Supplying False means the label does not expire.'), 399 233 ) 400 - def bulk_apply_label_without_effects( 234 + def apply_label( 401 235 entity_type: str, 402 - entity_ids_file_path: str, 236 + entity_id: str, 403 237 label_name: str, 404 - label_status: 'LabelStatusValue', 238 + label_status: LabelStatus, 405 239 reason: Optional[str], 406 240 description: Optional[str], 407 241 expire_instantly: bool, 408 242 ) -> None: 409 - """Manually apply a label to all entity IDs in the provided file at the file path. 243 + """Manually apply a label to an entity. 410 244 411 - Mainly intended to be used for debugging purposes or importing lists of labels from external sources. Does *NOT* 412 - do anything with label effects (eg, does *NOT* send webhooks for changed labels). 245 + Mainly intended to be used for debugging purposes or importing lists of labels from external sources. 413 246 """ 414 - entity_ids = get_lines_from_file_as_set(file_path=entity_ids_file_path) 415 - # I found that it *generally* took ~10ms per request; Multiply by 10.05 for 5% latency headroom 416 - expire_timestamp = datetime.datetime.now() + datetime.timedelta(milliseconds=int(len(entity_ids) * 10.05)) 417 - print(f'Found {len(entity_ids)} entity IDs to label.\nETA: {int(len(entity_ids) * 10.05 / 100)} second(s)') 418 247 if expire_instantly: 419 - mutation = EntityMutation( 248 + mutation = EntityLabelMutation( 420 249 label_name=label_name, 421 250 reason_name=reason or 'CliLabelMutationWithoutEffects', 422 251 status=label_status, 423 252 description=description or 'Manually changed from the command line for debugging.', 424 - expires_at=expire_timestamp, 253 + expires_at=(datetime.datetime.now() + datetime.timedelta(seconds=5)), 425 254 ) 426 255 else: 427 - mutation = EntityMutation( 256 + mutation = EntityLabelMutation( 428 257 label_name=label_name, 429 258 reason_name=reason or 'CliLabelMutationWithoutEffects', 430 259 status=label_status, 431 260 description=description or 'Manually changed from the command line for debugging.', 432 261 ) 433 262 434 - progress_tracker: CliCommandProgressTracker = CliCommandProgressTracker(total_actions=len(entity_ids)) 435 - for entity_id in entity_ids: 436 - labels.apply_entity_mutation( 437 - entity_key=EntityKey(type=entity_type, id=entity_id), 438 - mutations=[mutation], 439 - ) 440 - progress_tracker.increment() 263 + result = bootstrap_labels_provider().apply_entity_label_mutations( 264 + entity=EntityT(type=entity_type, id=entity_id), mutations=[mutation] 265 + ) 441 266 442 - print(f'Bulk labelling complete! Total labels applied: {progress_tracker.total_actions}') 267 + print(result) 443 268 444 269 445 270 @cli.command() ··· 451 276 '--reason', 452 277 help=( 453 278 'If specified, the reason the label is being applied.' 454 - ' Should be camel case, without spaces. Defaults to "CliLabelMutationWithEffects".' 279 + ' Should be camel case, without spaces. Defaults to "CliLabelMutationWithoutEffects".' 455 280 ), 456 281 ) 457 282 @click.option( ··· 466 291 default=False, 467 292 help=('Boolean option to make the label expire instantly. Supplying False means the label does not expire.'), 468 293 ) 469 - @click.option( 470 - '--delay-by', 471 - default=0, 472 - help=('Number of seconds to delay the action by. Defaults to instant.'), 473 - ) 474 - def bulk_apply_label_with_effects( 294 + def bulk_apply_label( 475 295 entity_type: str, 476 296 entity_ids_file_path: str, 477 297 label_name: str, 478 - label_status: 'LabelStatusValue', 298 + label_status: LabelStatus, 479 299 reason: Optional[str], 480 300 description: Optional[str], 481 301 expire_instantly: bool, 482 - delay_by: int, 483 302 ) -> None: 484 303 """Manually apply a label to all entity IDs in the provided file at the file path. 485 304 486 - This method applies label effects (eg, sends webhooks for changed labels). 305 + Mainly intended to be used for debugging purposes or importing lists of labels from external sources. 487 306 """ 488 307 entity_ids = get_lines_from_file_as_set(file_path=entity_ids_file_path) 489 - event_id = str(uuid.uuid4()) 490 - event_effects_output_sink = get_event_effects_output_sink() 491 308 # I found that it *generally* took ~10ms per request; Multiply by 10.05 for 5% latency headroom 492 309 expire_timestamp = datetime.datetime.now() + datetime.timedelta(milliseconds=int(len(entity_ids) * 10.05)) 493 - print(f'Found {len(entity_ids)} entity IDs to label. Proceeding with analytics event ID {event_id}.') 494 - print(f'ETA: {int(len(entity_ids) * 1.05 / 100)} second(s)') 310 + print(f'Found {len(entity_ids)} entity IDs to label.\nETA: {int(len(entity_ids) * 10.05 / 100)} second(s)') 495 311 if expire_instantly: 496 - mutation = EntityMutation( 312 + mutation = EntityLabelMutation( 497 313 label_name=label_name, 498 - reason_name=reason or 'CliLabelMutationWithEffects', 314 + reason_name=reason or 'CliLabelMutationWithoutEffects', 499 315 status=label_status, 500 316 description=description or 'Manually changed from the command line for debugging.', 501 317 expires_at=expire_timestamp, 502 318 ) 503 319 else: 504 - mutation = EntityMutation( 320 + mutation = EntityLabelMutation( 505 321 label_name=label_name, 506 - reason_name=reason or 'CliLabelMutationWithEffects', 322 + reason_name=reason or 'CliLabelMutationWithoutEffects', 507 323 status=label_status, 508 324 description=description or 'Manually changed from the command line for debugging.', 509 325 ) 510 326 511 - if not delay_by: 512 - correctly_typed_delay_by = None 513 - else: 514 - correctly_typed_delay_by = datetime.timedelta(seconds=float(delay_by)) 515 - 516 327 progress_tracker: CliCommandProgressTracker = CliCommandProgressTracker(total_actions=len(entity_ids)) 328 + provider = bootstrap_labels_provider() 517 329 for entity_id in entity_ids: 518 - event_effects_output_sink.apply_label_mutations( 519 - mutation_event_type=MutationEventType.MANUAL_UPDATE, 520 - mutation_event_id=event_id, 521 - entity_key=EntityKey(type=entity_type, id=entity_id), 522 - mutations=[ExtendedEntityMutation(mutation=mutation, delay_action_by=correctly_typed_delay_by)], 330 + provider.apply_entity_label_mutations( 331 + entity=EntityT(type=entity_type, id=entity_id), 332 + mutations=[mutation], 523 333 ) 524 334 progress_tracker.increment() 525 335 526 - time.sleep(2) 527 - 528 - print(f'Bulk labelling complete! Total labels applied: {progress_tracker.total_actions}\nEvent ID: {event_id}') 336 + print(f'Bulk labelling complete! Total labels applied: {progress_tracker.total_actions}') 529 337 530 338 531 339 class IpAddress(ParamType):
+2 -2
osprey_worker/src/osprey/worker/lib/osprey_shared/abilities.py
··· 3 3 from urllib.parse import urlencode 4 4 5 5 import requests 6 + from osprey.engine.language_types.entities import EntityT 6 7 from osprey.worker.lib.utils.flask_signing import Signer 7 - from osprey.worker.ui_api.osprey.validators.entities import EntityKey 8 8 from pydantic.main import BaseModel 9 9 from requests import ConnectionError, HTTPError, Timeout 10 10 from requests.models import ChunkedEncodingError ··· 40 40 osprey_ui_endpoint: str, 41 41 osprey_ui_api_endpoint: str, 42 42 creation_origin: str, 43 - entities: List[EntityKey], 43 + entities: List[EntityT[str]], 44 44 raise_on_error: bool = False, 45 45 entity_url_options: EntityUrlOptions = EntityUrlOptions(), 46 46 ) -> Optional[Sequence[str]]:
+276 -57
osprey_worker/src/osprey/worker/lib/osprey_shared/labels.py
··· 1 - from dataclasses import dataclass, field 2 - from datetime import datetime 3 - from enum import Enum 4 - from typing import TYPE_CHECKING, Dict, List, Mapping, Optional 1 + import copy 2 + from collections import UserDict 3 + from dataclasses import dataclass, field, replace 4 + from datetime import datetime, timedelta 5 + from enum import Enum, IntEnum 6 + from typing import Dict, Optional 5 7 6 - from osprey.engine.language_types.labels import LabelStatus 7 8 from osprey.worker.lib.osprey_shared.logging import get_logger 8 9 from osprey.worker.lib.utils.request_utils import SessionWithRetries 9 - from pydantic import BaseModel 10 - 11 - if TYPE_CHECKING: 12 - from osprey.worker.lib.utils.flask_signing import Signer 13 - 14 10 15 11 # The requests session we will be using to contact osprey API. 16 12 _session = SessionWithRetries() ··· 21 17 logger = get_logger(__name__) 22 18 23 19 20 + class MutationDropReason(IntEnum): 21 + # If a label mutation was dropped due to another mutation that conflicted & was higher priority 22 + # (priority of conflicting mutations in a given entity update is determined by the int value of the 23 + # label status enum) 24 + CONFLICTING_MUTATION = 0 25 + # If the existing label status was manual and the attempted mutation was not 26 + CANNOT_OVERRIDE_MANUAL = 1 27 + 28 + 29 + class LabelStatus(IntEnum): 30 + """ 31 + indicates the status of label. 32 + 33 + regular (a.k.a. "automatic") statuses are applied via rules. they can be overwritten by manual 34 + statuses, which can only be applied via humans using the ui. 35 + 36 + statuses have weights, which control which ones get dropped when conflicting statuses occur during 37 + a single attempted mutation; i.e., if an execution of the rules results in a label add and a label remove 38 + of the same entity/label pair. 39 + """ 40 + 41 + REMOVED = 0 42 + ADDED = 1 43 + MANUALLY_REMOVED = 2 44 + MANUALLY_ADDED = 3 45 + 46 + def effective_label_status(self) -> 'LabelStatus': 47 + """ 48 + Returns the effective status of the label, which is what the upstreams that are observing label 49 + status changes will see. Which is to say, the upstreams will currently not see if the label status was 50 + manually added or manually removed, just that it was added or removed. 51 + """ 52 + match self: 53 + case LabelStatus.ADDED | LabelStatus.MANUALLY_ADDED: 54 + return LabelStatus.ADDED 55 + case LabelStatus.REMOVED | LabelStatus.MANUALLY_REMOVED: 56 + return LabelStatus.REMOVED 57 + case _: 58 + raise NotImplementedError() 59 + 60 + def is_manual(self) -> bool: 61 + match self: 62 + case LabelStatus.MANUALLY_ADDED | LabelStatus.MANUALLY_REMOVED: 63 + return True 64 + case _: 65 + return False 66 + 67 + def is_automatic(self) -> bool: 68 + return not self.is_manual() 69 + 70 + 24 71 # If you change this also change osprey/osprey_engine/packages/osprey_stdlib/configs/labels_config.py 25 72 class LabelConnotation(Enum): 26 73 POSITIVE = 'positive' ··· 28 75 NEUTRAL = 'neutral' 29 76 30 77 31 - # Pydantic-compatible versions of pb2 types 32 78 @dataclass 33 79 class LabelReason: 80 + """ 81 + a label reason tells us why a label mutation was made, when it happened, and when it expires (if at all) 82 + """ 83 + 34 84 pending: bool = False 35 85 description: str = '' 36 - features: Dict[str, str] = field(default_factory=dict) 86 + """why the label was mutated""" 87 + features: dict[str, str] = field(default_factory=dict) 88 + """features are injected into the description as k/v's, similar to how fstrings work. for example, 89 + the {you} in 'hello {you}' would be substituted as 'person' with a feature dict of {'you': 'person'}""" 37 90 created_at: datetime | None = None 91 + """ 92 + when this reason was made 93 + """ 38 94 expires_at: datetime | None = None 95 + """marks when this label reason 'expires' 96 + 97 + if a LabelState.MANUALLY_REMOVED is applied with a reason that has a 1 day expiration, then 98 + for 1 day, the label cannot be applied via LabelState.ADDED. all LabelState.ADDED attempts will be dropped. 99 + 100 + if a given label state has multiple label reasons, all reasons would need to expire before the status/state 101 + is considered expired, too. 102 + """ 103 + 104 + def is_expired(self) -> bool: 105 + return bool(self.expires_at is not None and self.expires_at + timedelta(seconds=5) < datetime.now()) 106 + 107 + 108 + @dataclass 109 + class LabelReasons(UserDict[str, LabelReason]): 110 + """ 111 + the label reasons userdict allows us to add a helper function to the dict directly, while otherwise 112 + operating as a normal dict would~ 113 + """ 114 + 115 + def __init__(self, initial_data: dict[str, LabelReason] | None = None) -> None: 116 + super().__init__(initial_data) 117 + 118 + def insert_or_update(self, reason_name: str, reason: LabelReason) -> bool: 119 + """ 120 + returns true if the reason was able to be inserted or updated an existing reason; 121 + false if it was dropped due to being older than the current reason 122 + """ 123 + if reason_name not in self: 124 + self[reason_name] = reason 125 + return True 126 + 127 + current_reason = self[reason_name] 128 + if current_reason.created_at is None or reason.created_at is None: 129 + raise AssertionError( 130 + f'invariant: missing created_at on one of the following LabelReasons: {current_reason} {reason}' 131 + ) 132 + 133 + if current_reason.created_at > reason.created_at + timedelta(seconds=5): 134 + # the reason we are trying to append is older than the one currently at the reason_name key, 135 + # so we will discard it (5sec added to adjust for potential code exec time). 136 + return False 137 + 138 + self[reason_name] = replace( 139 + reason, 140 + # since the current reason is older by this point in the code, we want to preserve the original created_at timestamp 141 + created_at=current_reason.created_at, 142 + ) 143 + return True 144 + 145 + @classmethod 146 + def __get_validators__(cls): 147 + """Pydantic v1 validator""" 148 + yield cls.validate 149 + 150 + @classmethod 151 + def validate(cls, v): 152 + """Validate and convert to LabelReasons""" 153 + if isinstance(v, cls): 154 + return v 155 + if isinstance(v, dict): 156 + return cls(v) 157 + raise TypeError(f'LabelReasons expected dict or LabelReasons, got {type(v)}') 158 + 159 + def __repr__(self): 160 + return f'LabelReasons({self.data})' 39 161 40 162 41 163 @dataclass 42 164 class LabelStateInner: 43 165 status: LabelStatus 44 - reasons: Dict[str, LabelReason] 166 + reasons: LabelReasons 45 167 46 168 47 169 @dataclass 48 170 class LabelState: 49 171 status: LabelStatus 50 - reasons: Dict[str, LabelReason] 51 - previous_states: List[LabelStateInner] = field(default_factory=list) 172 + """statuses dictate the way the current state behaves; certain statuses have priority over others 173 + (see LabelStatus for more info)""" 174 + 175 + reasons: LabelReasons 176 + """ 177 + reasons are why this label state was applied; it is a dict because there may be multiple, 178 + with each reason being distinct based on it's reason name. 179 + 180 + reasons applied under the same name are merged (assuming the status has not changed), 181 + with precedence given to newer creaeted_at timestamps. 182 + """ 183 + 184 + previous_states: list[LabelStateInner] = field(default_factory=list) 185 + """the top-level label state also contains previous label states; we use an inner type 186 + because we don't need these prior states to have the previous_states field""" 187 + 188 + @property 189 + def expires_at(self) -> datetime | None: 190 + """ 191 + when a given label state is effectively expired. expiration can only occur if all of the 192 + reasons are expired. 52 193 194 + this field is a convenience value to save users time on computing the effective expiration time from the reasons. 53 195 54 - @dataclass 55 - class Labels: 56 - labels: Dict[str, LabelState] = field(default_factory=dict) 57 - expires_at: Optional[datetime] = None 196 + expiration defines when future label states can be applied. if the current label state is not expired, 197 + then then upon a new label state change attempt, the current and new statuses have their weights' compared. 198 + whichever has the higher weight will take precedence, and the lower weight(s) will be dropped. 199 + if the weights are the *same*, then a merge of reasons is performed, which can also cause the expiration to be delayed. 200 + """ 201 + if not self.reasons: 202 + AssertionError(f'invariant: the label state {self} did not have any associated reasons') 203 + expires_at = datetime.min 204 + for reason in self.reasons.values(): 205 + if reason.expires_at is None: 206 + return None 207 + expires_at = max(reason.expires_at, expires_at) 208 + return expires_at 58 209 210 + @classmethod 211 + def from_inner(cls, inner: LabelStateInner) -> 'LabelState': 212 + return cls( 213 + status=inner.status, 214 + reasons=inner.reasons, 215 + ) 59 216 60 - class LabelsAndConnotationsResponse(BaseModel): 61 - labels: Labels 62 - label_connotations: Mapping[str, LabelConnotation] 217 + def is_expired(self) -> bool: 218 + return bool(self.expires_at is not None and self.expires_at + timedelta(seconds=5) < datetime.now()) 219 + 220 + def _shift_current_state_to_previous_state(self) -> None: 221 + if not self.reasons: 222 + # to make this function idempotent, we don't want to shift an empty state to the previous state. 223 + # we should always have reasons to shift 224 + return 225 + self.previous_states.insert(0, LabelStateInner(status=self.status, reasons=copy.deepcopy(self.reasons))) 226 + self.reasons = LabelReasons() 227 + 228 + def try_apply_desired_state(self, desired_state: LabelStateInner) -> MutationDropReason | None: 229 + """ 230 + attempts to apply the desired state to this state. 231 + if the state could not be applied (i.e. due to an unexpired manual status blocking 232 + a status change to an automatic status), this method will return the MutationDropReason that 233 + should be applied to the responsible mutations. otherwise, it will return None to indicate success 234 + """ 235 + if self.is_expired(): 236 + self._shift_current_state_to_previous_state() 237 + self.status = desired_state.status 238 + self.reasons = desired_state.reasons 239 + return None 63 240 241 + # if the current status is manual, we will drop automatic statuses (unless the current state is expired) 242 + if self.status.is_manual() and desired_state.status.is_automatic(): 243 + return MutationDropReason.CANNOT_OVERRIDE_MANUAL 64 244 65 - def get_labels_for_entity( 66 - endpoint: str, signer: 'Signer', entity_type: str, entity_id: str 67 - ) -> LabelsAndConnotationsResponse: 68 - url = f'{endpoint}entity/{entity_type}/{entity_id}/labels' 69 - headers = signer.sign_url(url) 70 - raw_resp = _session.get(url, headers=headers, timeout=_REQUEST_TIMEOUT_SECS) 71 - logger.info(f'[get_labels_for_entity] status code is {raw_resp.status_code}') 72 - raw_resp.raise_for_status() 73 - return LabelsAndConnotationsResponse.parse_obj(raw_resp.json()) 245 + # if the statuses are different and we've made it this far, the desired state is allowed to overwrite 246 + # the current state. so lets do that by shifting to previous state and updating 247 + if self.status != desired_state.status: 248 + self._shift_current_state_to_previous_state() 249 + self.status = desired_state.status 74 250 251 + for reason_name, reason in desired_state.reasons.items(): 252 + self.reasons.insert_or_update(reason_name, reason) 75 253 76 - class EntityLabelDisagreeRequest(BaseModel): 77 - label_name: str 78 - description: str 79 - admin_email: str 80 - expires_at: Optional[datetime] 254 + return None 81 255 82 256 83 257 @dataclass 84 - class EntityMutation: 258 + class EntityLabels: 259 + """this class represents a given entity's current labels & label states""" 260 + 261 + labels: Dict[str, LabelState] = field(default_factory=dict) 262 + """a mapping of label names to their current states'""" 263 + 264 + 265 + @dataclass 266 + class EntityLabelMutation: 267 + """ 268 + a class that allows callers of LabelsProvider.apply_entity_label_mutations() to request how an 269 + entity's labels should be mutated. 270 + 271 + mutations are not guaranteed to be written to the labels provider. see EntityLabelMutationsResult.dropped 272 + """ 273 + 85 274 label_name: str = '' 86 275 reason_name: str = '' 87 276 status: LabelStatus = LabelStatus.ADDED 88 277 pending: bool = False 89 278 description: str = '' 90 - features: Dict[str, str] = field(default_factory=dict) 91 - expires_at: Optional[datetime] = None 279 + features: dict[str, str] = field(default_factory=dict) 280 + expires_at: datetime | None = None 281 + 282 + def desired_state(self) -> LabelStateInner: 283 + return LabelStateInner( 284 + status=self.status, 285 + reasons=LabelReasons({self.reason_name: self.reason}), 286 + ) 287 + 288 + @property 289 + def reason(self) -> LabelReason: 290 + return LabelReason( 291 + pending=self.pending, 292 + description=self.description, 293 + features=self.features, 294 + created_at=datetime.now(), 295 + expires_at=self.expires_at, 296 + ) 92 297 93 298 94 299 @dataclass 95 - class ApplyEntityMutationReply: 96 - added: List[str] = field(default_factory=list) 97 - removed: List[str] = field(default_factory=list) 98 - unchanged: List[str] = field(default_factory=list) 99 - dropped: List[EntityMutation] = field(default_factory=list) 300 + class DroppedEntityLabelMutation: 301 + mutation: EntityLabelMutation 302 + reason: MutationDropReason 100 303 101 304 102 - class EntityLabelDisagreeResponse(BaseModel): 103 - mutation_result: ApplyEntityMutationReply 104 - labels: Dict[str, LabelState] 105 - expires_at: Optional[datetime] 305 + @dataclass 306 + class EntityLabelMutationsResult: 307 + new_entity_labels: EntityLabels 308 + """ 309 + all of the entity's labels post-mutation 310 + """ 311 + 312 + old_entity_labels: Optional[EntityLabels] = None 313 + """ 314 + all of the entity's labels pre-mutation 315 + """ 106 316 317 + labels_added: list[str] = field(default_factory=list) 318 + """ 319 + all (effective-status) label adds that occurred during this mutation 320 + """ 107 321 108 - def disagree_wth_label( 109 - endpoint: str, signer: 'Signer', entity_type: str, entity_id: str, label_disagreement: EntityLabelDisagreeRequest 110 - ) -> EntityLabelDisagreeResponse: 111 - url = f'{endpoint}entity/{entity_type}/{entity_id}/labels/disagree' 322 + labels_removed: list[str] = field(default_factory=list) 323 + """ 324 + all (effective-status) label removes that occurred during this mutation 325 + """ 112 326 113 - label_disagreement_bytes = label_disagreement.json().encode() 114 - headers = signer.sign(label_disagreement_bytes) 327 + labels_updated: list[str] = field(default_factory=list) 328 + """ 329 + labels that had their state updated. this can include simply updating or 330 + appending to the reason 331 + """ 115 332 116 - raw_resp = _session.post(url, headers=headers, data=label_disagreement_bytes, timeout=_REQUEST_TIMEOUT_SECS) 117 - raw_resp.raise_for_status() 118 - return EntityLabelDisagreeResponse.parse_obj(raw_resp.json()) 333 + dropped_mutations: list[DroppedEntityLabelMutation] = field(default_factory=list) 334 + """ 335 + mutations that were dropped for one reason or another. each dropped mutation is 336 + given a drop reason 337 + """
-159
osprey_worker/src/osprey/worker/lib/osprey_shared/webhooks.py
··· 1 - from collections import defaultdict 2 - from datetime import datetime 3 - from enum import Enum 4 - from typing import Callable, Dict, List, Optional, Sequence, Union 5 - 6 - import requests 7 - from osprey.worker.lib.instruments import metrics 8 - from osprey.worker.lib.osprey_shared.labels import LabelState, LabelStatus 9 - from osprey.worker.lib.osprey_shared.logging import get_logger 10 - from pydantic import BaseModel 11 - from typing_extensions import Protocol 12 - 13 - logger = get_logger(__name__) 14 - 15 - 16 - class EntityLabelUpdateNotification(BaseModel): 17 - entity_type: str 18 - entity_id: str 19 - label_name: str 20 - label_state: LabelState 21 - expires_at: Optional[datetime] 22 - time: datetime 23 - features: Dict[str, Union[int, str, float, bool, None]] = {} 24 - 25 - @property 26 - def is_label_remove(self) -> bool: 27 - return self.label_state.status.value in ( 28 - LabelStatus.REMOVED, 29 - LabelStatus.MANUALLY_REMOVED, 30 - ) 31 - 32 - @property 33 - def is_label_addition(self) -> bool: 34 - return self.label_state.status.value in (LabelStatus.ADDED, LabelStatus.MANUALLY_ADDED) 35 - 36 - @property 37 - def is_manual_mutation(self) -> bool: 38 - return self.label_state.status.value in ( 39 - LabelStatus.MANUALLY_ADDED, 40 - LabelStatus.MANUALLY_REMOVED, 41 - ) 42 - 43 - def formatted_reasons(self) -> str: 44 - return ', '.join(sorted(self.label_state.reasons.keys())) 45 - 46 - 47 - _HandlerCallback = Callable[[EntityLabelUpdateNotification], None] 48 - 49 - 50 - class _HandlerDispatch(Protocol): 51 - def __init__(self, callback: _HandlerCallback, labels: Sequence[str]): 52 - pass 53 - 54 - def dispatch(self, payload: EntityLabelUpdateNotification) -> None: 55 - pass 56 - 57 - 58 - STATS_ROOT: str = 'osprey_shared.webhooks' 59 - 60 - 61 - def tags_from_dict(tag_dict: Dict[str, str]) -> List[str]: 62 - return [f'{k}:{v}' for k, v in tag_dict.items()] 63 - 64 - 65 - class _InstrumentedDispatch(_HandlerDispatch): 66 - def __init__(self, callback: _HandlerCallback, labels: Sequence[str]): 67 - self._callback = callback 68 - self._callback_name = getattr(callback, '__name__', 'Unknown') 69 - self._labels = labels 70 - 71 - def dispatch(self, payload: EntityLabelUpdateNotification) -> None: 72 - tags = tags_from_dict( 73 - { 74 - 'label_name': payload.label_name, 75 - 'entity_type': payload.entity_type, 76 - 'callback_name': self._callback_name, 77 - } 78 - ) 79 - if not self._labels or payload.label_name in self._labels: 80 - metrics.increment(f'{STATS_ROOT}.attempt', tags=tags) 81 - try: 82 - self._callback(payload) 83 - except Exception as e: 84 - logger.error( 85 - f'Error processing webhook callback for {self._callback_name} for label operation ' 86 - f'{payload.label_state} {payload.label_name} on {payload.entity_type} {payload.entity_id}: {e}', 87 - exc_info=True, 88 - ) 89 - metrics.histogram(f'{STATS_ROOT}.availability', value=0, tags=tags) 90 - metrics.increment(f'{STATS_ROOT}.failure', tags=tags) 91 - raise e 92 - else: 93 - metrics.histogram(f'{STATS_ROOT}.availability', value=1, tags=tags) 94 - metrics.increment(f'{STATS_ROOT}.success', tags=tags) 95 - else: 96 - metrics.increment(f'{STATS_ROOT}.skip', tags=tags) 97 - 98 - 99 - class OspreyCallbackRunWhen(str, Enum): 100 - INTERVENTIONS_CLIENT_ENABLED = 'INTERVENTIONS_CLIENT_ENABLED' 101 - INTERVENTIONS_CLIENT_DISABLED = 'INTERVENTIONS_CLIENT_DISABLED' 102 - ALWAYS = 'ALWAYS' 103 - 104 - 105 - class OspreyWebhookRouter: 106 - def __init__(self, is_interventions_client_enabled: Callable[[], bool]) -> None: 107 - # List of legacy osprey callback handlers, pre interventions client 108 - self._handlers_by_entity_type: Dict[str, List[_HandlerDispatch]] = defaultdict(list) 109 - # Interventions client handlers, that will be used when we turn the interventions feature flag on 110 - # After migrating to the intervention client handlers, we go back to only a single list of handlers 111 - # TODO: After Feature flag dialup, cleanup extra handler list 112 - # https://app.asana.com/0/1202424124203663/1203913333388469/f 113 - self._interventions_client_handlers_by_entity_type: Dict[str, List[_HandlerDispatch]] = defaultdict(list) 114 - self._is_interventions_client_enabled = is_interventions_client_enabled 115 - 116 - def register( 117 - self, 118 - entity_type: str, 119 - labels: Sequence[str] = tuple(), 120 - run_when: OspreyCallbackRunWhen = OspreyCallbackRunWhen.ALWAYS, 121 - ) -> Callable[[_HandlerCallback], _HandlerCallback]: 122 - """Registers a handler to be called when a Osprey webhook is received. 123 - 124 - Can give a set of labels to limit to, or if no labels are given the handler is invoked for all webhooks for 125 - the given entity type. 126 - """ 127 - 128 - def decorator(handler: _HandlerCallback) -> _HandlerCallback: 129 - dispatcher = _InstrumentedDispatch(handler, labels) 130 - 131 - if run_when == OspreyCallbackRunWhen.INTERVENTIONS_CLIENT_ENABLED: 132 - self._interventions_client_handlers_by_entity_type[entity_type].append(dispatcher) 133 - elif run_when == OspreyCallbackRunWhen.ALWAYS: 134 - # Some handlers will not be affected by the cut between interventions client 135 - # These handlers will be added to both maps so they get run regardless of which mode we are in 136 - self._interventions_client_handlers_by_entity_type[entity_type].append(dispatcher) 137 - self._handlers_by_entity_type[entity_type].append(dispatcher) 138 - else: 139 - self._handlers_by_entity_type[entity_type].append(dispatcher) 140 - 141 - return handler 142 - 143 - return decorator 144 - 145 - def call_handlers(self, payload: EntityLabelUpdateNotification) -> None: 146 - """Calls the handlers for a given incoming webhook.""" 147 - 148 - if self._is_interventions_client_enabled(): 149 - for handler in self._interventions_client_handlers_by_entity_type.get(payload.entity_type, []): 150 - handler.dispatch(payload) 151 - else: 152 - for handler in self._handlers_by_entity_type.get(payload.entity_type, []): 153 - handler.dispatch(payload) 154 - 155 - 156 - def get_osprey_public_keys_by_id(endpoint: str) -> Dict[str, str]: 157 - response = requests.get(endpoint + 'keys') 158 - response.raise_for_status() 159 - return response.json()
-103
osprey_worker/src/osprey/worker/lib/storage/entity_label_webhook.py
··· 1 - from __future__ import absolute_import 2 - 3 - import logging 4 - from random import random 5 - from typing import Optional 6 - 7 - from osprey.worker.sinks.sink.output_sink_utils.models import LabelStatus 8 - from sqlalchemy import BigInteger, Column, DateTime, Integer, Text, and_, func, or_ 9 - from sqlalchemy.dialects.postgresql import INTERVAL, JSONB 10 - 11 - from ..webhooks import WebhookStatus 12 - from .postgres import Model, scoped_session 13 - from .types import Enum 14 - 15 - BASE_DELAY_SECONDS = 60 16 - MAX_ATTEMPTS = 3 # update table index in osprey/osprey_lib/schemas/osprey.sql if we change this value 17 - logger = logging.getLogger(__name__) 18 - 19 - 20 - class EntityLabelWebhook(Model): 21 - __tablename__ = 'entity_label_webhooks' 22 - 23 - id = Column(BigInteger, primary_key=True, autoincrement=True) 24 - 25 - entity_type = Column(Text, nullable=False) 26 - entity_id = Column(Text, nullable=False) 27 - label_name = Column(Text, nullable=False) 28 - label_status = Column(Enum(LabelStatus, name='label_status', create_type=False), nullable=False) 29 - webhook_name = Column(Text, nullable=False) 30 - arguments = Column(JSONB) 31 - features = Column(JSONB) 32 - status = Column(Enum(WebhookStatus, name='webhook_status', create_type=False)) 33 - claim_until = Column(DateTime(timezone=True)) 34 - result = Column(Text) 35 - attempts = Column(Integer, nullable=False, default=0) 36 - created_at = Column(DateTime(timezone=True), nullable=False) 37 - updated_at = Column(DateTime(timezone=True), nullable=False) 38 - 39 - @classmethod 40 - def claim(cls) -> Optional['EntityLabelWebhook']: 41 - """Claim one webhook to send. 42 - 43 - The claim duration is also used as the retry cooldown, since that's already longer than we expect sending a 44 - webhook to take. That way, if the process totally dies, the webhook can still be retried at the correct 45 - interval. 46 - """ 47 - table = cls.__table__ 48 - jitter_percent = 1 + random() 49 - lock_seconds = BASE_DELAY_SECONDS * func.power(2, table.c.attempts) * jitter_percent 50 - 51 - # Selects the oldest claimable row's id in a subquery, because UPDATE doesn't support ORDER BY. 52 - # - oldest is based on claim_until (which is initially set to created_at, or some other time if it's a 53 - # delayed action). 54 - # - claimable means: 55 - # - the claim has expired 56 - # - status is one of the non-final statuses 57 - # - it hasn't already been attempted too many times 58 - order_subq = ( 59 - table.select() 60 - .with_only_columns([table.c.id]) 61 - .where( 62 - and_( 63 - table.c.claim_until < func.now(), 64 - or_(*(table.c.status == status for status in WebhookStatus.non_final_statuses())), 65 - table.c.attempts < MAX_ATTEMPTS, 66 - ) 67 - ) 68 - .with_for_update(skip_locked=True) 69 - .order_by(table.c.claim_until) 70 - .limit(1) 71 - .alias('order_subq') 72 - ) 73 - 74 - query = ( 75 - table.update() 76 - .where(table.c.id.in_(order_subq)) 77 - .values( 78 - claim_until=func.now() + func.cast(func.concat(lock_seconds, ' SECONDS'), INTERVAL), 79 - attempts=table.c.attempts + 1, 80 - status=WebhookStatus.RUNNING, 81 - updated_at=func.now(), 82 - ) 83 - .returning(table) 84 - ) 85 - 86 - with scoped_session(commit=True) as session: 87 - cursor = session.execute(query) 88 - # We need to construct the ORM object in a way SQLAlchemy approves of so it can track state under the 89 - # hood correctly (namely know that this object represents an existing row in the database). 90 - rows = list(session.query(cls).instances(cursor)) 91 - if len(rows) == 0: 92 - return None 93 - # We should only ever match up to one row 94 - (row,) = rows 95 - assert isinstance(row, EntityLabelWebhook) 96 - return row 97 - 98 - def release(self, status: WebhookStatus, result: str) -> None: 99 - with scoped_session(commit=True) as session: 100 - session.add(self) 101 - self.status = status 102 - self.result = result 103 - self.updated_at = func.now()
+218 -12
osprey_worker/src/osprey/worker/lib/storage/labels.py
··· 1 + import copy 1 2 from abc import ABC, abstractmethod 3 + from collections import defaultdict 4 + from contextlib import contextmanager 2 5 from datetime import timedelta 3 - from typing import Any, List, Optional, Sequence 6 + from typing import Any, Generator, Optional, Sequence 4 7 5 8 from osprey.engine.executor.external_service_utils import ExternalService 6 9 from osprey.engine.language_types.entities import EntityT 7 - from osprey.worker.lib.osprey_shared.labels import ApplyEntityMutationReply, EntityMutation, Labels 8 - from result import Result 10 + from osprey.worker.lib.osprey_shared.labels import ( 11 + DroppedEntityLabelMutation, 12 + EntityLabelMutation, 13 + EntityLabelMutationsResult, 14 + EntityLabels, 15 + LabelState, 16 + LabelStateInner, 17 + LabelStatus, 18 + MutationDropReason, 19 + ) 20 + from osprey.worker.lib.osprey_shared.logging import get_logger 21 + from result import Err, Ok, Result 22 + from tenacity import retry, stop_after_attempt, wait_exponential 9 23 24 + logger = get_logger(__name__) 10 25 11 - class LabelProvider(ExternalService[EntityT[Any], Labels], ABC): 12 - def cache_ttl(self) -> Optional[timedelta]: 13 - return timedelta(minutes=5) 14 26 27 + class LabelsServiceBase(ABC): 15 28 @abstractmethod 16 - def get_from_service(self, key: EntityT[Any]) -> Labels: 29 + def write_labels(self, entity: EntityT[Any], labels: EntityLabels) -> None: 30 + """ 31 + A standard write to the labels service that attempts to write the value to the primary key. 32 + 33 + This method may be retried upon exceptions, so keep that in mind when adding potentially 34 + non-idempotent behaviour. 35 + """ 17 36 raise NotImplementedError() 18 37 19 38 @abstractmethod 20 - def batch_get_from_service(self, keys: Sequence[EntityT[Any]]) -> Sequence[Result[Labels, Exception]]: 39 + def read_labels(self, entity: EntityT[Any]) -> EntityLabels: 40 + """ 41 + A standard read from the labels service. Keep in mind that if there is a cache_ttl greater than 0 seconds, 42 + this method will not be called for every single label read. 43 + 44 + This method may be retried upon exceptions, so keep that in mind when adding potentially 45 + non-idempotent behaviour. 46 + """ 21 47 raise NotImplementedError() 22 48 49 + def batch_read_labels(self, entities: Sequence[EntityT[Any]]) -> Sequence[Result[EntityLabels, Exception]]: 50 + """ 51 + Batching can optimize the number of RPCs that are sent out during executions, 52 + which has been observed to provide noticeable performance benefits in a python/gevent world. 53 + 54 + The order that the entieties are supplied in the incoming sequence will match the order the results are returned. 55 + 56 + By default, this will just call read_labels in a for-loop, but it is encouraged to implemenent your own batch 57 + endpoints and logic for the aforementioned performance benefits. 58 + """ 59 + results: list[Result[EntityLabels, Exception]] = [] 60 + for entity in entities: 61 + result: Result[EntityLabels, Exception] = Err( 62 + Exception('invariant: label could not be retrieved but no error was caught') 63 + ) 64 + try: 65 + result = Ok(self.read_labels(entity)) 66 + except Exception as e: 67 + result = Err(e) 68 + finally: 69 + results.append(result) 70 + return results 71 + 23 72 @abstractmethod 24 - def apply_entity_mutation( 25 - self, entity_key: EntityT[Any], mutations: List[EntityMutation] 26 - ) -> ApplyEntityMutationReply: 27 - raise NotImplementedError() 73 + @contextmanager 74 + def get_labels_atomically(self, entity: EntityT[Any]) -> Generator[EntityLabels, None, None]: 75 + """ 76 + Context manager for atomic read-modify-write operations. 77 + Implementations should ensure the entity key is locked/in a transaction. 78 + """ 79 + pass 80 + 81 + 82 + class LabelsProvider(ExternalService[EntityT[Any], EntityLabels]): 83 + def __init__(self, labels_service: LabelsServiceBase): 84 + self._labels_service = labels_service 85 + 86 + def _get_mutations_by_label_name_and_drop_conflicts( 87 + self, mutations: Sequence[EntityLabelMutation] 88 + ) -> tuple[dict[str, list[EntityLabelMutation]], list[DroppedEntityLabelMutation]]: 89 + """ 90 + collect mutations based on the value of their status. this means if a higher status and a lower status label mutation 91 + occur in the same mutations request, the lower status one(s) will be dropped. 92 + 93 + by the end of this method, the returned mutations will all be of the same label status for a given label. 94 + """ 95 + # first, we collect all of the highest status mutations per label name. we collect a list because 96 + # same status mutations will need to be merged into a single label state later to represent all 97 + # applicable mutation reasons 98 + mutations_by_label_name: dict[str, list[EntityLabelMutation]] = defaultdict(list) 99 + dropped_mutations: list[DroppedEntityLabelMutation] = [] 100 + for mutation in mutations: 101 + label_name = mutation.label_name 102 + if label_name in mutations_by_label_name: 103 + other_mutation = mutations_by_label_name[label_name][0] 104 + if mutation.status.value > other_mutation.status.value: 105 + for mut in mutations_by_label_name[label_name]: 106 + # we may have a list of more than one mutation if the statuses are all the same 107 + dropped_mutations.append( 108 + DroppedEntityLabelMutation(mutation=mut, reason=MutationDropReason.CONFLICTING_MUTATION) 109 + ) 110 + mutations_by_label_name[label_name] = [mutation] 111 + continue 112 + elif mutation.status.value < other_mutation.status.value: 113 + dropped_mutations.append( 114 + DroppedEntityLabelMutation(mutation=mutation, reason=MutationDropReason.CONFLICTING_MUTATION) 115 + ) 116 + continue 117 + # if the status weights are equal or if there is no previous statuses, append 118 + mutations_by_label_name[label_name].append(mutation) 119 + 120 + return (mutations_by_label_name, dropped_mutations) 121 + 122 + def _get_desired_states_by_label_name( 123 + self, mutations_by_label_name: dict[str, list[EntityLabelMutation]] 124 + ) -> dict[str, LabelStateInner]: 125 + """ 126 + given a dict of label names to entity label mutations, return the desired states that the mutations are seeking. 127 + if there is more than one mutation for a given label, the resulting state should contain a merge of the mutation reasons. 128 + """ 129 + desired_states_by_label_name: dict[str, LabelStateInner] = dict() 130 + 131 + for label_name, mutations in mutations_by_label_name.items(): 132 + assert len(mutations) > 0, 'invariant: mutations by label name should not be empty' 133 + assert len({mutation.status for mutation in mutations}) == 1, ( 134 + 'invariant: more than one unique label status AFTER dropping conflicts' 135 + ) 136 + desired_state = mutations[0].desired_state() 137 + for i in range(1, len(mutations)): 138 + desired_state.reasons.insert_or_update(mutations[i].reason_name, mutations[i].reason) 139 + desired_states_by_label_name[label_name] = desired_state 140 + 141 + return desired_states_by_label_name 142 + 143 + def _compute_new_labels_from_mutations( 144 + self, old_labels: EntityLabels, mutations: Sequence[EntityLabelMutation] 145 + ) -> EntityLabelMutationsResult: 146 + (mutations_by_label_name, dropped_mutations) = self._get_mutations_by_label_name_and_drop_conflicts(mutations) 147 + desired_states_by_label_name: dict[str, LabelStateInner] = self._get_desired_states_by_label_name( 148 + mutations_by_label_name 149 + ) 150 + 151 + # lets take desired states and try to apply them to the entity labels. 152 + # for end-user convenience, we also track if labels are added, removed, updated, or if mutations are dropped entirely 153 + added: list[str] = [] 154 + removed: list[str] = [] 155 + updated: list[str] = [] 156 + new_labels = copy.deepcopy(old_labels) 157 + for label_name, desired_state in desired_states_by_label_name.items(): 158 + if label_name not in new_labels.labels: 159 + new_labels.labels[label_name] = LabelState.from_inner(desired_state) 160 + added.append(label_name) 161 + continue 162 + current_state = new_labels.labels[label_name] 163 + prev_status = current_state.status 164 + drop_reason = current_state.try_apply_desired_state(desired_state) 165 + if drop_reason: 166 + # if the current state rejected the desired state, we will drop the mutation(s) with the provided drop reason 167 + for mutation in mutations_by_label_name[label_name]: 168 + dropped_mutations.append(DroppedEntityLabelMutation(mutation=mutation, reason=drop_reason)) 169 + continue 170 + # otherwise, let's compare the new status so we can add data to the EntityLabelMutationsResult c: 171 + new_status = current_state.status 172 + if prev_status == new_status: 173 + updated.append(label_name) 174 + continue 175 + match new_status.effective_label_status(): 176 + case LabelStatus.ADDED: 177 + added.append(label_name) 178 + continue 179 + case LabelStatus.REMOVED: 180 + removed.append(label_name) 181 + continue 182 + 183 + # finally, return the result! duhh :D 184 + return EntityLabelMutationsResult( 185 + new_entity_labels=new_labels, 186 + old_entity_labels=old_labels, 187 + labels_added=added, 188 + labels_removed=removed, 189 + labels_updated=updated, 190 + dropped_mutations=dropped_mutations, 191 + ) 192 + 193 + @retry(wait=wait_exponential(min=0.5, max=5), stop=stop_after_attempt(3)) 194 + def apply_entity_label_mutations_with_retry( 195 + self, entity: EntityT[Any], mutations: Sequence[EntityLabelMutation] 196 + ) -> EntityLabelMutationsResult: 197 + return self.apply_entity_label_mutations(entity=entity, mutations=mutations) 198 + 199 + def apply_entity_label_mutations( 200 + self, entity: EntityT[Any], mutations: Sequence[EntityLabelMutation] 201 + ) -> EntityLabelMutationsResult: 202 + with self._labels_service.get_labels_atomically(entity) as old_labels: 203 + result = self._compute_new_labels_from_mutations(old_labels, mutations) 204 + 205 + self._labels_service.write_labels(entity, result.new_entity_labels) 206 + 207 + return result 208 + 209 + def cache_ttl(self) -> Optional[timedelta]: 210 + return timedelta(minutes=1) 211 + 212 + def get_from_service(self, key: EntityT[Any]) -> EntityLabels: 213 + return self._labels_service.read_labels(entity=key) 214 + 215 + def batch_get_from_service(self, keys: Sequence[EntityT[Any]]) -> Sequence[Result[EntityLabels, Exception]]: 216 + """ 217 + Note: By default, the labels service batch_read_labels calls read_labels in a for loop. 218 + This is because the HasLabel UDF is batchable and requires batch support on the 219 + provider. 220 + 221 + If you would like to reap the performance benefits of batching, please re-implement 222 + the batch_read_labels to call a proper batch endpoint. 223 + 224 + See LabelsServiceBase.batch_read_labels for more information 225 + """ 226 + return self._labels_service.batch_read_labels(entities=keys) 227 + 228 + def stop(self) -> None: 229 + """ 230 + this method is called when the output sink receives a shutdown signal. if you would like to 231 + add shutdown logic, override this~ 232 + """ 233 + pass
+8 -8
osprey_worker/src/osprey/worker/lib/storage/local_label_provider.py
··· 1 1 from typing import Any, Dict, List, Sequence 2 2 3 3 from osprey.engine.language_types.entities import EntityT 4 - from osprey.worker.lib.osprey_shared.labels import ApplyEntityMutationReply, EntityMutation, Labels 5 - from osprey.worker.lib.storage.labels import LabelProvider 4 + from osprey.worker.lib.osprey_shared.labels import EntityLabelMutation, EntityLabelMutationsResult, EntityLabels 5 + from osprey.worker.lib.storage.labels import LabelsProvider 6 6 from result import Result 7 7 8 8 9 - class LocalLabelProvider(LabelProvider): 9 + class LocalLabelProvider(LabelsProvider): 10 10 def __init__(self): 11 - self._labels: Dict[str, Labels] = {} 11 + self._labels: Dict[str, EntityLabels] = {} 12 12 13 - def batch_get_from_service(self, keys: Sequence[EntityT[Any]]) -> Sequence[Result[Labels, Exception]]: 13 + def batch_get_from_service(self, keys: Sequence[EntityT[Any]]) -> Sequence[Result[EntityLabels, Exception]]: 14 14 raise NotImplementedError() 15 15 16 16 def apply_entity_mutation( 17 - self, entity_key: EntityT[Any], mutations: List[EntityMutation] 18 - ) -> ApplyEntityMutationReply: 17 + self, entity_key: EntityT[Any], mutations: List[EntityLabelMutation] 18 + ) -> EntityLabelMutationsResult: 19 19 raise NotImplementedError() 20 20 21 - def get_from_service(self, key: EntityT[Any]) -> Labels: 21 + def get_from_service(self, key: EntityT[Any]) -> EntityLabels: 22 22 raise NotImplementedError()
+2 -2
osprey_worker/src/osprey/worker/lib/storage/stored_execution_result.py
··· 4 4 import json 5 5 from abc import ABC, abstractmethod 6 6 from datetime import datetime 7 + from io import BytesIO 7 8 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence 8 9 9 10 import gevent 10 11 import google.cloud.storage as storage 11 - from osprey.worker.lib.snowflake import Snowflake 12 12 import pytz 13 13 from google.api_core import retry 14 14 from google.cloud.bigtable import row_filters 15 15 from google.cloud.bigtable.row import Row 16 - from io import BytesIO 17 16 from minio import Minio 18 17 from minio.error import S3Error 19 18 from osprey.engine.executor.execution_context import ExecutionResult 20 19 from osprey.worker.lib.instruments import metrics 21 20 from osprey.worker.lib.osprey_shared.logging import get_logger 21 + from osprey.worker.lib.snowflake import Snowflake 22 22 from osprey.worker.lib.storage.bigtable import osprey_bigtable 23 23 from pydantic.main import BaseModel 24 24
+423
osprey_worker/src/osprey/worker/lib/storage/tests/test_labels.py
··· 1 + from datetime import datetime, timedelta 2 + from typing import Any 3 + 4 + import pytest 5 + from osprey.engine.language_types.entities import EntityT 6 + from osprey.worker.lib.osprey_shared.labels import ( 7 + EntityLabelMutation, 8 + EntityLabels, 9 + LabelReason, 10 + LabelReasons, 11 + LabelState, 12 + LabelStatus, 13 + MutationDropReason, 14 + ) 15 + from osprey.worker.lib.storage.labels import LabelsProvider, LabelsServiceBase 16 + 17 + 18 + class MockLabelsService(LabelsServiceBase): 19 + """Mock implementation of LabelsServiceBase for testing""" 20 + 21 + def __init__(self): 22 + self.storage: dict[tuple[str, str], EntityLabels] = {} 23 + 24 + def write_labels(self, entity: EntityT[Any], labels: EntityLabels) -> None: 25 + key = (entity.type, entity.id) 26 + self.storage[key] = labels 27 + 28 + def read_labels(self, entity: EntityT[Any]) -> EntityLabels: 29 + key = (entity.type, entity.id) 30 + return self.storage.get(key, EntityLabels()) 31 + 32 + def get_labels_atomically(self, entity: EntityT[Any]): 33 + """Context manager that yields labels for atomic operations""" 34 + from contextlib import contextmanager 35 + 36 + @contextmanager 37 + def _context(): 38 + yield self.read_labels(entity) 39 + 40 + return _context() 41 + 42 + 43 + @pytest.fixture 44 + def labels_provider() -> LabelsProvider: 45 + """Fixture that provides a LabelsProvider with a mock service""" 46 + return LabelsProvider(MockLabelsService()) 47 + 48 + 49 + @pytest.fixture 50 + def now() -> datetime: 51 + """Fixture that provides a consistent 'now' timestamp""" 52 + return datetime.now() 53 + 54 + 55 + def test_compute_new_labels_from_mutations_adds_new_label(labels_provider: LabelsProvider, now: datetime): 56 + """Test adding a new label to an empty EntityLabels""" 57 + old_labels = EntityLabels() 58 + mutations = [ 59 + EntityLabelMutation( 60 + label_name='test_label', 61 + reason_name='test_reason', 62 + status=LabelStatus.ADDED, 63 + pending=False, 64 + description='Test description', 65 + features={}, 66 + expires_at=now + timedelta(days=1), 67 + ) 68 + ] 69 + 70 + result = labels_provider._compute_new_labels_from_mutations(old_labels, mutations) 71 + 72 + assert 'test_label' in result.new_entity_labels.labels 73 + assert 'test_label' in result.labels_added 74 + assert len(result.labels_removed) == 0 75 + assert len(result.labels_updated) == 0 76 + assert len(result.dropped_mutations) == 0 77 + 78 + 79 + def test_compute_new_labels_from_mutations_removes_existing_label(labels_provider: LabelsProvider, now: datetime): 80 + """Test removing an existing label""" 81 + old_labels = EntityLabels( 82 + labels={ 83 + 'test_label': LabelState( 84 + status=LabelStatus.ADDED, 85 + reasons=LabelReasons({'reason1': LabelReason(description='original', created_at=now)}), 86 + ) 87 + } 88 + ) 89 + mutations = [ 90 + EntityLabelMutation( 91 + label_name='test_label', 92 + reason_name='removal_reason', 93 + status=LabelStatus.REMOVED, 94 + pending=False, 95 + description='Removing label', 96 + features={}, 97 + expires_at=None, 98 + ) 99 + ] 100 + 101 + result = labels_provider._compute_new_labels_from_mutations(old_labels, mutations) 102 + 103 + assert result.new_entity_labels.labels['test_label'].status == LabelStatus.REMOVED 104 + assert 'test_label' in result.labels_removed 105 + assert len(result.labels_added) == 0 106 + assert len(result.labels_updated) == 0 107 + assert len(result.dropped_mutations) == 0 108 + 109 + 110 + def test_compute_new_labels_from_mutations_updates_existing_label_same_status( 111 + labels_provider: LabelsProvider, now: datetime 112 + ): 113 + """Test updating an existing label with the same status (adds a new reason)""" 114 + old_labels = EntityLabels( 115 + labels={ 116 + 'test_label': LabelState( 117 + status=LabelStatus.ADDED, 118 + reasons=LabelReasons({'reason1': LabelReason(description='original', created_at=now)}), 119 + ) 120 + } 121 + ) 122 + mutations = [ 123 + EntityLabelMutation( 124 + label_name='test_label', 125 + reason_name='reason2', 126 + status=LabelStatus.ADDED, 127 + pending=False, 128 + description='Additional reason', 129 + features={}, 130 + expires_at=None, 131 + ) 132 + ] 133 + 134 + result = labels_provider._compute_new_labels_from_mutations(old_labels, mutations) 135 + 136 + assert 'test_label' in result.labels_updated 137 + assert len(result.labels_added) == 0 138 + assert len(result.labels_removed) == 0 139 + assert 'reason1' in result.new_entity_labels.labels['test_label'].reasons 140 + assert 'reason2' in result.new_entity_labels.labels['test_label'].reasons 141 + 142 + 143 + def test_compute_new_labels_from_mutations_drops_conflicting_mutations(labels_provider: LabelsProvider, now: datetime): 144 + """Test that conflicting mutations (same label, different status) result in drops""" 145 + old_labels = EntityLabels() 146 + mutations = [ 147 + EntityLabelMutation( 148 + label_name='test_label', 149 + reason_name='add_reason', 150 + status=LabelStatus.ADDED, 151 + pending=False, 152 + description='Add label', 153 + features={}, 154 + expires_at=None, 155 + ), 156 + EntityLabelMutation( 157 + label_name='test_label', 158 + reason_name='remove_reason', 159 + status=LabelStatus.REMOVED, 160 + pending=False, 161 + description='Remove label', 162 + features={}, 163 + expires_at=None, 164 + ), 165 + ] 166 + 167 + result = labels_provider._compute_new_labels_from_mutations(old_labels, mutations) 168 + 169 + # One mutation should be dropped due to conflict 170 + assert len(result.dropped_mutations) == 1 171 + assert result.dropped_mutations[0].reason == MutationDropReason.CONFLICTING_MUTATION 172 + # The higher priority status (ADDED > REMOVED) should win 173 + assert result.new_entity_labels.labels['test_label'].status == LabelStatus.REMOVED 174 + 175 + 176 + def test_compute_new_labels_from_mutations_manual_blocks_automatic(labels_provider: LabelsProvider, now: datetime): 177 + """Test that manual status blocks automatic status changes""" 178 + old_labels = EntityLabels( 179 + labels={ 180 + 'test_label': LabelState( 181 + status=LabelStatus.MANUALLY_REMOVED, 182 + reasons=LabelReasons({'manual_reason': LabelReason(description='manually removed', created_at=now)}), 183 + ) 184 + } 185 + ) 186 + mutations = [ 187 + EntityLabelMutation( 188 + label_name='test_label', 189 + reason_name='auto_reason', 190 + status=LabelStatus.ADDED, 191 + pending=False, 192 + description='Try to add automatically', 193 + features={}, 194 + expires_at=None, 195 + ) 196 + ] 197 + 198 + result = labels_provider._compute_new_labels_from_mutations(old_labels, mutations) 199 + 200 + # The automatic mutation should be dropped 201 + assert len(result.dropped_mutations) == 1 202 + assert result.dropped_mutations[0].reason == MutationDropReason.CANNOT_OVERRIDE_MANUAL 203 + # The label should still be manually removed 204 + assert result.new_entity_labels.labels['test_label'].status == LabelStatus.MANUALLY_REMOVED 205 + 206 + 207 + def test_compute_new_labels_from_mutations_manual_overrides_automatic(labels_provider: LabelsProvider, now: datetime): 208 + """Test that manual status can override automatic status""" 209 + old_labels = EntityLabels( 210 + labels={ 211 + 'test_label': LabelState( 212 + status=LabelStatus.ADDED, 213 + reasons=LabelReasons({'auto_reason': LabelReason(description='automatic add', created_at=now)}), 214 + ) 215 + } 216 + ) 217 + mutations = [ 218 + EntityLabelMutation( 219 + label_name='test_label', 220 + reason_name='manual_reason', 221 + status=LabelStatus.MANUALLY_REMOVED, 222 + pending=False, 223 + description='Manual removal', 224 + features={}, 225 + expires_at=None, 226 + ) 227 + ] 228 + 229 + result = labels_provider._compute_new_labels_from_mutations(old_labels, mutations) 230 + 231 + # Manual should override automatic 232 + assert result.new_entity_labels.labels['test_label'].status == LabelStatus.MANUALLY_REMOVED 233 + assert 'test_label' in result.labels_removed 234 + assert len(result.dropped_mutations) == 0 235 + 236 + 237 + def test_compute_new_labels_from_mutations_multiple_labels(labels_provider: LabelsProvider, now: datetime): 238 + """Test mutations affecting multiple different labels""" 239 + old_labels = EntityLabels( 240 + labels={ 241 + 'existing_label': LabelState( 242 + status=LabelStatus.ADDED, 243 + reasons=LabelReasons({'reason1': LabelReason(description='exists', created_at=now)}), 244 + ) 245 + } 246 + ) 247 + mutations = [ 248 + EntityLabelMutation( 249 + label_name='new_label', 250 + reason_name='reason_new', 251 + status=LabelStatus.ADDED, 252 + pending=False, 253 + description='New label', 254 + features={}, 255 + expires_at=None, 256 + ), 257 + EntityLabelMutation( 258 + label_name='existing_label', 259 + reason_name='reason_update', 260 + status=LabelStatus.ADDED, 261 + pending=False, 262 + description='Update existing', 263 + features={}, 264 + expires_at=None, 265 + ), 266 + ] 267 + 268 + result = labels_provider._compute_new_labels_from_mutations(old_labels, mutations) 269 + 270 + assert 'new_label' in result.labels_added 271 + assert 'existing_label' in result.labels_updated 272 + assert len(result.new_entity_labels.labels) == 2 273 + 274 + 275 + def test_compute_new_labels_from_mutations_expired_label_can_be_changed(labels_provider: LabelsProvider, now: datetime): 276 + """Test that an expired label state can be changed""" 277 + # Create an expired label (expires_at in the past) 278 + old_labels = EntityLabels( 279 + labels={ 280 + 'test_label': LabelState( 281 + status=LabelStatus.MANUALLY_REMOVED, 282 + reasons=LabelReasons( 283 + { 284 + 'expired_reason': LabelReason( 285 + description='expired', 286 + created_at=now - timedelta(days=2), 287 + expires_at=now - timedelta(days=1), 288 + ) 289 + } 290 + ), 291 + ) 292 + } 293 + ) 294 + mutations = [ 295 + EntityLabelMutation( 296 + label_name='test_label', 297 + reason_name='new_reason', 298 + status=LabelStatus.ADDED, 299 + pending=False, 300 + description='Add after expiration', 301 + features={}, 302 + expires_at=None, 303 + ) 304 + ] 305 + 306 + result = labels_provider._compute_new_labels_from_mutations(old_labels, mutations) 307 + 308 + # Should be able to change expired manual status 309 + assert result.new_entity_labels.labels['test_label'].status == LabelStatus.ADDED 310 + assert 'test_label' in result.labels_added 311 + assert len(result.dropped_mutations) == 0 312 + 313 + 314 + def test_compute_new_labels_from_mutations_merge_multiple_mutations_same_label( 315 + labels_provider: LabelsProvider, now: datetime 316 + ): 317 + """Test that multiple mutations for the same label with same status are merged""" 318 + old_labels = EntityLabels() 319 + mutations = [ 320 + EntityLabelMutation( 321 + label_name='test_label', 322 + reason_name='reason1', 323 + status=LabelStatus.ADDED, 324 + pending=False, 325 + description='First reason', 326 + features={'key1': 'value1'}, 327 + expires_at=None, 328 + ), 329 + EntityLabelMutation( 330 + label_name='test_label', 331 + reason_name='reason2', 332 + status=LabelStatus.ADDED, 333 + pending=False, 334 + description='Second reason', 335 + features={'key2': 'value2'}, 336 + expires_at=None, 337 + ), 338 + ] 339 + 340 + result = labels_provider._compute_new_labels_from_mutations(old_labels, mutations) 341 + 342 + # Both reasons should be present 343 + assert 'reason1' in result.new_entity_labels.labels['test_label'].reasons 344 + assert 'reason2' in result.new_entity_labels.labels['test_label'].reasons 345 + assert len(result.dropped_mutations) == 0 346 + assert 'test_label' in result.labels_added 347 + 348 + 349 + def test_compute_new_labels_from_mutations_preserves_old_labels(labels_provider: LabelsProvider, now: datetime): 350 + """Test that old_labels reference is preserved in result""" 351 + old_labels = EntityLabels( 352 + labels={ 353 + 'existing_label': LabelState( 354 + status=LabelStatus.ADDED, 355 + reasons=LabelReasons({'reason1': LabelReason(description='exists', created_at=now)}), 356 + ) 357 + } 358 + ) 359 + mutations = [ 360 + EntityLabelMutation( 361 + label_name='new_label', 362 + reason_name='reason_new', 363 + status=LabelStatus.ADDED, 364 + pending=False, 365 + description='New label', 366 + features={}, 367 + expires_at=None, 368 + ) 369 + ] 370 + 371 + result = labels_provider._compute_new_labels_from_mutations(old_labels, mutations) 372 + 373 + # Old labels should be preserved 374 + assert result.old_entity_labels == old_labels 375 + # Old labels should not be modified 376 + assert 'new_label' not in old_labels.labels 377 + assert 'new_label' in result.new_entity_labels.labels 378 + 379 + 380 + def test_compute_new_labels_from_mutations_empty_mutations(labels_provider: LabelsProvider, now: datetime): 381 + """Test behavior with no mutations""" 382 + old_labels = EntityLabels( 383 + labels={ 384 + 'existing_label': LabelState( 385 + status=LabelStatus.ADDED, 386 + reasons=LabelReasons({'reason1': LabelReason(description='exists', created_at=now)}), 387 + ) 388 + } 389 + ) 390 + mutations = [] 391 + 392 + result = labels_provider._compute_new_labels_from_mutations(old_labels, mutations) 393 + 394 + # Nothing should change 395 + assert len(result.labels_added) == 0 396 + assert len(result.labels_removed) == 0 397 + assert len(result.labels_updated) == 0 398 + assert len(result.dropped_mutations) == 0 399 + assert result.new_entity_labels.labels == old_labels.labels 400 + 401 + 402 + def test_compute_new_labels_from_mutations_with_pending_status(labels_provider: LabelsProvider, now: datetime): 403 + """Test mutations with pending=True""" 404 + old_labels = EntityLabels() 405 + mutations = [ 406 + EntityLabelMutation( 407 + label_name='pending_label', 408 + reason_name='pending_reason', 409 + status=LabelStatus.ADDED, 410 + pending=True, 411 + description='Pending label', 412 + features={}, 413 + expires_at=None, 414 + ) 415 + ] 416 + 417 + result = labels_provider._compute_new_labels_from_mutations(old_labels, mutations) 418 + 419 + assert 'pending_label' in result.new_entity_labels.labels 420 + assert 'pending_label' in result.labels_added 421 + # Verify the reason is marked as pending 422 + reasons = result.new_entity_labels.labels['pending_label'].reasons 423 + assert reasons['pending_reason'].pending is True
+44 -59
osprey_worker/src/osprey/worker/sinks/sink/bulk_label_sink.py
··· 3 3 from typing import Any, Iterable, List, Optional, Set 4 4 5 5 import sentry_sdk 6 - from osprey.engine.executor.execution_context import ExtendedEntityMutation 6 + from osprey.engine.language_types.entities import EntityT 7 7 from osprey.engine.language_types.labels import LabelStatus 8 + from osprey.worker.adaptor.plugin_manager import bootstrap_labels_provider 8 9 from osprey.worker.lib.bulk_label import TaskStatus 9 10 from osprey.worker.lib.discovery.exceptions import ServiceUnavailable 10 11 from osprey.worker.lib.instruments import metrics 11 12 from osprey.worker.lib.osprey_engine import OspreyEngine 13 + from osprey.worker.lib.osprey_shared.labels import EntityLabelMutation 12 14 from osprey.worker.lib.osprey_shared.logging import get_logger 13 15 from osprey.worker.lib.pigeon.exceptions import RPCException 14 16 from osprey.worker.lib.publisher import BasePublisher 15 17 from osprey.worker.lib.storage.bulk_label_task import BASE_DELAY_SECONDS, MAX_ATTEMPTS, BulkLabelTask 18 + from osprey.worker.lib.storage.labels import LabelsProvider 16 19 from osprey.worker.sinks.sink.input_stream import BaseInputStream 17 - from osprey.worker.sinks.sink.output_sink import LabelOutputSink 18 - from osprey.worker.sinks.sink.output_sink_utils.constants import MutationEventType 19 20 from osprey.worker.sinks.sink.output_sink_utils.models import OspreyBulkJobAnalyticsEvent 20 21 from osprey.worker.ui_api.osprey.lib.druid import PeriodData, TopNDruidQuery, TopNPoPResponse 21 22 from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt, wait_exponential 22 23 23 - from ...adaptor.plugin_manager import bootstrap_label_provider 24 - from ...lib.osprey_shared.labels import EntityMutation 25 - from ...lib.storage.labels import LabelProvider 26 - from ...ui_api.osprey.validators.entities import EntityKey 27 24 from .base_sink import BaseSink 28 25 29 26 logger = get_logger() ··· 69 66 def __init__( 70 67 self, 71 68 input_stream: BaseInputStream[BulkLabelTask], 72 - label_provider: LabelProvider, 69 + labels_provider: LabelsProvider, 73 70 engine: OspreyEngine, 74 71 analytics_publisher: BasePublisher, 75 72 send_status_webhook: bool = True, 76 73 ): 77 74 self._input_stream = input_stream 78 - self._label_provider = label_provider 79 - self._label_output_sink = LabelOutputSink(self._label_provider) 75 + self._labels_provider = labels_provider 80 76 self._engine = engine 81 77 self._metric_tags = [f'sink:{self.__class__.__name__}'] 82 78 self._analytics_publisher = analytics_publisher ··· 319 315 # This range resumes work at the place last left off 320 316 for current_entity_index in task.iterate_entity_indices(): 321 317 entity_id = entity_ids_list[current_entity_index] 322 - entity_key = EntityKey(type=entity_type, id=entity_id) 323 - self._apply_label_mutations(entity_key, task) 318 + entity = EntityT(type=entity_type, id=entity_id) 319 + self._apply_label_mutations(entity, task) 324 320 325 321 self._send_bulk_job_analytics(task) 326 322 ··· 340 336 ) 341 337 self._analytics_publisher.publish(analytics_properties) 342 338 343 - def _apply_label_mutations(self, entity_key: EntityKey, task: BulkLabelTask) -> None: 339 + def _apply_label_mutations(self, entity: EntityT[Any], task: BulkLabelTask) -> None: 344 340 def _log_before_sleep(retry_state: RetryCallState) -> None: 345 341 assert isinstance(task.entities_labeled, int) 346 342 ··· 367 363 assert isinstance(task.dimension, str) 368 364 assert isinstance(task.excluded_entities, Iterable) 369 365 370 - self._label_output_sink.apply_label_mutations( 371 - mutation_event_type=MutationEventType.BULK_ACTION, 372 - mutation_event_id=str(task.id), 373 - entity_key=entity_key, 366 + _ = self._labels_provider.apply_entity_label_mutations( 367 + entity=entity, 374 368 mutations=[ 375 - ExtendedEntityMutation( 376 - mutation=EntityMutation( 377 - label_name=task.label_name, 378 - reason_name=BULK_LABEL_REASON, 379 - expires_at=task.label_expiry, 380 - status=task.label_status, # type: ignore 381 - description='Bulk label (id={BulkLabelTaskId}) by {AdminEmail}: {Reason}', 382 - features={ 383 - 'AdminEmail': task.initiated_by, 384 - 'Reason': task.label_reason, 385 - 'BulkLabelTaskId': str(task.id), 386 - }, 387 - ), 388 - delay_action_by=None, 389 - ) 369 + EntityLabelMutation( 370 + label_name=task.label_name, 371 + reason_name=BULK_LABEL_REASON, 372 + expires_at=task.label_expiry, 373 + status=task.label_status, # type: ignore 374 + description='Bulk label (id={BulkLabelTaskId}) by {AdminEmail}: {Reason}', 375 + features={ 376 + 'AdminEmail': task.initiated_by, 377 + 'Reason': task.label_reason, 378 + 'BulkLabelTaskId': str(task.id), 379 + }, 380 + ), 390 381 ], 391 382 ) 392 383 ··· 433 424 rows_rolled_back = 0 434 425 435 426 feature_name_to_entity_type_mapping = engine.get_feature_name_to_entity_type_mapping() 436 - label_provider = bootstrap_label_provider() 437 - label_output_sink = LabelOutputSink(label_provider) 427 + labels_provider = bootstrap_labels_provider() 438 428 feature_name = task.dimension 439 429 entity_type = feature_name_to_entity_type_mapping[feature_name] 440 430 ··· 451 441 rows_skipped += 1 452 442 continue 453 443 454 - entity_key = EntityKey(type=entity_type, id=value) 455 - if entity_key.id in excluded_ids: 444 + entity = EntityT(type=entity_type, id=value) 445 + if entity.id in excluded_ids: 456 446 rows_excluded += 1 457 447 continue 458 448 459 - if include_ids is not None and str(entity_key.id) not in include_ids: 449 + if include_ids is not None and str(entity.id) not in include_ids: 460 450 rows_excluded += 1 461 451 continue 462 452 463 - labels = label_provider.get_from_service(entity_key) 453 + labels = labels_provider.get_from_service(entity) 464 454 465 455 # No label anymore, nothing to do. 466 456 label_state = labels.labels.get(task.label_name) ··· 480 470 continue 481 471 482 472 # Apply the inverse effect of the rollback. 483 - label_output_sink.apply_label_mutations( 484 - mutation_event_type=MutationEventType.BULK_ACTION, 485 - mutation_event_id=str(task.id), 486 - entity_key=entity_key, 473 + _ = labels_provider.apply_entity_label_mutations( 474 + entity=entity, 487 475 mutations=[ 488 - ExtendedEntityMutation( 489 - mutation=EntityMutation( 490 - label_name=task.label_name, 491 - reason_name='_BulkLabelRollback', 492 - status=LabelStatus.MANUALLY_REMOVED, 493 - expires_at=datetime.now() + timedelta(hours=2), 494 - description=( 495 - 'Bulk label rollback of (id={BulkLabelTaskId}) ' 496 - '(initial reason: {Reason}, initially initiated by: {AdminEmail})' 497 - ), 498 - features={ 499 - 'AdminEmail': task.initiated_by, 500 - 'Reason': task.label_reason, 501 - 'BulkLabelTaskId': str(task.id), 502 - }, 476 + EntityLabelMutation( 477 + label_name=task.label_name, 478 + reason_name='_BulkLabelRollback', 479 + status=LabelStatus.MANUALLY_REMOVED, 480 + expires_at=datetime.now() + timedelta(hours=2), 481 + description=( 482 + 'Bulk label rollback of (id={BulkLabelTaskId}) ' 483 + '(initial reason: {Reason}, initially initiated by: {AdminEmail})' 503 484 ), 504 - delay_action_by=None, 505 - ) 485 + features={ 486 + 'AdminEmail': task.initiated_by, 487 + 'Reason': task.label_reason, 488 + 'BulkLabelTaskId': str(task.id), 489 + }, 490 + ), 506 491 ], 507 492 ) 508 493 rows_rolled_back += 1
-57
osprey_worker/src/osprey/worker/sinks/sink/input_stream.py
··· 23 23 from osprey.worker.lib.osprey_shared.logging import get_logger 24 24 from osprey.worker.lib.storage.postgres import Model, scoped_session 25 25 from osprey.worker.lib.utils.dates import parse_go_timestamp 26 - from osprey.worker.sinks.sink.output_sink_utils.models import OspreyEntityLabelWebhook 27 26 from osprey.worker.sinks.utils.acking_contexts import ( 28 27 BaseAckingContext, 29 28 NoopAckingContext, ··· 299 298 publish_time=message.publish_time.ToDatetime(), 300 299 attributes={k: v for k, v in message.attributes.items()}, 301 300 ) 302 - 303 - 304 - class PubSubEntityLabelWebhookInputStream(BasePubSubInputStream[BaseAckingContext[OspreyEntityLabelWebhook]]): 305 - def __init__( 306 - self, 307 - subscriber: SubscriberClient, 308 - subscription_path: str, 309 - max_messages: int = 250, 310 - gevent_queue_size: int = 1000, 311 - ): 312 - super().__init__(subscriber, subscription_path, max_messages) 313 - self.queue = GeventQueue(maxsize=gevent_queue_size) 314 - 315 - def _worker(self) -> None: 316 - logger.info('Webhook Pubsub worker spawned') 317 - 318 - def stream_callback(message: Message) -> None: 319 - self.queue.put(message) 320 - 321 - with self.subscriber as subscriber: 322 - flow_control = types.FlowControl( 323 - max_messages=self.max_messages, 324 - # assume 4KB per message 325 - max_bytes=4_000 * self.max_messages, 326 - max_lease_duration=60 * 60, 327 - max_duration_per_lease_extension=600, 328 - ) 329 - streaming_pull_future = subscriber.subscribe( 330 - self.subscription_path, callback=stream_callback, flow_control=flow_control 331 - ) 332 - 333 - while True: 334 - try: 335 - streaming_pull_future.result() 336 - except Exception as e: 337 - logger.error(e) 338 - sentry_sdk.capture_exception(error=e) 339 - continue 340 - 341 - def _gen(self) -> Iterator[PubSubMessageAckingContext[OspreyEntityLabelWebhook]]: 342 - gevent.spawn(self._worker) 343 - while True: 344 - try: 345 - received_message = self.queue.get() 346 - assert isinstance(received_message, Message) 347 - metrics.increment('webhook_pubsub_reads', tags=['status:success']) 348 - 349 - webhook_object_data = received_message.data 350 - webhook_object = OspreyEntityLabelWebhook.parse_raw(webhook_object_data) 351 - 352 - yield PubSubMessageAckingContext(webhook_object, received_message) 353 - except Exception: 354 - metrics.increment('webhook_pubsub_reads', tags=['status:failure']) 355 - logger.exception('Error while generating input message') 356 - sentry_sdk.capture_exception() 357 - continue 358 301 359 302 360 303 # Use for utility scripts, not for production
+18 -69
osprey_worker/src/osprey/worker/sinks/sink/output_sink.py
··· 1 1 import abc 2 2 from collections import defaultdict 3 3 from datetime import datetime 4 - from typing import Any, DefaultDict, Dict, List, Mapping, Optional, Sequence 4 + from typing import Any, DefaultDict, Dict, Mapping, Optional, Sequence 5 5 6 6 import gevent 7 7 import sentry_sdk 8 8 from osprey.engine.executor.execution_context import ( 9 9 ExecutionResult, 10 - ExtendedEntityMutation, 11 10 ) 12 11 from osprey.engine.language_types.entities import EntityT 13 12 from osprey.engine.language_types.labels import LabelEffect 14 13 from osprey.engine.stdlib.udfs.rules import RuleT 15 14 from osprey.worker.lib.ddtrace_utils import trace 16 15 from osprey.worker.lib.instruments import metrics 17 - from osprey.worker.lib.osprey_shared.labels import ApplyEntityMutationReply, EntityMutation 16 + from osprey.worker.lib.osprey_shared.labels import EntityLabelMutation 18 17 from osprey.worker.lib.osprey_shared.logging import DynamicLogSampler, get_logger 19 - from osprey.worker.lib.storage.labels import LabelProvider 20 - from osprey.worker.sinks.sink.output_sink_utils.constants import MutationEventType 21 - from osprey.worker.ui_api.osprey.validators.entities import EntityKey 22 - from tenacity import retry, stop_after_attempt, wait_exponential 18 + from osprey.worker.lib.storage.labels import LabelsProvider 23 19 24 20 logger = get_logger() 25 21 ··· 104 100 105 101 def _create_entity_mutation( 106 102 label_effect: LabelEffect, rule: RuleT, expires_at: Optional[datetime] 107 - ) -> ExtendedEntityMutation: 108 - return ExtendedEntityMutation( 109 - mutation=EntityMutation( 110 - label_name=label_effect.name, 111 - reason_name=rule.name, 112 - status=label_effect.status, 113 - description=rule.description, 114 - features=rule.features, 115 - expires_at=expires_at, 116 - ), 117 - delay_action_by=label_effect.delay_action_by, 103 + ) -> EntityLabelMutation: 104 + return EntityLabelMutation( 105 + label_name=label_effect.name, 106 + reason_name=rule.name, 107 + status=label_effect.status, 108 + description=rule.description, 109 + features=rule.features, 110 + expires_at=expires_at, 118 111 ) 119 112 120 113 121 - def _get_label_effects_from_result(result: ExecutionResult) -> Mapping[EntityT[Any], List[ExtendedEntityMutation]]: 122 - effects: DefaultDict[EntityT[Any], List[ExtendedEntityMutation]] = defaultdict(list) 114 + def _get_label_effects_from_result(result: ExecutionResult) -> Mapping[EntityT[Any], list[EntityLabelMutation]]: 115 + effects: DefaultDict[EntityT[Any], list[EntityLabelMutation]] = defaultdict(list) 123 116 124 117 for label_effect in result.effects.get(LabelEffect, []): 125 118 # assert for typing ··· 157 150 class LabelOutputSink(BaseOutputSink): 158 151 """An output sink that will send event effects to the label service.""" 159 152 160 - def __init__(self, label_provider: LabelProvider) -> None: 161 - self._label_provider = label_provider 153 + def __init__(self, labels_provider: LabelsProvider) -> None: 154 + self._labels_provider = labels_provider 162 155 163 156 def will_do_work(self, result: ExecutionResult) -> bool: 164 157 return len(_get_label_effects_from_result(result)) > 0 165 158 166 159 def push(self, result: ExecutionResult) -> None: 167 160 for entity, mutations in _get_label_effects_from_result(result).items(): 168 - entity_key = EntityKey(type=entity.type, id=str(entity.id)) 169 - self.apply_label_mutations( 170 - MutationEventType.OSPREY_ACTION, 171 - str(result.action.action_id), 172 - entity_key, 161 + _ = self._labels_provider.apply_entity_label_mutations( 162 + entity, 173 163 mutations, 174 - result.extracted_features, 175 - mutation_event_action_name=result.action.action_name, 176 164 ) 177 165 178 - @retry(wait=wait_exponential(min=0.5, max=5), stop=stop_after_attempt(3)) 179 - def apply_entity_mutation_with_retry( 180 - self, entity_key: EntityKey, mutations: Sequence[ExtendedEntityMutation] 181 - ) -> ApplyEntityMutationReply: 182 - return self._label_provider.apply_entity_mutation( 183 - entity_key=entity_key, mutations=[extended_mutation.mutation for extended_mutation in mutations] 184 - ) 185 - 186 - def apply_label_mutations( 187 - self, 188 - mutation_event_type: MutationEventType, 189 - mutation_event_id: str, 190 - entity_key: EntityKey, 191 - mutations: Sequence[ExtendedEntityMutation], 192 - features: Optional[Dict[str, Any]] = None, 193 - mutation_event_action_name: str = '', 194 - ) -> ApplyEntityMutationReply: 195 - if not entity_key.id: 196 - metrics.increment( 197 - 'output_sink.apply_entity_mutation', 198 - tags=['status:skipped', 'reason:no_entity_id', f'entity_type:{entity_key.type}'], 199 - ) 200 - return ApplyEntityMutationReply( 201 - unchanged=[mutation.mutation.label_name for mutation in mutations], 202 - ) 203 - 204 - try: 205 - result: ApplyEntityMutationReply = self.apply_entity_mutation_with_retry(entity_key, mutations) 206 - metrics.increment('output_sink.apply_entity_mutation', tags=['status:success']) 207 - except Exception as e: 208 - logger.error( 209 - f'Failed to apply entity mutation on entity of type: {entity_key.type} with id: {entity_key.id} - {e}', 210 - exc_info=True, 211 - ) 212 - metrics.increment('output_sink.apply_entity_mutation', tags=['status:failure']) 213 - raise e 214 - 215 - return result 216 - 217 166 def stop(self) -> None: 218 - pass 167 + self._labels_provider.stop()
-7
osprey_worker/src/osprey/worker/sinks/sink/output_sink_utils/constants.py
··· 10 10 RULES_VISUALIZER_GEN_GRAPH = 'network_action_osprey_rules_visualizer_generate_graph' 11 11 12 12 13 - class MutationEventType(str, Enum): 14 - OSPREY_ACTION = 'osprey_action' 15 - BULK_ACTION = 'bulk_action' 16 - LABEL_DISAGREEMENT = 'label_disagreement' 17 - MANUAL_UPDATE = 'manual_update' 18 - 19 - 20 13 # There are more types, currently listing the ones we need to use in code 21 14 class EntityType(str, Enum): 22 15 USER = 'User'
+2 -77
osprey_worker/src/osprey/worker/sinks/sink/output_sink_utils/models.py
··· 1 - from datetime import datetime 2 - from typing import Any, Dict, List, Optional 1 + from typing import List, Optional 3 2 4 - from osprey.worker.lib.osprey_shared.labels import LabelStatus 5 3 from pydantic import BaseModel 6 4 7 - from .constants import MutationEventType, OspreyAnalyticsEvents 8 - 9 - 10 - class OspreyLabelMutationAnalyticsEvent(BaseModel): 11 - name: str = 'event' 12 - event_type: str = OspreyAnalyticsEvents.LABEL_MUTATIONS 13 - mutation_event_type: MutationEventType 14 - mutation_event_id: str 15 - mutation_event_action_name: Optional[str] = None 16 - user_id: Optional[int] 17 - entity_id_v2: str 18 - entity_type: str 19 - labels: List[str] 20 - label_statuses: List[str] 21 - label_reasons: List[str] 22 - 23 - 24 - class OspreyEntityLabelWebhook(BaseModel): 25 - entity_type: str 26 - entity_id: str 27 - label_name: str 28 - label_status: LabelStatus 29 - webhook_name: str 30 - features: Dict[str, Any] 31 - created_at: datetime 32 - 33 - 34 - class OspreyActionClassificationAnalyticsEvent(BaseModel): 35 - name: str = 'event' 36 - event_type: str = OspreyAnalyticsEvents.ACTION_CLASSIFICATION 37 - action_name: str 38 - action_id: str 39 - action_timestamp: str 40 - 41 - 42 - class OspreyExtractedFeaturesAnalyticsEvent(BaseModel): 43 - name: str = 'event' 44 - event_type: str = OspreyAnalyticsEvents.EXTRACTED_FEATURES 45 - action_name: str 46 - action_id: str 47 - action_timestamp: str 48 - user_id: Optional[int] 49 - error_count: int 50 - extracted_features_json: Dict[str, Any] 51 - 52 - 53 - class OspreyExperimentExposureEvent(BaseModel): 54 - name: str = 'event' 55 - event_type: str = OspreyAnalyticsEvents.EXPERIMENT_EXPOSURE_EVENT 56 - experiment: str 57 - user_id: Optional[int] 58 - entity_id_v2: str 59 - entity_type: str 60 - bucket_name: str 61 - bucket_index: int 62 - action_id: str 63 - event: str = 'experiment_osprey_triggered' # follows experimentation platform naming convention 64 - experiment_version: int 65 - experiment_revision: int 66 - guild_id: Optional[int] 67 - action_timestamp: str 5 + from .constants import OspreyAnalyticsEvents 68 6 69 7 70 8 class OspreyBulkJobAnalyticsEvent(BaseModel): ··· 85 23 path: str 86 24 request_method: str 87 25 timestamp: str 88 - 89 - 90 - class OspreyExecutionResultBigQueryPubsubEvent(BaseModel): 91 - action_id: str 92 - action_name: str 93 - timestamp: datetime 94 - error_count: int 95 - sample_rate: int 96 - classifications: List[str] 97 - signals: List[str] 98 - entity_label_mutations: List[str] 99 - error_results: Optional[str] 100 - execution_results: str
+2 -8
osprey_worker/src/osprey/worker/ui_api/osprey/validators/entities.py
··· 1 - from dataclasses import dataclass 2 1 from datetime import datetime 3 2 from typing import List, Optional, Type 4 3 ··· 10 9 from pydantic import BaseModel 11 10 12 11 13 - @dataclass(frozen=True) 14 - class EntityKey(EntityT[str]): 15 - pass 16 - 17 - 18 12 class EntityMarshaller(FlaskRequestMarshaller): 19 13 @classmethod 20 14 def marshal(cls: Type[T], flask_request: Request) -> T: ··· 26 20 27 21 28 22 class GetLabelsForEntityRequest(BaseModel, EntityMarshaller): 29 - entity: EntityKey 23 + entity: EntityT[str] 30 24 31 25 32 26 class EventCountsByFeatureForEntityQuery(TimeseriesDruidQuery, EntityMarshaller): ··· 41 35 42 36 43 37 class ManualEntityLabelMutationRequest(BaseModel, EntityMarshaller): 44 - entity: EntityKey 38 + entity: EntityT[str] 45 39 mutations: List[EntityLabelMutation]
+45 -65
osprey_worker/src/osprey/worker/ui_api/osprey/views/entities.py
··· 1 - from datetime import datetime 2 - from typing import Any, Dict, Optional 1 + from typing import Any 3 2 4 3 from flask import Blueprint, abort, jsonify 5 - from osprey.worker.lib.osprey_shared.labels import ApplyEntityMutationReply, LabelState 4 + from osprey.worker.adaptor.plugin_manager import bootstrap_labels_provider, has_labels_service 5 + from osprey.worker.lib.osprey_shared.labels import EntityLabelMutation 6 6 from osprey.worker.ui_api.osprey.lib.abilities import ( 7 7 CanMutateEntities, 8 8 CanMutateLabels, ··· 12 12 require_ability, 13 13 require_ability_with_request, 14 14 ) 15 - from pydantic.main import BaseModel 15 + from osprey.worker.ui_api.osprey.lib.auth import get_current_user, get_current_user_email 16 16 17 17 from ..lib.marshal import marshal_with 18 18 from ..validators.entities import ( ··· 30 30 def get_labels_for_entity(request_model: GetLabelsForEntityRequest) -> Any: 31 31 require_ability_with_request(request_model, CanViewLabelsForEntity) 32 32 33 - # TODO(ayubun): Support plug-and-play label service 34 - return { 35 - 'labels': {}, 36 - 'expires_at': None, 37 - } 33 + if not has_labels_service(): 34 + return { 35 + 'labels': {}, 36 + # this field is deprecated 37 + 'expires_at': None, 38 + } 38 39 39 - # entity_labels = labels.get_for_entity(entity_key=request_model.entity.to_proto()) 40 - # # Filter out all but the allowed labels 41 - # ability = get_current_user().get_ability(CanViewLabels) 40 + labels_provider = bootstrap_labels_provider() 41 + 42 + entity_labels = labels_provider.get_from_service(key=request_model.entity) 43 + # Filter out all but the allowed labels 44 + ability = get_current_user().get_ability(CanViewLabels) 42 45 43 - # response_labels = {} 44 - # if hasattr(entity_labels, 'labels'): 45 - # for label_name, label_state in entity_labels.labels.items(): 46 - # if ability and ability.item_is_allowed(label_name): 47 - # response_labels[label_name] = MessageToDict( 48 - # label_state, 49 - # use_integers_for_enums=True, 50 - # preserving_proto_field_name=True, 51 - # ) 46 + response_labels = {} 47 + if hasattr(entity_labels, 'labels'): 48 + for label_name, label_state in entity_labels.labels.items(): 49 + if ability and ability.item_is_allowed(label_name): 50 + response_labels[label_name] = label_state 52 51 53 - # return { 54 - # 'labels': response_labels, 55 - # 'expires_at': entity_labels.expires_at.ToDatetime() if hasattr(entity_labels, 'expires_at') else None, 56 - # } 52 + return { 53 + 'labels': response_labels, 54 + } 57 55 58 56 59 57 @blueprint.route('/entities/event-count-by-feature', methods=['POST']) ··· 65 63 return jsonify(timeseries_result[0]['result']) 66 64 67 65 68 - class EntityLabelMutationResult(BaseModel): 69 - mutation_result: ApplyEntityMutationReply 70 - labels: Dict[str, LabelState] 71 - expires_at: Optional[datetime] 72 - 73 - 74 66 @blueprint.route('/entities/labels', methods=['POST']) 75 67 @marshal_with(ManualEntityLabelMutationRequest) 76 68 def manual_entity_mutation(request_model: ManualEntityLabelMutationRequest) -> Any: 77 69 require_ability_with_request(request_model, CanMutateEntities) 78 70 require_ability_with_request(request_model, CanMutateLabels) 79 71 80 - # TODO(ayubun): Support plug-and-play label service 81 - return abort(501, 'Not Implemented') 72 + if not has_labels_service(): 73 + return abort(501, 'Labels Provider Not Found') 82 74 83 - # can_mutate_labels_ability = get_current_user().get_ability(CanMutateLabels) 84 - # # We can make this assertion because of the above line that requires CanMutateLabel for the request 85 - # assert can_mutate_labels_ability is not None 75 + labels_provider = bootstrap_labels_provider() 86 76 87 - # mutations: List[ExtendedEntityMutation] = [] 88 - # for request_mutation in request_model.mutations: 89 - # if not can_mutate_labels_ability.item_is_allowed(request_mutation.label_name): 90 - # continue 91 - # entity_mutation = ExtendedEntityMutation( 92 - # mutation=EntityMutation( 93 - # label_name=request_mutation.label_name, 94 - # status=request_mutation.status.value, 95 - # expires_at=optional_datetime_to_timestamp(request_mutation.expires_at), 96 - # reason_name='_ManuallyUpdated', 97 - # description='Manual update by {AdminEmail}: {Reason}', 98 - # features={'AdminEmail': get_current_user_email(), 'Reason': request_mutation.reason}, 99 - # ), 100 - # delay_action_by=None, 101 - # ) 102 - # mutations.append(entity_mutation) 77 + can_mutate_labels_ability = get_current_user().get_ability(CanMutateLabels) 78 + # We can make this assertion because of the above line that requires CanMutateLabel for the request 79 + assert can_mutate_labels_ability is not None 80 + 81 + mutations: list[EntityLabelMutation] = [] 82 + for request_mutation in request_model.mutations: 83 + if not can_mutate_labels_ability.item_is_allowed(request_mutation.label_name): 84 + continue 85 + entity_mutation = EntityLabelMutation( 86 + label_name=request_mutation.label_name, 87 + status=request_mutation.status, 88 + expires_at=request_mutation.expires_at, 89 + reason_name='_ManuallyUpdated', 90 + description='Manual update by {AdminEmail}: {Reason}', 91 + features={'AdminEmail': get_current_user_email(), 'Reason': request_mutation.reason}, 92 + ) 93 + mutations.append(entity_mutation) 103 94 104 - # # TODO Give unique ids to manual update requests. 105 - # mutation_result_external = EVENT_EFFECT_SINK.instance().apply_label_mutations_pb2( 106 - # mutation_event_type=MutationEventType.MANUAL_UPDATE, 107 - # mutation_event_id=get_current_user_email(), 108 - # entity_key=request_model.entity.to_proto(), 109 - # mutations=mutations, 110 - # ) 111 - # mutation_result = ApplyEntityMutationReply.from_pb2(mutation_result_external) 95 + result = labels_provider.apply_entity_label_mutations(request_model.entity, mutations) 112 96 113 - # entity_labels_internal = labels.get_for_entity(request_model.entity.to_proto()) 114 - # entity_labels = LabelsModel.from_pb2(entity_labels_internal) 115 - # return EntityLabelMutationResult( 116 - # labels=entity_labels.labels, expires_at=entity_labels.expires_at, mutation_result=mutation_result 117 - # ) 97 + return result