@@ -317,13 +317,21 @@ def validate(self, valid_iter, moving_average=None):
                 tgt = batch.tgt

                 # F-prop through the model.
-                outputs, attns, enc_src, enc_tgt = valid_model(
-                    src, tgt, src_lengths,
-                    with_align=self.with_align)
+                if self.encode_tgt:
+                    outputs, attns, enc_src, enc_tgt = valid_model(
+                        src, tgt, src_lengths,
+                        with_align=self.with_align,
+                        encode_tgt=self.encode_tgt)
+                else:
+                    outputs, attns = valid_model(
+                        src, tgt, src_lengths,
+                        with_align=self.with_align)
+                    enc_src, enc_tgt = None, None

                 # Compute loss.
                 _, batch_stats = self.valid_loss(
-                    batch, outputs, attns, enc_src, enc_tgt)
+                    batch, outputs, attns,
+                    enc_src=enc_src, enc_tgt=enc_tgt)

                 # Update statistics.
                 stats.update(batch_stats)
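Both this validation hunk and the training hunk below rely on the same return-arity contract: the model yields a 4-tuple only when `encode_tgt` is set, and the caller otherwise falls back to `None` encodings. A minimal sketch of that dispatch, assuming any model callable with this contract (the helper name is illustrative and not part of the patch):

```python
# Hedged sketch of the dispatch used in both hunks: the model returns
# four values only when encode_tgt is set; otherwise the encodings
# default to None. `model` stands for any callable with this contract.
def forward_with_optional_tgt_encoding(model, src, tgt, src_lengths,
                                       with_align=False, encode_tgt=False,
                                       **kwargs):
    if encode_tgt:
        # Four-tuple: decoder outputs, attentions, and both encodings.
        outputs, attns, enc_src, enc_tgt = model(
            src, tgt, src_lengths,
            with_align=with_align, encode_tgt=encode_tgt, **kwargs)
    else:
        # Plain two-tuple forward pass; no encodings are produced.
        outputs, attns = model(
            src, tgt, src_lengths, with_align=with_align, **kwargs)
        enc_src, enc_tgt = None, None
    return outputs, attns, enc_src, enc_tgt
```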
@@ -366,9 +374,16 @@ def _gradient_accumulation(self, true_batches, normalization, total_stats,
                 if self.accum_count == 1:
                     self.optim.zero_grad()

-                outputs, attns, enc_src, enc_tgt = self.model(
-                    src, tgt, src_lengths, bptt=bptt,
-                    with_align=self.with_align, encode_tgt=self.encode_tgt)
+                if self.encode_tgt:
+                    outputs, attns, enc_src, enc_tgt = self.model(
+                        src, tgt, src_lengths, bptt=bptt,
+                        with_align=self.with_align, encode_tgt=self.encode_tgt)
+                else:
+                    outputs, attns = self.model(
+                        src, tgt, src_lengths, bptt=bptt,
+                        with_align=self.with_align)
+                    enc_src, enc_tgt = None, None
+
                 bptt = True

                 # 3. Compute loss.
@@ -377,8 +392,8 @@ def _gradient_accumulation(self, true_batches, normalization, total_stats,
                         batch,
                         outputs,
                         attns,
-                        enc_src,
-                        enc_tgt,
+                        enc_src=enc_src,
+                        enc_tgt=enc_tgt,
                         normalization=normalization,
                         shard_size=self.shard_size,
                         trunc_start=j,
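The loss call sites now pass the encodings by keyword, which implies that `valid_loss`/`train_loss` accept `enc_src` and `enc_tgt` as optional parameters defaulting to `None`. A hedged sketch of a loss callable with that shape; the squared-error stand-in and the auxiliary cosine term are assumptions for illustration, not the actual `onmt.utils.loss` code:

```python
import torch.nn.functional as F

# Illustrative loss callable matching the keyword-argument call sites
# above. Only the optional enc_src/enc_tgt keywords are taken from the
# diff; the base loss and the cosine term are stand-ins.
def loss_with_optional_encodings(batch, outputs, attns,
                                 enc_src=None, enc_tgt=None, **kwargs):
    # Stand-in for the real sharded NMT loss over decoder outputs.
    loss = outputs.float().pow(2).mean()
    if enc_src is not None and enc_tgt is not None:
        # Auxiliary term pulling pooled source/target encodings together
        # (an assumed choice of similarity objective).
        src_vec = enc_src.float().mean(dim=0)
        tgt_vec = enc_tgt.float().mean(dim=0)
        loss = loss + (1.0 - F.cosine_similarity(src_vec, tgt_vec,
                                                 dim=-1)).mean()
    stats = {"loss": loss.item()}  # stand-in for batch_stats
    return loss, stats
```

With this signature, the unconditional `None` defaults let the same call site serve both branches: when `encode_tgt` is off, the auxiliary term is simply skipped.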