Mirror of https://github.com/roostorg/osprey github.com/roostorg/osprey
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

Support Optional types in comparison operators (<, >, <=, >=) (#146)

Co-authored-by: Cursor <cursoragent@cursor.com>

authored by

Alphonso Crawford
Cursor
and committed by
GitHub
63abe402 e794aa81

+186 -14
+113 -14
osprey_worker/src/osprey/engine/ast_validator/validators/validate_static_types.py
··· 457 457 458 458 return False 459 459 460 + def _get_non_none_type(self, t: type) -> Optional[type]: 461 + """Extract the non-None type from an Optional[T] type. 462 + 463 + Returns T if t is Optional[T], otherwise returns None. 464 + """ 465 + if not self._is_optional_type(t): 466 + return None 467 + 468 + origin = get_origin(t) 469 + if origin is EntityT: 470 + # EntityT is treated as optional but we don't narrow it 471 + return None 472 + 473 + # For Union types (which includes Optional), get the non-None type 474 + if hasattr(t, '__args__'): 475 + non_none_args = [arg for arg in t.__args__ if arg is not type(None)] 476 + if len(non_none_args) == 1: 477 + return non_none_args[0] 478 + elif len(non_none_args) > 1: 479 + # Multiple non-None types, return a Union of them 480 + return cast(type, Union[tuple(non_none_args)]) 481 + 482 + return None 483 + 484 + def _get_type_narrowing_from_expression( 485 + self, expr: grammar.Expression, operand: grammar.BooleanOperand 486 + ) -> Dict[str, type]: 487 + """Detect type narrowing from null-check patterns. 488 + 489 + For 'and' operations: X != None narrows X from Optional[T] to T 490 + For 'or' operations: X == None narrows X from Optional[T] to T (for subsequent expressions) 491 + 492 + Returns a dict of identifier_key -> narrowed type 493 + """ 494 + narrowed_types: Dict[str, type] = {} 495 + 496 + if not isinstance(expr, grammar.BinaryComparison): 497 + return narrowed_types 498 + 499 + is_and_operation = isinstance(operand, grammar.And) 500 + is_not_equals = isinstance(expr.comparator, grammar.NotEquals) 501 + is_equals = isinstance(expr.comparator, grammar.Equals) 502 + 503 + # For 'and': X != None narrows X 504 + # For 'or': X == None narrows X (because if X == None is false, X is not None) 505 + should_narrow = (is_and_operation and is_not_equals) or (not is_and_operation and is_equals) 506 + 507 + if not should_narrow: 508 + return narrowed_types 509 + 510 + # Check if one side is None and the other is a Name 511 + left_is_none = isinstance(expr.left, grammar.None_) 512 + right_is_none = isinstance(expr.right, grammar.None_) 513 + 514 + if left_is_none and isinstance(expr.right, grammar.Name): 515 + name = expr.right 516 + elif right_is_none and isinstance(expr.left, grammar.Name): 517 + name = expr.left 518 + else: 519 + return narrowed_types 520 + 521 + # Get the current type of the name 522 + if name.identifier_key not in self._name_type_and_span_cache: 523 + return narrowed_types 524 + 525 + current_type = self._name_type_and_span_cache[name.identifier_key].type 526 + non_none_type = self._get_non_none_type(current_type) 527 + 528 + if non_none_type is not None: 529 + narrowed_types[name.identifier_key] = non_none_type 530 + 531 + return narrowed_types 532 + 460 533 def _validate_binary_comparison(self, binary_comparison: grammar.BinaryComparison) -> type: 461 534 def valid_transition_hook(left_type: type, right_type: type) -> None: 462 535 # Some extra warnings for certain cases ··· 585 658 def _validate_boolean_operation(self, boolean_operation: grammar.BooleanOperation) -> type: 586 659 # Type check left and right sides, but return bool regardless because the underlying `any` and `all` can 587 660 # handle arbitrary types and will always return bools. 661 + # 662 + # Apply type narrowing: for 'and' operations, X != None narrows X from Optional[T] to T 663 + # for subsequent expressions. For 'or' operations, X == None narrows X. 664 + narrowed_types: Dict[str, type] = {} 665 + 588 666 for value in boolean_operation.values: 589 - value_type = self._validate_expression(value) 590 - value_type_str = to_display_str(value_type) 591 - self._check_compatible_type( 592 - type_t=value_type, 593 - accepted_by_t=bool, 594 - message=f'unsupported operand type for `{boolean_operation.operand.original_operand}`', 595 - node=value, 596 - hint=f'has type {value_type_str}, expected `bool`', 597 - additional_spans=self._maybe_get_additional_span_for_identifier_definition(value, value_type_str), 598 - ) 667 + # Temporarily apply any accumulated type narrowings for this expression 668 + original_types: Dict[str, _TypeAndSpan] = {} 669 + for identifier_key, narrowed_type in narrowed_types.items(): 670 + if identifier_key in self._name_type_and_span_cache: 671 + original_types[identifier_key] = self._name_type_and_span_cache[identifier_key] 672 + self._name_type_and_span_cache[identifier_key] = original_types[identifier_key].copy( 673 + type=narrowed_type 674 + ) 675 + 676 + try: 677 + value_type = self._validate_expression(value) 678 + value_type_str = to_display_str(value_type) 679 + self._check_compatible_type( 680 + type_t=value_type, 681 + accepted_by_t=bool, 682 + message=f'unsupported operand type for `{boolean_operation.operand.original_operand}`', 683 + node=value, 684 + hint=f'has type {value_type_str}, expected `bool`', 685 + additional_spans=self._maybe_get_additional_span_for_identifier_definition(value, value_type_str), 686 + ) 687 + 688 + # Detect any new type narrowings from this expression 689 + new_narrowings = self._get_type_narrowing_from_expression(value, boolean_operation.operand) 690 + narrowed_types.update(new_narrowings) 691 + finally: 692 + # Restore original types 693 + for identifier_key, original_type_and_span in original_types.items(): 694 + self._name_type_and_span_cache[identifier_key] = original_type_and_span 599 695 600 696 return bool 601 697 ··· 752 848 number_to_bool_transition = _ValidTwoArgTypeTransition( 753 849 valid_left_type=_INT_OR_FLOAT_T, valid_right_type=_INT_OR_FLOAT_T, resulting_type=bool 754 850 ) 851 + # Note: Optional types are not directly supported for comparison operators. 852 + # Use type narrowing with a null check first: X != None and X >= 90 853 + number_comparison_transitions = [number_to_bool_transition] 755 854 # For "in"/"not in" 756 855 in_transitions = [ 757 856 _ValidTwoArgTypeTransition(valid_left_type=str, valid_right_type=str, resulting_type=bool), ··· 763 862 comparators_to_transitions = { 764 863 grammar.Equals: [any_to_bool_transition], 765 864 grammar.NotEquals: [any_to_bool_transition], 766 - grammar.LessThan: [number_to_bool_transition], 767 - grammar.LessThanEquals: [number_to_bool_transition], 768 - grammar.GreaterThan: [number_to_bool_transition], 769 - grammar.GreaterThanEquals: [number_to_bool_transition], 865 + grammar.LessThan: number_comparison_transitions, 866 + grammar.LessThanEquals: number_comparison_transitions, 867 + grammar.GreaterThan: number_comparison_transitions, 868 + grammar.GreaterThanEquals: number_comparison_transitions, 770 869 grammar.In: in_transitions, 771 870 grammar.NotIn: in_transitions, 772 871 }
+17
osprey_worker/src/osprey/engine/executor/node_executor/binary_comparison_executor.py
··· 41 41 _COMPARATORS[Equals], 42 42 _COMPARATORS[NotEquals], 43 43 ) 44 + # For numerical comparisons (<, <=, >, >=), we need to handle None at runtime 45 + # because the executor resolves all boolean operation dependencies before 46 + # short-circuiting. The static validator enforces null-check patterns, but 47 + # at runtime the comparison may still be evaluated with None values. 48 + self.handles_none_comparison = self.comparator in ( 49 + _COMPARATORS[LessThan], 50 + _COMPARATORS[LessThanEquals], 51 + _COMPARATORS[GreaterThan], 52 + _COMPARATORS[GreaterThanEquals], 53 + ) 44 54 45 55 def execute(self, execution_context: 'ExecutionContext') -> bool: 46 56 left = execution_context.resolved(self._node.left, return_none_for_failed_values=self.left_can_be_none) 47 57 right = execution_context.resolved(self._node.right, return_none_for_failed_values=self.right_can_be_none) 58 + 59 + # Handle None values for numerical comparisons at runtime. 60 + # Even with null-check patterns like "X != None and X >= 90", the executor 61 + # resolves all dependencies before the boolean operation short-circuits. 62 + if self.handles_none_comparison and (left is None or right is None): 63 + return False 64 + 48 65 return bool(self.comparator(left, right)) 49 66 50 67 def get_dependent_nodes(self) -> List[ASTNode]:
+56
osprey_worker/src/osprey/engine/executor/tests/test_binary_comparison.py
··· 167 167 ) 168 168 169 169 assert data == {'Ret': expected} 170 + 171 + 172 + @pytest.mark.parametrize( 173 + 'opt_val, expected', 174 + [ 175 + (90, True), 176 + (80, False), 177 + ('None', False), # None value - condition short-circuits 178 + ], 179 + ) 180 + def test_optional_null_check_before_comparison(execute: ExecuteFunction, opt_val: object, expected: bool) -> None: 181 + """Test that type narrowing works for X != None and X >= 90 pattern.""" 182 + data = execute( 183 + f""" 184 + OptVal: Optional[int] = {opt_val} 185 + # Type narrowing: after OptVal != None, OptVal is narrowed from Optional[int] to int 186 + Ret = OptVal != None and OptVal >= 90 187 + """ 188 + ) 189 + 190 + assert data == {'Ret': expected} 191 + 192 + 193 + def test_optional_null_check_chained_narrowing(execute: ExecuteFunction) -> None: 194 + """Test chained type narrowing with multiple optional values.""" 195 + data = execute( 196 + """ 197 + A: Optional[int] = 100 198 + B: Optional[int] = 50 199 + # Both A and B get narrowed after their respective null checks 200 + Ret = A != None and B != None and A >= 90 and B >= 40 201 + """ 202 + ) 203 + 204 + assert data == {'Ret': True} 205 + 206 + 207 + @pytest.mark.parametrize( 208 + 'opt_val, expected', 209 + [ 210 + (90, True), 211 + (80, False), 212 + ('None', True), # None value - first condition is True, so result is True 213 + ], 214 + ) 215 + def test_optional_or_pattern_null_check(execute: ExecuteFunction, opt_val: object, expected: bool) -> None: 216 + """Test that type narrowing works for X == None or X >= 90 pattern.""" 217 + data = execute( 218 + f""" 219 + OptVal: Optional[int] = {opt_val} 220 + # Type narrowing: after OptVal == None is false, OptVal is narrowed from Optional[int] to int 221 + Ret = OptVal == None or OptVal >= 90 222 + """ 223 + ) 224 + 225 + assert data == {'Ret': expected}