Skip to content

Commit a25ab90

Browse files
update develop code
1 parent d319e80 commit a25ab90

14 files changed

+110
-28
lines changed

ppsci/arch/phycrnet.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def __init__(
147147
)
148148

149149
# ConvLSTM
150-
self.ConvLSTM = paddle.nn.LayerList(
150+
self.convlstm = paddle.nn.LayerList(
151151
[
152152
ConvLSTMCell(
153153
input_channels=self.input_channels[i],
@@ -194,16 +194,16 @@ def forward(self, x):
194194
x = encoder(x)
195195

196196
# convlstm
197-
for i, LSTM in enumerate(self.ConvLSTM):
197+
for i, lstm in enumerate(self.convlstm, self.num_encoder):
198198
if step == 0:
199-
(h, c) = LSTM.init_hidden_tensor(
199+
(h, c) = lstm.init_hidden_tensor(
200200
prev_state=self.initial_state[i - self.num_encoder]
201201
)
202202
internal_state.append((h, c))
203203

204204
# one-step forward
205205
(h, c) = internal_state[i - self.num_encoder]
206-
x, new_c = LSTM(x, h, c)
206+
x, new_c = lstm(x, h, c)
207207
internal_state[i - self.num_encoder] = (x, new_c)
208208

209209
# output

ppsci/equation/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def build_equation(cfg):
5454
"""Build equation(s)
5555
5656
Args:
57-
cfg (List[DictConfig]): Equation(s) config list.
57+
cfg (List[AttrDict]): Equation(s) config list.
5858
5959
Returns:
6060
Dict[str, Equation]: Equation(s) in dict.

ppsci/equation/pde/base.py

+73-15
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from typing import Union
2323

2424
import paddle
25-
import sympy
25+
import sympy as sp
2626
from paddle import nn
2727

2828
DETACH_FUNC_NAME = "detach"
@@ -33,7 +33,7 @@ class PDE:
3333

3434
def __init__(self):
3535
super().__init__()
36-
self.equations = {}
36+
self.equations: Dict[str, Union[Callable, sp.Basic]] = {}
3737
# for PDE which has learnable parameter(s)
3838
self.learnable_parameters = nn.ParameterList()
3939

@@ -42,7 +42,7 @@ def __init__(self):
4242
@staticmethod
4343
def create_symbols(
4444
symbol_str: str,
45-
) -> Union[sympy.Symbol, Tuple[sympy.Symbol, ...]]:
45+
) -> Union[sp.Symbol, Tuple[sp.Symbol, ...]]:
4646
"""create symbolic variables.
4747
4848
Args:
@@ -61,11 +61,9 @@ def create_symbols(
6161
>>> print(symbols_xyz)
6262
(x, y, z)
6363
"""
64-
return sympy.symbols(symbol_str)
64+
return sp.symbols(symbol_str)
6565

66-
def create_function(
67-
self, name: str, invars: Tuple[sympy.Symbol, ...]
68-
) -> sympy.Function:
66+
def create_function(self, name: str, invars: Tuple[sp.Symbol, ...]) -> sp.Function:
6967
"""Create named function depending on given invars.
7068
7169
Args:
@@ -86,14 +84,73 @@ def create_function(
8684
>>> print(f)
8785
f(x, y, z)
8886
"""
89-
expr = sympy.Function(name)(*invars)
87+
expr = sp.Function(name)(*invars)
9088

91-
# wrap `expression(...)` to `detach(expression(...))`
92-
# if name of expression is in given detach_keys
93-
if self.detach_keys and name in self.detach_keys:
94-
expr = sympy.Function(DETACH_FUNC_NAME)(expr)
9589
return expr
9690

91+
def _apply_detach(self):
92+
"""
93+
Wrap detached sub_expr into detach(sub_expr) to prevent gradient
94+
back-propagation, only for those items speicified in self.detach_keys.
95+
96+
NOTE: This function is expected to be called after self.equations is ready in PDE.__init__.
97+
98+
Examples:
99+
>>> import ppsci
100+
>>> ns = ppsci.equation.NavierStokes(1.0, 1.0, 2, False)
101+
>>> print(ns)
102+
NavierStokes
103+
continuity: Derivative(u(x, y), x) + Derivative(v(x, y), y)
104+
momentum_x: u(x, y)*Derivative(u(x, y), x) + v(x, y)*Derivative(u(x, y), y) + 1.0*Derivative(p(x, y), x) - 1.0*Derivative(u(x, y), (x, 2)) - 1.0*Derivative(u(x, y), (y, 2))
105+
momentum_y: u(x, y)*Derivative(v(x, y), x) + v(x, y)*Derivative(v(x, y), y) + 1.0*Derivative(p(x, y), y) - 1.0*Derivative(v(x, y), (x, 2)) - 1.0*Derivative(v(x, y), (y, 2))
106+
>>> detach_keys = ("u", "v__y")
107+
>>> ns = ppsci.equation.NavierStokes(1.0, 1.0, 2, False, detach_keys=detach_keys)
108+
>>> print(ns)
109+
NavierStokes
110+
continuity: detach(Derivative(v(x, y), y)) + Derivative(u(x, y), x)
111+
momentum_x: detach(u(x, y))*Derivative(u(x, y), x) + v(x, y)*Derivative(u(x, y), y) + 1.0*Derivative(p(x, y), x) - 1.0*Derivative(u(x, y), (x, 2)) - 1.0*Derivative(u(x, y), (y, 2))
112+
momentum_y: detach(u(x, y))*Derivative(v(x, y), x) + detach(Derivative(v(x, y), y))*v(x, y) + 1.0*Derivative(p(x, y), y) - 1.0*Derivative(v(x, y), (x, 2)) - 1.0*Derivative(v(x, y), (y, 2))
113+
"""
114+
if self.detach_keys is None:
115+
return
116+
117+
from copy import deepcopy
118+
119+
from sympy.core.traversal import postorder_traversal
120+
121+
from ppsci.utils.symbolic import _cvt_to_key
122+
123+
for name, expr in self.equations.items():
124+
if not isinstance(expr, sp.Basic):
125+
continue
126+
# only process sympy expression
127+
expr_ = deepcopy(expr)
128+
for item in postorder_traversal(expr):
129+
if _cvt_to_key(item) in self.detach_keys:
130+
# inplace all related sub_expr into detach(sub_expr)
131+
expr_ = expr_.replace(item, sp.Function(DETACH_FUNC_NAME)(item))
132+
133+
# remove all detach wrapper for more-than-once wrapped items to prevent duplicated wrapping
134+
expr_ = expr_.replace(
135+
sp.Function(DETACH_FUNC_NAME)(
136+
sp.Function(DETACH_FUNC_NAME)(item)
137+
),
138+
sp.Function(DETACH_FUNC_NAME)(item),
139+
)
140+
141+
# remove unccessary detach wrapping for the first arg of Derivative
142+
for item_ in list(postorder_traversal(expr_)):
143+
if isinstance(item_, sp.Derivative):
144+
if item_.args[0].name == DETACH_FUNC_NAME:
145+
expr_ = expr_.replace(
146+
item_,
147+
sp.Derivative(
148+
item_.args[0].args[0], *item_.args[1:]
149+
),
150+
)
151+
152+
self.equations[name] = expr_
153+
97154
def add_equation(self, name: str, equation: Callable):
98155
"""Add an equation.
99156
@@ -110,7 +167,8 @@ def add_equation(self, name: str, equation: Callable):
110167
>>> equation = sympy.diff(u, x) + sympy.diff(u, y)
111168
>>> pde.add_equation('linear_pde', equation)
112169
>>> print(pde)
113-
PDE, linear_pde: 2*x + 2*y
170+
PDE
171+
linear_pde: 2*x + 2*y
114172
"""
115173
self.equations.update({name: equation})
116174

@@ -181,7 +239,7 @@ def set_state_dict(
181239
return self.learnable_parameters.set_state_dict(state_dict)
182240

183241
def __str__(self):
184-
return ", ".join(
242+
return "\n".join(
185243
[self.__class__.__name__]
186-
+ [f"{name}: {eq}" for name, eq in self.equations.items()]
244+
+ [f" {name}: {eq}" for name, eq in self.equations.items()]
187245
)

ppsci/equation/pde/biharmonic.py

+2
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,5 @@ def __init__(
7070
biharmonic += u.diff(invar_i, 2).diff(invar_j, 2)
7171

7272
self.add_equation("biharmonic", biharmonic)
73+
74+
self._apply_detach()

ppsci/equation/pde/heat_exchanger.py

+2
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,5 @@ def __init__(
9090
self.add_equation("heat_boundary", heat_boundary)
9191
self.add_equation("cold_boundary", cold_boundary)
9292
self.add_equation("wall", wall)
93+
94+
self._apply_detach()

ppsci/equation/pde/laplace.py

+2
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,5 @@ def __init__(self, dim: int, detach_keys: Optional[Tuple[str, ...]] = None):
5151
laplace += u.diff(invar, 2)
5252

5353
self.add_equation("laplace", laplace)
54+
55+
self._apply_detach()

ppsci/equation/pde/linear_elasticity.py

+2
Original file line numberDiff line numberDiff line change
@@ -179,3 +179,5 @@ def __init__(
179179
self.add_equation("traction_y", traction_y)
180180
if self.dim == 3:
181181
self.add_equation("traction_z", traction_z)
182+
183+
self._apply_detach()

ppsci/equation/pde/navier_stokes.py

+2
Original file line numberDiff line numberDiff line change
@@ -147,3 +147,5 @@ def __init__(
147147
self.add_equation("momentum_y", momentum_y)
148148
if self.dim == 3:
149149
self.add_equation("momentum_z", momentum_z)
150+
151+
self._apply_detach()

ppsci/equation/pde/nls_m_b.py

+2
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,5 @@ def __init__(
9797
self.add_equation("Maxwell_1", Maxwell_1)
9898
self.add_equation("Maxwell_2", Maxwell_2)
9999
self.add_equation("Bloch", Bloch)
100+
101+
self._apply_detach()

ppsci/equation/pde/normal_dot_vec.py

+2
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,5 @@ def __init__(
5555
normal_dot_vec += normal * vec
5656

5757
self.add_equation("normal_dot_vec", normal_dot_vec)
58+
59+
self._apply_detach()

ppsci/equation/pde/poisson.py

+2
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,5 @@ def __init__(self, dim: int, detach_keys: Optional[Tuple[str, ...]] = None):
4949
poisson += p.diff(invar, 2)
5050

5151
self.add_equation("poisson", poisson)
52+
53+
self._apply_detach()

ppsci/equation/pde/viv.py

+2
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,5 @@ def __init__(self, rho: float, k1: float, k2: float):
6060
k2 = self.create_symbols(self.k2.name)
6161
f = self.rho * eta.diff(t_f, 2) + sp.exp(k1) * eta.diff(t_f) + sp.exp(k2) * eta
6262
self.add_equation("f", f)
63+
64+
self._apply_detach()

ppsci/utils/download.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def _download(url, path, md5sum=None):
157157
if chunk:
158158
f.write(chunk)
159159
shutil.move(tmp_fullname, fullname)
160-
logger.message(f"Finished downloading pretrained model and saved to {fullname}")
160+
logger.message(f"Finish downloading pretrained model and saved to {fullname}")
161161

162162
return fullname
163163

ppsci/utils/symbolic.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040

4141
__all__ = [
4242
"lambdify",
43+
"_cvt_to_key",
4344
]
4445

4546

@@ -116,14 +117,18 @@ def _cvt_to_key(expr: sp.Basic) -> str:
116117
Returns:
117118
str: Converted string key.
118119
"""
120+
if isinstance(expr, sp.Function) and str(expr.func) == equation.DETACH_FUNC_NAME:
121+
return f"{_cvt_to_key(expr.args[0])}_{equation.DETACH_FUNC_NAME}"
122+
119123
if isinstance(expr, (sp.Symbol, sp.core.function.UndefinedFunction, sp.Function)):
124+
# use name of custom function(e.g. "f") instead of itself(e.g. "f(x, y)")
125+
# for simplicity.
120126
if hasattr(expr, "name"):
121-
# use name of custom function instead of itself.
122127
return expr.name
123128
else:
124129
return str(expr)
125130
elif isinstance(expr, sp.Derivative):
126-
# convert Derivative(u(x,y),(x,2),(y,2)) to "u__x__x__y__y"
131+
# convert "Derivative(u(x,y),(x,2),(y,2))" to "u__x__x__y__y"
127132
expr_str = expr.args[0].name
128133
for symbol, order in expr.args[1:]:
129134
expr_str += f"__{symbol}" * order
@@ -813,12 +818,13 @@ def _expr_to_callable_nodes(
813818
else:
814819
callable_nodes.append(OperatorNode(node))
815820
elif isinstance(node, sp.Function):
816-
if node.name == equation.DETACH_FUNC_NAME:
821+
if str(node.func) == equation.DETACH_FUNC_NAME:
817822
callable_nodes.append(DetachNode(node))
823+
logger.debug(f"Detected detach node {node}")
818824
else:
819825
match_index = None
820826
for j, model in enumerate(models):
821-
if str(node.func.name) in model.output_keys:
827+
if str(node.func) in model.output_keys:
822828
callable_nodes.append(
823829
LayerNode(
824830
node,
@@ -828,13 +834,13 @@ def _expr_to_callable_nodes(
828834
if match_index is not None:
829835
raise ValueError(
830836
f"Name of function: '{node}' should be unique along given"
831-
f" models, but got same output_key: '{node.func.name}' "
837+
f" models, but got same output_key: '{str(node.func)}' "
832838
f"in given models[{match_index}] and models[{j}]."
833839
)
834840
match_index = j
835841
# NOTE: Skip 'sdf' function, which should be already generated in
836842
# given data_dict
837-
if match_index is None and node.name != "sdf":
843+
if match_index is None and str(node.func) != "sdf":
838844
raise ValueError(
839845
f"Node {node} can not match any model in given model(s)."
840846
)
@@ -925,7 +931,7 @@ def _expr_to_callable_nodes(
925931
logger.debug(
926932
f"Fused {len(candidate_pos)} derivatives nodes: "
927933
f"{[callable_nodes_group[i][j].expr for i, j in candidate_pos]} into"
928-
f" fuse node sequence: {fused_node_seq} at position: ([{gid0}][{nid0}])"
934+
f" {len(fused_node_seq)} fuse node sequence: {fused_node_seq} at position: ([{gid0}][{nid0}])"
929935
)
930936

931937
# mark merged node

0 commit comments

Comments
 (0)