Skip to content

Commit ee2d77c

Browse files
ADream-kiHydrogenSulfatewangguan1995
authored
【Hackathon 7th PPSCI No.7】No.7 AI-aided geometric design of anti-infection catheters 论文复现 (#986)
* 黑客松No.7 * update * NO7 * update * fix config * fix code * Update examples/catheter/conf/catheter.yaml Co-authored-by: HydrogenSulfate <490868991@qq.com> * update config * update code * update code * add inference * add inference * add inference * codestyle * codestyle * bugfix * fix codestyle * fix bug * codestyle * Complex GradNode bug fix * codestyle * add docs * add docs 2 * add docs 3 * add docs 4 * add docs 5 * add docs 6 * add docs 7 * add docs 7 * add docs 8 * add docs 9 * add docs 10 * add docs 10 * add docs 11 * add docs 11 * add docs 11 * add docs 12 * add docs 13 * add docs 13 * add docs 14 * add docs 14 * add docs 15 * docfix for transovler * finally * finally * fix codestyle * fix codestyle * fix codestyle * fix codestyle * fix codestyle * fix codestyle * fix * fix * fix * fix * fix * fix --------- Co-authored-by: HydrogenSulfate <490868991@qq.com> Co-authored-by: wangguan <772359200@qq.com>
1 parent c4979df commit ee2d77c

File tree

9 files changed

+1170
-0
lines changed

9 files changed

+1170
-0
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ PaddleScience 是一个基于深度学习框架 PaddlePaddle 开发的科学计
8585
| 流场高分辨率重构 | [2D 湍流流场重构](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/tempoGAN) | 数据驱动 | tempoGAN | 监督学习 | [Train Data](https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_train.mat)<br>[Eval Data](https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_valid.mat) | [Paper](https://dl.acm.org/doi/10.1145/3197517.3201304)|
8686
| 流场高分辨率重构 | [2D 湍流流场重构](https://aistudio.baidu.com/projectdetail/4493261?contributionType=1) | 数据驱动 | cycleGAN | 监督学习 | [Train Data](https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_train.mat)<br>[Eval Data](https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_valid.mat) | [Paper](https://arxiv.org/abs/2007.15324)|
8787
| 流场高分辨率重构 | [基于Voronoi嵌入辅助深度学习的稀疏传感器全局场重建](https://aistudio.baidu.com/projectdetail/5807904) | 数据驱动 | CNN | 监督学习 | [Data1](https://drive.google.com/drive/folders/1K7upSyHAIVtsyNAqe6P8TY1nS5WpxJ2c)<br>[Data2](https://drive.google.com/drive/folders/1pVW4epkeHkT2WHZB7Dym5IURcfOP4cXu)<br>[Data3](https://drive.google.com/drive/folders/1xIY_jIu-hNcRY-TTf4oYX1Xg4_fx8ZvD) | [Paper](https://arxiv.org/pdf/2202.11214.pdf) |
88+
| 流场预测 | [Catheter](https://aistudio.baidu.com/projectdetail/5379212) | 数据驱动 | FNO | 监督学习 | [Data](https://aistudio.baidu.com/datasetdetail/291940) | [Paper](https://www.science.org/doi/pdf/10.1126/sciadv.adj1741) |
8889
| 求解器耦合 | [CFD-GCN](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/cfdgcn) | 数据驱动 | GCN | 监督学习 | [Data](https://aistudio.baidu.com/aistudio/datasetdetail/184778)<br>[Mesh](https://paddle-org.bj.bcebos.com/paddlescience/datasets/CFDGCN/meshes.tar) | [Paper](https://arxiv.org/abs/2007.04439)|
8990
| 受力分析 | [1D 欧拉梁变形](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/euler_beam) | 机理驱动 | MLP | 无监督学习 | - | - |
9091
| 受力分析 | [2D 平板变形](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/biharmonic2d) | 机理驱动 | MLP | 无监督学习 | - | [Paper](https://arxiv.org/abs/2108.07243) |

docs/index.md

+1
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@
123123
| 流场高分辨率重构 | [2D 湍流流场重构](./zh/examples/tempoGAN.md) | 数据驱动 | tempoGAN | 监督学习 | [Train Data](https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_train.mat)<br>[Eval Data](https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_valid.mat) | [Paper](https://dl.acm.org/doi/10.1145/3197517.3201304)|
124124
| 流场高分辨率重构 | [2D 湍流流场重构](https://aistudio.baidu.com/projectdetail/4493261?contributionType=1) | 数据驱动 | cycleGAN | 监督学习 | [Train Data](https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_train.mat)<br>[Eval Data](https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_valid.mat) | [Paper](https://arxiv.org/abs/2007.15324)|
125125
| 流场高分辨率重构 | [基于Voronoi嵌入辅助深度学习的稀疏传感器全局场重建](https://aistudio.baidu.com/projectdetail/5807904) | 数据驱动 | CNN | 监督学习 | [Data1](https://drive.google.com/drive/folders/1K7upSyHAIVtsyNAqe6P8TY1nS5WpxJ2c)<br>[Data2](https://drive.google.com/drive/folders/1pVW4epkeHkT2WHZB7Dym5IURcfOP4cXu)<br>[Data3](https://drive.google.com/drive/folders/1xIY_jIu-hNcRY-TTf4oYX1Xg4_fx8ZvD) | [Paper](https://arxiv.org/pdf/2202.11214.pdf) |
126+
| 流场预测 | [Catheter](https://aistudio.baidu.com/projectdetail/5379212) | 数据驱动 | FNO | 监督学习 | [Data](https://aistudio.baidu.com/datasetdetail/291940) | [Paper](https://www.science.org/doi/pdf/10.1126/sciadv.adj1741) |
126127
| 求解器耦合 | [CFD-GCN](./zh/examples/cfdgcn.md) | 数据驱动 | GCN | 监督学习 | [Data](https://aistudio.baidu.com/aistudio/datasetdetail/184778)<br>[Mesh](https://paddle-org.bj.bcebos.com/paddlescience/datasets/CFDGCN/meshes.tar) | [Paper](https://arxiv.org/abs/2007.04439)|
127128
| 受力分析 | [1D 欧拉梁变形](./zh/examples/euler_beam.md) | 机理驱动 | MLP | 无监督学习 | - | - |
128129
| 受力分析 | [2D 平板变形](./zh/examples/biharmonic2d.md) | 机理驱动 | MLP | 无监督学习 | - | [Paper](https://arxiv.org/abs/2108.07243) |

docs/zh/api/arch.md

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
- DGMR
1919
- Discriminator
2020
- ExtFormerMoECuboid
21+
- FNO1d
2122
- Generator
2223
- HEDeepONets
2324
- LorenzEmbedding

docs/zh/examples/catheter.md

+563
Large diffs are not rendered by default.

examples/catheter/catheter.py

+303
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,303 @@
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
from os import path as osp
17+
18+
import hydra
19+
import matplotlib.pyplot as plt
20+
import numpy as np
21+
import paddle
22+
from omegaconf import DictConfig
23+
24+
import ppsci
25+
from ppsci.loss import L2RelLoss
26+
from ppsci.optimizer import Adam
27+
from ppsci.optimizer import lr_scheduler
28+
from ppsci.utils import logger
29+
30+
31+
# build data
32+
def getdata(
33+
x_path,
34+
y_path,
35+
para_path,
36+
output_path,
37+
n_data,
38+
n,
39+
s,
40+
is_train=True,
41+
is_inference=False,
42+
):
43+
# load data
44+
inputX_raw = np.load(x_path)[:, 0:n_data]
45+
inputY_raw = np.load(y_path)[:, 0:n_data]
46+
inputPara_raw = np.load(para_path)[:, 0:n_data]
47+
output_raw = np.load(output_path)[:, 0:n_data]
48+
49+
# preprocess data
50+
inputX = inputX_raw[:, 0::3]
51+
inputY = inputY_raw[:, 0::3]
52+
inputPara = inputPara_raw[:, 0::3]
53+
label = (output_raw[:, 0::3] + output_raw[:, 1::3] + output_raw[:, 2::3]) / 3.0
54+
55+
if is_inference:
56+
inputX = np.transpose(inputX, (1, 0))
57+
inputY = np.transpose(inputY, (1, 0))
58+
input = np.stack(arrays=[inputX, inputY], axis=-1).astype(np.float32)
59+
input = input.reshape(n, s, 2)
60+
return input
61+
62+
inputX = paddle.to_tensor(data=inputX, dtype="float32").transpose(perm=[1, 0])
63+
inputY = paddle.to_tensor(data=inputY, dtype="float32").transpose(perm=[1, 0])
64+
input = paddle.stack(x=[inputX, inputY], axis=-1)
65+
label = paddle.to_tensor(data=label, dtype="float32").transpose(perm=[1, 0])
66+
if is_train:
67+
index = paddle.randperm(n=n)
68+
index = index[:n]
69+
input = paddle.index_select(input, index)
70+
label = paddle.index_select(label, index)
71+
input = input.reshape([n, s, 2])
72+
else:
73+
input = input.reshape([n, s, 2])
74+
label = label.unsqueeze(axis=-1)
75+
return input, label, inputPara
76+
77+
78+
def plot(input: np.ndarray, out_pred: np.ndarray, output_dir: str):
79+
os.makedirs(output_dir, exist_ok=True)
80+
fig_path = osp.join(output_dir, "inference.png")
81+
82+
xx = np.linspace(-500, 0, 2001)
83+
fig = plt.figure(figsize=(5, 4))
84+
plt.plot(input[:, 0], input[:, 1], color="C1", label="Channel geometry")
85+
plt.plot(input[:, 0], 100 - input[:, 1], color="C1")
86+
plt.plot(
87+
xx,
88+
out_pred,
89+
"--*",
90+
color="C2",
91+
fillstyle="none",
92+
markevery=len(xx) // 10,
93+
label="Predicted bacteria distribution",
94+
)
95+
plt.xlabel(r"x")
96+
plt.legend()
97+
plt.tight_layout()
98+
fig.savefig(fig_path, bbox_inches="tight", dpi=400)
99+
plt.close()
100+
ppsci.utils.logger.info(f"Saving figure to {fig_path}")
101+
102+
103+
def train(cfg: DictConfig):
104+
# generate training dataset
105+
inputs_train, labels_train, _ = getdata(**cfg.TRAIN_DATA, is_train=True)
106+
107+
# set constraints
108+
sup_constraint = ppsci.constraint.SupervisedConstraint(
109+
{
110+
"dataset": {
111+
"name": "NamedArrayDataset",
112+
"input": {"input": inputs_train},
113+
"label": {"output": labels_train},
114+
},
115+
"batch_size": cfg.TRAIN.batch_size,
116+
"sampler": {
117+
"name": "BatchSampler",
118+
"drop_last": False,
119+
"shuffle": True,
120+
},
121+
},
122+
L2RelLoss(reduction="sum"),
123+
name="sup_constraint",
124+
)
125+
constraint = {sup_constraint.name: sup_constraint}
126+
127+
# set model
128+
model = ppsci.arch.FNO1d(**cfg.MODEL)
129+
if cfg.TRAIN.use_pretrained_model is True:
130+
logger.info(
131+
"Loading pretrained model from {}".format(cfg.TRAIN.pretrained_model_path)
132+
)
133+
model.set_state_dict(paddle.load(cfg.TRAIN.pretrained_model_path))
134+
135+
# set optimizer
136+
ITERS_PER_EPOCH = int(cfg.TRAIN_DATA.n / cfg.TRAIN.batch_size)
137+
scheduler = lr_scheduler.Step(
138+
**cfg.TRAIN.lr_scheduler, iters_per_epoch=ITERS_PER_EPOCH
139+
)
140+
optimizer = Adam(scheduler(), weight_decay=cfg.TRAIN.weight_decay)(model)
141+
142+
# generate test dataset
143+
inputs_test, labels_test, _ = getdata(**cfg.TEST_DATA, is_train=False)
144+
145+
# set validator
146+
l2rel_validator = {
147+
"validator1": ppsci.validate.SupervisedValidator(
148+
{
149+
"dataset": {
150+
"name": "NamedArrayDataset",
151+
"input": {"input": inputs_test},
152+
"label": {"output": labels_test},
153+
},
154+
"batch_size": cfg.TRAIN.batch_size,
155+
},
156+
L2RelLoss(reduction="sum"),
157+
metric={"L2Rel": ppsci.metric.L2Rel()},
158+
name="L2Rel_Validator",
159+
)
160+
}
161+
162+
# initialize solver
163+
solver = ppsci.solver.Solver(
164+
model,
165+
constraint,
166+
cfg.output_dir,
167+
optimizer,
168+
epochs=cfg.TRAIN.epochs,
169+
iters_per_epoch=ITERS_PER_EPOCH,
170+
eval_with_no_grad=True,
171+
eval_during_train=cfg.TRAIN.eval_during_train,
172+
validator=l2rel_validator,
173+
save_freq=cfg.TRAIN.save_freq,
174+
)
175+
176+
# train model
177+
solver.train()
178+
# plot losses
179+
solver.plot_loss_history(by_epoch=True, smooth_step=1)
180+
181+
182+
def evaluate(cfg: DictConfig):
183+
# set model
184+
model = ppsci.arch.FNO1d(**cfg.MODEL)
185+
ppsci.utils.save_load.load_pretrain(
186+
model,
187+
cfg.EVAL.pretrained_model_path,
188+
)
189+
190+
# set data
191+
x_test, y_test, para = getdata(**cfg.TEST_DATA, is_train=False)
192+
y_test = y_test.numpy()
193+
194+
for sample_id in [0, 8]:
195+
sample, uf, L_p, x1, x2, x3, h = para[:, sample_id]
196+
mesh = x_test[sample_id, :, :]
197+
mesh = mesh.numpy()
198+
199+
y_test_pred = (
200+
paddle.exp(
201+
model({"input": x_test[sample_id : sample_id + 1, :, :]})["output"]
202+
)
203+
.numpy()
204+
.flatten()
205+
)
206+
logger.info(
207+
"rel. error is ",
208+
np.linalg.norm(y_test_pred - y_test[sample_id, :].flatten())
209+
/ np.linalg.norm(y_test[sample_id, :].flatten()),
210+
)
211+
xx = np.linspace(-500, 0, 2001)
212+
plt.figure(figsize=(5, 4))
213+
214+
plt.plot(mesh[:, 0], mesh[:, 1], color="C1", label="Channel geometry")
215+
plt.plot(mesh[:, 0], 100 - mesh[:, 1], color="C1")
216+
217+
plt.plot(
218+
xx,
219+
y_test[sample_id, :],
220+
"--o",
221+
color="red",
222+
markevery=len(xx) // 10,
223+
label="Reference",
224+
)
225+
plt.plot(
226+
xx,
227+
y_test_pred,
228+
"--*",
229+
color="C2",
230+
fillstyle="none",
231+
markevery=len(xx) // 10,
232+
label="Predicted bacteria distribution",
233+
)
234+
235+
plt.xlabel(r"x")
236+
237+
plt.legend()
238+
plt.tight_layout()
239+
plt.savefig(f"Validation.{sample_id}.pdf")
240+
241+
242+
def export(cfg: DictConfig):
243+
# set model
244+
model = ppsci.arch.FNO1d(**cfg.MODEL)
245+
# initialize solver
246+
solver = ppsci.solver.Solver(
247+
model,
248+
pretrained_model_path=cfg.INFER.pretrained_model_path,
249+
)
250+
# export model
251+
from paddle.static import InputSpec
252+
253+
input_spec = [
254+
{
255+
key: InputSpec([None, 2001, 2], "float32", name=key)
256+
for key in model.input_keys
257+
},
258+
]
259+
solver.export(input_spec, cfg.INFER.export_path)
260+
261+
262+
def inference(cfg: DictConfig):
263+
from deploy import python_infer
264+
265+
predictor = python_infer.GeneralPredictor(cfg)
266+
267+
# evaluate
268+
input = getdata(**cfg.TEST_DATA, is_train=False, is_inference=True)
269+
input_dict = {"input": input}
270+
271+
output_dict = predictor.predict(input_dict, cfg.INFER.batch_size)
272+
# mapping data to cfg.INFER.output_keys
273+
output_keys = ["output"]
274+
output_dict = {
275+
store_key: paddle.exp(paddle.to_tensor(output_dict[infer_key]))
276+
.numpy()
277+
.flatten()
278+
for store_key, infer_key in zip(output_keys, output_dict.keys())
279+
}
280+
281+
mesh = input_dict["input"][5, :, :]
282+
yy = output_dict["output"][5]
283+
plot(mesh, yy, cfg.output_dir)
284+
285+
286+
@hydra.main(version_base=None, config_path="./conf", config_name="catheter.yaml")
287+
def main(cfg: DictConfig):
288+
if cfg.mode == "train":
289+
train(cfg)
290+
elif cfg.mode == "eval":
291+
evaluate(cfg)
292+
elif cfg.mode == "export":
293+
export(cfg)
294+
elif cfg.mode == "infer":
295+
inference(cfg)
296+
else:
297+
raise ValueError(
298+
f"cfg.mode should in ['train', 'eval', 'export', 'infer], but got '{cfg.mode}'"
299+
)
300+
301+
302+
if __name__ == "__main__":
303+
main()

0 commit comments

Comments
 (0)