[XLA?] jnp.log1p for complex inputs loses accuracy

I think the complex implementation of jax.numpy.log1p has accuracy issues. Could someone point me to where it is implemented? I find it hard to go through the TensorFlow repository to find what I'm looking for…

Anyhow, with the latest releases of jax and jaxlib:

>>> import jax
>>> float(jax.numpy.log(1.00001))   # real-valued, log, inaccurate
1.0013530300057027e-05
>>> float(jax.numpy.log1p(0.00001))  # real-valued, log1p, accurate
9.999949725170154e-06
>>> complex(jax.numpy.log(1.00001+0.0j))  # complex-valued, log, inaccurate
(1.0013530300057027e-05+0j)
>>> complex(jax.numpy.log1p(0.00001+0.0j))  # complex-valued, log1p, also inaccurate
(1.001348027784843e-05+0j)

I would expect log1p to give the same accurate result in both the real-valued and the complex-valued case.
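The inaccurate values above are consistent with the sum 1 + x being rounded to single precision before the logarithm is taken. The following is a minimal sketch of that effect, assuming the computations ran in JAX's default float32/complex64 precision; `to_float32` is a hypothetical helper written for this illustration and is not part of JAX or NumPy.

```python
import math
import struct


def to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


x = 1e-5

# What a float32 computation stores for 1 + x (assumption: float32 inputs).
u = to_float32(1.0 + x)

# The part of x that survived the addition: a multiple of 2**-23,
# the spacing of float32 values just above 1.
stored_x = u - 1.0

print(stored_x)       # about 1.00136e-05 rather than 1e-05
print(math.log(u))    # close to the inaccurate log / complex log1p values
print(math.log1p(x))  # float64 reference value
```

The rounding replaces x by the nearest multiple of 2**-23, which is off by roughly 0.14% here, and that error carries straight through the logarithm.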

For comparison, NumPy gives the accurate result in both cases:

>>> import numpy as np
>>> np.log1p(0.00001)
9.999950000333332e-06
>>> np.log1p(0.00001+0.0j)
(9.999950000398841e-06+0j)

I could not find the NumPy implementation in their repo, however. Instead, here is the Julia implementation, which also gives correct results: it checks whether the real part is finite and, if so, uses a correction trick to increase accuracy.

function log1p(z::Complex{T}) where T
    zr,zi = reim(z)
    if isfinite(zr)
        isinf(zi) && return log(z)
        # This is based on a well-known trick for log1p of real z,
        # allegedly due to Kahan, only modified to handle real(u) <= 0
        # differently to avoid inaccuracy near z==-2 and for correct branch cut
        u = one(float(T)) + z
        u == 1 ? convert(typeof(u), z) : real(u) <= 0 ? log(u) : log(u)*z/(u-1)
    elseif isnan(zr)
        Complex(zr, zr)
    elseif isfinite(zi)
        Complex(T(Inf), copysign(zr > 0 ? zero(T) : convert(T, pi), zi))
    else
        Complex(T(Inf), T(NaN))
    end
end
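For readers more at home in Python, the finite-input branch of the Julia routine can be sketched as follows. This is an illustrative port using only the standard library, not the code that ended up in jaxlib; the name `log1p_complex` is hypothetical, and the special-value handling for infinite or NaN inputs is deliberately left out.

```python
import cmath
import math


def log1p_complex(z: complex) -> complex:
    """Sketch of the finite-input branch of the Julia log1p quoted above."""
    z = complex(z)
    if not (math.isfinite(z.real) and math.isfinite(z.imag)):
        # Infinite/NaN inputs need the dedicated cases from the Julia code;
        # this sketch simply defers to the plain logarithm.
        return cmath.log(1 + z)
    u = 1 + z
    if u == 1:
        # 1 + z rounded to exactly 1, so log1p(z) equals z to working precision.
        return z
    if u.real <= 0:
        # Near the branch cut and near z == -2: use the plain logarithm.
        # Note: z == -1 raises ValueError here, where Julia returns -Inf.
        return cmath.log(u)
    # Kahan-style correction: the factor z / (u - 1) compensates for the
    # rounding error committed when forming u = 1 + z.
    return cmath.log(u) * z / (u - 1)
```

For example, `log1p_complex(1e-5 + 0j)` agrees with `math.log1p(1e-5)` to within a few units in the last place, and `log1p_complex(1e-20 + 0j)` returns its input unchanged.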

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 7 (7 by maintainers)

Top GitHub Comments

hawkinsp commented, Jun 17, 2021 (1 reaction)

Let me have a go at implementing the Julia version.

hawkinsp commented, Jun 18, 2021 (0 reactions)

This is now fixed at jaxlib head. Hope that helps!
