Add xtensor broadcast #1489
base: labeled_tensors
Conversation
@ricardoV94 Here's my attempt to rebase on the changes you just force-pushed. Looks like mypy is unhappy -- is that something you expected? Other than that, I think this is ready for review.
Yeah, I didn't make mypy pass yet.
pytensor/xtensor/rewriting/shape.py
x_tensor = x_tensor.dimshuffle(shuffle_pattern)

# Now we are aligned with target dims and correct ndim
x_tensor = broadcast_to(x_tensor, out.type.shape)
This won't work when the output shape is not statically known. The target shape has to be computed symbolically from the symbolic input shapes.
You can test by having an xtensor with shape=(None) for a dim that only that tensor has
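To make that concrete, here is a minimal sketch at the plain-tensor level (illustrative names, not the rewrite itself): the static output shape contains None for the unknown dim, so the target for broadcast_to has to be built from the inputs' symbolic shapes, e.g. with pt.broadcast_shape:

import pytensor.tensor as pt

# "x" only has dim a (length 3); "y" has dims (b, a), where b's length is unknown.
x = pt.tensor("x", shape=(3,), dtype="float64")
y = pt.tensor("y", shape=(None, 3), dtype="float64")

# Align x with (b, a) by inserting a broadcastable axis.
x_aligned = x.dimshuffle("x", 0)

# broadcast_to(x_aligned, (None, 3)) would not work: None is not a usable length.
# Build the target shape symbolically from the inputs instead.
target_shape = pt.broadcast_shape(x_aligned, y)
x_bc = pt.broadcast_to(x_aligned, target_shape)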
Force-pushed from 71bc4ef to 41d9be4.
@ricardoV94 I think I have symbolic dimensions working. My solution is more complicated than I think any of us would like, but I don't see a simpler one. Maybe you will. Should we continue work on this PR for now, and I will rebase later?
Here is an idea:

def lower_broadcast(fgraph, node):
    excluded_dims = node.op.exclude
    broadcast_dims = tuple(dim for dim in node.outputs[0].type.dims if dim not in excluded_dims)
    all_dims = broadcast_dims + excluded_dims

    # align inputs with all_dims like we do in other rewrites
    # probably time to refactor this kind of logic into a helper
    inp_tensors = []
    for inp, out in zip(node.inputs, node.outputs, strict=True):
        inp_dims = inp.type.dims
        order = tuple(inp_dims.index(dim) if dim in inp_dims else "x" for dim in all_dims)
        inp_tensors.append(inp.values.dimshuffle(order))

    if not excluded_dims:
        out_tensors = pt.broadcast_arrays(*inp_tensors)
    else:
        out_tensors = []
        all_shape = tuple(pt.broadcast_shape(*inp_tensors))
        assert len(all_shape) == len(all_dims)
        for inp_tensor, out in zip(inp_tensors, node.outputs):
            out_dims = out.type.dims
            out_shape = tuple(length for length, dim in zip(all_shape, all_dims) if dim in out_dims)
            out_tensors.append(pt.broadcast_to(inp_tensor, out_shape))

    new_outs = [as_xtensor(out_tensor, dims=out.type.dims) for out_tensor, out in zip(out_tensors, node.outputs)]
    return new_outs

Btw the base branch is merged, you can rebase / start from it. Note that you don't need to open a new PR; you can force-push your changes to your current remote after cleaning up the branch.
@ricardoV94 I've added ...
Your version of ... I'll work on debugging it, but at the moment it's not clear to me whether there is a small error in your implementation or an actual problem with the logic.
I suspect some wrong assumption about the excluded dims alignment, but the general idea should work.
I think the incorrect assumption is that all outputs have the same shape. When exclude is not empty, they don't, in general.
Actually there's a logical flaw. Two inputs could have an excluded dim with the same name but a different length, in which case they shouldn't be aligned for the broadcast shape. We should add that as a test. Still, the logic for each output should be something like ...
I didn't assume that; the dimshuffle was supposed to take care of it so that things were put on different axes for broadcasting. Still, as I just wrote, there was a wrong assumption that you could align shared excluded dims. They don't even come out in a uniform order, do they?
I don't think this logical flaw is why the tests are failing, though. We should test that case as well.
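For the record, here is the case under discussion, sketched against xarray (whose broadcast semantics xtensor mirrors); the dim names and lengths are made up, and this assumes xarray's documented behavior that excluded dims are ignored for alignment and broadcasting:

import numpy as np
import xarray as xr

# Two inputs share the excluded dim name "batch" but with different lengths.
a = xr.DataArray(np.zeros((2, 3)), dims=("batch", "x"))
b = xr.DataArray(np.zeros((4, 5)), dims=("batch", "y"))

# "batch" is excluded, so each output keeps its own batch length,
# while "x" and "y" are still broadcast across both outputs.
a_bc, b_bc = xr.broadcast(a, b, exclude=["batch"])
assert a_bc.sizes["batch"] == 2 and b_bc.sizes["batch"] == 4
assert a_bc.sizes["x"] == b_bc.sizes["x"] == 3
assert a_bc.sizes["y"] == b_bc.sizes["y"] == 5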
This replaces #1486. This one is based on a rebased labeled_tensor branch.

📚 Documentation preview 📚: https://pytensor--1489.org.readthedocs.build/en/1489/