Prepending batch dimensions to match Keras interfaces #39
@@ -54,14 +54,22 @@ def _drop_input_dims(ds, input_dims, suffix='_input'):
     return out


-def _maybe_stack_batch_dims(ds, input_dims, stacked_dim_name='sample'):
+def _maybe_stack_batch_dims(
+    ds, input_dims, squeeze_batch_dim, stacked_dim_name='sample'
+):
     batch_dims = [d for d in ds.dims if d not in input_dims]
-    if len(batch_dims) < 2:
+    if len(batch_dims) == 0:
+        if squeeze_batch_dim:
+            return ds
+        else:
+            return ds.expand_dims(stacked_dim_name, 0)
+    elif len(batch_dims) == 1:
         return ds
-    ds_stack = ds.stack(**{stacked_dim_name: batch_dims})
-    # ensure correct order
-    dim_order = (stacked_dim_name,) + tuple(input_dims)
-    return ds_stack.transpose(*dim_order)
+    else:
+        ds_stack = ds.stack(**{stacked_dim_name: batch_dims})
+        # ensure correct order
+        dim_order = (stacked_dim_name,) + tuple(input_dims)
+        return ds_stack.transpose(*dim_order)


 class BatchGenerator:
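To make the three branches above concrete, here is a small standalone sketch using plain xarray (the dimension names `x`/`y`/`time`/`member` and sizes are invented for illustration) of what the patched helper would return in each case:

```python
import numpy as np
import xarray as xr

# A single patch where every dimension is an input dim (no leftover batch dims).
patch = xr.Dataset({'foo': (('x', 'y'), np.zeros((4, 4)))})

# len(batch_dims) == 0 branch:
#   squeeze_batch_dim=True  -> patch returned unchanged, dims ('x', 'y')
#   squeeze_batch_dim=False -> a length-1 'sample' dim is prepended
print(patch['foo'].dims)                           # ('x', 'y')
print(patch.expand_dims('sample', 0)['foo'].dims)  # ('sample', 'x', 'y')

# len(batch_dims) >= 2 branch: leftover dims are stacked into 'sample'
# and the result is transposed so 'sample' comes first.
ds_multi = xr.Dataset(
    {'foo': (('time', 'member', 'x', 'y'), np.zeros((3, 2, 4, 4)))}
)
stacked = ds_multi.stack(sample=['time', 'member']).transpose('sample', 'x', 'y')
print(stacked['foo'].dims)  # ('sample', 'x', 'y'), with 'sample' of length 6
```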
@@ -90,6 +98,10 @@ class BatchGenerator:
     preload_batch : bool, optional
         If ``True``, each batch will be loaded into memory before reshaping /
         processing, triggering any dask arrays to be computed.
+    squeeze_batch_dim : bool, optional
+        If ``False`` and all dims are input dims, each batch's dataset will have a
+        "batch" dimension of size 1 prepended to the array. This functionality is
+        useful for interoperability with Keras / Tensorflow.

     Yields
     ------
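As a rough illustration of the Keras / TensorFlow motivation in the docstring (the 32x32 patch size is made up for the example): such frameworks expect a leading batch axis, so a single-sample batch otherwise needs a manual reshape before being passed to `model.predict()` or `model.fit()`:

```python
import numpy as np
import xarray as xr

patch = xr.Dataset({'foo': (('x', 'y'), np.random.rand(32, 32))})

# With squeeze_batch_dim=True (the current behavior) the array is 2-D,
# so a batch axis has to be added by hand before handing it to a model.
arr = patch['foo'].values
print(arr.shape)                    # (32, 32)
print(arr[np.newaxis, ...].shape)   # (1, 32, 32) -- manual workaround

# With squeeze_batch_dim=False the generator would yield the length-1
# 'sample' dim already prepended, matching the (batch, x, y) layout.
print(patch.expand_dims('sample', 0)['foo'].shape)  # (1, 32, 32)
```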
@@ -105,6 +117,7 @@ def __init__(
         batch_dims={},
         concat_input_dims=False,
         preload_batch=True,
+        squeeze_batch_dim=True,
     ):

         self.ds = _as_xarray_dataset(ds)

Review thread on the `squeeze_batch_dim=True` default:

Comment: Is it ok to set the default to …

Reply: It would definitely break existing behavior. I'm not really sure which would be more intuitive; I think that's kind of subjective, but personally I would agree with you. The problem is that the existing behavior already squeezes the array, so I think it would be kind of a pain to have to use …

Reply: Yeah, I also find this unintuitive, the fact that xbatcher squeezes any non …
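For context, here is a hypothetical usage sketch of the proposed keyword (the `squeeze_batch_dim` argument only exists with this patch applied; the toy dims and sizes are invented):

```python
import numpy as np
import xarray as xr
from xbatcher import BatchGenerator

ds = xr.Dataset({'foo': (('x', 'y'), np.random.rand(100, 100))})

# Every dim is an input dim, so each batch has no leftover batch dims.
# With the proposed squeeze_batch_dim=False, each yielded batch gains a
# length-1 'sample' dimension; with the default True it stays 2-D.
bgen = BatchGenerator(
    ds, input_dims={'x': 10, 'y': 10}, squeeze_batch_dim=False
)
print(next(iter(bgen))['foo'].dims)  # ('sample', 'x', 'y')
```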
@@ -114,6 +127,7 @@ def __init__(
         self.batch_dims = OrderedDict(batch_dims)
         self.concat_input_dims = concat_input_dims
         self.preload_batch = preload_batch
+        self.squeeze_batch_dim = squeeze_batch_dim

     def __iter__(self):
         for ds_batch in self._iterate_batch_dims(self.ds):
@@ -132,11 +146,13 @@ def __iter__(self):
                 new_input_dims = [
                     dim + new_dim_suffix for dim in self.input_dims
                 ]
-                yield _maybe_stack_batch_dims(dsc, new_input_dims)
+                yield _maybe_stack_batch_dims(
+                    dsc, new_input_dims, self.squeeze_batch_dim
+                )
             else:
                 for ds_input in input_generator:
                     yield _maybe_stack_batch_dims(
-                        ds_input, list(self.input_dims)
+                        ds_input, list(self.input_dims), self.squeeze_batch_dim
                     )

     def _iterate_batch_dims(self, ds):
Review thread on the `squeeze_batch_dim` docstring:

Comment: Some suggestions on the documentation. Maybe L86 and L87 could be edited to indicate the squeeze behavior is controllable with `squeeze_batch_dim`? Also, this sentence might be a bit confusing to read. So an extra dimension of size 1 is added/prepended only when `batch_dims` is None or unset. For cases where `len(batch_dims) >= 1`, the squeezing/collapsing of dimensions is still happening though. I'm wondering what's a good way to reword this to make it clearer on what is happening and why this option exists.

Reply: The problem would only appear in this one corner case, as far as I could tell. So, my solution was coded to only apply when there were no `batch_dims` (at least that was my intention). If you're changing the default behavior though, you can probably make this simpler, and therefore the docs would be less confusing too.
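A small sketch of the corner case being discussed (hypothetical dims and sizes; the `squeeze_batch_dim` keyword assumes this patch is applied): when `batch_dims` is set and a leftover batch dim remains on each patch, the zero-batch-dim branch of the helper never runs, so the new flag has no effect there:

```python
import numpy as np
import xarray as xr
from xbatcher import BatchGenerator

ds = xr.Dataset({'foo': (('time', 'x'), np.random.rand(20, 100))})

# Each yielded patch keeps a leftover 'time' dim of length 5, so the helper
# takes the single-batch-dim branch and returns the patch unchanged; the
# value of squeeze_batch_dim makes no difference in this case.
bgen = BatchGenerator(
    ds, input_dims={'x': 10}, batch_dims={'time': 5}, squeeze_batch_dim=False
)
print(next(iter(bgen))['foo'].dims)  # ('time', 'x')
```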