File tree: 2 files changed, +21 -3 lines changed
Original file line number | Diff line number | Diff line change
@@ -595,9 +595,21 @@ def __init__(
595
595
],
596
596
default_initializer = nn .initializer .Constant (0 ),
597
597
)
598
- self .act1 = act_mod .get_activation (activation )
599
- self .act2 = act_mod .get_activation (activation )
600
- self .act3 = act_mod .get_activation (activation )
598
+ self .act1 = (
599
+ act_mod .get_activation (activation )
600
+ if activation != "stan"
601
+ else act_mod .get_activation (activation )(embed_dim )
602
+ )
603
+ self .act2 = (
604
+ act_mod .get_activation (activation )
605
+ if activation != "stan"
606
+ else act_mod .get_activation (activation )(embed_dim )
607
+ )
608
+ self .act3 = (
609
+ act_mod .get_activation (activation )
610
+ if activation != "stan"
611
+ else act_mod .get_activation (activation )(embed_dim )
612
+ )
601
613
602
614
def forward (self , x , u , v ):
603
615
f = self .act1 (self .linear1 (x ))
Original file line number | Diff line number | Diff line change
@@ -173,6 +173,7 @@ def load_checkpoint(
173
173
equation_dict = paddle .load (f"{ path } .pdeqn" )
174
174
175
175
# set state dict
176
+ logger .message (f"* Loading model checkpoint from { path } .pdparams" )
176
177
missing_keys , unexpected_keys = model .set_state_dict (param_dict )
177
178
if missing_keys :
178
179
logger .warning (
@@ -185,18 +186,23 @@ def load_checkpoint(
185
186
"and corresponding weights will be ignored."
186
187
)
187
188
189
+ logger .message (f"* Loading optimizer checkpoint from { path } .pdopt" )
188
190
optimizer .set_state_dict (optim_dict )
189
191
if grad_scaler is not None :
192
+ logger .message (f"* Loading grad scaler checkpoint from { path } .pdscaler" )
190
193
grad_scaler .load_state_dict (scaler_dict )
191
194
if equation is not None and equation_dict is not None :
195
+ logger .message (f"* Loading equation checkpoint from { path } .pdeqn" )
192
196
for name , _equation in equation .items ():
193
197
_equation .set_state_dict (equation_dict [name ])
194
198
195
199
if ema_model :
200
+ logger .message (f"* Loading EMA checkpoint from { path } _ema.pdparams" )
196
201
avg_param_dict = paddle .load (f"{ path } _ema.pdparams" )
197
202
ema_model .set_state_dict (avg_param_dict )
198
203
199
204
if aggregator is not None and aggregator .should_persist :
205
+ logger .message (f"* Loading loss aggregator checkpoint from { path } .pdagg" )
200
206
aggregator_dict = paddle .load (f"{ path } .pdagg" )
201
207
aggregator .set_state_dict (aggregator_dict )
202
208
You can’t perform that action at this time.
0 commit comments