Commit 8616b99

return encoder representations only if necessary
1 parent 306b2e5 commit 8616b99

3 files changed (+35, -19 lines)


onmt/models/model.py (+9, -8)
@@ -46,22 +46,23 @@ def forward(self, src, tgt, lengths, bptt=False,
 
         enc_state, memory_bank, lengths = self.encoder(src, lengths)
 
+        if bptt is False:
+            self.decoder.init_state(src, memory_bank, enc_state)
+
+        dec_out, attns = self.decoder(dec_in, memory_bank,
+                                      memory_lengths=lengths,
+                                      with_align=with_align)
+
         if encode_tgt:
             # tgt for zero shot alignment loss
             tgt_lengths = torch.Tensor(tgt.size(1))\
                 .type_as(memory_bank) \
                 .long() \
                 .fill_(tgt.size(0))
             embs_tgt, memory_bank_tgt, ltgt = self.encoder(tgt, tgt_lengths)
-        else:
-            memory_bank_tgt = None
+            return dec_out, attns, memory_bank, memory_bank_tgt
 
-        if bptt is False:
-            self.decoder.init_state(src, memory_bank, enc_state)
-        dec_out, attns = self.decoder(dec_in, memory_bank,
-                                      memory_lengths=lengths,
-                                      with_align=with_align)
-        return dec_out, attns, memory_bank, memory_bank_tgt
+        return dec_out, attns
 
     def update_dropout(self, dropout):
         self.encoder.update_dropout(dropout)
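With this change, forward returns four values (decoder outputs, attentions, and the two encoder memory banks) only when encode_tgt is set; every other caller gets the usual pair back. A minimal, runnable sketch of that contract with a toy stand-in (ToyModel and its tensors are illustrative, not code from this repo):

import torch

class ToyModel(torch.nn.Module):
    """Toy mirroring the new forward contract; not OpenNMT code."""
    def forward(self, src, tgt, lengths, bptt=False, encode_tgt=False):
        dec_out, attns = src.float(), {"std": None}   # stand-in decoder outputs
        if encode_tgt:
            # Encode tgt as well, e.g. for a zero-shot alignment loss.
            memory_bank, memory_bank_tgt = src.float(), tgt.float()
            return dec_out, attns, memory_bank, memory_bank_tgt
        return dec_out, attns

model = ToyModel()
src = tgt = torch.ones(2, 3)
out, attns, enc_src, enc_tgt = model(src, tgt, None, encode_tgt=True)
out, attns = model(src, tgt, None)   # default path: two values only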

onmt/trainer.py (+24, -9)
@@ -317,13 +317,21 @@ def validate(self, valid_iter, moving_average=None):
                 tgt = batch.tgt
 
                 # F-prop through the model.
-                outputs, attns, enc_src, enc_tgt = valid_model(
-                    src, tgt, src_lengths,
-                    with_align=self.with_align)
+                if self.encode_tgt:
+                    outputs, attns, enc_src, enc_tgt = valid_model(
+                        src, tgt, src_lengths,
+                        with_align=self.with_align,
+                        encode_tgt=self.encode_tgt)
+                else:
+                    outputs, attns = valid_model(
+                        src, tgt, src_lengths,
+                        with_align=self.with_align)
+                    enc_src, enc_tgt = None, None
 
                 # Compute loss.
                 _, batch_stats = self.valid_loss(
-                    batch, outputs, attns, enc_src, enc_tgt)
+                    batch, outputs, attns,
+                    enc_src=enc_src, enc_tgt=enc_tgt)
 
                 # Update statistics.
                 stats.update(batch_stats)
@@ -366,9 +374,16 @@ def _gradient_accumulation(self, true_batches, normalization, total_stats,
                 if self.accum_count == 1:
                     self.optim.zero_grad()
 
-                outputs, attns, enc_src, enc_tgt = self.model(
-                    src, tgt, src_lengths, bptt=bptt,
-                    with_align=self.with_align, encode_tgt=self.encode_tgt)
+                if self.encode_tgt:
+                    outputs, attns, enc_src, enc_tgt = self.model(
+                        src, tgt, src_lengths, bptt=bptt,
+                        with_align=self.with_align, encode_tgt=self.encode_tgt)
+                else:
+                    outputs, attns = self.model(
+                        src, tgt, src_lengths, bptt=bptt,
+                        with_align=self.with_align)
+                    enc_src, enc_tgt = None, None
+
                 bptt = True
 
                 # 3. Compute loss.
@@ -377,8 +392,8 @@ def _gradient_accumulation(self, true_batches, normalization, total_stats,
                     batch,
                     outputs,
                     attns,
-                    enc_src,
-                    enc_tgt,
+                    enc_src=enc_src,
+                    enc_tgt=enc_tgt,
                     normalization=normalization,
                     shard_size=self.shard_size,
                     trunc_start=j,
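The trainer has to branch on the same flag because unpacking a two-element return into four names raises a ValueError, which is why each call site mirrors the model's condition and fills in None for the missing states. A quick self-contained illustration of that arity hazard (plain Python, hypothetical names):

def forward(encode_tgt=False):
    if encode_tgt:
        return 1, 2, 3, 4
    return 1, 2

try:
    dec_out, attns, enc_src, enc_tgt = forward()   # 2 values, 4 targets
except ValueError as err:
    print(err)   # not enough values to unpack (expected 4, got 2)

dec_out, attns = forward()        # arity matches the default path
enc_src, enc_tgt = None, None     # mirrors the trainer's else branch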

onmt/utils/loss.py (+2, -2)
@@ -124,8 +124,8 @@ def __call__(self,
                 batch,
                 output,
                 attns,
-                enc_src,
-                enc_tgt,
+                enc_src=None,
+                enc_tgt=None,
                 normalization=1.0,
                 shard_size=0,
                 trunc_start=0,
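Defaulting enc_src and enc_tgt to None keeps every pre-existing call site that never computes encoder states valid, while the alignment path passes them explicitly as keywords. A small sketch of the pattern (compute_loss is a hypothetical stub, not the repo's loss class):

def compute_loss(batch, output, attns, enc_src=None, enc_tgt=None):
    """Stub with the new signature; the encoder term is optional."""
    align_term = 0.0 if enc_src is None or enc_tgt is None else 1.0
    return align_term

print(compute_loss("batch", "out", "attns"))                          # 0.0
print(compute_loss("batch", "out", "attns", enc_src=[], enc_tgt=[]))  # 1.0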
