-
-
Notifications
You must be signed in to change notification settings - Fork 62
/
Copy pathrules.py
372 lines (297 loc) · 14 KB
/
rules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
# cython: language_level=3
# -*- coding: utf-8 -*-
from inspect import signature
from itertools import chain
from mathics.core.element import KeyComparable
from mathics.core.expression import Expression
from mathics.core.pattern import Pattern, StopGenerator
from mathics.core.symbols import SymbolTrue, strip_context
def _python_function_arguments(f):
return signature(f).parameters.keys()
def function_arguments(f):
    """
    Return the parameter names accepted by the callable ``f``.

    Public entry point that delegates to ``_python_function_arguments``.
    BuiltinRule uses it to find out whether a builtin function declares an
    ``expression`` parameter.
    """
    arguments = _python_function_arguments(f)
    return arguments
class StopMatchConditionFailed(StopGenerator):
    """
    Signal that a rule's right-hand-side condition did not evaluate to True.

    ``Rule.do_replace()`` raises it, and ``BaseRule.apply()`` catches it to
    skip the current match and let the pattern matcher continue.
    """
class StopGenerator_BaseRule(StopGenerator):
    """
    Abort pattern matching inside ``BaseRule.apply()`` early.

    The exception's ``value`` carries the outcome back to ``apply()``. It is
    either a single rewritten expression or the list of collected results.
    """
class BaseRule(KeyComparable):
    """
    This is the base class from which all other Rules are derived.

    Rules are part of the rewriting system of Mathics. See
    https://en.wikipedia.org/wiki/Rewriting

    This class is not complete in and of itself; subclasses should adapt
    or fill in what is needed. In particular, ``do_replace()`` needs to be
    implemented.

    Important subclasses: BuiltinRule and Rule.
    """

    def __init__(self, pattern, system=False) -> None:
        # ``pattern`` is normalized into a Pattern object.
        self.pattern = Pattern.create(pattern)
        # First component of the sort key (see get_sort_key()).
        self.system = system

    def apply(
        self, expression, evaluation, fully=True, return_list=False, max_list=None
    ):
        """
        Try to rewrite ``expression`` using this rule.

        Parameters:
          expression: the expression matched against ``self.pattern``.
          evaluation: the current evaluation; it is forwarded to the pattern
            matcher and to ``do_replace()``.
          fully: forwarded unchanged to ``Pattern.match()``.
          return_list: if True, collect a result for every match instead of
            stopping at the first one.
          max_list: when ``return_list`` is True, stop after this many
            results. ``None`` means no limit; a value <= 0 returns ``[]``
            immediately.

        Returns:
          With ``return_list`` False, the rewritten expression for the first
          accepted match, or None if no match was accepted. With
          ``return_list`` True, a (possibly empty) list of rewritten
          expressions.
        """
        result_list = []
        if return_list and max_list is not None and max_list <= 0:
            return []

        def yield_match(vars, rest):
            # Callback invoked by the pattern matcher for each candidate match.
            # ``rest`` holds the elements before/after the matched part.
            if rest is None:
                rest = ([], [])
            # Chained comparison: skip the candidate when there are leftover
            # elements and they account for *all* elements of ``expression``.
            if 0 < len(rest[0]) + len(rest[1]) == len(expression.get_elements()):
                return

            # Move "_option_<name>" bindings out of ``vars`` into ``options``
            # (iterate over a copy because ``vars`` is mutated).
            options = {}
            for name, value in list(vars.items()):
                if name.startswith("_option_"):
                    options[name[len("_option_") :]] = value
                    del vars[name]

            try:
                new_expression = self.do_replace(expression, vars, options, evaluation)
            except StopMatchConditionFailed:
                # A right-hand-side condition failed: try the next candidate.
                return
            if new_expression is None:
                new_expression = expression

            if rest[0] or rest[1]:
                # Re-insert the replacement between the unmatched elements.
                result = Expression(
                    expression.get_head(),
                    *list(chain(rest[0], [new_expression], rest[1]))
                )
            else:
                result = new_expression

            if isinstance(result, Expression):
                if result.elements_properties is None:
                    result._build_elements_properties()
                # Flatten out sequences (important for Rule itself!)
                result = result.flatten_pattern_sequence(evaluation)

            # Raising is how the result escapes the matcher's callback loop.
            if return_list:
                result_list.append(result)
                if max_list is not None and len(result_list) >= max_list:
                    raise StopGenerator_BaseRule(result_list)
            else:
                # Only the first possibility counts.
                raise StopGenerator_BaseRule(result)

        try:
            self.pattern.match(yield_match, expression, {}, evaluation, fully=fully)
        except StopGenerator_BaseRule as exc:
            # FIXME: figure where these values are not getting set or updated properly.
            # For now we have to take a pessimistic view
            expr = exc.value
            # FIXME: expr is sometimes a list - why the changing types
            if hasattr(expr, "_elements_fully_evaluated"):
                expr._elements_fully_evaluated = False
                expr._is_flat = False
                expr._is_ordered = False
            return expr

        if return_list:
            return result_list
        return None

    def do_replace(self, expression, vars, options, evaluation):
        """
        Build the replacement for a successful match.

        Subclasses must override this. The signature matches the call made in
        ``apply()``, so a missing override surfaces as NotImplementedError
        rather than a TypeError about the argument count.
        """
        raise NotImplementedError(
            "%s must implement do_replace()" % type(self).__name__
        )

    def get_sort_key(self, pattern_sort=False) -> tuple:
        """Return ``(system, pattern sort key)``; ``pattern_sort`` is ignored."""
        # FIXME: check if this makes sense:
        return (self.system, self.pattern.get_sort_key(True))
