···2323 Returns a time to live for items in the cache. By default, KVs are cached indefinitely.
24242525 To have cache entries auto-expire, override this method in your external service definition.
2626+2727+ Note that timedeltas can accept negative values to represent the past, but only on the days field.
2828+ You *can* use timedelta(seconds=0) to disable caching, but a negative time delta *ensures* that even
2929+ if a time shift occurs (such as daylight savings), the cache_ttl will still be immediate.
3030+3131+ Therefore, to disable the read cache, it is recommended to set this to `timedelta(days=-1)`
2632 """
2733 return None
2834
···11from dataclasses import dataclass
22from datetime import timedelta
33-from enum import IntEnum
43from typing import Any, List, Optional, Self, cast
5465from osprey.engine.executor.custom_extracted_features import CustomExtractedFeature
···98 ENTITY_LABEL_MUTATION_DIMENSION_NAME,
109 ENTITY_LABEL_MUTATION_DIMENSION_VALUE,
1110)
1111+from osprey.worker.lib.osprey_shared.labels import LabelStatus
12121313from .entities import EntityT
1414from .rules import RuleT, add_slots
151516161717-class LabelStatus(IntEnum):
1818- ADDED = 0
1919- REMOVED = 1
2020- MANUALLY_ADDED = 2
2121- MANUALLY_REMOVED = 3
2222-2323- def effective_label_status(self) -> 'LabelStatus':
2424- """
2525- Returns the effective status of the label, which is what the upstreams that are observing label
2626- status changes will see. Which is to say, the upstreams will currently not see if the label status was
2727- manually added or manually removed, just that it was added or removed.
2828- """
2929- match self:
3030- case LabelStatus.ADDED | LabelStatus.MANUALLY_ADDED:
3131- return LabelStatus.ADDED
3232- case LabelStatus.REMOVED | LabelStatus.MANUALLY_REMOVED:
3333- return LabelStatus.REMOVED
3434-3535-3617@add_slots
3718@dataclass
3819class LabelEffect(EffectToCustomExtractedFeatureBase[List[str]]):
···5132 expires_after: Optional[timedelta] = None
5233 """If set, the label effect has a timed expiration, which means that the reason will expire after this time."""
53345454- delay_action_by: Optional[timedelta] = None
5555- """If set, the propagation of the effect to the upstream (if configured) will be delayed.S"""
3535+ # delay_action_by: Optional[timedelta] = None
3636+ # """If set, the propagation of the effect to the upstream (if configured via LabelsService.after_add or LabelsService.after_remove) will be delayed."""
56375738 dependent_rule: Optional[RuleT] = None
5839 """If set, the effect will only be applied if the dependent rule evaluates to true."""
···2424from osprey.engine.utils.get_closest_string_within_threshold import (
2525 get_closest_string_within_threshold,
2626)
2727-from osprey.worker.lib.osprey_shared.labels import Labels
2828-from osprey.worker.lib.storage.labels import LabelProvider
2727+from osprey.worker.lib.osprey_shared.labels import EntityLabels
2828+from osprey.worker.lib.storage.labels import LabelsProvider
2929from result import Err, Ok, Result
303031313232-# TODO: move back to labels.py once we actually make it stdlib
3332class LabelArguments(ArgumentsBase):
3433 entity: EntityT[Any]
3534 """An entity to mutate a label on."""
3635 label: ConstExpr[str]
3736 """The label to mutate."""
3838- delay_action_by: Optional[TimeDeltaT] = None
3939- """Optional: Delays a label action by a specified `TimeDeltaT` time."""
3737+ # NOTE(ayubun): delayed actions are removed; they are legacy code from when discord used osprey
3838+ # to trigger webhooks upon label adds/removes.
3939+ #
4040+ # we may eventually add something *similar* to this in the future? but i suspect
4141+ # that a better abstraction would be to have any sort of "external impact" come
4242+ # from verdicts, which were created to be an output (whereas labels were created
4343+ # to simply store state, thus making label webhooks a leaky abstraction)
4444+ # delay_action_by: Optional[TimeDeltaT] = None
4545+ # """Optional: Delays a label action by a specified `TimeDeltaT` time."""
4046 apply_if: Optional[RuleT] = None
4147 """Optional: Conditions that must be met for the label mutation to succeed."""
4248 expires_after: Optional[TimeDeltaT] = None
···4955 status=status,
5056 name=arguments.label.value,
5157 expires_after=TimeDeltaT.inner_from_optional(arguments.expires_after),
5252- delay_action_by=TimeDeltaT.inner_from_optional(arguments.delay_action_by),
5858+ # delay_action_by=TimeDeltaT.inner_from_optional(arguments.delay_action_by),
5359 dependent_rule=arguments.apply_if,
5460 # NOTE: This is fairly significant, if this call node has an `apply_if` ast, but
5561 # the resolved apply_if is None, that means that the evaluation of the rule failed.
···123129 desired_status: Optional[_SimpleStatus]
124130125131126126-class HasLabel(HasHelperInternal[LabelProvider], BatchableUDFBase[HasLabelArguments, bool, BatchableHasLabelArguments]):
132132+class HasLabel(
133133+ HasHelperInternal[LabelsProvider], BatchableUDFBase[HasLabelArguments, bool, BatchableHasLabelArguments]
134134+):
127135 """Returns `True` if the specified label is currently present in a given non-expired state on a provided Entity."""
128136129137 category = UdfCategories.ENGINE
···157165 validation_context.add_error(message='unknown label', span=arguments.label.argument_span, hint=hint)
158166159167 def _execute(
160160- self, execution_context: ExecutionContext, arguments: BatchableHasLabelArguments, entity_labels: Labels
168168+ self, execution_context: ExecutionContext, arguments: BatchableHasLabelArguments, entity_labels: EntityLabels
161169 ) -> bool:
162170 desired_manual = _ManualType.get(arguments.manual)
163171 desired_delay = TimeDeltaT.inner_from_optional(arguments.min_label_age)
···88from osprey.engine.udf.base import UDFBase
99from osprey.worker.adaptor.constants import OSPREY_ADAPTOR
1010from osprey.worker.lib.action_proto_deserializer import ActionProtoDeserializer
1111-from osprey.worker.lib.storage.labels import LabelProvider
1111+from osprey.worker.lib.storage.labels import LabelsProvider, LabelsServiceBase
1212from osprey.worker.sinks.sink.input_stream import BaseInputStream
1313from osprey.worker.sinks.utils.acking_contexts import BaseAckingContext
1414···565657575858@hookspec(firstresult=True)
5959-def register_label_provider(config: Config) -> LabelProvider:
6060- """Register an execution result storage backend instance."""
6161- raise NotImplementedError('register_label_provider must be implemented by the plugin')
5959+def register_labels_service(config: Config) -> LabelsServiceBase | LabelsProvider:
6060+ """Register a labels service backend. This can be achieved by implementing a labels service
6161+ and utilizing the provided labels provider, or by overriding the labels provider to fit your
6262+ business needs"""
6363+ raise NotImplementedError('register_labels_service must be implemented by the plugin')
···11-from dataclasses import dataclass, field
22-from datetime import datetime
33-from enum import Enum
44-from typing import TYPE_CHECKING, Dict, List, Mapping, Optional
11+import copy
22+from collections import UserDict
33+from dataclasses import dataclass, field, replace
44+from datetime import datetime, timedelta
55+from enum import Enum, IntEnum
66+from typing import Dict, Optional
5766-from osprey.engine.language_types.labels import LabelStatus
78from osprey.worker.lib.osprey_shared.logging import get_logger
89from osprey.worker.lib.utils.request_utils import SessionWithRetries
99-from pydantic import BaseModel
1010-1111-if TYPE_CHECKING:
1212- from osprey.worker.lib.utils.flask_signing import Signer
1313-14101511# The requests session we will be using to contact osprey API.
1612_session = SessionWithRetries()
···2117logger = get_logger(__name__)
221823192020+class MutationDropReason(IntEnum):
2121+ # If a label mutation was dropped due to another mutation that conflicted & was higher priority
2222+ # (priority of conflicting mutations in a given entity update is determined by the int value of the
2323+ # label status enum)
2424+ CONFLICTING_MUTATION = 0
2525+ # If the existing label status was manual and the attempted mutation was not
2626+ CANNOT_OVERRIDE_MANUAL = 1
2727+2828+2929+class LabelStatus(IntEnum):
3030+ """
3131+ indicates the status of label.
3232+3333+ regular (a.k.a. "automatic") statuses are applied via rules. they can be overwritten by manual
3434+ statuses, which can only be applied via humans using the ui.
3535+3636+ statuses have weights, which control which ones get dropped when conflicting statuses occur during
3737+ a single attempted mutation; i.e., if an execution of the rules results in a label add and a label remove
3838+ of the same entity/label pair.
3939+ """
4040+4141+ REMOVED = 0
4242+ ADDED = 1
4343+ MANUALLY_REMOVED = 2
4444+ MANUALLY_ADDED = 3
4545+4646+ def effective_label_status(self) -> 'LabelStatus':
4747+ """
4848+ Returns the effective status of the label, which is what the upstreams that are observing label
4949+ status changes will see. Which is to say, the upstreams will currently not see if the label status was
5050+ manually added or manually removed, just that it was added or removed.
5151+ """
5252+ match self:
5353+ case LabelStatus.ADDED | LabelStatus.MANUALLY_ADDED:
5454+ return LabelStatus.ADDED
5555+ case LabelStatus.REMOVED | LabelStatus.MANUALLY_REMOVED:
5656+ return LabelStatus.REMOVED
5757+ case _:
5858+ raise NotImplementedError()
5959+6060+ def is_manual(self) -> bool:
6161+ match self:
6262+ case LabelStatus.MANUALLY_ADDED | LabelStatus.MANUALLY_REMOVED:
6363+ return True
6464+ case _:
6565+ return False
6666+6767+ def is_automatic(self) -> bool:
6868+ return not self.is_manual()
6969+7070+2471# If you change this also change osprey/osprey_engine/packages/osprey_stdlib/configs/labels_config.py
2572class LabelConnotation(Enum):
2673 POSITIVE = 'positive'
···2875 NEUTRAL = 'neutral'
297630773131-# Pydantic-compatible versions of pb2 types
3278@dataclass
3379class LabelReason:
8080+ """
8181+ a label reason tells us why a label mutation was made, when it happened, and when it expires (if at all)
8282+ """
8383+3484 pending: bool = False
3585 description: str = ''
3636- features: Dict[str, str] = field(default_factory=dict)
8686+ """why the label was mutated"""
8787+ features: dict[str, str] = field(default_factory=dict)
8888+ """features are injected into the description as k/v's, similar to how fstrings work. for example,
8989+ the {you} in 'hello {you}' would be substituted as 'person' with a feature dict of {'you': 'person'}"""
3790 created_at: datetime | None = None
9191+ """
9292+ when this reason was made
9393+ """
3894 expires_at: datetime | None = None
9595+ """marks when this label reason 'expires'
9696+9797+ if a LabelState.MANUALLY_REMOVED is applied with a reason that has a 1 day expiration, then
9898+ for 1 day, the label cannot be applied via LabelState.ADDED. all LabelState.ADDED attempts will be dropped.
9999+100100+ if a given label state has multiple label reasons, all reasons would need to expire before the status/state
101101+ is considered expired, too.
102102+ """
103103+104104+ def is_expired(self) -> bool:
105105+ return bool(self.expires_at is not None and self.expires_at + timedelta(seconds=5) < datetime.now())
106106+107107+108108+@dataclass
109109+class LabelReasons(UserDict[str, LabelReason]):
110110+ """
111111+ the label reasons userdict allows us to add a helper function to the dict directly, while otherwise
112112+ operating as a normal dict would~
113113+ """
114114+115115+ def __init__(self, initial_data: dict[str, LabelReason] | None = None) -> None:
116116+ super().__init__(initial_data)
117117+118118+ def insert_or_update(self, reason_name: str, reason: LabelReason) -> bool:
119119+ """
120120+ returns true if the reason was able to be inserted or updated an existing reason;
121121+ false if it was dropped due to being older than the current reason
122122+ """
123123+ if reason_name not in self:
124124+ self[reason_name] = reason
125125+ return True
126126+127127+ current_reason = self[reason_name]
128128+ if current_reason.created_at is None or reason.created_at is None:
129129+ raise AssertionError(
130130+ f'invariant: missing created_at on one of the following LabelReasons: {current_reason} {reason}'
131131+ )
132132+133133+ if current_reason.created_at > reason.created_at + timedelta(seconds=5):
134134+ # the reason we are trying to append is older than the one currently at the reason_name key,
135135+ # so we will discard it (5sec added to adjust for potential code exec time).
136136+ return False
137137+138138+ self[reason_name] = replace(
139139+ reason,
140140+ # since the current reason is older by this point in the code, we want to preserve the original created_at timestamp
141141+ created_at=current_reason.created_at,
142142+ )
143143+ return True
144144+145145+ @classmethod
146146+ def __get_validators__(cls):
147147+ """Pydantic v1 validator"""
148148+ yield cls.validate
149149+150150+ @classmethod
151151+ def validate(cls, v):
152152+ """Validate and convert to LabelReasons"""
153153+ if isinstance(v, cls):
154154+ return v
155155+ if isinstance(v, dict):
156156+ return cls(v)
157157+ raise TypeError(f'LabelReasons expected dict or LabelReasons, got {type(v)}')
158158+159159+ def __repr__(self):
160160+ return f'LabelReasons({self.data})'
391614016241163@dataclass
42164class LabelStateInner:
43165 status: LabelStatus
4444- reasons: Dict[str, LabelReason]
166166+ reasons: LabelReasons
451674616847169@dataclass
48170class LabelState:
49171 status: LabelStatus
5050- reasons: Dict[str, LabelReason]
5151- previous_states: List[LabelStateInner] = field(default_factory=list)
172172+ """statuses dictate the way the current state behaves; certain statuses have priority over others
173173+ (see LabelStatus for more info)"""
174174+175175+ reasons: LabelReasons
176176+ """
177177+ reasons are why this label state was applied; it is a dict because there may be multiple,
178178+ with each reason being distinct based on it's reason name.
179179+180180+ reasons applied under the same name are merged (assuming the status has not changed),
181181+ with precedence given to newer creaeted_at timestamps.
182182+ """
183183+184184+ previous_states: list[LabelStateInner] = field(default_factory=list)
185185+ """the top-level label state also contains previous label states; we use an inner type
186186+ because we don't need these prior states to have the previous_states field"""
187187+188188+ @property
189189+ def expires_at(self) -> datetime | None:
190190+ """
191191+ when a given label state is effectively expired. expiration can only occur if all of the
192192+ reasons are expired.
52193194194+ this field is a convenience value to save users time on computing the effective expiration time from the reasons.
531955454-@dataclass
5555-class Labels:
5656- labels: Dict[str, LabelState] = field(default_factory=dict)
5757- expires_at: Optional[datetime] = None
196196+ expiration defines when future label states can be applied. if the current label state is not expired,
197197+ then then upon a new label state change attempt, the current and new statuses have their weights' compared.
198198+ whichever has the higher weight will take precedence, and the lower weight(s) will be dropped.
199199+ if the weights are the *same*, then a merge of reasons is performed, which can also cause the expiration to be delayed.
200200+ """
201201+ if not self.reasons:
202202+ AssertionError(f'invariant: the label state {self} did not have any associated reasons')
203203+ expires_at = datetime.min
204204+ for reason in self.reasons.values():
205205+ if reason.expires_at is None:
206206+ return None
207207+ expires_at = max(reason.expires_at, expires_at)
208208+ return expires_at
58209210210+ @classmethod
211211+ def from_inner(cls, inner: LabelStateInner) -> 'LabelState':
212212+ return cls(
213213+ status=inner.status,
214214+ reasons=inner.reasons,
215215+ )
592166060-class LabelsAndConnotationsResponse(BaseModel):
6161- labels: Labels
6262- label_connotations: Mapping[str, LabelConnotation]
217217+ def is_expired(self) -> bool:
218218+ return bool(self.expires_at is not None and self.expires_at + timedelta(seconds=5) < datetime.now())
219219+220220+ def _shift_current_state_to_previous_state(self) -> None:
221221+ if not self.reasons:
222222+ # to make this function idempotent, we don't want to shift an empty state to the previous state.
223223+ # we should always have reasons to shift
224224+ return
225225+ self.previous_states.insert(0, LabelStateInner(status=self.status, reasons=copy.deepcopy(self.reasons)))
226226+ self.reasons = LabelReasons()
227227+228228+ def try_apply_desired_state(self, desired_state: LabelStateInner) -> MutationDropReason | None:
229229+ """
230230+ attempts to apply the desired state to this state.
231231+ if the state could not be applied (i.e. due to an unexpired manual status blocking
232232+ a status change to an automatic status), this method will return the MutationDropReason that
233233+ should be applied to the responsible mutations. otherwise, it will return None to indicate success
234234+ """
235235+ if self.is_expired():
236236+ self._shift_current_state_to_previous_state()
237237+ self.status = desired_state.status
238238+ self.reasons = desired_state.reasons
239239+ return None
63240241241+ # if the current status is manual, we will drop automatic statuses (unless the current state is expired)
242242+ if self.status.is_manual() and desired_state.status.is_automatic():
243243+ return MutationDropReason.CANNOT_OVERRIDE_MANUAL
642446565-def get_labels_for_entity(
6666- endpoint: str, signer: 'Signer', entity_type: str, entity_id: str
6767-) -> LabelsAndConnotationsResponse:
6868- url = f'{endpoint}entity/{entity_type}/{entity_id}/labels'
6969- headers = signer.sign_url(url)
7070- raw_resp = _session.get(url, headers=headers, timeout=_REQUEST_TIMEOUT_SECS)
7171- logger.info(f'[get_labels_for_entity] status code is {raw_resp.status_code}')
7272- raw_resp.raise_for_status()
7373- return LabelsAndConnotationsResponse.parse_obj(raw_resp.json())
245245+ # if the statuses are different and we've made it this far, the desired state is allowed to overwrite
246246+ # the current state. so lets do that by shifting to previous state and updating
247247+ if self.status != desired_state.status:
248248+ self._shift_current_state_to_previous_state()
249249+ self.status = desired_state.status
74250251251+ for reason_name, reason in desired_state.reasons.items():
252252+ self.reasons.insert_or_update(reason_name, reason)
752537676-class EntityLabelDisagreeRequest(BaseModel):
7777- label_name: str
7878- description: str
7979- admin_email: str
8080- expires_at: Optional[datetime]
254254+ return None
812558225683257@dataclass
8484-class EntityMutation:
258258+class EntityLabels:
259259+ """this class represents a given entity's current labels & label states"""
260260+261261+ labels: Dict[str, LabelState] = field(default_factory=dict)
262262+ """a mapping of label names to their current states'"""
263263+264264+265265+@dataclass
266266+class EntityLabelMutation:
267267+ """
268268+ a class that allows callers of LabelsProvider.apply_entity_label_mutations() to request how an
269269+ entity's labels should be mutated.
270270+271271+ mutations are not guaranteed to be written to the labels provider. see EntityLabelMutationsResult.dropped
272272+ """
273273+85274 label_name: str = ''
86275 reason_name: str = ''
87276 status: LabelStatus = LabelStatus.ADDED
88277 pending: bool = False
89278 description: str = ''
9090- features: Dict[str, str] = field(default_factory=dict)
9191- expires_at: Optional[datetime] = None
279279+ features: dict[str, str] = field(default_factory=dict)
280280+ expires_at: datetime | None = None
281281+282282+ def desired_state(self) -> LabelStateInner:
283283+ return LabelStateInner(
284284+ status=self.status,
285285+ reasons=LabelReasons({self.reason_name: self.reason}),
286286+ )
287287+288288+ @property
289289+ def reason(self) -> LabelReason:
290290+ return LabelReason(
291291+ pending=self.pending,
292292+ description=self.description,
293293+ features=self.features,
294294+ created_at=datetime.now(),
295295+ expires_at=self.expires_at,
296296+ )
922979329894299@dataclass
9595-class ApplyEntityMutationReply:
9696- added: List[str] = field(default_factory=list)
9797- removed: List[str] = field(default_factory=list)
9898- unchanged: List[str] = field(default_factory=list)
9999- dropped: List[EntityMutation] = field(default_factory=list)
300300+class DroppedEntityLabelMutation:
301301+ mutation: EntityLabelMutation
302302+ reason: MutationDropReason
100303101304102102-class EntityLabelDisagreeResponse(BaseModel):
103103- mutation_result: ApplyEntityMutationReply
104104- labels: Dict[str, LabelState]
105105- expires_at: Optional[datetime]
305305+@dataclass
306306+class EntityLabelMutationsResult:
307307+ new_entity_labels: EntityLabels
308308+ """
309309+ all of the entity's labels post-mutation
310310+ """
311311+312312+ old_entity_labels: Optional[EntityLabels] = None
313313+ """
314314+ all of the entity's labels pre-mutation
315315+ """
106316317317+ labels_added: list[str] = field(default_factory=list)
318318+ """
319319+ all (effective-status) label adds that occurred during this mutation
320320+ """
107321108108-def disagree_wth_label(
109109- endpoint: str, signer: 'Signer', entity_type: str, entity_id: str, label_disagreement: EntityLabelDisagreeRequest
110110-) -> EntityLabelDisagreeResponse:
111111- url = f'{endpoint}entity/{entity_type}/{entity_id}/labels/disagree'
322322+ labels_removed: list[str] = field(default_factory=list)
323323+ """
324324+ all (effective-status) label removes that occurred during this mutation
325325+ """
112326113113- label_disagreement_bytes = label_disagreement.json().encode()
114114- headers = signer.sign(label_disagreement_bytes)
327327+ labels_updated: list[str] = field(default_factory=list)
328328+ """
329329+ labels that had their state updated. this can include simply updating or
330330+ appending to the reason
331331+ """
115332116116- raw_resp = _session.post(url, headers=headers, data=label_disagreement_bytes, timeout=_REQUEST_TIMEOUT_SECS)
117117- raw_resp.raise_for_status()
118118- return EntityLabelDisagreeResponse.parse_obj(raw_resp.json())
333333+ dropped_mutations: list[DroppedEntityLabelMutation] = field(default_factory=list)
334334+ """
335335+ mutations that were dropped for one reason or another. each dropped mutation is
336336+ given a drop reason
337337+ """
···11-from __future__ import absolute_import
22-33-import logging
44-from random import random
55-from typing import Optional
66-77-from osprey.worker.sinks.sink.output_sink_utils.models import LabelStatus
88-from sqlalchemy import BigInteger, Column, DateTime, Integer, Text, and_, func, or_
99-from sqlalchemy.dialects.postgresql import INTERVAL, JSONB
1010-1111-from ..webhooks import WebhookStatus
1212-from .postgres import Model, scoped_session
1313-from .types import Enum
1414-1515-BASE_DELAY_SECONDS = 60
1616-MAX_ATTEMPTS = 3 # update table index in osprey/osprey_lib/schemas/osprey.sql if we change this value
1717-logger = logging.getLogger(__name__)
1818-1919-2020-class EntityLabelWebhook(Model):
2121- __tablename__ = 'entity_label_webhooks'
2222-2323- id = Column(BigInteger, primary_key=True, autoincrement=True)
2424-2525- entity_type = Column(Text, nullable=False)
2626- entity_id = Column(Text, nullable=False)
2727- label_name = Column(Text, nullable=False)
2828- label_status = Column(Enum(LabelStatus, name='label_status', create_type=False), nullable=False)
2929- webhook_name = Column(Text, nullable=False)
3030- arguments = Column(JSONB)
3131- features = Column(JSONB)
3232- status = Column(Enum(WebhookStatus, name='webhook_status', create_type=False))
3333- claim_until = Column(DateTime(timezone=True))
3434- result = Column(Text)
3535- attempts = Column(Integer, nullable=False, default=0)
3636- created_at = Column(DateTime(timezone=True), nullable=False)
3737- updated_at = Column(DateTime(timezone=True), nullable=False)
3838-3939- @classmethod
4040- def claim(cls) -> Optional['EntityLabelWebhook']:
4141- """Claim one webhook to send.
4242-4343- The claim duration is also used as the retry cooldown, since that's already longer than we expect sending a
4444- webhook to take. That way, if the process totally dies, the webhook can still be retried at the correct
4545- interval.
4646- """
4747- table = cls.__table__
4848- jitter_percent = 1 + random()
4949- lock_seconds = BASE_DELAY_SECONDS * func.power(2, table.c.attempts) * jitter_percent
5050-5151- # Selects the oldest claimable row's id in a subquery, because UPDATE doesn't support ORDER BY.
5252- # - oldest is based on claim_until (which is initially set to created_at, or some other time if it's a
5353- # delayed action).
5454- # - claimable means:
5555- # - the claim has expired
5656- # - status is one of the non-final statuses
5757- # - it hasn't already been attempted too many times
5858- order_subq = (
5959- table.select()
6060- .with_only_columns([table.c.id])
6161- .where(
6262- and_(
6363- table.c.claim_until < func.now(),
6464- or_(*(table.c.status == status for status in WebhookStatus.non_final_statuses())),
6565- table.c.attempts < MAX_ATTEMPTS,
6666- )
6767- )
6868- .with_for_update(skip_locked=True)
6969- .order_by(table.c.claim_until)
7070- .limit(1)
7171- .alias('order_subq')
7272- )
7373-7474- query = (
7575- table.update()
7676- .where(table.c.id.in_(order_subq))
7777- .values(
7878- claim_until=func.now() + func.cast(func.concat(lock_seconds, ' SECONDS'), INTERVAL),
7979- attempts=table.c.attempts + 1,
8080- status=WebhookStatus.RUNNING,
8181- updated_at=func.now(),
8282- )
8383- .returning(table)
8484- )
8585-8686- with scoped_session(commit=True) as session:
8787- cursor = session.execute(query)
8888- # We need to construct the ORM object in a way SQLAlchemy approves of so it can track state under the
8989- # hood correctly (namely know that this object represents an existing row in the database).
9090- rows = list(session.query(cls).instances(cursor))
9191- if len(rows) == 0:
9292- return None
9393- # We should only ever match up to one row
9494- (row,) = rows
9595- assert isinstance(row, EntityLabelWebhook)
9696- return row
9797-9898- def release(self, status: WebhookStatus, result: str) -> None:
9999- with scoped_session(commit=True) as session:
100100- session.add(self)
101101- self.status = status
102102- self.result = result
103103- self.updated_at = func.now()
···11+import copy
12from abc import ABC, abstractmethod
33+from collections import defaultdict
44+from contextlib import contextmanager
25from datetime import timedelta
33-from typing import Any, List, Optional, Sequence
66+from typing import Any, Generator, Optional, Sequence
4758from osprey.engine.executor.external_service_utils import ExternalService
69from osprey.engine.language_types.entities import EntityT
77-from osprey.worker.lib.osprey_shared.labels import ApplyEntityMutationReply, EntityMutation, Labels
88-from result import Result
1010+from osprey.worker.lib.osprey_shared.labels import (
1111+ DroppedEntityLabelMutation,
1212+ EntityLabelMutation,
1313+ EntityLabelMutationsResult,
1414+ EntityLabels,
1515+ LabelState,
1616+ LabelStateInner,
1717+ LabelStatus,
1818+ MutationDropReason,
1919+)
2020+from osprey.worker.lib.osprey_shared.logging import get_logger
2121+from result import Err, Ok, Result
2222+from tenacity import retry, stop_after_attempt, wait_exponential
9232424+logger = get_logger(__name__)
10251111-class LabelProvider(ExternalService[EntityT[Any], Labels], ABC):
1212- def cache_ttl(self) -> Optional[timedelta]:
1313- return timedelta(minutes=5)
14262727+class LabelsServiceBase(ABC):
1528 @abstractmethod
1616- def get_from_service(self, key: EntityT[Any]) -> Labels:
2929+ def write_labels(self, entity: EntityT[Any], labels: EntityLabels) -> None:
3030+ """
3131+ A standard write to the labels service that attempts to write the value to the primary key.
3232+3333+ This method may be retried upon exceptions, so keep that in mind when adding potentially
3434+ non-idempotent behaviour.
3535+ """
1736 raise NotImplementedError()
18371938 @abstractmethod
2020- def batch_get_from_service(self, keys: Sequence[EntityT[Any]]) -> Sequence[Result[Labels, Exception]]:
3939+ def read_labels(self, entity: EntityT[Any]) -> EntityLabels:
4040+ """
4141+ A standard read from the labels service. Keep in mind that if there is a cache_ttl greater than 0 seconds,
4242+ this method will not be called for every single label read.
4343+4444+ This method may be retried upon exceptions, so keep that in mind when adding potentially
4545+ non-idempotent behaviour.
4646+ """
2147 raise NotImplementedError()
22484949+ def batch_read_labels(self, entities: Sequence[EntityT[Any]]) -> Sequence[Result[EntityLabels, Exception]]:
5050+ """
5151+ Batching can optimize the number of RPCs that are sent out during executions,
5252+ which has been observed to provide noticeable performance benefits in a python/gevent world.
5353+5454+ The order that the entieties are supplied in the incoming sequence will match the order the results are returned.
5555+5656+ By default, this will just call read_labels in a for-loop, but it is encouraged to implemenent your own batch
5757+ endpoints and logic for the aforementioned performance benefits.
5858+ """
5959+ results: list[Result[EntityLabels, Exception]] = []
6060+ for entity in entities:
6161+ result: Result[EntityLabels, Exception] = Err(
6262+ Exception('invariant: label could not be retrieved but no error was caught')
6363+ )
6464+ try:
6565+ result = Ok(self.read_labels(entity))
6666+ except Exception as e:
6767+ result = Err(e)
6868+ finally:
6969+ results.append(result)
7070+ return results
7171+2372 @abstractmethod
2424- def apply_entity_mutation(
2525- self, entity_key: EntityT[Any], mutations: List[EntityMutation]
2626- ) -> ApplyEntityMutationReply:
2727- raise NotImplementedError()
7373+ @contextmanager
7474+ def get_labels_atomically(self, entity: EntityT[Any]) -> Generator[EntityLabels, None, None]:
7575+ """
7676+ Context manager for atomic read-modify-write operations.
7777+ Implementations should ensure the entity key is locked/in a transaction.
7878+ """
7979+ pass
8080+8181+8282+class LabelsProvider(ExternalService[EntityT[Any], EntityLabels]):
8383+ def __init__(self, labels_service: LabelsServiceBase):
8484+ self._labels_service = labels_service
8585+8686+ def _get_mutations_by_label_name_and_drop_conflicts(
8787+ self, mutations: Sequence[EntityLabelMutation]
8888+ ) -> tuple[dict[str, list[EntityLabelMutation]], list[DroppedEntityLabelMutation]]:
8989+ """
9090+ collect mutations based on the value of their status. this means if a higher status and a lower status label mutation
9191+ occur in the same mutations request, the lower status one(s) will be dropped.
9292+9393+ by the end of this method, the returned mutations will all be of the same label status for a given label.
9494+ """
9595+ # first, we collect all of the highest status mutations per label name. we collect a list because
9696+ # same status mutations will need to be merged into a single label state later to represent all
9797+ # applicable mutation reasons
9898+ mutations_by_label_name: dict[str, list[EntityLabelMutation]] = defaultdict(list)
9999+ dropped_mutations: list[DroppedEntityLabelMutation] = []
100100+ for mutation in mutations:
101101+ label_name = mutation.label_name
102102+ if label_name in mutations_by_label_name:
103103+ other_mutation = mutations_by_label_name[label_name][0]
104104+ if mutation.status.value > other_mutation.status.value:
105105+ for mut in mutations_by_label_name[label_name]:
106106+ # we may have a list of more than one mutation if the statuses are all the same
107107+ dropped_mutations.append(
108108+ DroppedEntityLabelMutation(mutation=mut, reason=MutationDropReason.CONFLICTING_MUTATION)
109109+ )
110110+ mutations_by_label_name[label_name] = [mutation]
111111+ continue
112112+ elif mutation.status.value < other_mutation.status.value:
113113+ dropped_mutations.append(
114114+ DroppedEntityLabelMutation(mutation=mutation, reason=MutationDropReason.CONFLICTING_MUTATION)
115115+ )
116116+ continue
117117+ # if the status weights are equal or if there is no previous statuses, append
118118+ mutations_by_label_name[label_name].append(mutation)
119119+120120+ return (mutations_by_label_name, dropped_mutations)
121121+122122+ def _get_desired_states_by_label_name(
123123+ self, mutations_by_label_name: dict[str, list[EntityLabelMutation]]
124124+ ) -> dict[str, LabelStateInner]:
125125+ """
126126+ given a dict of label names to entity label mutations, return the desired states that the mutations are seeking.
127127+ if there is more than one mutation for a given label, the resulting state should contain a merge of the mutation reasons.
128128+ """
129129+ desired_states_by_label_name: dict[str, LabelStateInner] = dict()
130130+131131+ for label_name, mutations in mutations_by_label_name.items():
132132+ assert len(mutations) > 0, 'invariant: mutations by label name should not be empty'
133133+ assert len({mutation.status for mutation in mutations}) == 1, (
134134+ 'invariant: more than one unique label status AFTER dropping conflicts'
135135+ )
136136+ desired_state = mutations[0].desired_state()
137137+ for i in range(1, len(mutations)):
138138+ desired_state.reasons.insert_or_update(mutations[i].reason_name, mutations[i].reason)
139139+ desired_states_by_label_name[label_name] = desired_state
140140+141141+ return desired_states_by_label_name
142142+143143+ def _compute_new_labels_from_mutations(
144144+ self, old_labels: EntityLabels, mutations: Sequence[EntityLabelMutation]
145145+ ) -> EntityLabelMutationsResult:
146146+ (mutations_by_label_name, dropped_mutations) = self._get_mutations_by_label_name_and_drop_conflicts(mutations)
147147+ desired_states_by_label_name: dict[str, LabelStateInner] = self._get_desired_states_by_label_name(
148148+ mutations_by_label_name
149149+ )
150150+151151+ # lets take desired states and try to apply them to the entity labels.
152152+ # for end-user convenience, we also track if labels are added, removed, updated, or if mutations are dropped entirely
153153+ added: list[str] = []
154154+ removed: list[str] = []
155155+ updated: list[str] = []
156156+ new_labels = copy.deepcopy(old_labels)
157157+ for label_name, desired_state in desired_states_by_label_name.items():
158158+ if label_name not in new_labels.labels:
159159+ new_labels.labels[label_name] = LabelState.from_inner(desired_state)
160160+ added.append(label_name)
161161+ continue
162162+ current_state = new_labels.labels[label_name]
163163+ prev_status = current_state.status
164164+ drop_reason = current_state.try_apply_desired_state(desired_state)
165165+ if drop_reason:
166166+ # if the current state rejected the desired state, we will drop the mutation(s) with the provided drop reason
167167+ for mutation in mutations_by_label_name[label_name]:
168168+ dropped_mutations.append(DroppedEntityLabelMutation(mutation=mutation, reason=drop_reason))
169169+ continue
170170+ # otherwise, let's compare the new status so we can add data to the EntityLabelMutationsResult c:
171171+ new_status = current_state.status
172172+ if prev_status == new_status:
173173+ updated.append(label_name)
174174+ continue
175175+ match new_status.effective_label_status():
176176+ case LabelStatus.ADDED:
177177+ added.append(label_name)
178178+ continue
179179+ case LabelStatus.REMOVED:
180180+ removed.append(label_name)
181181+ continue
182182+183183+ # finally, return the result! duhh :D
184184+ return EntityLabelMutationsResult(
185185+ new_entity_labels=new_labels,
186186+ old_entity_labels=old_labels,
187187+ labels_added=added,
188188+ labels_removed=removed,
189189+ labels_updated=updated,
190190+ dropped_mutations=dropped_mutations,
191191+ )
192192+193193+ @retry(wait=wait_exponential(min=0.5, max=5), stop=stop_after_attempt(3))
194194+ def apply_entity_label_mutations_with_retry(
195195+ self, entity: EntityT[Any], mutations: Sequence[EntityLabelMutation]
196196+ ) -> EntityLabelMutationsResult:
197197+ return self.apply_entity_label_mutations(entity=entity, mutations=mutations)
198198+199199+ def apply_entity_label_mutations(
200200+ self, entity: EntityT[Any], mutations: Sequence[EntityLabelMutation]
201201+ ) -> EntityLabelMutationsResult:
202202+ with self._labels_service.get_labels_atomically(entity) as old_labels:
203203+ result = self._compute_new_labels_from_mutations(old_labels, mutations)
204204+205205+ self._labels_service.write_labels(entity, result.new_entity_labels)
206206+207207+ return result
208208+209209+ def cache_ttl(self) -> Optional[timedelta]:
210210+ return timedelta(minutes=1)
211211+212212+ def get_from_service(self, key: EntityT[Any]) -> EntityLabels:
213213+ return self._labels_service.read_labels(entity=key)
214214+215215+ def batch_get_from_service(self, keys: Sequence[EntityT[Any]]) -> Sequence[Result[EntityLabels, Exception]]:
216216+ """
217217+ Note: By default, the labels service batch_read_labels calls read_labels in a for loop.
218218+ This is because the HasLabel UDF is batchable and requires batch support on the
219219+ provider.
220220+221221+ If you would like to reap the performance benefits of batching, please re-implement
222222+ the batch_read_labels to call a proper batch endpoint.
223223+224224+ See LabelsServiceBase.batch_read_labels for more information
225225+ """
226226+ return self._labels_service.batch_read_labels(entities=keys)
227227+228228+ def stop(self) -> None:
229229+ """
230230+ this method is called when the output sink receives a shutdown signal. if you would like to
231231+ add shutdown logic, override this~
232232+ """
233233+ pass
···1010 RULES_VISUALIZER_GEN_GRAPH = 'network_action_osprey_rules_visualizer_generate_graph'
111112121313-class MutationEventType(str, Enum):
1414- OSPREY_ACTION = 'osprey_action'
1515- BULK_ACTION = 'bulk_action'
1616- LABEL_DISAGREEMENT = 'label_disagreement'
1717- MANUAL_UPDATE = 'manual_update'
1818-1919-2013# There are more types, currently listing the ones we need to use in code
2114class EntityType(str, Enum):
2215 USER = 'User'