Prepending batch dimensions to match Keras interfaces #39

Status: Open — wants to merge 14 commits into base: main
32 changes: 24 additions & 8 deletions xbatcher/generators.py
@@ -54,14 +54,22 @@ def _drop_input_dims(ds, input_dims, suffix='_input'):
return out


-def _maybe_stack_batch_dims(ds, input_dims, stacked_dim_name='sample'):
+def _maybe_stack_batch_dims(
+    ds, input_dims, squeeze_batch_dim, stacked_dim_name='sample'
+):
     batch_dims = [d for d in ds.dims if d not in input_dims]
-    if len(batch_dims) < 2:
+    if len(batch_dims) == 0:
+        if squeeze_batch_dim:
+            return ds
+        else:
+            return ds.expand_dims(stacked_dim_name, 0)
+    elif len(batch_dims) == 1:
         return ds
-    ds_stack = ds.stack(**{stacked_dim_name: batch_dims})
-    # ensure correct order
-    dim_order = (stacked_dim_name,) + tuple(input_dims)
-    return ds_stack.transpose(*dim_order)
+    else:
+        ds_stack = ds.stack(**{stacked_dim_name: batch_dims})
+        # ensure correct order
+        dim_order = (stacked_dim_name,) + tuple(input_dims)
+        return ds_stack.transpose(*dim_order)
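The stacking step above can be illustrated with a plain-NumPy analogue (a sketch only: xbatcher operates on labeled xarray objects, and the helper name `collapse_batch_axes` plus the assumption that batch axes lead are hypothetical). All non-input axes collapse into one leading "sample" axis, and when there are no batch axes at all, a size-1 axis is prepended, which is what `squeeze_batch_dim=False` requests.

```python
import numpy as np

def collapse_batch_axes(arr, n_input_dims):
    # Hypothetical NumPy analogue of _maybe_stack_batch_dims:
    # all leading (batch) axes are flattened into a single
    # 'sample' axis, leaving the trailing input axes intact.
    batch_shape = arr.shape[: arr.ndim - n_input_dims]
    input_shape = arr.shape[arr.ndim - n_input_dims :]
    if len(batch_shape) == 0:
        # No batch axes: prepend a size-1 sample axis
        # (the squeeze_batch_dim=False behavior).
        return arr[np.newaxis, ...]
    return arr.reshape((-1,) + input_shape)

x = np.zeros((4, 5, 16, 16))            # two batch axes, two input axes
print(collapse_batch_axes(x, 2).shape)  # (20, 16, 16)
```

Unlike the real helper, this sketch also flattens a single batch axis; the xarray version returns the dataset unchanged in that case, which gives the same shape anyway.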


class BatchGenerator:
@@ -90,6 +98,10 @@ class BatchGenerator:
preload_batch : bool, optional
If ``True``, each batch will be loaded into memory before reshaping /
processing, triggering any dask arrays to be computed.
squeeze_batch_dim : bool, optional
If ``False`` and all dims are input dims, each batch's dataset will have a
"batch" dimension of size 1 prepended to the array. This functionality is
Review comment (Member):

Some suggestions on the documentation. Maybe L86 and L87 could be edited to indicate that the squeeze behavior is controllable with squeeze_batch_dim?

Also, this sentence is a bit confusing to read. An extra dimension of size 1 is added/prepended only when batch_dims is None or unset; for cases where len(batch_dims) >= 1, the squeezing/collapsing of dimensions still happens. I'm wondering what's a good way to reword this to make it clearer what is happening and why this option exists.

Reply (Author):

The problem only appears in this one corner case, as far as I could tell, so my solution was coded to apply only when there are no batch_dims (at least that was my intention). If you're changing the default behavior though, you can probably make this simpler, and the docs would then be less confusing too.

useful for interoperability with Keras / TensorFlow.

Yields
------
@@ -105,6 +117,7 @@ def __init__(
batch_dims={},
concat_input_dims=False,
preload_batch=True,
squeeze_batch_dim=True,
Review comment (Member):

Is it OK to set the default to True, i.e. squeeze the batch dimension if all dims are included in the input dims? This might break existing behaviour, so I'm wondering if the default should be False instead for backward compatibility.

Reply (Author, @cmdupuis3, Nov 28, 2022):

It would definitely break existing behavior. I'm not really sure which would be more intuitive (I think that's kind of subjective), but personally I would agree with you.

The problem is that the existing behavior already squeezes the array, so it would be kind of a pain to have to use xr.DataArray.expand_dims, because you don't really know which dimension was squeezed (it's not there anymore). You'd probably end up digging through the code and iterating a few times, like I did in this scenario. Also, the fact that this behavior produces arrays of different dimensionality just from changing the batch dims breaks, I think, the grammar being established here. Your proposal of False as the default plus xr.DataArray.squeeze is much more logical.

Reply (Member, @weiji14, Nov 28, 2022):

> The problem is that the existing behavior already squeezes the array, so I think it would be kind of a pain to have to use xr.DataArray.expand_dims,

Yeah, I also find it unintuitive that xbatcher squeezes any non input_dims into a dim called sample (edit: there's actually a related issue at #127). Another way to think of this squeeze_batch_dim feature is that xbatcher will just return the cropped/sliced/chipped array without squeezing the dims. This matters more with higher-dimensional (>3D) arrays, because sometimes you do want to preserve the extra dims.
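The workaround discussed in this thread (manually restoring or dropping the size-1 batch axis) can be sketched with NumPy's counterparts of xr.DataArray.expand_dims and .squeeze. A sketch only; the shapes are illustrative, not taken from the PR.

```python
import numpy as np

# A batch whose leading batch axis was squeezed away by the generator.
batch = np.zeros((16, 16))

# What a Keras-style interface expects: a leading batch axis of size 1.
# np.expand_dims plays the role of xr.DataArray.expand_dims here.
keras_ready = np.expand_dims(batch, axis=0)
print(keras_ready.shape)  # (1, 16, 16)

# Going the other way (the squeeze_batch_dim=True semantics):
# drop the size-1 leading axis again.
squeezed = np.squeeze(keras_ready, axis=0)
print(squeezed.shape)  # (16, 16)
```

As the author notes, doing this by hand is awkward precisely because the squeezed dimension's name is gone from the output; with labeled xarray dims the user would have to rediscover which axis to expand.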

):

self.ds = _as_xarray_dataset(ds)
@@ -114,6 +127,7 @@ def __init__(
self.batch_dims = OrderedDict(batch_dims)
self.concat_input_dims = concat_input_dims
self.preload_batch = preload_batch
self.squeeze_batch_dim = squeeze_batch_dim

def __iter__(self):
for ds_batch in self._iterate_batch_dims(self.ds):
@@ -132,11 +146,13 @@ def __iter__(self):
new_input_dims = [
dim + new_dim_suffix for dim in self.input_dims
]
-                yield _maybe_stack_batch_dims(dsc, new_input_dims)
+                yield _maybe_stack_batch_dims(
+                    dsc, new_input_dims, self.squeeze_batch_dim
+                )
else:
for ds_input in input_generator:
yield _maybe_stack_batch_dims(
-                    ds_input, list(self.input_dims)
+                    ds_input, list(self.input_dims), self.squeeze_batch_dim
)

def _iterate_batch_dims(self, ds):
57 changes: 57 additions & 0 deletions xbatcher/tests/test_generators.py
@@ -169,6 +169,63 @@ def test_batch_3d_2d_input_concat(sample_ds_3d, bsize):
)


@pytest.mark.parametrize('bsize', [10, 20])
def test_batch_1d_squeeze_batch_dim(sample_ds_1d, bsize):
bg = BatchGenerator(
sample_ds_1d,
input_dims={'x': bsize},
squeeze_batch_dim=False,
)
for ds_batch in bg:
assert list(ds_batch['foo'].shape) == [1, bsize]

bg2 = BatchGenerator(
sample_ds_1d,
input_dims={'x': bsize},
squeeze_batch_dim=True,
)
for ds_batch in bg2:
assert list(ds_batch['foo'].shape) == [bsize]


@pytest.mark.parametrize('bsize', [5, 10])
def test_batch_3d_squeeze_batch_dim(sample_ds_3d, bsize):
bg = BatchGenerator(
sample_ds_3d,
input_dims={'y': bsize, 'x': bsize},
squeeze_batch_dim=False,
)
for ds_batch in bg:
assert list(ds_batch['foo'].shape) == [10, bsize, bsize]

bg2 = BatchGenerator(
sample_ds_3d,
input_dims={'y': bsize, 'x': bsize},
squeeze_batch_dim=True,
)
for ds_batch in bg2:
assert list(ds_batch['foo'].shape) == [10, bsize, bsize]


@pytest.mark.parametrize('bsize', [5, 10])
def test_batch_3d_squeeze_batch_dim2(sample_ds_3d, bsize):
bg = BatchGenerator(
sample_ds_3d,
input_dims={'x': bsize},
squeeze_batch_dim=False,
)
for ds_batch in bg:
assert list(ds_batch['foo'].shape) == [500, bsize]

bg2 = BatchGenerator(
sample_ds_3d,
input_dims={'x': bsize},
squeeze_batch_dim=True,
)
for ds_batch in bg2:
assert list(ds_batch['foo'].shape) == [500, bsize]
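Where the 500 in these assertions comes from can be worked out directly. A pure-Python sketch; the fixture dimension names and sizes below are inferred from the asserted shapes, not taken from the fixture code.

```python
# Inferred (hypothetical) fixture sizes: sample_ds_3d appears to have
# a 'time' dim of 10 and a 'y' dim of 50 besides 'x'.
n_time, n_y = 10, 50

# input_dims={'y': bsize, 'x': bsize}: only 'time' is left over, and a
# single leftover dim is never stacked, so shapes stay (10, bsize, bsize).
leftover_dims = (n_time,)

# input_dims={'x': bsize}: both 'time' and 'y' are left over and get
# stacked into one 'sample' axis of length 10 * 50 = 500.
n_samples = n_time * n_y
print(n_samples)  # 500
```

This also shows why squeeze_batch_dim has no visible effect in the 3D tests: with at least one leftover dim, no size-1 axis is ever prepended.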


def test_preload_batch_false(sample_ds_1d):
sample_ds_1d_dask = sample_ds_1d.chunk({'x': 2})
bg = BatchGenerator(