Modeling multi-channel NCC-based registration

Hi Adrian & co.,

For multi-channel registration (e.g., RGB image registration, or 4D registration of subject A with T1 and T2 <-> subject B with T1 and T2), vxm implements 4D windows for local NCC (e.g., a window size of [9, 9, 9, 2] for T1+T2).

I wonder whether this may be a problem when dealing with domain shifts (e.g., scanner differences) in a heterogeneous dataset. Typically, 3D NCC handles this by standardizing local statistics and is therefore mostly insensitive to domain shift. However, T1 and T2 intensities may not change under the same transformation, and this affects the statistics of the 4D window.
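
To make that concrete, here is a toy 1-D analogue (plain Pearson correlation over flat patches rather than vxm's windowed NCC; the synthetic values are arbitrary): per-channel correlation is unchanged by a per-channel affine intensity transform, while correlation computed from statistics pooled across channels is not.

import numpy as np

rng = np.random.default_rng(0)

# Synthetic "T1" and "T2" patches, each correlated with its moving counterpart.
fixed_t1 = rng.normal(100, 20, size=1000)
moving_t1 = fixed_t1 + rng.normal(0, 5, size=1000)
fixed_t2 = rng.normal(60, 10, size=1000)
moving_t2 = fixed_t2 + rng.normal(0, 5, size=1000)


def corr(a, b):
    """Pearson correlation, i.e. NCC of a single flat patch."""
    a = a - a.mean()
    b = b - b.mean()
    return (a @ b) / np.sqrt((a @ a) * (b @ b))


# Per channel: unchanged by an affine intensity transform of that channel.
print(corr(fixed_t1, moving_t1), corr(0.5 * fixed_t1 + 10, moving_t1))

# Pooled across channels (the 4D-window situation): the value changes when the
# two channels are transformed differently.
pooled_fixed = np.concatenate([fixed_t1, fixed_t2])
pooled_moving = np.concatenate([moving_t1, moving_t2])
pooled_shifted = np.concatenate([0.5 * fixed_t1 + 10, 1.3 * fixed_t2 + 47])
print(corr(pooled_fixed, pooled_moving), corr(pooled_shifted, pooled_moving))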

In practice, when training multi-channel templates on a dataset from multiple centers, the NCC loss values had high variance and depended strongly on the center (which eventually led to divergence). This effect went away once I used two separate 3D NCC terms, one per modality (ANTs uses separate NCC terms as well). I imagine that if the batch size were high enough this would not be an issue, but we’re stuck with a small batch size for 3D MRI. 😃
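
For reference, such a split loss can be wired up along these lines (a sketch only; the split_channel_ncc name and the equal 0.5/0.5 channel weights are illustrative, while NCC.loss is the same vxm call used in the example below):

import tensorflow as tf
from voxelmorph.tf.losses import NCC

ncc_3d = NCC(win=[9, 9, 9])


def split_channel_ncc(y_true, y_pred, weights=(0.5, 0.5)):
    """Sum of per-channel 3D NCC terms instead of one joint multi-channel term."""
    loss = 0.0
    for c, w in enumerate(weights):
        # slice out channel c, keeping a trailing singleton channel axis
        loss += w * ncc_3d.loss(y_true[..., c, tf.newaxis],
                                y_pred[..., c, tf.newaxis])
    return loss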

Here’s a minimal example demonstrating that 4D NCC is sensitive to domain shifts, whereas 3D NCC on each channel is relatively insensitive. The example uses ICBM 2009a Nonlinear Asymmetric T1+T2 as image 1 and NIH’s pediatric template as image 2.

import numpy as np
import SimpleITK as sitk
import tensorflow as tf

from voxelmorph.tf.losses import NCC

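# 3D spatial window for local NCC. With 2-channel (T1+T2) inputs, vxm pools the
# window statistics across channels as well, i.e. the effective window is the
# 4D [9, 9, 9, 2] described above.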
ncc_object = NCC(win=[9, 9, 9], eps=1e-3)

# -----------------------------------------------------------------------------
# Utility functions

def load_images(fpath):
    img = sitk.GetArrayFromImage(sitk.ReadImage(fpath))
    return img


def stack_to_tf_tensor(arr1, arr2):
    arr = np.stack((arr1, arr2), axis=-1)  # stack T1 and T2 along a new channel axis
    arr = arr[np.newaxis, ...]  # add batch axis
    return tf.convert_to_tensor(arr)


def scale_shift_clamp(arr, scale, shift):
    arr = scale*arr + shift  # linearly transform image intensities
    return np.maximum(arr, 0)


def ch(tfarr, dim):
    """Extract a channel from a (bs, x, y, z, ch) array.""" 
    return tfarr[..., dim, tf.newaxis]


# -----------------------------------------------------------------------------
# Load images

# Multimodal image 1:
adult_t1 = load_images('./adult/mni_icbm152_t1_tal_nlin_asym_09a.nii')
adult_t2 = load_images('./adult/mni_icbm152_t2_tal_nlin_asym_09a.nii')

adult = stack_to_tf_tensor(adult_t1, adult_t2)

# Multimodal image 2:
pediatric_t1 = load_images('./pediatric/nihpd_asym_04.5-18.5_t1w.nii')
pediatric_t2 = load_images('./pediatric/nihpd_asym_04.5-18.5_t2w.nii')

pediatric = stack_to_tf_tensor(pediatric_t1, pediatric_t2)

# -----------------------------------------------------------------------------
# Initial NCC

print('Original 4D NCC: {}'.format(ncc_object.loss(adult, pediatric)))


# -----------------------------------------------------------------------------
# Domain shift images

# Simulate 3 different domains/scanner pairs with arbitrary transforms: 
# Adult images:
adult_transform1 = stack_to_tf_tensor(
    scale_shift_clamp(adult_t1, 0.5, 10), 
    scale_shift_clamp(adult_t2, 1.3, 47),
)
adult_transform2 = stack_to_tf_tensor(
    scale_shift_clamp(adult_t1, 1.2, 16), 
    scale_shift_clamp(adult_t2, 0.4, 0),
)
adult_transform3 = stack_to_tf_tensor(
    scale_shift_clamp(adult_t1, 1.0, 20), 
    scale_shift_clamp(adult_t2, 2.0, 60),
)

# Pediatric images:
pediatric_transform1 = stack_to_tf_tensor(
    scale_shift_clamp(pediatric_t1, 0.9, 30), 
    scale_shift_clamp(pediatric_t2, 1.4, 3),
)
pediatric_transform2 = stack_to_tf_tensor(
    scale_shift_clamp(pediatric_t1, 2.0, 12), 
    scale_shift_clamp(pediatric_t2, 0.9, 0),
)
pediatric_transform3 = stack_to_tf_tensor(
    scale_shift_clamp(pediatric_t1, 0.8, 0), 
    scale_shift_clamp(pediatric_t2, 1.1, 0),
)


# -----------------------------------------------------------------------------
# Calculate 4D NCC between the domain-shifted image pairs

print('4D NCC domain 1: {}'.format(
    ncc_object.loss(adult_transform1, pediatric_transform1),
))
print('4D NCC domain 2: {}'.format(
    ncc_object.loss(adult_transform2, pediatric_transform2),
))
print('4D NCC domain 3: {}'.format(
    ncc_object.loss(adult_transform3, pediatric_transform3),
))


# -----------------------------------------------------------------------------
# Calculate per-channel 3D NCC (NCC_T1 + NCC_T2) between the domain-shifted image pairs

print('Split 3D NCC domain 1: {}'.format(
    0.5*ncc_object.loss(ch(adult_transform1, 0), ch(pediatric_transform1, 0))
    + 0.5*ncc_object.loss(ch(adult_transform1, 1), ch(pediatric_transform1, 1)),
))
print('Split 3D NCC domain 2: {}'.format(
    0.5*ncc_object.loss(ch(adult_transform2, 0), ch(pediatric_transform2, 0))
    + 0.5*ncc_object.loss(ch(adult_transform2, 1), ch(pediatric_transform2, 1)),
))
print('Split 3D NCC domain 3: {}'.format(
    0.5*ncc_object.loss(ch(adult_transform3, 0), ch(pediatric_transform3, 0))
    + 0.5*ncc_object.loss(ch(adult_transform3, 1), ch(pediatric_transform3, 1)),
))

This yields the following output:

Original 4D NCC: [-0.6175599]
4D NCC domain 1: [-0.57820976]
4D NCC domain 2: [-0.9330061]
4D NCC domain 3: [-0.6425893]
Split 3D NCC domain 1: [-0.54893446]
Split 3D NCC domain 2: [-0.54617786]
Split 3D NCC domain 3: [-0.53752065]

Do you have any thoughts on this phenomenon, and on whether 4D NCC would be preferable to split 3D NCC in other applications?

Thanks!

Comments

brf2 commented, May 24, 2021

Sure. I don’t think I have used multi-channel NCC yet, but certainly we will want it!

neel-dey commented, May 23, 2021

Great, opened PR #314.
