-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path architect.py
34 lines (25 loc) · 1.13 KB
/
architect.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
import numpy as np
import torch.nn as nn
from SSIM import SSIM
from percep_loss import networks
from percep_loss import vgg
# Module-level loss objects shared by every Architect instance.
# NOTE(review): these are built at import time and pinned to the first CUDA
# device, so merely importing this module requires a GPU plus the project's
# SSIM / percep_loss packages on the path.
# All loss modules are placed on this one explicit device; previously mse/ssim
# used `.cuda()` (the *current* device at import) while VGG used cuda:0, which
# could split the losses across GPUs.
self_device = torch.device('cuda:0')
# Pixel-wise mean-squared error.
mse = nn.MSELoss().to(self_device)
# Structural-similarity module from the project's SSIM package.
ssim = SSIM().to(self_device)
# VGG19 feature network, constructed with requires_grad=False.
self_vgg = vgg.Vgg19(requires_grad=False).to(self_device)
# Perceptual loss over self_vgg features (presumably an L1 feature distance,
# judging by the name VGGLoss1 — TODO confirm in percep_loss.networks).
criterionVgg = networks.VGGLoss1(self_device, vgg=self_vgg, normalize=False)
class Architect:
    """Optimizer wrapper that updates a model's architecture parameters.

    Each call to :meth:`step` zeroes the architecture gradients, backpropagates
    a validation loss computed from already-produced network outputs, and
    applies one Adam update to ``model.arch_parameters()``.
    """

    # Added to the denominator of the perceptual ratio so that an output equal
    # to the blended input cannot yield an inf (or 0/0 NaN) loss, which would
    # permanently poison the Adam moment estimates.
    VGG_RATIO_EPS = 1e-8

    def __init__(self, model, args):
        """Create the Adam optimizer over the model's architecture parameters.

        Args:
            model: network exposing ``arch_parameters()``; only those
                parameters are registered with this optimizer.
            args: namespace providing ``arch_learning_rate`` and
                ``arch_weight_decay``.
        """
        self.model = model
        self.optimizer = torch.optim.Adam(
            self.model.arch_parameters(),
            lr=args.arch_learning_rate,
            betas=(0.5, 0.999),
            weight_decay=args.arch_weight_decay,
        )

    def step(self, output_valid, target_valid, blended_valid):
        """Run one architecture update on a validation batch.

        Args:
            output_valid: network output for the validation batch; expected to
                still be attached to the autograd graph so gradients reach the
                architecture parameters.
            target_valid: ground-truth images for the batch.
            blended_valid: the blended images (presumably the network inputs —
                TODO confirm against the caller).

        Returns:
            The scalar validation loss, detached from the graph (for logging).
        """
        self.optimizer.zero_grad()
        loss = self._backward_step(output_valid, target_valid, blended_valid)
        self.optimizer.step()
        return loss.detach()

    def _backward_step(self, output_valid, target_valid, blended_valid):
        """Compute the validation loss and backpropagate it.

        loss = 0.1 * MSE(out, target) + (1 - SSIM(out, target))
               + 0.1 * VGG(target, out) / (VGG(blended, out) + eps)

        Returns:
            The loss tensor, still attached to the graph. ``backward()`` is
            called without arguments, so it must be a single-element tensor.
        """
        loss = 0.1 * mse(output_valid, target_valid) + (
            1 - ssim(output_valid, target_valid)
        )
        # Ratio of perceptual distances: assuming criterionVgg returns a
        # non-negative feature distance, this favours outputs close to the
        # target and far from the blended input.
        vgg_loss = criterionVgg(target_valid, output_valid) / (
            criterionVgg(blended_valid, output_valid) + self.VGG_RATIO_EPS
        )
        loss = loss + 0.1 * vgg_loss
        # NOTE(review): backward() also accumulates .grad on any non-architecture
        # parameters in the graph, but only the architecture grads are zeroed in
        # step(); confirm the weight optimizer zeroes its own grads first.
        loss.backward()
        return loss