Inputs to modeling not always broadcasted as expected
Description
This might not actually be a bug, and perhaps nobody but me expects it to work this way. However, the same code worked when I ran it through astropy.wcs.WCS, and it fails with gwcs.WCS because of how modeling handles inputs.
Expected behavior
When calling a 2D model with sparse but mutually broadcastable inputs, I expect the inputs to be broadcast together.
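Plain NumPy integer indexing already behaves this way. The sketch below is a rough pure-NumPy stand-in for the model, not astropy's implementation. It assumes that "nearest" lookup on the unit-spaced grid 0, 1, 2 used in the reproduction reduces to rounding plus indexing. Under that assumption, the sparse (3, 1) and (1, 3) index arrays broadcast to a full (3, 3) result:

```python
import numpy as np

# Same 3x3 lookup table as in the reproduction (values only, units dropped).
shape = (3, 3)
table = np.arange(np.prod(shape), dtype=float).reshape(shape)

# Sparse 'ij' grid: x has shape (3, 1), y has shape (1, 3).
x, y = np.meshgrid(np.arange(3), np.arange(3), indexing="ij", sparse=True)

# Assumption: on a unit-spaced grid starting at 0, 'nearest' lookup is just
# rounding each coordinate to the closest integer index.
ix = np.rint(x).astype(int)
iy = np.rint(y).astype(int)

# Integer-array indexing broadcasts the (3, 1) and (1, 3) index arrays
# against each other, so the result has the full (3, 3) shape.
result = table[ix, iy]
print(result.shape)  # (3, 3)
```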
Actual behavior
A traceback ending in ValueError: All inputs must have identical shapes or must be scalars.
Steps to Reproduce
import numpy as np
from astropy.modeling import models
import astropy.units as u
shape = (3, 3)
data = np.arange(np.prod(shape)).reshape(shape) * u.m / u.s
points_unit = u.pix
points = [np.arange(size) * points_unit for size in shape]
kwargs = {
    'bounds_error': False,
    'fill_value': np.nan,
    'method': 'nearest',
}
transform = models.Tabular2D(points, data, **kwargs)
print(transform)
Model: Tabular2D
N_inputs: 2
N_outputs: 1
Parameters:
  points: [<Quantity [0., 1., 2.] pix>, <Quantity [0., 1., 2.] pix>]
  lookup_table: [[0. 1. 2.]
 [3. 4. 5.]
 [6. 7. 8.]] m / s
  method: nearest
  fill_value: nan
  bounds_error: False
points = np.meshgrid(np.arange(3), np.arange(3), indexing='ij', sparse=True)
print(points)
[array([[0],
       [1],
       [2]]), array([[0, 1, 2]])]
transform(*points)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-14-e007ddf28445> in <module>
----> 1 import codecs, os;__pyfile = codecs.open('''/tmp/pypuTmFL''', encoding='''utf-8''');__code = __pyfile.read().encode('''utf-8''');__pyfile.close();os.remove('''/tmp/pypuTmFL''');exec(compile(__code, '''/tmp/pypuTmFL''', 'exec'));
/tmp/pypuTmFL in <module>
~/.virtualenvs/ndcube-dev/lib/python3.8/site-packages/astropy/modeling/core.py in __call__(self, model_set_axis, with_bounding_box, fill_value, equivalencies, inputs_map, *inputs, **new_inputs)
398 ('inputs_map', None)])
399
--> 400 new_call = make_function_with_signature(
401 __call__, args, kwargs, varargs='inputs', varkwargs='new_inputs')
402
~/.virtualenvs/ndcube-dev/lib/python3.8/site-packages/astropy/modeling/core.py in __call__(self, *inputs, **kwargs)
377 def __call__(self, *inputs, **kwargs):
378 """Evaluate this model on the supplied inputs."""
--> 379 return super(cls, self).__call__(*inputs, **kwargs)
380
381 # When called, models can take two optional keyword arguments:
~/.virtualenvs/ndcube-dev/lib/python3.8/site-packages/astropy/modeling/core.py in __call__(self, *args, **kwargs)
911 new_args, kwargs = self._get_renamed_inputs_as_positional(*args, **kwargs)
912
--> 913 return generic_call(self, *new_args, **kwargs)
914
915 def _get_renamed_inputs_as_positional(self, *args, **kwargs):
~/.virtualenvs/ndcube-dev/lib/python3.8/site-packages/astropy/modeling/core.py in generic_call(self, *inputs, **kwargs)
3945 def generic_call(self, *inputs, **kwargs):
3946 """ The base ``Model. __call__`` method."""
-> 3947 inputs, format_info = self.prepare_inputs(*inputs, **kwargs)
3948 if isinstance(self, CompoundModel):
3949 # CompoundModels do not normally hold parameters at that level
~/.virtualenvs/ndcube-dev/lib/python3.8/site-packages/astropy/modeling/core.py in prepare_inputs(self, model_set_axis, equivalencies, *inputs, **kwargs)
1608 inputs = [np.asanyarray(_input, dtype=float) for _input in inputs]
1609
-> 1610 _validate_input_shapes(inputs, self.inputs, n_models,
1611 model_set_axis, self.standard_broadcasting)
1612
~/.virtualenvs/ndcube-dev/lib/python3.8/site-packages/astropy/modeling/core.py in _validate_input_shapes(inputs, argnames, n_models, model_set_axis, validate_broadcasting)
3885 input_shape = check_consistent_shapes(*all_shapes)
3886 if input_shape is None:
-> 3887 raise ValueError(
3888 "All inputs must have identical shapes or must be scalars.")
3889
ValueError: All inputs must have identical shapes or must be scalars.
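The error message says modeling requires every non-scalar input to have the same shape, rather than shapes that merely broadcast. The sketch below contrasts that strict rule with a broadcasting-aware one. Both functions are hypothetical illustrations (neither `check_identical_shapes` nor `check_broadcastable_shapes` is an astropy function), and the second assumes NumPy >= 1.20 for `np.broadcast_shapes`:

```python
import numpy as np


def check_identical_shapes(*inputs):
    """Strict rule implied by the error message: all non-scalar inputs
    must share one shape. Returns that shape, () if all are scalars,
    or None if the shapes differ."""
    shapes = {np.shape(x) for x in inputs if np.ndim(x) > 0}
    if len(shapes) > 1:
        return None
    return shapes.pop() if shapes else ()


def check_broadcastable_shapes(*inputs):
    """Hypothetical broadcasting-aware rule: accept any shapes NumPy can
    broadcast together and return the resulting broadcast shape."""
    try:
        return np.broadcast_shapes(*(np.shape(x) for x in inputs))
    except ValueError:
        shapes = ", ".join(str(np.shape(x)) for x in inputs)
        raise ValueError(
            f"Inputs could not be broadcast together with shapes {shapes}"
        ) from None


# The sparse grid from the report: shapes (3, 1) and (1, 3).
x, y = np.meshgrid(np.arange(3), np.arange(3), indexing="ij", sparse=True)
print(check_identical_shapes(x, y))      # None: the strict rule rejects them
print(check_broadcastable_shapes(x, y))  # (3, 3)
```

Genuinely incompatible shapes, such as (2,) and (3,), would still raise a `ValueError` under the broadcasting rule, so relaxing the check only admits inputs NumPy itself can combine.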
The output I expect is what I get when I set sparse=False:
points = np.meshgrid(np.arange(3), np.arange(3), indexing='ij', sparse=False) * u.pix
print(transform(*points))
[[[0. 0. 0.]
  [1. 1. 1.]
  [2. 2. 2.]]

 [[0. 1. 2.]
  [0. 1. 2.]
  [0. 1. 2.]]] pix
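Until modeling broadcasts its inputs, one possible workaround is to densify the sparse grid before the call with `np.broadcast_arrays`. This is a suggestion and has not been confirmed in the thread. The check below uses only NumPy to show that the broadcast arrays match the `sparse=False` grid. The final model call is left as a comment because it needs astropy and the `transform` object from the reproduction above:

```python
import numpy as np

# Sparse 'ij' grid: shapes (3, 1) and (1, 3).
sparse = np.meshgrid(np.arange(3), np.arange(3), indexing="ij", sparse=True)
# Dense 'ij' grid: both arrays have shape (3, 3).
dense = np.meshgrid(np.arange(3), np.arange(3), indexing="ij", sparse=False)

# broadcast_arrays expands the sparse arrays to their common (3, 3) shape.
# The results are views rather than copies, so copy them before writing.
broadcast = np.broadcast_arrays(*sparse)

for b, d in zip(broadcast, dense):
    assert b.shape == d.shape == (3, 3)
    assert np.array_equal(b, d)

# With the Tabular2D model from the reproduction (requires astropy), the
# densified inputs could then be passed as, for example:
#     transform(*[p * u.pix for p in broadcast])
```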
System Details
Astropy 4.1
Top GitHub Comments
This is a known issue. I attempted a solution in #10362, but it is incomplete: the handling of the bounding box still needs to be fixed. Similar issues: #9953, #10170.
@WilliamJamieson 👍 The above numbers match the answer without units.