From 62cf433c1c8ca8dff574d347a02459eadbabf133 Mon Sep 17 00:00:00 2001 From: mmatera Date: Wed, 16 Nov 2022 09:52:53 -0300 Subject: [PATCH 1/2] Fix issue that prevented handling rule application for rules of the form pat->Condition[expr_,cond] --- CHANGES.rst | 3 +- mathics/builtin/assignments/assignment.py | 14 ++- mathics/builtin/patterns.py | 13 +- mathics/core/definitions.py | 45 +++++-- mathics/core/rules.py | 143 +++++++++++++++++++++- mathics/core/systemsymbols.py | 5 + test/builtin/test_patterns.py | 28 +++++ 7 files changed, 228 insertions(+), 23 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index d763134ae..af42c290b 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -50,7 +50,8 @@ Bugs # ``0`` with a given precision (like in ```0`3```) is now parsed as ``0``, an integer number. #. ``RandomSample`` with one list argument now returns a random ordering of the list items. Previously it would return just one item. - +#. Rules of the form ``pat->Condition[expr, cond]`` are handled as in WL. The same also works for nested `Condition` expressions. In particular, the comparison between two Rules with the same pattern but an iterated ``Condition`` expressionare considered equal if the conditions are the same. + Enhancements ++++++++++++ diff --git a/mathics/builtin/assignments/assignment.py b/mathics/builtin/assignments/assignment.py index 7f3008bf1..40b6af129 100644 --- a/mathics/builtin/assignments/assignment.py +++ b/mathics/builtin/assignments/assignment.py @@ -170,11 +170,21 @@ class SetDelayed(Set): 'Condition' ('/;') can be used with 'SetDelayed' to make an assignment that only holds if a condition is satisfied: >> f[x_] := p[x] /; x>0 + >> f[x_] := p[-x]/; x<-2 >> f[3] = p[3] >> f[-3] - = f[-3] - It also works if the condition is set in the LHS: + = p[3] + >> f[-1] + = f[-1] + Notice that the LHS is the same in both definitions, but the second + does not overwrite the first one. + + To overwrite one of these definitions, we have to assign using the same condition: + >> f[x_] := Sin[x] /; x>0 + >> f[3] + = Sin[3] + In a similar way, the condition can be set in the LHS: >> F[x_, y_] /; x < y /; x>0 := x / y; >> F[x_, y_] := y / x; >> F[2, 3] diff --git a/mathics/builtin/patterns.py b/mathics/builtin/patterns.py index 72fae8e61..b9df63681 100644 --- a/mathics/builtin/patterns.py +++ b/mathics/builtin/patterns.py @@ -172,7 +172,8 @@ def create_rules(rules_expr, expr, name, evaluation, extra_args=[]): else: result = [] for rule in rules: - if rule.get_head_name() not in ("System`Rule", "System`RuleDelayed"): + head_name = rule.get_head_name() + if head_name not in ("System`Rule", "System`RuleDelayed"): evaluation.message(name, "reps", rule) return None, True elif len(rule.elements) != 2: @@ -186,7 +187,13 @@ def create_rules(rules_expr, expr, name, evaluation, extra_args=[]): ) return None, True else: - result.append(Rule(rule.elements[0], rule.elements[1])) + result.append( + Rule( + rule.elements[0], + rule.elements[1], + delayed=(head_name == "System`RuleDelayed"), + ) + ) return result, False @@ -1690,7 +1697,7 @@ def __init__(self, rulelist, evaluation): self._elements = None self._head = SymbolDispatch - def get_sort_key(self) -> tuple: + def get_sort_key(self, pattern_sort=False) -> tuple: return self.src.get_sort_key() def get_atom_name(self): diff --git a/mathics/core/definitions.py b/mathics/core/definitions.py index def6b1a93..fe64e9d80 100644 --- a/mathics/core/definitions.py +++ b/mathics/core/definitions.py @@ -11,11 +11,13 @@ from typing import List, Optional -from mathics.core.atoms import String +from mathics.core.atoms import Integer, String from mathics.core.attributes import A_NO_ATTRIBUTES from mathics.core.convert.expression import to_mathics_list from mathics.core.element import fully_qualified_symbol_name from mathics.core.expression import Expression +from mathics.core.pattern import Pattern +from mathics.core.rules import Rule from mathics.core.symbols import ( Atom, Symbol, @@ -721,9 +723,6 @@ def get_ownvalue(self, name): return None def set_ownvalue(self, name, value) -> None: - from .expression import Symbol - from .rules import Rule - name = self.lookup_name(name) self.add_rule(name, Rule(Symbol(name), value)) self.clear_cache(name) @@ -759,8 +758,6 @@ def get_config_value(self, name, default=None): return default def set_config_value(self, name, new_value) -> None: - from mathics.core.expression import Integer - self.set_ownvalue(name, Integer(new_value)) def set_line_no(self, line_no) -> None: @@ -780,6 +777,25 @@ def get_history_length(self): def get_tag_position(pattern, name) -> Optional[str]: + # Strip first the pattern from HoldPattern, Pattern + # and Condition wrappings + while True: + # TODO: Not Atom/Expression, + # pattern -> pattern.to_expression() + if isinstance(pattern, Pattern): + pattern = pattern.expr + continue + if pattern.has_form("System`HoldPattern", 1): + pattern = pattern.elements[0] + continue + if pattern.has_form("System`Pattern", 2): + pattern = pattern.elements[1] + continue + if pattern.has_form("System`Condition", 2): + pattern = pattern.elements[0] + continue + break + if pattern.get_name() == name: return "own" elif isinstance(pattern, Atom): @@ -788,10 +804,8 @@ def get_tag_position(pattern, name) -> Optional[str]: head_name = pattern.get_head_name() if head_name == name: return "down" - elif head_name == "System`N" and len(pattern.elements) == 2: + elif pattern.has_form("System`N", 2): return "n" - elif head_name == "System`Condition" and len(pattern.elements) > 0: - return get_tag_position(pattern.elements[0], name) elif pattern.get_lookup_name() == name: return "sub" else: @@ -801,11 +815,18 @@ def get_tag_position(pattern, name) -> Optional[str]: return None -def insert_rule(values, rule) -> None: +def insert_rule(values: list, rule: Rule) -> None: + rhs_conds = getattr(rule, "rhs_conditions", []) for index, existing in enumerate(values): if existing.pattern.sameQ(rule.pattern): - del values[index] - break + # Check for coincidences in the replace conditions, + # it they are there. + # This ensures that the rules are equivalent even taking + # into accound the RHS conditions. + existing_rhs_conds = getattr(existing, "rhs_conditions", []) + if existing_rhs_conds == rhs_conds: + del values[index] + break # use insort_left to guarantee that if equal rules exist, newer rules will # get higher precedence by being inserted before them. see DownValues[]. bisect.insort_left(values, rule) diff --git a/mathics/core/rules.py b/mathics/core/rules.py index 717f9d64c..c4ba7a2a2 100644 --- a/mathics/core/rules.py +++ b/mathics/core/rules.py @@ -5,7 +5,7 @@ from mathics.core.element import KeyComparable from mathics.core.expression import Expression -from mathics.core.symbols import strip_context +from mathics.core.symbols import strip_context, SymbolTrue from mathics.core.pattern import Pattern, StopGenerator from itertools import chain @@ -19,6 +19,10 @@ def function_arguments(f): return _python_function_arguments(f) +class StopMatchConditionFailed(StopGenerator): + pass + + class StopGenerator_BaseRule(StopGenerator): pass @@ -59,7 +63,11 @@ def yield_match(vars, rest): if name.startswith("_option_"): options[name[len("_option_") :]] = value del vars[name] - new_expression = self.do_replace(expression, vars, options, evaluation) + try: + new_expression = self.do_replace(expression, vars, options, evaluation) + except StopMatchConditionFailed: + return + if new_expression is None: new_expression = expression if rest[0] or rest[1]: @@ -107,7 +115,7 @@ def yield_match(vars, rest): def do_replace(self): raise NotImplementedError - def get_sort_key(self) -> tuple: + def get_sort_key(self, pattern_sort=False) -> tuple: # FIXME: check if this makes sense: return tuple((self.system, self.pattern.get_sort_key(True))) @@ -131,12 +139,131 @@ class Rule(BaseRule): ``G[1.^2, a^2]`` """ - def __init__(self, pattern, replace, system=False) -> None: + def __ge__(self, other): + if isinstance(other, Rule): + sys, key, rhs_cond = self.get_sort_key() + sys_other, key_other, rhs_cond_other = other.get_sort_key() + if sys != sys_other: + return sys > sys_other + if key != key_other: + return key > key_other + + # larger and more complex conditions come first + len_cond, len_cond_other = len(rhs_cond), len(rhs_cond_other) + if len_cond != len_cond_other: + return len_cond_other > len_cond + if len_cond == 0: + return False + for me_cond, other_cond in zip(rhs_cond, rhs_cond_other): + me_sk = me_cond.get_sort_key(True) + o_sk = other_cond.get_sort_key(True) + if me_sk > o_sk: + return False + return True + # Follow the usual rule + return self.get_sort_key(True) >= other.get_sort_key(True) + + def __gt__(self, other): + if isinstance(other, Rule): + sys, key, rhs_cond = self.get_sort_key() + sys_other, key_other, rhs_cond_other = other.get_sort_key() + if sys != sys_other: + return sys > sys_other + if key != key_other: + return key > key_other + + # larger and more complex conditions come first + len_cond, len_cond_other = len(rhs_cond), len(rhs_cond_other) + if len_cond != len_cond_other: + return len_cond_other > len_cond + if len_cond == 0: + return False + + for me_cond, other_cond in zip(rhs_cond, rhs_cond_other): + me_sk = me_cond.get_sort_key(True) + o_sk = other_cond.get_sort_key(True) + if me_sk > o_sk: + return False + return me_sk > o_sk + # Follow the usual rule + return self.get_sort_key(True) > other.get_sort_key(True) + + def __le__(self, other): + if isinstance(other, Rule): + sys, key, rhs_cond = self.get_sort_key() + sys_other, key_other, rhs_cond_other = other.get_sort_key() + if sys != sys_other: + return sys < sys_other + if key != key_other: + return key < key_other + + # larger and more complex conditions come first + len_cond, len_cond_other = len(rhs_cond), len(rhs_cond_other) + if len_cond != len_cond_other: + return len_cond_other < len_cond + if len_cond == 0: + return False + for me_cond, other_cond in zip(rhs_cond, rhs_cond_other): + me_sk = me_cond.get_sort_key(True) + o_sk = other_cond.get_sort_key(True) + if me_sk < o_sk: + return False + return True + # Follow the usual rule + return self.get_sort_key(True) <= other.get_sort_key(True) + + def __lt__(self, other): + if isinstance(other, Rule): + sys, key, rhs_cond = self.get_sort_key() + sys_other, key_other, rhs_cond_other = other.get_sort_key() + if sys != sys_other: + return sys < sys_other + if key != key_other: + return key < key_other + + # larger and more complex conditions come first + len_cond, len_cond_other = len(rhs_cond), len(rhs_cond_other) + if len_cond != len_cond_other: + return len_cond_other < len_cond + if len_cond == 0: + return False + + for me_cond, other_cond in zip(rhs_cond, rhs_cond_other): + me_sk = me_cond.get_sort_key(True) + o_sk = other_cond.get_sort_key(True) + if me_sk < o_sk: + return False + return me_sk > o_sk + # Follow the usual rule + return self.get_sort_key(True) < other.get_sort_key(True) + + def __init__(self, pattern, replace, delayed=True, system=False) -> None: super(Rule, self).__init__(pattern, system=system) self.replace = replace + self.delayed = delayed + # If delayed is True, and replace is a nested + # Condition expression, stores the conditions and the + # remaining stripped expression. + # This is going to be used to compare and sort rules, + # and also to decide if the rule matches an expression. + conds = [] + if delayed: + while replace.has_form("System`Condition", 2): + replace, cond = replace.elements + conds.append(cond) + + self.rhs_conditions = sorted(conds) + self.strip_replace = replace def do_replace(self, expression, vars, options, evaluation): - new = self.replace.replace_vars(vars) + replace = self.replace if self.rhs_conditions == [] else self.strip_replace + for cond in self.rhs_conditions: + cond = cond.replace_vars(vars) + cond = cond.evaluate(evaluation) + if cond is not SymbolTrue: + raise StopMatchConditionFailed + + new = replace.replace_vars(vars) new.options = options # if options is a non-empty dict, we need to ensure reevaluation of the whole expression, since 'new' will @@ -159,6 +286,12 @@ def do_replace(self, expression, vars, options, evaluation): def __repr__(self) -> str: return " %s>" % (self.pattern, self.replace) + def get_sort_key(self, pattern_sort=False) -> tuple: + # FIXME: check if this makes sense: + return tuple( + (self.system, self.pattern.get_sort_key(True), self.rhs_conditions) + ) + class BuiltinRule(BaseRule): """ diff --git a/mathics/core/systemsymbols.py b/mathics/core/systemsymbols.py index a49c30a5d..dd77fac98 100644 --- a/mathics/core/systemsymbols.py +++ b/mathics/core/systemsymbols.py @@ -169,6 +169,7 @@ SymbolSeries = Symbol("System`Series") SymbolSeriesData = Symbol("System`SeriesData") SymbolSet = Symbol("System`Set") +SymbolSetDelayed = Symbol("System`SetDelayed") SymbolSign = Symbol("System`Sign") SymbolSimplify = Symbol("System`Simplify") SymbolSin = Symbol("System`Sin") @@ -186,6 +187,8 @@ SymbolSubsuperscriptBox = Symbol("System`SubsuperscriptBox") SymbolSuperscriptBox = Symbol("System`SuperscriptBox") SymbolTable = Symbol("System`Table") +SymbolTagSet = Symbol("System`TagSet") +SymbolTagSetDelayed = Symbol("System`TagSetDelayed") SymbolTeXForm = Symbol("System`TeXForm") SymbolThrow = Symbol("System`Throw") SymbolToString = Symbol("System`ToString") @@ -194,5 +197,7 @@ SymbolUndefined = Symbol("System`Undefined") SymbolUnequal = Symbol("System`Unequal") SymbolUnevaluated = Symbol("System`Unevaluated") +SymbolUpSet = Symbol("System`UpSet") +SymbolUpSetDelayed = Symbol("System`UpSetDelayed") SymbolUpValues = Symbol("System`UpValues") SymbolXor = Symbol("System`Xor") diff --git a/test/builtin/test_patterns.py b/test/builtin/test_patterns.py index 9fb3d8a0f..507c331f9 100644 --- a/test/builtin/test_patterns.py +++ b/test/builtin/test_patterns.py @@ -26,3 +26,31 @@ def test_replace_all(): ), ): check_evaluation(str_expr, str_expected, message) + + +def test_rule_repl_cond(): + for str_expr, str_expected, message in ( + # For Rules, replacement is not evaluated + ( + "f[x]/.(f[u_]->u^2/; u>3/; u>2)", + "x^2/; x>3/; x>2", + "conditions are not evaluated in Rule", + ), + ( + "f[4]/.(f[u_]->u^2/; u>3/; u>2)", + "16 /; 4 > 3 /; 4 > 2", + "still not evaluated, even if values are provided, due to the HoldAll attribute.", + ), + # However, for delayed rules, the behavior is different: + # Conditions defines if the rule is applied + # and do not appears in the result. + ("f[x]/.(f[u_]:>u^2/; u>3/; u>2)", "f[x]", "conditions are not evaluated"), + ("f[4]/.(f[u_]:>u^2/; u>3/; u>2)", "16", "both conditions are True"), + ( + "f[2.5]/.(f[u_]:>u^2/; u>3/; u>2)", + "f[2.5]", + "just the first condition is True", + ), + ("f[1.]/.(f[u_]:>u^2/; u>3/; u>2)", "f[1.]", "Both conditions are False"), + ): + check_evaluation(str_expr, str_expected, message) From 0888468c1a60c9fa4743e5beec556bbc6afcd069 Mon Sep 17 00:00:00 2001 From: mmatera Date: Sat, 31 Dec 2022 13:33:36 -0300 Subject: [PATCH 2/2] isort --- mathics/core/rules.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mathics/core/rules.py b/mathics/core/rules.py index c70f14e63..b9f10382c 100644 --- a/mathics/core/rules.py +++ b/mathics/core/rules.py @@ -6,9 +6,8 @@ from mathics.core.element import KeyComparable from mathics.core.expression import Expression -from mathics.core.symbols import strip_context, SymbolTrue from mathics.core.pattern import Pattern, StopGenerator -from mathics.core.symbols import strip_context +from mathics.core.symbols import SymbolTrue, strip_context def _python_function_arguments(f):