Skip to content

Commit c4979df

Browse files
[Fix&Refine] Fix PirateBlock and refine save_load log (#1054)
* fix stan act bug in PirateNetBlock * refine load printing log
1 parent 8cea0c0 commit c4979df

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

ppsci/arch/mlp.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -595,9 +595,21 @@ def __init__(
595595
],
596596
default_initializer=nn.initializer.Constant(0),
597597
)
598-
self.act1 = act_mod.get_activation(activation)
599-
self.act2 = act_mod.get_activation(activation)
600-
self.act3 = act_mod.get_activation(activation)
598+
self.act1 = (
599+
act_mod.get_activation(activation)
600+
if activation != "stan"
601+
else act_mod.get_activation(activation)(embed_dim)
602+
)
603+
self.act2 = (
604+
act_mod.get_activation(activation)
605+
if activation != "stan"
606+
else act_mod.get_activation(activation)(embed_dim)
607+
)
608+
self.act3 = (
609+
act_mod.get_activation(activation)
610+
if activation != "stan"
611+
else act_mod.get_activation(activation)(embed_dim)
612+
)
601613

602614
def forward(self, x, u, v):
603615
f = self.act1(self.linear1(x))

ppsci/utils/save_load.py

+6
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ def load_checkpoint(
173173
equation_dict = paddle.load(f"{path}.pdeqn")
174174

175175
# set state dict
176+
logger.message(f"* Loading model checkpoint from {path}.pdparams")
176177
missing_keys, unexpected_keys = model.set_state_dict(param_dict)
177178
if missing_keys:
178179
logger.warning(
@@ -185,18 +186,23 @@ def load_checkpoint(
185186
"and corresponding weights will be ignored."
186187
)
187188

189+
logger.message(f"* Loading optimizer checkpoint from {path}.pdopt")
188190
optimizer.set_state_dict(optim_dict)
189191
if grad_scaler is not None:
192+
logger.message(f"* Loading grad scaler checkpoint from {path}.pdscaler")
190193
grad_scaler.load_state_dict(scaler_dict)
191194
if equation is not None and equation_dict is not None:
195+
logger.message(f"* Loading equation checkpoint from {path}.pdeqn")
192196
for name, _equation in equation.items():
193197
_equation.set_state_dict(equation_dict[name])
194198

195199
if ema_model:
200+
logger.message(f"* Loading EMA checkpoint from {path}_ema.pdparams")
196201
avg_param_dict = paddle.load(f"{path}_ema.pdparams")
197202
ema_model.set_state_dict(avg_param_dict)
198203

199204
if aggregator is not None and aggregator.should_persist:
205+
logger.message(f"* Loading loss aggregator checkpoint from {path}.pdagg")
200206
aggregator_dict = paddle.load(f"{path}.pdagg")
201207
aggregator.set_state_dict(aggregator_dict)
202208

0 commit comments

Comments
 (0)