22
22
from typing import Union
23
23
24
24
import paddle
25
- import sympy
25
+ import sympy as sp
26
26
from paddle import nn
27
27
28
28
DETACH_FUNC_NAME = "detach"
@@ -33,7 +33,7 @@ class PDE:
33
33
34
34
def __init__ (self ):
35
35
super ().__init__ ()
36
- self .equations = {}
36
+ self .equations : Dict [ str , Union [ Callable , sp . Basic ]] = {}
37
37
# for PDE which has learnable parameter(s)
38
38
self .learnable_parameters = nn .ParameterList ()
39
39
@@ -42,7 +42,7 @@ def __init__(self):
42
42
@staticmethod
43
43
def create_symbols (
44
44
symbol_str : str ,
45
- ) -> Union [sympy .Symbol , Tuple [sympy .Symbol , ...]]:
45
+ ) -> Union [sp .Symbol , Tuple [sp .Symbol , ...]]:
46
46
"""create symbolic variables.
47
47
48
48
Args:
@@ -61,11 +61,9 @@ def create_symbols(
61
61
>>> print(symbols_xyz)
62
62
(x, y, z)
63
63
"""
64
- return sympy .symbols (symbol_str )
64
+ return sp .symbols (symbol_str )
65
65
66
- def create_function (
67
- self , name : str , invars : Tuple [sympy .Symbol , ...]
68
- ) -> sympy .Function :
66
+ def create_function (self , name : str , invars : Tuple [sp .Symbol , ...]) -> sp .Function :
69
67
"""Create named function depending on given invars.
70
68
71
69
Args:
@@ -86,14 +84,73 @@ def create_function(
86
84
>>> print(f)
87
85
f(x, y, z)
88
86
"""
89
- expr = sympy .Function (name )(* invars )
87
+ expr = sp .Function (name )(* invars )
90
88
91
- # wrap `expression(...)` to `detach(expression(...))`
92
- # if name of expression is in given detach_keys
93
- if self .detach_keys and name in self .detach_keys :
94
- expr = sympy .Function (DETACH_FUNC_NAME )(expr )
95
89
return expr
96
90
91
+ def _apply_detach (self ):
92
+ """
93
+ Wrap detached sub_expr into detach(sub_expr) to prevent gradient
94
+ back-propagation, only for those items speicified in self.detach_keys.
95
+
96
+ NOTE: This function is expected to be called after self.equations is ready in PDE.__init__.
97
+
98
+ Examples:
99
+ >>> import ppsci
100
+ >>> ns = ppsci.equation.NavierStokes(1.0, 1.0, 2, False)
101
+ >>> print(ns)
102
+ NavierStokes
103
+ continuity: Derivative(u(x, y), x) + Derivative(v(x, y), y)
104
+ momentum_x: u(x, y)*Derivative(u(x, y), x) + v(x, y)*Derivative(u(x, y), y) + 1.0*Derivative(p(x, y), x) - 1.0*Derivative(u(x, y), (x, 2)) - 1.0*Derivative(u(x, y), (y, 2))
105
+ momentum_y: u(x, y)*Derivative(v(x, y), x) + v(x, y)*Derivative(v(x, y), y) + 1.0*Derivative(p(x, y), y) - 1.0*Derivative(v(x, y), (x, 2)) - 1.0*Derivative(v(x, y), (y, 2))
106
+ >>> detach_keys = ("u", "v__y")
107
+ >>> ns = ppsci.equation.NavierStokes(1.0, 1.0, 2, False, detach_keys=detach_keys)
108
+ >>> print(ns)
109
+ NavierStokes
110
+ continuity: detach(Derivative(v(x, y), y)) + Derivative(u(x, y), x)
111
+ momentum_x: detach(u(x, y))*Derivative(u(x, y), x) + v(x, y)*Derivative(u(x, y), y) + 1.0*Derivative(p(x, y), x) - 1.0*Derivative(u(x, y), (x, 2)) - 1.0*Derivative(u(x, y), (y, 2))
112
+ momentum_y: detach(u(x, y))*Derivative(v(x, y), x) + detach(Derivative(v(x, y), y))*v(x, y) + 1.0*Derivative(p(x, y), y) - 1.0*Derivative(v(x, y), (x, 2)) - 1.0*Derivative(v(x, y), (y, 2))
113
+ """
114
+ if self .detach_keys is None :
115
+ return
116
+
117
+ from copy import deepcopy
118
+
119
+ from sympy .core .traversal import postorder_traversal
120
+
121
+ from ppsci .utils .symbolic import _cvt_to_key
122
+
123
+ for name , expr in self .equations .items ():
124
+ if not isinstance (expr , sp .Basic ):
125
+ continue
126
+ # only process sympy expression
127
+ expr_ = deepcopy (expr )
128
+ for item in postorder_traversal (expr ):
129
+ if _cvt_to_key (item ) in self .detach_keys :
130
+ # inplace all related sub_expr into detach(sub_expr)
131
+ expr_ = expr_ .replace (item , sp .Function (DETACH_FUNC_NAME )(item ))
132
+
133
+ # remove all detach wrapper for more-than-once wrapped items to prevent duplicated wrapping
134
+ expr_ = expr_ .replace (
135
+ sp .Function (DETACH_FUNC_NAME )(
136
+ sp .Function (DETACH_FUNC_NAME )(item )
137
+ ),
138
+ sp .Function (DETACH_FUNC_NAME )(item ),
139
+ )
140
+
141
+ # remove unccessary detach wrapping for the first arg of Derivative
142
+ for item_ in list (postorder_traversal (expr_ )):
143
+ if isinstance (item_ , sp .Derivative ):
144
+ if item_ .args [0 ].name == DETACH_FUNC_NAME :
145
+ expr_ = expr_ .replace (
146
+ item_ ,
147
+ sp .Derivative (
148
+ item_ .args [0 ].args [0 ], * item_ .args [1 :]
149
+ ),
150
+ )
151
+
152
+ self .equations [name ] = expr_
153
+
97
154
def add_equation (self , name : str , equation : Callable ):
98
155
"""Add an equation.
99
156
@@ -110,7 +167,8 @@ def add_equation(self, name: str, equation: Callable):
110
167
>>> equation = sympy.diff(u, x) + sympy.diff(u, y)
111
168
>>> pde.add_equation('linear_pde', equation)
112
169
>>> print(pde)
113
- PDE, linear_pde: 2*x + 2*y
170
+ PDE
171
+ linear_pde: 2*x + 2*y
114
172
"""
115
173
self .equations .update ({name : equation })
116
174
@@ -181,7 +239,7 @@ def set_state_dict(
181
239
return self .learnable_parameters .set_state_dict (state_dict )
182
240
183
241
def __str__ (self ):
184
- return ", " .join (
242
+ return "\n " .join (
185
243
[self .__class__ .__name__ ]
186
- + [f"{ name } : { eq } " for name , eq in self .equations .items ()]
244
+ + [f" { name } : { eq } " for name , eq in self .equations .items ()]
187
245
)
0 commit comments