-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_sp.py
50 lines (38 loc) · 1.48 KB
/
train_sp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#test
import argparse
import autotracker
import matplotlib.pyplot as plt
import deeptrack as dt
import numpy as np
# Command-line interface for the training script. The values are consumed by
# main() and forwarded to the autotracker package; their precise semantics are
# defined there (not visible in this file).
parser = argparse.ArgumentParser(
    description="Train a label-free single-particle tracker",
)
# Input data path, passed to autotracker.load() in main().
parser.add_argument("filename", metavar="d", type=str)
parser.add_argument("--batch_size", dest="batch_size", type=int, default=8)
parser.add_argument("--epochs", dest="epochs", type=int, default=20)
# NumPy-style index string (e.g. ":", "0:100", "::2") selecting the training frames.
parser.add_argument("--trainframes", dest="train_frames", type=str, default=":")
# Loss name handed to autotracker.single_particle_model(loss=...).
parser.add_argument("--lossfn", dest="lossfn", type=str, default="mae")
# Not read in main(); it only reaches autotracker.save() via the args namespace.
parser.add_argument("--prefix", dest="prefix", type=str, default="")
# --radius and --rotate are forwarded to model.fit() as generator_kwargs.
parser.add_argument("--radius", dest="radius", type=int, default=2)
parser.add_argument("--rotate", dest="rotate", type=int, default=1)
# float, not int: the default is 0.01 and an int converter would reject any
# user-supplied value such as "--sigma 0.05". Not read in main(); it only
# reaches autotracker.save() via the args namespace.
parser.add_argument("--sigma", dest="sigma", type=float, default=0.01)
def _parse_index_component(text):
    """Convert one index component such as ``"5"``, ``":"`` or ``"0:100:2"``.

    Empty slice fields become ``None``, so ``"10:"`` means ``slice(10, None)``,
    mirroring Python's own slice syntax. Negative integers are allowed.

    Raises:
        ValueError: if a field is not an integer or the slice has more than
            three fields.
    """
    text = text.strip()
    if ":" not in text:
        return int(text)
    fields = text.split(":")
    if len(fields) > 3:
        raise ValueError(f"too many ':' in slice {text!r}")
    return slice(*(int(field) if field.strip() else None for field in fields))


def _parse_frame_index(spec):
    """Turn a NumPy-style index string into an index object without ``eval``.

    Accepts a single integer (``"5"``), a slice (``":"``, ``"10:"``,
    ``"0:100:2"``, ``"-50:"``) or a comma-separated tuple of those
    (``"0:50, 10:90"``) for multi-axis indexing. Other Python expressions
    (lists, ``...``, arithmetic) are rejected.

    Raises:
        ValueError: if ``spec`` is empty or any component is malformed.
    """
    index = tuple(_parse_index_component(part) for part in spec.split(","))
    return index[0] if len(index) == 1 else index


def main():
    """Train a label-free single-particle tracker from the command line.

    Parses the command-line arguments, loads the data named by ``filename``
    with ``autotracker.load``, and selects the frames given by
    ``--trainframes``. It saves a preview of the first selected frame to
    ``trainim.png``, then fits ``autotracker.single_particle_model`` on those
    frames. Finally it passes the trained model together with the parsed
    arguments to ``autotracker.save``.
    """
    args = parser.parse_args()

    # --trainframes comes straight from the command line; parse it as an index
    # instead of eval()-ing it, which would run arbitrary Python code. Validate
    # it before the (possibly slow) data load so bad input fails fast.
    try:
        frame_index = _parse_frame_index(args.train_frames)
    except ValueError as err:
        parser.error(f"invalid --trainframes {args.train_frames!r}: {err}")

    frames, _ = autotracker.load(args.filename)
    # NOTE(review): `frames` is indexed NumPy-style and has a `.shape`, so it is
    # presumably an ndarray with frames along axis 0 — confirm in autotracker.
    training_set = frames[frame_index]

    # Save a preview of the first training frame as a quick visual sanity check.
    plt.imshow(training_set[0])
    plt.savefig("trainim.png")

    dataloader = autotracker.dataloader(training_set)
    # Input shape is every axis of `frames` except the first (presumably the
    # shape of a single frame).
    model = autotracker.single_particle_model(
        input_shape=frames.shape[1:],
        loss=args.lossfn,
    )
    model.fit(
        dataloader,
        epochs=args.epochs,
        batch_size=args.batch_size,
        generator_kwargs={"radius": args.radius, "rotate": args.rotate},
    )
    autotracker.save(model, args)
# Run training only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()