
dataset's [ implementation has inconsistent output shapes when $.getitem() is implemented #1307

Open
sebffischer opened this issue Apr 17, 2025 · 4 comments

@sebffischer
Collaborator

sebffischer commented Apr 17, 2025

In the example below, the dataset that implements $.getbatch() always returns the batch dimension from [. For the dataset that implements $.getitem(), however, the [ method returns an element without a batch dimension when given an index of length 1, and includes the batch dimension otherwise. I think this should be consistent, with [ always returning the batch dimension.

```r
library(torch)

ds_batch = dataset("batch",
  initialize = function() {
    self$x = torch_randn(100, 10)
  },
  .getbatch = function(i) {
    self$x[i,.., drop = FALSE]
  },
  .length = function() nrow(self$x)
)()

print(ds_batch[1L]$shape)
#> [1]  1 10
print(ds_batch[1:2]$shape)
#> [1]  2 10

ds_item = dataset("batch",
  initialize = function() {
    self$x = torch_randn(100, 10)
  },
  .getitem = function(i) {
    self$x[i]
  },
  .length = function() nrow(self$x)
)()

print(ds_item[1L]$shape)
#> [1] 10
print(ds_item[1:2]$shape)
#> [1]  2 10
```

Created on 2025-04-17 with reprex v2.1.1

@sebffischer
Collaborator Author

OK, I realized that this happens because [.dataset simply forwards whatever indices it receives to $.getitem().
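The forwarding behavior can be modeled with a small base-R mock. This is only a sketch, not torch's actual source: it assumes [ passes the full index vector unchanged to .getitem() and prefers .getbatch() when one exists. A plain matrix stands in for a tensor so it runs without torch; m[i, ] drops a length-1 dimension by default, much like self$x[i] does in the reprex. The names mock_dataset and m are hypothetical.

```r
# Hypothetical base-R mock, NOT torch's source: it only models the
# forwarding behavior, with a matrix standing in for a tensor.
mock_dataset <- function(getitem = NULL, getbatch = NULL) {
  structure(list(.getitem = getitem, .getbatch = getbatch),
            class = "mock_dataset")
}

# Assumed dispatch: use .getbatch() if present, otherwise pass the whole
# index vector straight through to .getitem().
`[.mock_dataset` <- function(x, i) {
  i <- as.integer(i)
  getbatch <- .subset2(x, ".getbatch")
  if (!is.null(getbatch)) getbatch(i) else .subset2(x, ".getitem")(i)
}

m <- matrix(rnorm(100 * 10), nrow = 100, ncol = 10)

ds_batch <- mock_dataset(getbatch = function(i) m[i, , drop = FALSE])
ds_item  <- mock_dataset(getitem  = function(i) m[i, ])  # drop = TRUE by default

dim(ds_batch[1L])   # 1 10
dim(ds_batch[1:2])  # 2 10
dim(ds_item[1L])    # NULL: a plain length-10 vector, no batch dimension
dim(ds_item[1:2])   # 2 10
```

The shape of the result depends entirely on how .getitem() happens to treat a vector index, which is why the single-index case loses its batch dimension.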

I am not sure what the correct behavior here is, but I think the current implementation is somewhat inconsistent.

One suggestion would be to make [.dataset throw an error when more than one index is provided (for datasets that implement only $.getitem()). We can't simply concatenate the results along the first dimension, because the returned tensors might have different shapes.

Also, for consistency, I think [.dataset should include the batch dimension when called with a single index on a dataset that implements $.getitem().

@dfalbel
Member

dfalbel commented Apr 17, 2025

I agree with your second suggestion: [ should include the batch dimension when called with a single index on a dataset that only implements .getitem(). We could implement [[ to extract a single element by index, using .getitem().

@sebffischer
Collaborator Author

But the question remains whether ds[1:2] should throw an error if ds only implements $.getitem(). The individual tensors might have different shapes, so it's not always possible to torch_cat() them.

@dfalbel
Member

dfalbel commented Apr 17, 2025

Yes, maybe a simpler solution is to throw an error if the dataset only implements .getitem(); but then I don't think we should include the batch dimension in that case. Maybe just allow [[ if .getitem() is implemented, and reserve [ for .getbatch().
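A rough sketch of that split, again as a hypothetical base-R mock (mock_dataset and m are made-up names, and a matrix stands in for a tensor), not the torch package's API: [[ fetches exactly one element through .getitem(), while [ requires .getbatch() and errors otherwise, so any result of [ always carries the batch dimension.

```r
# Hypothetical base-R mock of the proposed split, NOT torch's API;
# a matrix stands in for a tensor.
mock_dataset <- function(getitem = NULL, getbatch = NULL) {
  structure(list(.getitem = getitem, .getbatch = getbatch),
            class = "mock_dataset")
}

# [[ : exactly one index, answered by .getitem() (no batch dimension)
`[[.mock_dataset` <- function(x, i) {
  if (length(i) != 1L) stop("[[ takes exactly one index", call. = FALSE)
  getitem <- .subset2(x, ".getitem")
  if (is.null(getitem)) stop("dataset does not implement .getitem()", call. = FALSE)
  getitem(as.integer(i))
}

# [ : any index vector, answered only by .getbatch() (batch dimension kept)
`[.mock_dataset` <- function(x, i) {
  getbatch <- .subset2(x, ".getbatch")
  if (is.null(getbatch)) {
    stop("dataset only implements .getitem(); use [[ for single elements",
         call. = FALSE)
  }
  getbatch(as.integer(i))
}

m <- matrix(rnorm(100 * 10), nrow = 100, ncol = 10)
ds_item  <- mock_dataset(getitem  = function(i) m[i, ])
ds_batch <- mock_dataset(getbatch = function(i) m[i, , drop = FALSE])

length(ds_item[[1L]])  # 10
dim(ds_batch[1L])      # 1 10
dim(ds_batch[1:2])     # 2 10
try(ds_item[1:2])      # errors instead of stacking items of possibly different shapes
```

With this split, [ never has to stack .getitem() results whose shapes may differ; a caller who wants several elements from a .getitem()-only dataset would collect them explicitly, e.g. with lapply(1:2, function(j) ds_item[[j]]).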
