numpy.roots type error - casting from complex128 to float64
See original GitHub issue
Hey Numba team, is casting from complex128 to float64 supported? If not, it’d be nice to have. Thanks.
This runs:
import numpy as np
def np_findroot(coefficients):
root_values = np.roots(coefficients)
return(root_values)
np_findroot(np.array([1.1, 2.2, 3.3, 4.4]))
array([-1.65062919+0.j , -0.1746854 +1.54686889j,
-0.1746854 -1.54686889j])
However this does not:
from numba import njit
@njit
def findroot(coefficients):
root_values = np.roots(coefficients)
return(root_values)
findroot(np.array([1.1, 2.2, 3.3, 4.4]))
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
in
7 return(root_values)
8
----> 9 findroot(np.array([1.1, 2.2, 3.3, 4.4]))
/opt/conda/envs/linfit/lib/python3.7/site-packages/numba/np/linalg.py in real_eigvals_impl()
1124 if np.any(wi):
1125 raise ValueError(
-> 1126 "eigvals() argument must not cause a domain change.")
1127
1128 # put these in to help with liveness analysis,
ValueError: eigvals() argument must not cause a domain change.
Neither does this (trying to split the complex roots into their real and imaginary parts, each of which is a float):
@njit
def findroot(coefficients):
root_values = np.roots(coefficients).real + np.roots(coefficients).imag
return(root_values)
findroot(np.array([1.1, 2.2, 3.3, 4.4]))
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
in
4 return(root_values)
5
----> 6 findroot(np.array([1.1, 2.2, 3.3, 4.4]))
/opt/conda/envs/linfit/lib/python3.7/site-packages/numba/np/linalg.py in real_eigvals_impl()
1124 if np.any(wi):
1125 raise ValueError(
-> 1126 "eigvals() argument must not cause a domain change.")
1127
1128 # put these in to help with liveness analysis,
ValueError: eigvals() argument must not cause a domain change.
Nor does this:
@njit('complex128[:](float64[:])')
def findroot(coefficients):
root_values = np.roots(coefficients)
return(root_values)
findroot(np.array([1.1, 2.2, 3.3, 4.4]))
---------------------------------------------------------------------------
TypingError Traceback (most recent call last)
in
----> 1 @njit('complex128[:](float64[:])')
2 def findroot(coefficients):
3 root_values = np.roots(coefficients)
4 return(root_values)
5
/opt/conda/envs/linfit/lib/python3.7/site-packages/numba/core/decorators.py in wrapper(func)
216 with typeinfer.register_dispatcher(disp):
217 for sig in sigs:
--> 218 disp.compile(sig)
219 disp.disable_compile()
220 return disp
/opt/conda/envs/linfit/lib/python3.7/site-packages/numba/core/compiler_lock.py in _acquire_compile_lock(*args, **kwargs)
30 def _acquire_compile_lock(*args, **kwargs):
31 with self:
---> 32 return func(*args, **kwargs)
33 return _acquire_compile_lock
34
/opt/conda/envs/linfit/lib/python3.7/site-packages/numba/core/dispatcher.py in compile(self, sig)
817 self._cache_misses[sig] += 1
818 try:
--> 819 cres = self._compiler.compile(args, return_type)
820 except errors.ForceLiteralArg as e:
821 def folded(args, kws):
/opt/conda/envs/linfit/lib/python3.7/site-packages/numba/core/dispatcher.py in compile(self, args, return_type)
80 return retval
81 else:
---> 82 raise retval
83
84 def _compile_cached(self, args, return_type):
/opt/conda/envs/linfit/lib/python3.7/site-packages/numba/core/dispatcher.py in _compile_cached(self, args, return_type)
90
91 try:
---> 92 retval = self._compile_core(args, return_type)
93 except errors.TypingError as e:
94 self._failed_cache[key] = e
/opt/conda/envs/linfit/lib/python3.7/site-packages/numba/core/dispatcher.py in _compile_core(self, args, return_type)
108 args=args, return_type=return_type,
109 flags=flags, locals=self.locals,
--> 110 pipeline_class=self.pipeline_class)
111 # Check typing error if object mode is used
112 if cres.typing_error is not None and not flags.enable_pyobject:
/opt/conda/envs/linfit/lib/python3.7/site-packages/numba/core/compiler.py in compile_extra(typingctx, targetctx, func, args, return_type, flags, locals, library, pipeline_class)
623 pipeline = pipeline_class(typingctx, targetctx, library,
624 args, return_type, flags, locals)
--> 625 return pipeline.compile_extra(func)
626
627
/opt/conda/envs/linfit/lib/python3.7/site-packages/numba/core/compiler.py in compile_extra(self, func)
359 self.state.lifted = ()
360 self.state.lifted_from = None
--> 361 return self._compile_bytecode()
362
363 def compile_ir(self, func_ir, lifted=(), lifted_from=None):
/opt/conda/envs/linfit/lib/python3.7/site-packages/numba/core/compiler.py in _compile_bytecode(self)
421 """
422 assert self.state.func_ir is None
--> 423 return self._compile_core()
424
425 def _compile_ir(self):
/opt/conda/envs/linfit/lib/python3.7/site-packages/numba/core/compiler.py in _compile_core(self)
401 self.state.status.fail_reason = e
402 if is_final_pipeline:
--> 403 raise e
404 else:
405 raise CompilerError("All available pipelines exhausted")
/opt/conda/envs/linfit/lib/python3.7/site-packages/numba/core/compiler.py in _compile_core(self)
392 res = None
393 try:
--> 394 pm.run(self.state)
395 if self.state.cr is not None:
396 break
/opt/conda/envs/linfit/lib/python3.7/site-packages/numba/core/compiler_machinery.py in run(self, state)
339 (self.pipeline_name, pass_desc)
340 patched_exception = self._patch_error(msg, e)
--> 341 raise patched_exception
342
343 def dependency_analysis(self):
/opt/conda/envs/linfit/lib/python3.7/site-packages/numba/core/compiler_machinery.py in run(self, state)
330 pass_inst = _pass_registry.get(pss).pass_inst
331 if isinstance(pass_inst, CompilerPass):
--> 332 self._runPass(idx, pass_inst, state)
333 else:
334 raise BaseException("Legacy pass in use")
/opt/conda/envs/linfit/lib/python3.7/site-packages/numba/core/compiler_lock.py in _acquire_compile_lock(*args, **kwargs)
30 def _acquire_compile_lock(*args, **kwargs):
31 with self:
---> 32 return func(*args, **kwargs)
33 return _acquire_compile_lock
34
/opt/conda/envs/linfit/lib/python3.7/site-packages/numba/core/compiler_machinery.py in _runPass(self, index, pss, internal_state)
289 mutated |= check(pss.run_initialization, internal_state)
290 with SimpleTimer() as pass_time:
--> 291 mutated |= check(pss.run_pass, internal_state)
292 with SimpleTimer() as finalize_time:
293 mutated |= check(pss.run_finalizer, internal_state)
/opt/conda/envs/linfit/lib/python3.7/site-packages/numba/core/compiler_machinery.py in check(func, compiler_state)
262
263 def check(func, compiler_state):
--> 264 mangled = func(compiler_state)
265 if mangled not in (True, False):
266 msg = ("CompilerPass implementations should return True/False. "
/opt/conda/envs/linfit/lib/python3.7/site-packages/numba/core/typed_passes.py in run_pass(self, state)
96 state.return_type,
97 state.locals,
---> 98 raise_errors=self._raise_errors)
99 state.typemap = typemap
100 if self._raise_errors:
/opt/conda/envs/linfit/lib/python3.7/site-packages/numba/core/typed_passes.py in type_inference_stage(typingctx, interp, args, return_type, locals, raise_errors)
68
69 infer.build_constraint()
---> 70 infer.propagate(raise_errors=raise_errors)
71 typemap, restype, calltypes = infer.unify(raise_errors=raise_errors)
72
/opt/conda/envs/linfit/lib/python3.7/site-packages/numba/core/typeinfer.py in propagate(self, raise_errors)
1069 if isinstance(e, ForceLiteralArg)]
1070 if not force_lit_args:
-> 1071 raise errors[0]
1072 else:
1073 raise reduce(operator.or_, force_lit_args)
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No conversion from array(float64, 1d, C) to array(complex128, 1d, A) for '$14return_value.5', defined at None
But this does!
@njit('complex128[:](complex128[:])')
def findroot(coefficients):
root_values = np.roots(coefficients)
return(root_values)
np_findroot(np.array([1.1, 2.2, 3.3, 4.4], dtype=np.complex128))
array([-0.1746854 +1.54686889e+00j, -1.65062919+6.17561557e-16j,
-0.1746854 -1.54686889e+00j])
My Numba version is 0.51.0. Python 3.7.8. Running on an Ubuntu 18.04.4 Docker container.
BR, Ryan
Issue Analytics
- State:
- Created 3 years ago
- Comments:5 (3 by maintainers)
Top Results From Across the Web
Cannot cast array data from dtype('complex128') to ...
But every time I get an error Cannot cast array data from dtype('complex128') to dtype('float64') according to the rule 'safe' . Can anyone...
Read more >
Iterating Over Arrays — NumPy v1.24 Manual
The iterator uses NumPy's casting rules to determine whether a specific conversion is permitted.
Read more >
[SciPy-User] Broyden with complex numbers [Was Re: ANN
TypeError : Cannot cast ufunc add output from dtype('complex128') to dtype('float64') with casting rule 'same_kind'. I get "nonlin.py:314: ComplexWarning: ...
Read more >
NumPy Reference
NumPy provides an N-dimensional array type, the ndarray, ... means only safe casts or casts within a kind, like float64 to float32, are....
Read more >
4. NumPy Basics: Arrays and Vectorized Computation
If casting were to fail for some reason (like a string that cannot be converted to float64 ), a TypeError will be raised....
Read more >
Top Related Medium Post
No results found
Top Related StackOverflow Question
No results found
Troubleshoot Live Code
Lightrun enables developers to add logs, metrics and snapshots to live code - no restarts or redeploys required.
Start Free
Top Related Reddit Thread
No results found
Top Related Hackernoon Post
No results found
Top Related Tweet
No results found
Top Related Dev.to Post
No results found
Top Related Hashnode Post
No results found
Closing this issue as it seems to be resolved. If this is not the case please re-open with a comment about any item that appears to be unresolved. Many thanks.
@sklam your interpretation is correct. There’s no alternative to the domain change error; the OP is requesting value-based dispatch.
@ryan-chien I think your approach may well work, albeit approaching the complexity of just solving for the roots in the first place given a fixed cubic polynomial. You may also wish to look at the formula for computing a discriminant of a cubic again as I am not sure that it’s quite right, for example:
x = np.array([-100, 2, 3, 4])
causes problems in the findroots_wrapper function, as it reports a positive discriminant when the answer is -4352492.