Mirror of https://github.com/roostorg/osprey github.com/roostorg/osprey
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

fix: engine tests for labels (#27)

authored by

Caidan and committed by
GitHub
dca6bada f76102dd

+98 -85
+3 -9
osprey_worker/src/osprey/engine/executor/tests/test_render_graph.py
··· 4 4 5 5 import pytest 6 6 from osprey.engine.ast_validator.validation_context import ValidatedSources 7 - from osprey.engine.ast_validator.validator_registry import ValidatorRegistry 8 7 from osprey.engine.ast_validator.validators.unique_stored_names import UniqueStoredNames 9 8 from osprey.engine.ast_validator.validators.validate_call_kwargs import ValidateCallKwargs 10 - from osprey.engine.ast_validator.validators.validate_dynamic_calls_have_annotated_rvalue import ( 11 - ValidateDynamicCallsHaveAnnotatedRValue, 12 - ) 13 9 from osprey.engine.conftest import RunValidationFunction 14 10 from osprey.engine.executor.execution_graph import ExecutionGraph, compile_execution_graph 15 11 from osprey.engine.executor.execution_visualizer import _render_graph 16 - from osprey.engine.stdlib import get_config_registry 17 12 18 13 pytestmark = [ 19 - pytest.mark.use_validators([ValidateCallKwargs, ValidateDynamicCallsHaveAnnotatedRValue, UniqueStoredNames]), 14 + pytest.mark.use_validators([ValidateCallKwargs, UniqueStoredNames]), 15 + pytest.mark.use_standard_rules_validators, 20 16 pytest.mark.use_osprey_stdlib, 21 17 ] 22 18 ··· 147 143 """ 148 144 Compiles an ExecutionGraph based on the above Osprey Rules configs 149 145 """ 150 - config_validator = get_config_registry().get_validator() 151 - validator_registry = ValidatorRegistry.get_instance().instance_with_additional_validators(config_validator) 152 - validated_sources: ValidatedSources = run_validation(config, validator_registry=validator_registry) 146 + validated_sources: ValidatedSources = run_validation(config) 153 147 execution_graph = compile_execution_graph(validated_sources) 154 148 return execution_graph 155 149
+3 -4
osprey_worker/src/osprey/engine/query_language/__init__.py
··· 3 3 from osprey.engine.ast_validator.validators.unique_stored_names import UniqueStoredNames 4 4 from osprey.engine.ast_validator.validators.validate_static_types import ValidateStaticTypes 5 5 from osprey.engine.ast_validator.validators.variables_must_be_defined import VariablesMustBeDefined 6 + from osprey.engine.query_language import udfs 7 + from osprey.engine.query_language.ast_validator import REGISTRY 8 + from osprey.engine.query_language.udfs.registry import UDF_REGISTRY 6 9 from osprey.engine.utils.imports import import_all_direct_children 7 - 8 - from . import udfs 9 - from .ast_validator import REGISTRY 10 - from .udfs.registry import UDF_REGISTRY 11 10 12 11 13 12 def parse_query_to_validated_ast(query: str, rules_sources: ValidatedSources) -> ValidatedSources:
+2 -3
osprey_worker/src/osprey/engine/query_language/udfs/did_declare_verdict.py
··· 1 1 from typing import Dict 2 2 3 + from osprey.engine import shared_constants 3 4 from osprey.engine.ast_validator.validation_context import ValidationContext 5 + from osprey.engine.query_language.udfs.registry import register 4 6 from osprey.engine.udf.arguments import ArgumentsBase, ConstExpr 5 7 from osprey.engine.udf.base import QueryUdfBase 6 - 7 - from ... import shared_constants 8 - from .registry import register 9 8 10 9 11 10 class Arguments(ArgumentsBase):
+1 -2
osprey_worker/src/osprey/engine/query_language/udfs/regex_match.py
··· 3 3 4 4 from osprey.engine.ast import grammar 5 5 from osprey.engine.ast_validator.validation_context import ValidationContext 6 + from osprey.engine.query_language.udfs.registry import register 6 7 from osprey.engine.udf.arguments import ArgumentsBase, ConstExpr 7 8 from osprey.engine.udf.base import QueryUdfBase 8 - 9 - from .registry import register 10 9 11 10 12 11 class Arguments(ArgumentsBase):
+78 -61
osprey_worker/src/osprey/engine/stdlib/udfs/tests/test_labels.py
··· 1 1 import json 2 2 from datetime import datetime, timedelta 3 - from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set 3 + from typing import Any, Callable, Dict, List, Optional, Sequence, Set 4 4 5 5 import gevent 6 6 import pytest ··· 20 20 ) 21 21 from osprey.engine.executor.udf_execution_helpers import UDFHelpers 22 22 from osprey.engine.language_types.entities import EntityT 23 + from osprey.engine.language_types.labels import LabelStatus 23 24 from osprey.engine.stdlib import get_config_registry 24 25 from osprey.engine.stdlib.udfs.entity import Entity 25 26 from osprey.engine.stdlib.udfs.labels import HasLabel, LabelAdd, LabelRemove 26 27 from osprey.engine.stdlib.udfs.rules import Rule, WhenRules 27 28 from osprey.engine.stdlib.udfs.time_delta import TimeDelta 28 29 from osprey.engine.udf.registry import UDFRegistry 29 - from osprey.engine.utils.proto_utils import datetime_to_timestamp 30 - from osprey.rpc.labels.v1.service_pb2 import LabelReason, Labels, LabelState, LabelStatus 30 + from osprey.worker.lib.osprey_shared.labels import ( 31 + EntityLabelMutation, 32 + EntityLabelMutationsResult, 33 + EntityLabels, 34 + LabelReason, 35 + LabelReasons, 36 + LabelState, 37 + ) 31 38 from osprey.worker.lib.storage.labels import LabelsProvider 32 - 33 - if TYPE_CHECKING: 34 - from osprey.rpc.labels.v1.service_pb2 import LabelStatusValue 39 + from result import Result 35 40 36 41 pytestmark: List[Callable[[Any], Any]] = [ 37 42 pytest.mark.use_validators( ··· 50 55 51 56 52 57 class StaticLabelProvider(LabelsProvider): 53 - def __init__(self, entity_labels: Dict[EntityT[Any], Labels]) -> None: 58 + def __init__(self, entity_labels: Dict[EntityT[Any], EntityLabels]) -> None: 54 59 self._entity_labels = entity_labels 55 60 56 - def get_from_service(self, key: EntityT[Any]) -> Labels: 61 + def get_from_service(self, key: EntityT[Any]) -> EntityLabels: 57 62 return self._entity_labels[key] 58 63 64 + def batch_get_from_service(self, keys: Sequence[EntityT[Any]]) -> Sequence[Result[EntityLabels, Exception]]: 65 + return [Result.Ok(self.get_from_service(key)) for key in keys] 66 + 67 + def apply_entity_mutation( 68 + self, entity_key: EntityT[Any], mutations: List[EntityLabelMutation] 69 + ) -> EntityLabelMutationsResult: 70 + return self.apply_entity_label_mutations(entity_key, mutations) 71 + 59 72 60 73 class BlockingLabelProvider(StaticLabelProvider): 61 - def __init__(self, entity_labels: Dict[EntityT[Any], Labels]) -> None: 74 + def __init__(self, entity_labels: Dict[EntityT[Any], EntityLabels]) -> None: 62 75 super().__init__(entity_labels) 63 76 self.blocking_events: List[Event] = [] 64 77 self.calls: List[EntityT[Any]] = [] 65 78 66 - def get_from_service(self, key: EntityT[Any]) -> Labels: 79 + def get_from_service(self, key: EntityT[Any]) -> EntityLabels: 67 80 event = Event() 68 81 self.blocking_events.append(event) 69 82 event.wait() ··· 81 94 @pytest.mark.parametrize( 82 95 'checking_status, manual, actual_status, reasons, result', 83 96 ( 84 - ('added', None, LabelStatus.ADDED, {'TestReason': LabelReason()}, True), 97 + ('added', None, LabelStatus.ADDED, LabelReasons({'TestReason': LabelReason()}), True), 85 98 ( 86 99 'added', 87 100 None, 88 101 LabelStatus.ADDED, 89 - {'ExpiredReason': LabelReason(expires_at=datetime_to_timestamp(datetime.now() - timedelta(hours=1)))}, 102 + LabelReasons({'ExpiredReason': LabelReason(expires_at=(datetime.now() - timedelta(hours=1)))}), 90 103 False, 91 104 ), 92 105 ( 93 106 'added', 94 107 None, 95 108 LabelStatus.ADDED, 96 - { 97 - 'ExpiredReason': LabelReason(expires_at=datetime_to_timestamp(datetime.now() - timedelta(hours=1))), 98 - 'TestReason': LabelReason(), 99 - }, 109 + LabelReasons( 110 + { 111 + 'ExpiredReason': LabelReason(expires_at=(datetime.now() - timedelta(hours=1))), 112 + 'TestReason': LabelReason(), 113 + } 114 + ), 100 115 True, 101 116 ), 102 117 ( 103 118 'added', 104 119 None, 105 120 LabelStatus.ADDED, 106 - { 107 - 'ExpiredReason': LabelReason(expires_at=datetime_to_timestamp(datetime.now() - timedelta(hours=1))), 108 - 'ExpiringReason': LabelReason(expires_at=datetime_to_timestamp(datetime.now() + timedelta(hours=1))), 109 - }, 121 + LabelReasons( 122 + { 123 + 'ExpiredReason': LabelReason(expires_at=(datetime.now() - timedelta(hours=1))), 124 + 'ExpiringReason': LabelReason(expires_at=(datetime.now() + timedelta(hours=1))), 125 + } 126 + ), 110 127 True, 111 128 ), 112 129 ( 113 130 'added', 114 131 None, 115 132 LabelStatus.ADDED, 116 - {'ExpiringReason': LabelReason(expires_at=datetime_to_timestamp(datetime.now() + timedelta(hours=1)))}, 133 + LabelReasons({'ExpiringReason': LabelReason(expires_at=(datetime.now() + timedelta(hours=1)))}), 117 134 True, 118 135 ), 119 - ('added', None, LabelStatus.MANUALLY_ADDED, {'TestReason': LabelReason()}, True), 120 - ('added', None, LabelStatus.REMOVED, {'TestReason': LabelReason()}, False), 121 - ('added', None, LabelStatus.MANUALLY_REMOVED, {'TestReason': LabelReason()}, False), 122 - ('added', None, None, {'TestReason': LabelReason()}, False), 123 - ('added', True, LabelStatus.ADDED, {'TestReason': LabelReason()}, False), 124 - ('added', True, LabelStatus.MANUALLY_ADDED, {'TestReason': LabelReason()}, True), 125 - ('added', True, LabelStatus.REMOVED, {'TestReason': LabelReason()}, False), 126 - ('added', True, LabelStatus.MANUALLY_REMOVED, {'TestReason': LabelReason()}, False), 127 - ('added', True, None, {'TestReason': LabelReason()}, False), 128 - ('added', False, LabelStatus.ADDED, {'TestReason': LabelReason()}, True), 129 - ('added', False, LabelStatus.MANUALLY_ADDED, {'TestReason': LabelReason()}, False), 130 - ('added', False, LabelStatus.REMOVED, {'TestReason': LabelReason()}, False), 131 - ('added', False, LabelStatus.MANUALLY_REMOVED, {'TestReason': LabelReason()}, False), 132 - ('added', False, None, {'TestReason': LabelReason()}, False), 133 - ('removed', None, LabelStatus.ADDED, {'TestReason': LabelReason()}, False), 134 - ('removed', None, LabelStatus.MANUALLY_ADDED, {'TestReason': LabelReason()}, False), 135 - ('removed', None, LabelStatus.REMOVED, {'TestReason': LabelReason()}, True), 136 - ('removed', None, LabelStatus.MANUALLY_REMOVED, {'TestReason': LabelReason()}, True), 137 - ('removed', None, None, {'TestReason': LabelReason()}, True), 138 - ('removed', True, LabelStatus.ADDED, {'TestReason': LabelReason()}, False), 139 - ('removed', True, LabelStatus.MANUALLY_ADDED, {'TestReason': LabelReason()}, False), 140 - ('removed', True, LabelStatus.REMOVED, {'TestReason': LabelReason()}, False), 141 - ('removed', True, LabelStatus.MANUALLY_REMOVED, {'TestReason': LabelReason()}, True), 142 - ('removed', True, None, {'TestReason': LabelReason()}, False), 143 - ('removed', False, LabelStatus.ADDED, {'TestReason': LabelReason()}, False), 144 - ('removed', False, LabelStatus.MANUALLY_ADDED, {'TestReason': LabelReason()}, False), 145 - ('removed', False, LabelStatus.REMOVED, {'TestReason': LabelReason()}, True), 146 - ('removed', False, LabelStatus.MANUALLY_REMOVED, {'TestReason': LabelReason()}, False), 147 - ('removed', False, None, {'TestReason': LabelReason()}, True), 136 + ('added', None, LabelStatus.MANUALLY_ADDED, LabelReasons({'TestReason': LabelReason()}), True), 137 + ('added', None, LabelStatus.REMOVED, LabelReasons({'TestReason': LabelReason()}), False), 138 + ('added', None, LabelStatus.MANUALLY_REMOVED, LabelReasons({'TestReason': LabelReason()}), False), 139 + ('added', None, None, LabelReasons({'TestReason': LabelReason()}), False), 140 + ('added', True, LabelStatus.ADDED, LabelReasons({'TestReason': LabelReason()}), False), 141 + ('added', True, LabelStatus.MANUALLY_ADDED, LabelReasons({'TestReason': LabelReason()}), True), 142 + ('added', True, LabelStatus.REMOVED, LabelReasons({'TestReason': LabelReason()}), False), 143 + ('added', True, LabelStatus.MANUALLY_REMOVED, LabelReasons({'TestReason': LabelReason()}), False), 144 + ('added', True, None, LabelReasons({'TestReason': LabelReason()}), False), 145 + ('added', False, LabelStatus.ADDED, LabelReasons({'TestReason': LabelReason()}), True), 146 + ('added', False, LabelStatus.MANUALLY_ADDED, LabelReasons({'TestReason': LabelReason()}), False), 147 + ('added', False, LabelStatus.REMOVED, LabelReasons({'TestReason': LabelReason()}), False), 148 + ('added', False, LabelStatus.MANUALLY_REMOVED, LabelReasons({'TestReason': LabelReason()}), False), 149 + ('added', False, None, LabelReasons({'TestReason': LabelReason()}), False), 150 + ('removed', None, LabelStatus.ADDED, LabelReasons({'TestReason': LabelReason()}), False), 151 + ('removed', None, LabelStatus.MANUALLY_ADDED, LabelReasons({'TestReason': LabelReason()}), False), 152 + ('removed', None, LabelStatus.REMOVED, LabelReasons({'TestReason': LabelReason()}), True), 153 + ('removed', None, LabelStatus.MANUALLY_REMOVED, LabelReasons({'TestReason': LabelReason()}), True), 154 + ('removed', None, None, LabelReasons({'TestReason': LabelReason()}), True), 155 + ('removed', True, LabelStatus.ADDED, LabelReasons({'TestReason': LabelReason()}), False), 156 + ('removed', True, LabelStatus.MANUALLY_ADDED, LabelReasons({'TestReason': LabelReason()}), False), 157 + ('removed', True, LabelStatus.REMOVED, LabelReasons({'TestReason': LabelReason()}), False), 158 + ('removed', True, LabelStatus.MANUALLY_REMOVED, LabelReasons({'TestReason': LabelReason()}), True), 159 + ('removed', True, None, LabelReasons({'TestReason': LabelReason()}), False), 160 + ('removed', False, LabelStatus.ADDED, LabelReasons({'TestReason': LabelReason()}), False), 161 + ('removed', False, LabelStatus.MANUALLY_ADDED, LabelReasons({'TestReason': LabelReason()}), False), 162 + ('removed', False, LabelStatus.REMOVED, LabelReasons({'TestReason': LabelReason()}), True), 163 + ('removed', False, LabelStatus.MANUALLY_REMOVED, LabelReasons({'TestReason': LabelReason()}), False), 164 + ('removed', False, None, LabelReasons({'TestReason': LabelReason()}), True), 148 165 ), 149 166 ) 150 167 def test_get_labels_retrieves_data( 151 168 execute: ExecuteFunction, 152 169 checking_status: str, 153 170 manual: Optional[bool], 154 - actual_status: Optional['LabelStatusValue'], 155 - reasons: Dict[str, LabelReason], 171 + actual_status: Optional[LabelStatus], 172 + reasons: LabelReasons, 156 173 result: bool, 157 174 ) -> None: 158 175 if actual_status is None: 159 - labels = Labels(labels={}) 176 + labels = EntityLabels(labels={}) 160 177 else: 161 - labels = Labels(labels={'my_label': LabelState(status=actual_status, reasons=reasons)}) 178 + labels = EntityLabels(labels={'my_label': LabelState(status=actual_status, reasons=reasons)}) 162 179 label_provider = StaticLabelProvider({EntityT('MyEntity', 'my_id'): labels}) 163 180 data = execute( 164 181 source_with_labels_config( ··· 184 201 'added', 185 202 None, 186 203 LabelStatus.ADDED, 187 - {'TestReason': LabelReason(created_at=datetime_to_timestamp(datetime.now() - timedelta(days=1)))}, 204 + LabelReasons({'TestReason': LabelReason(created_at=(datetime.now() - timedelta(days=1)))}), 188 205 timedelta(days=1), 189 206 True, 190 207 ), ··· 192 209 'added', 193 210 None, 194 211 LabelStatus.ADDED, 195 - {'TestReason': LabelReason(created_at=datetime_to_timestamp(datetime.now()))}, 212 + LabelReasons({'TestReason': LabelReason(created_at=(datetime.now()))}), 196 213 timedelta(days=1), 197 214 False, 198 215 ), ··· 202 219 execute: ExecuteFunction, 203 220 checking_status: str, 204 221 manual: Optional[bool], 205 - actual_status: Optional['LabelStatusValue'], 206 - reasons: Dict[str, LabelReason], 222 + actual_status: Optional[LabelStatus], 223 + reasons: LabelReasons, 207 224 min_label_age: timedelta, 208 225 result: bool, 209 226 ) -> None: 210 227 if actual_status is None: 211 - labels = Labels(labels={}) 228 + labels = EntityLabels(labels={}) 212 229 else: 213 - labels = Labels(labels={'my_label': LabelState(status=actual_status, reasons=reasons)}) 230 + labels = EntityLabels(labels={'my_label': LabelState(status=actual_status, reasons=reasons)}) 214 231 label_provider = StaticLabelProvider({EntityT('MyEntity', 'my_id'): labels}) 215 232 data = execute( 216 233 source_with_labels_config( ··· 268 285 269 286 270 287 def test_gets_only_debounces_single_execution(execute: ExecuteFunction) -> None: 271 - label_provider = BlockingLabelProvider({EntityT('User', 123): Labels(labels={})}) 288 + label_provider = BlockingLabelProvider({EntityT('User', 123): EntityLabels(labels={})}) 272 289 273 290 def do_execute() -> None: 274 291 execute(
+1 -4
osprey_worker/src/osprey/engine/stdlib/udfs/tests/test_rules.py
··· 13 13 RunValidationFunction, 14 14 ) 15 15 from osprey.engine.executor.execution_context import ( 16 - EntityLabelMutation, 17 16 ExecutionContext, 18 17 ) 19 18 from osprey.engine.language_types.entities import EntityT ··· 24 23 from osprey.engine.udf.arguments import ArgumentsBase 25 24 from osprey.engine.udf.base import UDFBase 26 25 from osprey.engine.udf.registry import UDFRegistry 27 - from osprey.rpc.labels.v1.service_pb2 import LabelStatus 26 + from osprey.worker.lib.osprey_shared.labels import EntityLabelMutation, LabelStatus 28 27 from osprey.worker.sinks.sink.output_sink import _get_label_effects_from_result 29 - 30 - # Moved here because WhenRules is not included in the MVP yet 31 28 32 29 33 30 class FailingUdf(UDFBase[ArgumentsBase, bool]):
+6 -2
osprey_worker/src/osprey/engine/udf/registry.py
··· 21 21 return instance 22 22 23 23 def register(self, func: Type[UDFBase[Any, Any]]) -> Type[UDFBase[Any, Any]]: 24 - if func.__name__ in self._functions: 25 - raise Exception(f'A function with the name {func.__name__} is already registered.') 24 + # Allow idempotent re-registration of the exact same class. 25 + existing = self.get(func.__name__) 26 + if existing is not None: 27 + if existing is func: 28 + return existing 29 + raise Exception(f'A function with the name {func.__name__} is already registered with {func}.') 26 30 27 31 try: 28 32 rvalue_type = func.get_rvalue_type()
+4
osprey_worker/src/osprey/worker/_stdlibplugin/udf_register.py
··· 15 15 from osprey.engine.stdlib.udfs.import_ import Import 16 16 from osprey.engine.stdlib.udfs.ip_network import IpNetwork 17 17 from osprey.engine.stdlib.udfs.json_data import JsonData 18 + from osprey.engine.stdlib.udfs.labels import HasLabel, LabelAdd, LabelRemove 18 19 from osprey.engine.stdlib.udfs.list_length import ListLength 19 20 from osprey.engine.stdlib.udfs.list_read import ListRead 20 21 from osprey.engine.stdlib.udfs.list_sort import ListSort ··· 82 83 ExperimentsBucketAssignment, 83 84 ExtractCookie, 84 85 GetActionName, 86 + HasLabel, 85 87 Import, 86 88 IpNetwork, 87 89 JsonData, 90 + LabelAdd, 91 + LabelRemove, 88 92 ListLength, 89 93 ListRead, 90 94 ListSort,