class Rule(BaseRule):
    """
    There are two kinds of Rules. This kind of Rule transforms an
    Expression into another Expression based on the pattern and a
    replacement term and doesn't involve function application.

    Also, in contrast to BuiltinRule[], rule application cannot force
    a reevaluation of the expression when the rewrite/apply/eval step
    finishes.

    Here is an example of a Rule::

        F[x_] -> x^2   (* The same thing as: Rule[F[x_], x^2] *)

    ``F[x_]`` is a pattern and ``x^2`` is the replacement term. When
    applied to the expression ``G[F[1.], F[a]]`` the result is
    ``G[1.^2, a^2]``.

    Ordering: rules compare by ``system`` first, then by the pattern sort
    key. Ties are broken by the right-hand-side conditions. Rules with more
    conditions come first; with the same number of conditions, rules whose
    conditions have larger sort keys come first.
    """

    def __init__(self, pattern, replace, delayed=True, system=False) -> None:
        super(Rule, self).__init__(pattern, system=system)
        self.replace = replace
        self.delayed = delayed
        # If delayed is True and ``replace`` is a (possibly nested)
        # Condition[expr, cond] expression, peel off the conditions and keep
        # the stripped expression. Both are used to compare and sort rules,
        # and to decide whether the rule matches an expression.
        conds = []
        if delayed:
            while replace.has_form("System`Condition", 2):
                replace, cond = replace.elements
                conds.append(cond)
        self.rhs_conditions = sorted(conds)
        self.strip_replace = replace

    def do_replace(self, expression, vars, options, evaluation):
        """
        Instantiate the replacement term for a successful match.

        ``vars`` maps pattern-variable names to matched values, and
        ``options`` holds the matched options. Every right-hand-side
        condition is instantiated with ``vars`` and evaluated. If any of them
        is not ``SymbolTrue``, StopMatchConditionFailed is raised so the
        match is rejected.
        """
        replace = self.replace if not self.rhs_conditions else self.strip_replace
        for cond in self.rhs_conditions:
            cond = cond.replace_vars(vars)
            cond = cond.evaluate(evaluation)
            if cond is not SymbolTrue:
                raise StopMatchConditionFailed
        new = replace.replace_vars(vars)
        new.options = options
        # If options is a non-empty dict, we need to ensure reevaluation of the
        # whole expression, since 'new' will usually contain one or more
        # matching OptionValue[symbol_] patterns that need to get replaced with
        # the options' values. This is achieved through Expression.evaluate(),
        # which then triggers OptionValue.apply, which in turn consults
        # evaluation.options to return an option value. To get there, we copy
        # 'new' using copy(reevaluate=True), which ensures that the whole thing
        # gets reevaluated.
        # If the expression contains OptionValue[] patterns but options is
        # empty here, we don't need to act, as the expression won't change in
        # that case.
        if options:
            new = new.copy(reevaluate=True)
        return new

    def __repr__(self) -> str:
        return "<Rule: %s -> %s>" % (self.pattern, self.replace)

    def get_sort_key(self, pattern_sort=False) -> tuple:
        """Return ``(system, pattern sort key, rhs_conditions)``; ``pattern_sort`` is ignored."""
        # FIXME: check if this makes sense:
        return (self.system, self.pattern.get_sort_key(True), self.rhs_conditions)

    def _compare_rules(self, other) -> int:
        """
        Three-way comparison with another Rule.

        Returns -1, 0 or 1 when ``self`` sorts before, the same as, or after
        ``other``. All four ordering operators are derived from this, so they
        are mutually consistent. In particular, ``a <= a`` and ``a >= a``
        hold, and ``a < b`` holds exactly when ``b > a``.
        """
        system, key, rhs_cond = self.get_sort_key()
        system_other, key_other, rhs_cond_other = other.get_sort_key()
        if system != system_other:
            return -1 if system < system_other else 1
        if key != key_other:
            return -1 if key < key_other else 1
        # Larger and more complex conditions come first: more conditions
        # sort earlier...
        if len(rhs_cond) != len(rhs_cond_other):
            return -1 if len(rhs_cond) > len(rhs_cond_other) else 1
        # ...and, for the same number, the first differing condition with
        # the larger sort key sorts earlier.
        for me_cond, other_cond in zip(rhs_cond, rhs_cond_other):
            me_sk = me_cond.get_sort_key(True)
            o_sk = other_cond.get_sort_key(True)
            if me_sk != o_sk:
                return -1 if me_sk > o_sk else 1
        return 0

    def __ge__(self, other):
        if isinstance(other, Rule):
            return self._compare_rules(other) >= 0
        # Follow the usual rule
        return self.get_sort_key(True) >= other.get_sort_key(True)

    def __gt__(self, other):
        if isinstance(other, Rule):
            return self._compare_rules(other) > 0
        # Follow the usual rule
        return self.get_sort_key(True) > other.get_sort_key(True)

    def __le__(self, other):
        if isinstance(other, Rule):
            return self._compare_rules(other) <= 0
        # Follow the usual rule
        return self.get_sort_key(True) <= other.get_sort_key(True)

    def __lt__(self, other):
        if isinstance(other, Rule):
            return self._compare_rules(other) < 0
        # Follow the usual rule
        return self.get_sort_key(True) < other.get_sort_key(True)
