18
18
19
19
# This file is for step1: training a embedding model.
20
20
# This file is based on PaddleScience/ppsci API.
21
+ from os import path as osp
21
22
22
23
import hydra
23
24
import numpy as np
24
25
import paddle
25
26
from omegaconf import DictConfig
26
27
27
28
import ppsci
29
+ from ppsci .utils import logger
28
30
29
31
30
32
def get_mean_std (data : np .ndarray ):
@@ -38,6 +40,11 @@ def get_mean_std(data: np.ndarray):
38
40
39
41
40
42
def train (cfg : DictConfig ):
43
+ # set random seed for reproducibility
44
+ ppsci .utils .misc .set_random_seed (cfg .seed )
45
+ # initialize logger
46
+ logger .init_logger ("ppsci" , osp .join (cfg .output_dir , f"{ cfg .mode } .log" ), "info" )
47
+
41
48
weights = (1.0 * (cfg .TRAIN_BLOCK_SIZE - 1 ), 1.0e4 * cfg .TRAIN_BLOCK_SIZE )
42
49
regularization_key = "k_matrix"
43
50
# manually build constraint(s)
@@ -130,9 +137,12 @@ def train(cfg: DictConfig):
130
137
solver = ppsci .solver .Solver (
131
138
model ,
132
139
constraint ,
133
- optimizer = optimizer ,
140
+ cfg .output_dir ,
141
+ optimizer ,
142
+ epochs = cfg .TRAIN .epochs ,
143
+ iters_per_epoch = ITERS_PER_EPOCH ,
144
+ eval_during_train = True ,
134
145
validator = validator ,
135
- cfg = cfg ,
136
146
)
137
147
# train model
138
148
solver .train ()
@@ -141,6 +151,11 @@ def train(cfg: DictConfig):
141
151
142
152
143
153
def evaluate (cfg : DictConfig ):
154
+ # set random seed for reproducibility
155
+ ppsci .utils .misc .set_random_seed (cfg .seed )
156
+ # initialize logger
157
+ logger .init_logger ("ppsci" , osp .join (cfg .output_dir , f"{ cfg .mode } .log" ), "info" )
158
+
144
159
weights = (1.0 * (cfg .TRAIN_BLOCK_SIZE - 1 ), 1.0e4 * cfg .TRAIN_BLOCK_SIZE )
145
160
regularization_key = "k_matrix"
146
161
# manually build constraint(s)
0 commit comments