-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathfull_mnist.py
65 lines (35 loc) · 2.29 KB
/
full_mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from tensorflow_generator import TensorflowGenerator
from products_tree import ProductSet
import multiprocessing
import json
def run_tensorflow(product, url, index, datasets=None, epochs=12, depth=1, data_augmentation=False):
    """Build a TensorflowGenerator for *product* on each dataset and log the results.

    For every name in *datasets* a ``TensorflowGenerator`` is constructed and one
    result line is appended to
    ``report_all_{epochs}epochs_{depth}depth_{url}_{dataset}.txt``.

    Args:
        product: Product description passed unchanged to ``TensorflowGenerator``.
        url: Tag embedded in the log-file name (``main`` passes its ``target``).
            NOTE(review): if it contains path separators (the script passes
            ``"./datasets/1000Products"``) the log path points into a
            sub-directory that must already exist -- confirm this is intended.
        index: Product index written at the start of each log line.
        datasets: Dataset names to run; ``None`` or empty means nothing is run.
        epochs: Forwarded to ``TensorflowGenerator`` and embedded in the log name.
        depth: Forwarded to ``TensorflowGenerator`` and embedded in the log name.
        data_augmentation: Forwarded to ``TensorflowGenerator``.
    """
    # ``None`` default instead of a shared mutable ``[]``; both mean "no datasets".
    for dataset in datasets or ():
        logpath = "report_all_{2}epochs_{3}depth_{0}_{1}.txt".format(url, dataset, epochs, depth)
        # Presumably the constructor performs the training; the results are read
        # from its attributes below -- TODO confirm against TensorflowGenerator.
        tensorflow = TensorflowGenerator(product, epochs, dataset, depth=depth, data_augmentation=data_augmentation)
        # "acc1#acc2#...|val1#val2#..." built from history[0] and history[1].
        history = "{accuracy}|{validation_accuracy}".format(
            accuracy="#".join(map(str, tensorflow.history[0])),
            validation_accuracy="#".join(map(str, tensorflow.history[1])),
        )
        # Record layout kept byte-identical (including the leading "\r\n") so
        # new lines match those already present in existing report files.
        line = "\r\n{0}: {1} {2} {3} {4} {5} {6}".format(
            index,
            tensorflow.accuracy,
            tensorflow.stop_training,
            tensorflow.training_time,
            tensorflow.params,
            tensorflow.flops,
            history,
        )
        # ``with`` guarantees the handle is closed even if the write fails.
        with open(logpath, "a") as log_file:
            log_file.write(line)
def main(target, min_index=0, max_index=0, filter_indices=None, datasets=None, epochs=12, depth=1, data_augmentation=False):
    """Dump each selected product to JSON and run ``run_tensorflow`` on it.

    Products come from ``ProductSet("./" + target + ".pdt").format_products()``
    and are visited in enumeration order.

    Args:
        target: Path stem of the ``.pdt`` product file; also reused in the
            JSON file name and passed to ``run_tensorflow`` as its log tag.
            NOTE(review): a stem containing "/" (as the script passes) makes
            the JSON path ``./products/<target>_<i>.json`` nested -- the
            directories must already exist.
        min_index: Products with a lower index are skipped.
        max_index: Last index visited (inclusive); ``0`` means no upper bound.
        filter_indices: If non-empty, only these indices are processed;
            ``None`` or empty means no filtering.
        datasets: Dataset names for ``run_tensorflow``; ``None`` or empty
            falls back to ``["mnist"]``.
        epochs: Forwarded to ``run_tensorflow``.
        depth: Forwarded to ``run_tensorflow``.
        data_augmentation: Forwarded to ``run_tensorflow``.
    """
    baseurl = "./"
    product_set = ProductSet(baseurl + target + ".pdt")
    if not datasets:
        datasets = ["mnist"]
    # Build the filter once as a set (O(1) lookups); ``None`` means "accept all".
    wanted = set(filter_indices) if filter_indices else None

    for index, product in enumerate(product_set.format_products()):
        print("product {0}".format(index))
        if index >= min_index and (wanted is None or index in wanted):
            json_path = "{0}products/{1}_{2}.json".format(baseurl, target, index)
            # Serialise fully before writing, then let ``with`` close the file.
            serialized = json.dumps(product)
            with open(json_path, "w") as json_file:
                json_file.write(serialized)
            run_tensorflow(product, target, index, datasets, epochs, depth, data_augmentation=data_augmentation)
        # Checked after processing, so max_index itself is included.
        if max_index != 0 and index == max_index:
            break
if __name__ == "__main__":
    # Execute only when run as a script: process every product loaded from
    # datasets/1000Products.pdt on MNIST for 300 epochs, depth 1, no augmentation.
    #
    # Reference indices, presumably the best-scoring products on CIFAR (inferred
    # from the former name ``top_cifar`` -- confirm). To restrict a run to them,
    # pass filter_indices=[59, 63, 143, 161, 203, 444, 477, 595, 634, 936].
    main("./datasets/1000Products", datasets=["mnist"], epochs=300, depth=1, data_augmentation=False)