[XLA?] jnp.log1p for complex inputs loses accuracy
I think the complex-valued implementation of jax.numpy.log1p has an accuracy problem.
Could someone point me to where it’s implemented? I find it hard to navigate the TensorFlow repository to find what I’m looking for.
Anyhow, with the latest releases of jax and jaxlib:
```python
>>> import jax
>>> float(jax.numpy.log(1.00001))  # real-valued, log, inaccurate
1.0013530300057027e-05
>>> float(jax.numpy.log1p(0.00001))  # real-valued, log1p, accurate
9.999949725170154e-06
>>> complex(jax.numpy.log(1.00001+0.0j))  # complex-valued, log, inaccurate
(1.0013530300057027e-05+0j)
>>> complex(jax.numpy.log1p(0.00001+0.0j))  # complex-valued, log1p, also inaccurate
(1.001348027784843e-05+0j)
```
I would expect the same result for the real- and complex-valued cases.
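The size of the complex-valued error (about 0.135%) roughly matches what rounding 1 + z to single precision would cause before any logarithm is taken. JAX uses 32-bit floats by default (64-bit requires enabling `jax_enable_x64`), and the nearest float32 to 1.00001 is 1 + 84·2⁻²³, so the increment is already about 0.136% too large. The sketch below reproduces that rounding in plain Python; `to_float32` is a hypothetical helper built on `struct`, not part of JAX.

```python
import math
import struct


def to_float32(value: float) -> float:
    """Hypothetical helper: round a Python float (binary64) to the nearest binary32."""
    return struct.unpack("f", struct.pack("f", value))[0]


x = 1e-5
u = to_float32(1.0 + x)   # 1 + x as a float32 computation would store it
increment = u - 1.0       # exact, since u and 1.0 are within a factor of 2

naive = math.log(u)       # log(1 + x) after 1 + x was rounded
accurate = math.log1p(x)  # log1p never forms 1 + x explicitly

print(increment == 84 * 2.0**-23)     # True: 1e-5 * 2**23 = 83.886..., rounded to 84
print(naive)                          # about 1.0013530e-05, like the inaccurate outputs above
print(accurate)                       # about 9.99995e-06
print((naive - accurate) / accurate)  # about 1.36e-3
```

Once 1 + x has been rounded, no choice of logarithm can recover the lost digits; that is why log1p exists as a separate function.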
For comparison, NumPy gets this right:
```python
>>> import numpy as np
>>> np.log1p(0.00001)
9.999950000333332e-06
>>> np.log1p(0.00001+0.0j)
(9.999950000398841e-06+0j)
```
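One way to check which value is right, independently of any library, is the Taylor series log1p(x) = x - x²/2 + x³/3 - …, which converges very quickly at x = 1e-5 because each term is about 1e-5 times the previous one. A minimal sketch in plain Python:

```python
import math

x = 1e-5
# Partial sum of log1p(x) = sum over k >= 1 of (-1)**(k + 1) * x**k / k.
# Five terms are plenty: the fifth term is already about 2e-26.
series = sum((-1) ** (k + 1) * x**k / k for k in range(1, 6))

# Real part of the complex jnp.log1p result from the session above.
jax_complex_real = 1.001348027784843e-05

print(series)          # about 9.99995000033333e-06, matching np.log1p
print(math.log1p(x))   # C library log1p, for comparison
print((jax_complex_real - series) / series)  # relative error, about 1.35e-3
```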
I could not find NumPy’s implementation in their repository, however. Here is the Julia implementation, which also gives correct results: when the real part is finite, it uses a well-known trick (attributed to Kahan in the code comments) to improve accuracy.
```julia
function log1p(z::Complex{T}) where T
    zr,zi = reim(z)
    if isfinite(zr)
        isinf(zi) && return log(z)
        # This is based on a well-known trick for log1p of real z,
        # allegedly due to Kahan, only modified to handle real(u) <= 0
        # differently to avoid inaccuracy near z==-2 and for correct branch cut
        u = one(float(T)) + z
        u == 1 ? convert(typeof(u), z) : real(u) <= 0 ? log(u) : log(u)*z/(u-1)
    elseif isnan(zr)
        Complex(zr, zr)
    elseif isfinite(zi)
        Complex(T(Inf), copysign(zr > 0 ? zero(T) : convert(T, pi), zi))
    else
        Complex(T(Inf), T(NaN))
    end
end
```
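To experiment with the Julia algorithm from Python, here is a minimal sketch of the same branch structure using the standard `cmath` module. The name `log1p_complex` is hypothetical; this is not the JAX fix or a NumPy API, only a translation of the Julia code above. The core is the Kahan trick mentioned in its comments: with u = fl(1 + z), for small z the difference u - 1 is computed exactly, and log(u)/(u - 1) is the function log(1 + w)/w = 1 - w/2 + w²/3 - … evaluated at w = u - 1. Because that function varies slowly, multiplying it by the original z recovers log1p(z) without inheriting the rounding error made when forming 1 + z. One deliberate difference from the Julia code: 1 + z is built with `complex(1.0 + zr, zi)` so a negative-zero imaginary part keeps its sign.

```python
import cmath
import math


def log1p_complex(z: complex) -> complex:
    """Hypothetical helper: log(1 + z) for complex z, ported from the Julia code above."""
    zr, zi = z.real, z.imag
    if math.isfinite(zr):
        if math.isinf(zi):
            return cmath.log(z)
        # Build 1 + z explicitly so the sign of a zero imaginary part is kept.
        u = complex(1.0 + zr, zi)
        if u == 1:
            # 1 + z rounded to exactly 1: z is tiny, and log1p(z) is z to working precision.
            return complex(z)
        if u.real <= 0:
            # Near z == -2 and on the branch cut, use plain log (as the Julia comment explains).
            return cmath.log(u)
        # Kahan's trick: log(u) / (u - 1) at the exactly computed increment, times z.
        return cmath.log(u) * z / (u - 1)
    if math.isnan(zr):
        return complex(zr, zr)
    if math.isfinite(zi):
        # zr is +inf or -inf.
        return complex(math.inf, math.copysign(0.0 if zr > 0 else math.pi, zi))
    return complex(math.inf, math.nan)


print(log1p_complex(1e-5 + 0j))  # real part about 9.99995000033333e-06
print(log1p_complex(-3 + 0j))    # log(2) + pi*1j, on the branch cut
```

The `u == 1` branch is not just an optimization: without it, the final formula would divide by zero whenever z is small enough that 1 + z rounds to exactly 1.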
Issue Analytics
- Created 2 years ago
- Comments: 7 (7 by maintainers)
Top GitHub Comments
Let me have a go at implementing the Julia version.
This is now fixed at jaxlib head. Hope that helps!