class BuiltinRule(BaseRule):
    """
    A BuiltinRule is a rule whose replacement term is a Python function
    rather than a Mathics Expression (as is the case for a Rule).

    Each time the Pattern part of the rule matches an Expression, the
    matching subexpression is replaced by the expression returned by
    calling that function. The function's parameters are bound to the
    variables matched by the pattern.

    Here is an example taken from the symbol ``System`Plus``, which has
    the associated BuiltinRule::

        Plus[items___] -> mathics.builtin.arithfns.basic.Plus.apply

    The pattern ``items___`` matches a sequence of Expressions.
    When applied to the expression ``F[a+a]``, the method
    ``mathics.builtin.arithfns.basic.Plus.apply`` is called with the
    parameter ``items`` bound to ``Sequence[a, a]``. The function returns
    ``Times[2, a]`` (more compactly: ``2*a``). Replacing it in the original
    expression gives ``F[2*a]``.

    In contrast to Rule, BuiltinRules can change the state of definitions
    in the system. For example, the rule::

        SetAttributes[a_, b_] -> mathics.builtin.attributes.SetAttributes.apply

    when applied to the expression ``SetAttributes[F, NumericFunction]``,
    sets the attribute ``NumericFunction`` in the definition of the symbol
    ``F`` and returns Null (``SymbolNull``).
    This will cause ``Expression.evaluate()`` to perform an additional
    ``rewrite_apply_eval()`` step.
    """

    def __init__(self, name, pattern, function, check_options, system=False) -> None:
        super(BuiltinRule, self).__init__(pattern, system=system)
        self.name = name
        self.function = function
        self.check_options = check_options
        # The matched expression itself is passed to ``function`` only when
        # the function declares a parameter named ``expression``.
        self.pass_expression = "expression" in function_arguments(function)

    # If you update this, you must also update traced_do_replace
    # (that's in the same file TraceBuiltins is)
    def do_replace(self, expression, vars, options, evaluation):
        """
        Call the Python function implementing this builtin.

        Returns None (leaving the expression unchanged) when ``options`` is
        non-empty and ``check_options`` rejects it. Otherwise the function is
        called with keyword arguments:
          * the matched variables, with context marks stripped from names;
          * ``evaluation``;
          * ``options``, only when it is non-empty;
          * ``expression``, only when the function declares that parameter.
        """
        if options and self.check_options:
            if not self.check_options(options, evaluation):
                return None
        # The Python function implementing this builtin expects
        # argument names corresponding to the symbol names without
        # context marks.
        vars_noctx = {strip_context(s): vars[s] for s in vars}
        if self.pass_expression:
            vars_noctx["expression"] = expression
        if options:
            return self.function(evaluation=evaluation, options=options, **vars_noctx)
        return self.function(evaluation=evaluation, **vars_noctx)

    def __repr__(self) -> str:
        return "<BuiltinRule: %s -> %s>" % (self.pattern, self.function)

    def __getstate__(self):
        """
        Return picklable state.

        The bound method ``function`` is replaced by the pair
        ``(owner.get_name(), method name)`` stored under ``"function_"``.
        ``__setstate__`` uses that pair to look the method up again.
        """
        odict = self.__dict__.copy()
        del odict["function"]
        odict["function_"] = (self.function.__self__.get_name(), self.function.__name__)
        return odict

    def __setstate__(self, state):
        """Restore pickled state, re-binding ``function`` from ``mathics.builtin._builtins``."""
        from mathics.builtin import _builtins

        self.__dict__.update(state)  # update attributes
        class_name, name = state["function_"]
        self.function = getattr(_builtins[class_name], name)