Skip to content

Commit 0d6d3bf

Browse files
committed
fix: ga step
1 parent 6be5428 commit 0d6d3bf

File tree

1 file changed

+19
-28
lines changed

1 file changed

+19
-28
lines changed

tensorflow_asr/models/base_model.py

+19-28
Original file line numberDiff line numberDiff line change
@@ -184,14 +184,19 @@ def train_step(self, data):
184184
metrics = self.get_metrics_result()
185185
return metrics
186186

187-
def train_step_ga(self, data_buffer):
188-
first_data, *rest_data = data_buffer
189-
gradients = self._train_step(first_data)
190-
for data in rest_data:
191-
next_gradients = self._train_step(data)
192-
gradients = self.ga.accumulate(gradients, next_gradients)
193-
self._apply_gradients(gradients)
187+
def train_step_ga(self, data):
188+
gradients = self._train_step(data)
194189
metrics = self.get_metrics_result()
190+
return metrics, gradients
191+
192+
def train_step_ga_next(self, data, prev_gradients):
193+
metrics, gradients = self.train_step_ga(data)
194+
gradients = self.ga.accumulate(prev_gradients, gradients)
195+
return metrics, gradients
196+
197+
def train_step_ga_last(self, data, prev_gradients):
198+
metrics, gradients = self.train_step_ga_next(data, prev_gradients)
199+
self._apply_gradients(gradients)
195200
return metrics
196201

197202
def _test_step(self, data: schemas.TrainData):
@@ -244,29 +249,17 @@ def make_train_function(self, force=False):
244249
return self.train_function
245250

246251
@tf.autograph.experimental.do_not_convert
247-
def one_ga_step_on_data(data_buffer):
248-
outputs = self.distribute_strategy.run(self.train_step_ga, args=(data_buffer,))
252+
def one_ga_step_on_data(iterator):
253+
outputs, gradients = self.distribute_strategy.run(self.train_step_ga, args=(next(iterator),))
254+
for i, data in zip(range(1, self.ga.total_steps - 1), iterator):
255+
outputs, gradients = self.distribute_strategy.run(self.train_step_ga_next, args=(data, gradients))
256+
outputs = self.distribute_strategy.run(self.train_step_ga_last, args=(next(iterator), gradients))
249257
outputs = reduce_per_replica(
250258
outputs,
251259
self.distribute_strategy,
252260
reduction="auto",
253261
)
254262
return outputs
255-
# data = next(iterator)
256-
# outputs, gradients = self.distribute_strategy.run(self.train_step_ga, args=(data, None))
257-
# for _ in range(1, self.ga.total_steps):
258-
# try:
259-
# data = next(iterator)
260-
# outputs, gradients = self.distribute_strategy.run(self.train_step_ga, args=(data, gradients))
261-
# except StopIteration:
262-
# break
263-
# self.distribute_strategy.run(self._apply_gradients, args=(gradients,))
264-
# outputs = reduce_per_replica(
265-
# outputs,
266-
# self.distribute_strategy,
267-
# reduction="auto",
268-
# )
269-
# return outputs
270263

271264
if not self.run_eagerly:
272265
one_ga_step_on_data = tf.function(
@@ -276,10 +269,8 @@ def one_ga_step_on_data(data_buffer):
276269
)
277270

278271
def function(iterator):
279-
data_buffer = []
280-
for _, data in zip(range(self.ga.total_steps), iterator):
281-
data_buffer.append(data)
282-
return one_ga_step_on_data(data_buffer)
272+
outputs = one_ga_step_on_data(iterator)
273+
return outputs
283274

284275
self.train_function = function
285276
return self.train_function

0 commit comments

Comments
 (0)