
[feature request] Image Histogram Transformation


It is often useful (especially in the field of astronomy) to transform the histogram of images. I would like to suggest an image histogram transformation function (under torchvision.transforms) that transforms the histogram of an image to match that of a template image as closely as possible. For instance, consider the following function:

import numpy as np


def match_histogram(source, template):

    source   = np.asanyarray(source)
    template = np.asanyarray(template)
    oldshape = source.shape
    source   = source.ravel()
    template = template.ravel()

    # get the set of unique pixel values and their corresponding indices and
    # counts
    s_values, bin_idx, s_counts = np.unique(source, return_inverse=True,
                                            return_counts=True)
    t_values, t_counts = np.unique(template, return_counts=True)

    # take the cumsum of the counts and normalize by the number of pixels to
    # get the empirical cumulative distribution functions for the source and
    # template images (maps pixel value --> quantile)
    s_quantiles  = np.cumsum(s_counts).astype(np.float32)
    s_quantiles /= s_quantiles[-1]
    t_quantiles  = np.cumsum(t_counts).astype(np.float32)
    t_quantiles /= t_quantiles[-1]

    # interpolate linearly to find the pixel values in the template image
    # that correspond most closely to the quantiles in the source image
    interp_t_values = np.interp(s_quantiles, t_quantiles, t_values)

    return interp_t_values[bin_idx].reshape(oldshape)

The function above is not optimal: it recomputes the template image's statistics on every call. It does not discretize (bin) float images, so it only performs well on highly discretized images such as 8-bit PNGs (256 bins, 0-255). It also performs poorly when the number of distinct pixel values is too low, which might be fixed by adding a small amount of noise.
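The first two limitations could be addressed by computing the template's CDF once and reusing it, and by optionally binning values so float images are treated like 8-bit ones. Below is a minimal sketch under those assumptions; the `HistogramMatcher` class and its `bins` parameter are hypothetical names, not part of torchvision. Binning here uses `np.histogram` with equal-width bins, each represented by its centre.

```python
import numpy as np


class HistogramMatcher:
    """Hypothetical helper: compute a template's empirical CDF once, reuse it."""

    def __init__(self, template, bins=None):
        template = np.asarray(template, dtype=np.float64).ravel()
        if bins is None:
            # Exact unique values, as in match_histogram
            values, counts = np.unique(template, return_counts=True)
        else:
            # Equal-width bins over the template's range; each bin is
            # represented by its centre
            counts, edges = np.histogram(template, bins=bins)
            values = 0.5 * (edges[:-1] + edges[1:])
            # Drop empty bins so the CDF passed to np.interp is strictly increasing
            keep = counts > 0
            values, counts = values[keep], counts[keep]
        cdf = np.cumsum(counts).astype(np.float64)
        self.t_values = values
        self.t_quantiles = cdf / cdf[-1]

    def __call__(self, source):
        source = np.asarray(source, dtype=np.float64)
        _, bin_idx, s_counts = np.unique(
            source.ravel(), return_inverse=True, return_counts=True
        )
        s_quantiles = np.cumsum(s_counts).astype(np.float64)
        s_quantiles /= s_quantiles[-1]
        # Map each source quantile onto the cached template CDF
        interp_t_values = np.interp(s_quantiles, self.t_quantiles, self.t_values)
        return interp_t_values[bin_idx].reshape(source.shape)
```

For example, `HistogramMatcher(template_image, bins=256)` can be built once and then called on many source images. The noise suggestion for images with few distinct values would be applied to the source before the call; it is not part of this sketch.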

Issue Analytics

  • State: open
  • Created 5 years ago
  • Reactions: 1
  • Comments: 10 (5 by maintainers)

Top GitHub Comments

3 reactions
gheaeckkseqrz commented, Jul 14, 2019

Hey 😃

I was browsing through the vision issues and found this one; it turns out I did some work on histogram specification some time ago. Something like this:

[Image: histogram specification example]

I wrote it as a CUDA module because I was running the transform in an optimisation loop and needed it to be fast. The code is available here, if it is useful: https://github.com/pierre-wilmot/NeuralTextureSynthesis/ Happy to help clean it up if you think it's worth adding to the vision repo.
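One common way to do exact histogram specification is by sorting: each source pixel takes the template value with the same rank. The sketch below is a minimal single-channel NumPy version of that idea; it is not the linked CUDA code, and `match_histogram_sorted` is a hypothetical name. Because it works on ranks rather than on a histogram, it handles float images without binning, and ties among equal source pixels are broken by position.

```python
import numpy as np


def match_histogram_sorted(source, template):
    """Sort-based histogram specification for a single channel (sketch)."""
    source = np.asarray(source)
    flat = source.ravel()
    t_sorted = np.sort(np.asarray(template, dtype=np.float64).ravel())

    # Rank of every source pixel; a stable sort breaks ties by position
    order = np.argsort(flat, kind="stable")
    ranks = np.empty_like(order)
    ranks[order] = np.arange(flat.size)

    # Map rank r in [0, n) to the template index at the same quantile
    idx = (ranks * t_sorted.size) // flat.size
    return t_sorted[idx].reshape(source.shape)
```

When `source` and `template` have the same number of pixels, the sorted output equals the sorted template exactly. The same steps (sort, rank, gather) map directly onto tensor operations if a GPU version is needed.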

2 reactions
ProGamerGov commented, Dec 5, 2021

@Miladiouss I created this function, which essentially matches the color histogram of one image to another (by matching the per-channel mean and the channel covariance), and it should hopefully help with use cases like astronomy and neural style transfer.

I wrote the code for a different PyTorch project (pytorch/captum), but Torchvision is free to use it as well! @fmassa

from typing import Tuple

import torch


def color_transfer(
    input: torch.Tensor,
    source: torch.Tensor,
    mode: str = "pca",
    eps: float = 1e-5,
) -> torch.Tensor:
    """
    Transfer the colors from one image tensor to another, so that the target image's
    color statistics (per-channel mean and channel covariance) match those of the
    source image, approximating histogram matching. Applications of image histogram
    matching include neural style transfer and astronomy.

    The source image is not required to have the same height and width as the target
    image. Batch and channel dimensions are required to be the same for both inputs.

    Gatys, et al., "Controlling Perceptual Factors in Neural Style Transfer", arXiv, 2017.
    https://arxiv.org/abs/1611.07865

    Args:

        input (torch.Tensor): The NCHW or CHW image to transfer the source image's
            colors to.
        source (torch.Tensor): The NCHW or CHW image to transfer colors from.
        mode (str): The color transfer mode to use. One of 'pca', 'cholesky', or 'sym'.
            Default: "pca"
        eps (float): The desired epsilon value to use.
            Default: 1e-5

    Returns:
        matched_image (torch.Tensor): The NCHW input image with the colors of the source
            image. Outputs should ideally be clamped to the desired value range to
            avoid artifacts.
    """

    assert input.dim() == 3 or input.dim() == 4
    assert source.dim() == 3 or source.dim() == 4
    input = input.unsqueeze(0) if input.dim() == 3 else input
    source = source.unsqueeze(0) if source.dim() == 3 else source
    assert input.shape[:2] == source.shape[:2]

    # Handle older versions of PyTorch. Comparing version strings is unreliable
    # ("1.10.0" >= "1.9.0" is False as strings), so check for torch.linalg directly.
    use_linalg = hasattr(torch, "linalg") and all(
        hasattr(torch.linalg, name) for name in ("cholesky", "eigh")
    )
    torch_cholesky = torch.linalg.cholesky if use_linalg else torch.cholesky

    def torch_symeig_eigh(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        torch.symeig() was deprecated in favor of torch.linalg.eigh()
        """
        if use_linalg:
            L, V = torch.linalg.eigh(x, UPLO="U")
        else:
            L, V = torch.symeig(x, eigenvectors=True, upper=True)
        return L, V

    def get_mean_vec_and_cov(
        x_input: torch.Tensor, eps: float
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Convert input images into a vector, subtract the mean, and calculate the
        covariance matrix of colors.
        """
        x_mean = x_input.mean(3).mean(2)[:, :, None, None]

        # Subtract the color mean and convert to a vector
        B, C = x_input.shape[:2]
        x_vec = (x_input - x_mean).reshape(B, C, -1)

        # Calculate covariance matrix
        x_cov = torch.bmm(x_vec, x_vec.permute(0, 2, 1)) / x_vec.shape[2]

        # This line is only important if you get artifacts in the output image
        x_cov = x_cov + (eps * torch.eye(C, device=x_input.device)[None, :])
        return x_mean, x_vec, x_cov

    def pca(x: torch.Tensor) -> torch.Tensor:
        """Perform principal component analysis"""
        eigenvalues, eigenvectors = torch_symeig_eigh(x)
        e = torch.sqrt(torch.diag_embed(eigenvalues.reshape(eigenvalues.size(0), -1)))
        # Remove any NaN values if they occur
        if torch.isnan(e).any():
            e = torch.where(torch.isnan(e), torch.zeros_like(e), e)
        return torch.bmm(torch.bmm(eigenvectors, e), eigenvectors.permute(0, 2, 1))

    # Collect & calculate required values
    _, input_vec, input_cov = get_mean_vec_and_cov(input, eps)
    source_mean, _, source_cov = get_mean_vec_and_cov(source, eps)

    # Calculate new cov matrix for input
    if mode == "pca":
        new_cov = torch.bmm(pca(source_cov), torch.inverse(pca(input_cov)))
    elif mode == "cholesky":
        new_cov = torch.bmm(
            torch_cholesky(source_cov), torch.inverse(torch_cholesky(input_cov))
        )
    elif mode == "sym":
        p = pca(input_cov)
        pca_out = pca(torch.bmm(torch.bmm(p, source_cov), p))
        new_cov = torch.bmm(torch.bmm(torch.inverse(p), pca_out), torch.inverse(p))
    else:
        raise ValueError(
            "mode has to be one of 'pca', 'cholesky', or 'sym'."
            + " Received '{}'.".format(mode)
        )

    # Multiply input vector by new cov matrix
    new_vec = torch.bmm(new_cov, input_vec)

    # Reshape output vector back to input's shape &
    # add the source mean to our output vector
    return new_vec.reshape(input.shape) + source_mean


# Example for standard PyTorch images with values in [0, 1], where
# target_image and source_image are CHW or NCHW tensors
matched_image = color_transfer(target_image, source_image).clamp(0, 1)

The inner functions can easily be eliminated for TorchScript / JIT compatibility, and the function is fully autograd compatible.
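The covariance-based modes above match the first two color moments (per-channel mean and channel covariance) rather than the full histogram. The NumPy sketch below mirrors the 'cholesky' mode for a single CHW image pair to make that algebra explicit; `cholesky_color_transfer` is a hypothetical name, and batching is left out.

```python
import numpy as np


def cholesky_color_transfer(target, source, eps=1e-5):
    """NumPy sketch of the 'cholesky' mode for a single CHW image pair."""

    def mean_vec_and_cov(x):
        c = x.shape[0]
        vec = x.reshape(c, -1)
        mean = vec.mean(axis=1, keepdims=True)
        centred = vec - mean
        # Biased (1/N) covariance plus eps on the diagonal, as in the PyTorch code
        cov = centred @ centred.T / centred.shape[1] + eps * np.eye(c)
        return mean, centred, cov

    target = np.asarray(target, dtype=np.float64)
    source = np.asarray(source, dtype=np.float64)
    _, t_vec, t_cov = mean_vec_and_cov(target)
    s_mean, _, s_cov = mean_vec_and_cov(source)

    # With C = L L^T, A = L_s L_t^{-1} gives A C_t A^T = L_s L_s^T = C_s
    a = np.linalg.cholesky(s_cov) @ np.linalg.inv(np.linalg.cholesky(t_cov))
    # Recolor the centred target pixels, then shift them to the source mean
    return (a @ t_vec + s_mean).reshape(target.shape)
```

With `eps=0`, the output's per-channel mean and covariance equal the source's up to floating-point error; a positive `eps` regularizes both covariances, as in the PyTorch version, at the cost of a slight mismatch.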
