@@ -184,14 +184,19 @@ def train_step(self, data):
184
184
metrics = self .get_metrics_result ()
185
185
return metrics
186
186
187
- def train_step_ga (self , data_buffer ):
188
- first_data , * rest_data = data_buffer
189
- gradients = self ._train_step (first_data )
190
- for data in rest_data :
191
- next_gradients = self ._train_step (data )
192
- gradients = self .ga .accumulate (gradients , next_gradients )
193
- self ._apply_gradients (gradients )
187
+ def train_step_ga (self , data ):
188
+ gradients = self ._train_step (data )
194
189
metrics = self .get_metrics_result ()
190
+ return metrics , gradients
191
+
192
+ def train_step_ga_next (self , data , prev_gradients ):
193
+ metrics , gradients = self .train_step_ga (data )
194
+ gradients = self .ga .accumulate (prev_gradients , gradients )
195
+ return metrics , gradients
196
+
197
+ def train_step_ga_last (self , data , prev_gradients ):
198
+ metrics , gradients = self .train_step_ga_next (data , prev_gradients )
199
+ self ._apply_gradients (gradients )
195
200
return metrics
196
201
197
202
def _test_step (self , data : schemas .TrainData ):
@@ -244,29 +249,17 @@ def make_train_function(self, force=False):
244
249
return self .train_function
245
250
246
251
@tf .autograph .experimental .do_not_convert
247
- def one_ga_step_on_data (data_buffer ):
248
- outputs = self .distribute_strategy .run (self .train_step_ga , args = (data_buffer ,))
252
+ def one_ga_step_on_data (iterator ):
253
+ outputs , gradients = self .distribute_strategy .run (self .train_step_ga , args = (next (iterator ),))
254
+ for i , data in zip (range (1 , self .ga .total_steps - 1 ), iterator ):
255
+ outputs , gradients = self .distribute_strategy .run (self .train_step_ga_next , args = (data , gradients ))
256
+ outputs = self .distribute_strategy .run (self .train_step_ga_last , args = (next (iterator ), gradients ))
249
257
outputs = reduce_per_replica (
250
258
outputs ,
251
259
self .distribute_strategy ,
252
260
reduction = "auto" ,
253
261
)
254
262
return outputs
255
- # data = next(iterator)
256
- # outputs, gradients = self.distribute_strategy.run(self.train_step_ga, args=(data, None))
257
- # for _ in range(1, self.ga.total_steps):
258
- # try:
259
- # data = next(iterator)
260
- # outputs, gradients = self.distribute_strategy.run(self.train_step_ga, args=(data, gradients))
261
- # except StopIteration:
262
- # break
263
- # self.distribute_strategy.run(self._apply_gradients, args=(gradients,))
264
- # outputs = reduce_per_replica(
265
- # outputs,
266
- # self.distribute_strategy,
267
- # reduction="auto",
268
- # )
269
- # return outputs
270
263
271
264
if not self .run_eagerly :
272
265
one_ga_step_on_data = tf .function (
@@ -276,10 +269,8 @@ def one_ga_step_on_data(data_buffer):
276
269
)
277
270
278
271
def function (iterator ):
279
- data_buffer = []
280
- for _ , data in zip (range (self .ga .total_steps ), iterator ):
281
- data_buffer .append (data )
282
- return one_ga_step_on_data (data_buffer )
272
+ outputs = one_ga_step_on_data (iterator )
273
+ return outputs
283
274
284
275
self .train_function = function
285
276
return self .train_function
0 commit comments