Modeling multi-channel NCC-based registration
Hi Adrian & co.,
For multi-channel registration (e.g., RGB image registration, or 4D registration of subject A with T1 and T2 <-> subject B with T1 and T2), vxm computes local NCC over 4D windows that span all channels (e.g., a window of size [9, 9, 9, 2] for T1+T2).
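I wonder whether this is a problem under domain shift (e.g., scanner differences) in a heterogeneous dataset. 3D NCC usually handles such shifts well: it standardizes the local statistics (subtracting the window mean and dividing by the window standard deviation), so it is invariant to an affine intensity change aI + b (a > 0) within a window and mostly insensitive to domain shift. However, T1 and T2 intensities need not change under the same transformation, and when both channels share one 4D window, their different scales and offsets alter the window's pooled mean and variance.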
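A short derivation of why pooling channels breaks this invariance. The notation is introduced here for illustration: a spatial window Ω of N voxels, per-channel window means \bar I_c, pooled window mean \bar{\hat I}, and per-channel intensity maps \hat I_c = a_c I_c + b_c. The single-channel line uses the squared correlation; the unsquared version is invariant for a > 0.

```latex
% Single channel: an affine change only rescales the centered values
\hat I = aI + b \;\Rightarrow\; \hat I(x) - \bar{\hat I} = a\,\big(I(x) - \bar I\big)
\;\Rightarrow\;
\mathrm{NCC}^2(\hat I, J)
= \frac{a^2 \big(\sum_{x\in\Omega} (I(x)-\bar I)(J(x)-\bar J)\big)^2}
       {a^2 \sum_{x\in\Omega} (I(x)-\bar I)^2 \,\sum_{x\in\Omega} (J(x)-\bar J)^2}
= \mathrm{NCC}^2(I, J)

% Pooled channels c = 1..C, each with its own map \hat I_c = a_c I_c + b_c
\hat I_c(x) - \bar{\hat I} = a_c\big(I_c(x) - \bar I_c\big) + \delta_c,
\qquad
\delta_c = a_c \bar I_c + b_c - \frac{1}{C}\sum_{k=1}^{C}\big(a_k \bar I_k + b_k\big)

\sum_{c=1}^{C}\sum_{x\in\Omega}\big(\hat I_c(x) - \bar{\hat I}\big)^2
= \sum_{c=1}^{C} a_c^2 \sum_{x\in\Omega}\big(I_c(x) - \bar I_c\big)^2 + N\sum_{c=1}^{C}\delta_c^2

\sum_{c=1}^{C}\sum_{x\in\Omega}\big(\hat I_c(x) - \bar{\hat I}\big)\big(J_c(x) - \bar J\big)
= \sum_{c=1}^{C} a_c \sum_{x\in\Omega}\big(I_c(x) - \bar I_c\big)\big(J_c(x) - \bar J_c\big) + N\sum_{c=1}^{C}\delta_c\big(\bar J_c - \bar J\big)
```

The N-weighted δ_c terms depend on the offsets b_c and on the ratios between the a_c, so the pooled NCC changes with scanner-specific (a_c, b_c) unless every channel shares the same a and b.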
In practice, when training multi-channel templates on a multi-center dataset, the NCC loss values had high variance and depended strongly on the center (which eventually led to divergence). The effect went away once I used two separate 3D NCC terms, one per modality (ANTs also uses separate NCC terms). I imagine this would matter less with a large enough batch size, but for 3D MRI we're stuck with small batches. 😃
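To reproduce the effect without TensorFlow, here is a minimal NumPy sketch of a local squared NCC (similar in form to vxm's, but not its implementation), computed either with windows pooled across channels or as the mean of separate per-channel 3D terms. The function names `local_ncc2`, `pooled_ncc_loss` and `split_ncc_loss`, the 'valid'-only windows, and the random test volumes are assumptions made for illustration.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def local_ncc2(y_true, y_pred, win=(9, 9, 9), eps=1e-5):
    """Mean squared local NCC between two (x, y, z, ch) arrays.

    Each window spans `win` spatially and all channels at once, so channel
    statistics are pooled. Only windows that fit fully inside the volume
    are used ('valid'), unlike a zero-padded 'same' convolution.
    """
    t = np.asarray(y_true, dtype=np.float64)
    p = np.asarray(y_pred, dtype=np.float64)
    shape = tuple(win) + (t.shape[-1],)
    # Read-only views of shape (nx, ny, nz, 1, wx, wy, wz, ch).
    wt = sliding_window_view(t, shape)
    wp = sliding_window_view(p, shape)
    axes = (-4, -3, -2, -1)  # all voxels and channels inside one window
    dt = wt - wt.mean(axis=axes, keepdims=True)
    dp = wp - wp.mean(axis=axes, keepdims=True)
    cross = (dt * dp).sum(axis=axes)
    var_t = (dt * dt).sum(axis=axes)
    var_p = (dp * dp).sum(axis=axes)
    return float(np.mean(cross * cross / (var_t * var_p + eps)))


def pooled_ncc_loss(y_true, y_pred, win=(9, 9, 9)):
    """One NCC term whose windows pool all channels (the 4D window)."""
    return -local_ncc2(y_true, y_pred, win)


def split_ncc_loss(y_true, y_pred, win=(9, 9, 9)):
    """Equally weighted mean of separate 3D NCC terms, one per channel."""
    terms = [local_ncc2(y_true[..., c:c + 1], y_pred[..., c:c + 1], win)
             for c in range(y_true.shape[-1])]
    return -float(np.mean(terms))


# Hypothetical two-channel volumes (random data, not MRI).
rng = np.random.default_rng(0)
fixed = rng.normal(size=(16, 16, 16, 2))
moving = fixed + 0.5 * rng.normal(size=fixed.shape)

# Per-channel affine intensity change, like a scanner-specific shift.
shifted = moving.copy()
shifted[..., 0] = 0.5 * moving[..., 0] + 10
shifted[..., 1] = 1.3 * moving[..., 1] + 47

pooled_orig = pooled_ncc_loss(fixed, moving)
pooled_shift = pooled_ncc_loss(fixed, shifted)
split_orig = split_ncc_loss(fixed, moving)
split_shift = split_ncc_loss(fixed, shifted)
print(f"pooled 4D: {pooled_orig:.4f} -> {pooled_shift:.4f}")
print(f"split 3D:  {split_orig:.4f} -> {split_shift:.4f}")
```

Under the per-channel affine change, the split loss should stay essentially unchanged while the pooled loss moves toward 0. The split version weights the channels equally, matching the 0.5/0.5 weighting in the split-NCC comparison; other per-channel weights are a design choice.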
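Here's a minimal example demonstrating that 4D NCC is sensitive to domain shifts, whereas 3D NCC applied to each channel separately is relatively insensitive. It uses the ICBM 2009a Nonlinear Asymmetric T1+T2 template as image 1 and NIH's pediatric template (NIHPD asymmetric, ages 4.5-18.5) as image 2, and simulates three domains by applying a different linear intensity transform (followed by clamping at 0) to each channel.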
import numpy as np
import SimpleITK as sitk
import tensorflow as tf
from voxelmorph.tf.losses import NCC
ncc_object = NCC(win=[9, 9, 9], eps=1e-3)
# -----------------------------------------------------------------------------
# Utility functions
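def load_images(fpath):
    """Read an image file into a NumPy array ((z, y, x) index order)."""
    img = sitk.GetArrayFromImage(sitk.ReadImage(fpath))
    return img.astype(np.float32)  # ensure a float dtype for the TF convolutions in NCC

def stack_to_tf_tensor(arr1, arr2):
    """Stack two 3D volumes into a (1, x, y, z, 2) tensor."""
    arr = np.stack((arr1, arr2), axis=-1)  # 4D concatenate T1 and T2
    arr = arr[np.newaxis, ...]  # add batch axis
    return tf.convert_to_tensor(arr)

def scale_shift_clamp(arr, scale, shift):
    """Linearly transform intensities, then clamp negative values to 0."""
    arr = scale*arr + shift  # linearly transform image intensities
    return np.maximum(arr, 0)

def ch(tfarr, dim):
    """Extract a channel from a (bs, x, y, z, ch) array, keeping a channel axis."""
    return tfarr[..., dim, tf.newaxis]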
# -----------------------------------------------------------------------------
# Load images
# Multimodal image 1:
adult_t1 = load_images('./adult/mni_icbm152_t1_tal_nlin_asym_09a.nii')
adult_t2 = load_images('./adult/mni_icbm152_t2_tal_nlin_asym_09a.nii')
adult = stack_to_tf_tensor(adult_t1, adult_t2)
# Multimodal image 2:
pediatric_t1 = load_images('./pediatric/nihpd_asym_04.5-18.5_t1w.nii')
pediatric_t2 = load_images('./pediatric/nihpd_asym_04.5-18.5_t2w.nii')
pediatric = stack_to_tf_tensor(pediatric_t1, pediatric_t2)
# -----------------------------------------------------------------------------
# Initial NCC
print('Original 4D NCC: {}'.format(ncc_object.loss(adult, pediatric)))
# -----------------------------------------------------------------------------
# Domain shift images
# Simulate 3 different domains/scanner pairs with arbitrary transforms:
# Adult images:
adult_transform1 = stack_to_tf_tensor(
    scale_shift_clamp(adult_t1, 0.5, 10),
    scale_shift_clamp(adult_t2, 1.3, 47),
)
adult_transform2 = stack_to_tf_tensor(
    scale_shift_clamp(adult_t1, 1.2, 16),
    scale_shift_clamp(adult_t2, 0.4, 0),
)
adult_transform3 = stack_to_tf_tensor(
    scale_shift_clamp(adult_t1, 1.0, 20),
    scale_shift_clamp(adult_t2, 2.0, 60),
)
# Pediatric images:
pediatric_transform1 = stack_to_tf_tensor(
    scale_shift_clamp(pediatric_t1, 0.9, 30),
    scale_shift_clamp(pediatric_t2, 1.4, 3),
)
pediatric_transform2 = stack_to_tf_tensor(
    scale_shift_clamp(pediatric_t1, 2.0, 12),
    scale_shift_clamp(pediatric_t2, 0.9, 0),
)
pediatric_transform3 = stack_to_tf_tensor(
    scale_shift_clamp(pediatric_t1, 0.8, 0),
    scale_shift_clamp(pediatric_t2, 1.1, 0),
)
# -----------------------------------------------------------------------------
# Calculate 4D NCC between original images with new domain shifts
print('4D NCC domain 1: {}'.format(
    ncc_object.loss(adult_transform1, pediatric_transform1),
))
print('4D NCC domain 2: {}'.format(
    ncc_object.loss(adult_transform2, pediatric_transform2),
))
print('4D NCC domain 3: {}'.format(
    ncc_object.loss(adult_transform3, pediatric_transform3),
))
# -----------------------------------------------------------------------------
# Calculate 3D NCC_T1 + NCC_T2 between original images with new domain shifts
print('Split 3D NCC domain 1: {}'.format(
    0.5*ncc_object.loss(ch(adult_transform1, 0), ch(pediatric_transform1, 0))
    + 0.5*ncc_object.loss(ch(adult_transform1, 1), ch(pediatric_transform1, 1)),
))
print('Split 3D NCC domain 2: {}'.format(
    0.5*ncc_object.loss(ch(adult_transform2, 0), ch(pediatric_transform2, 0))
    + 0.5*ncc_object.loss(ch(adult_transform2, 1), ch(pediatric_transform2, 1)),
))
print('Split 3D NCC domain 3: {}'.format(
    0.5*ncc_object.loss(ch(adult_transform3, 0), ch(pediatric_transform3, 0))
    + 0.5*ncc_object.loss(ch(adult_transform3, 1), ch(pediatric_transform3, 1)),
))
This yields output:
Original 4D NCC: [-0.6175599]
4D NCC domain 1: [-0.57820976]
4D NCC domain 2: [-0.9330061]
4D NCC domain 3: [-0.6425893]
Split 3D NCC domain 1: [-0.54893446]
Split 3D NCC domain 2: [-0.54617786]
Split 3D NCC domain 3: [-0.53752065]
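The spread of these losses across the three simulated domains can be quantified directly from the printed numbers. This snippet only restates the output above; it does not rerun the experiment.

```python
import numpy as np

# Losses copied from the output above (simulated domains 1-3).
ncc_4d = np.array([-0.57820976, -0.9330061, -0.6425893])
ncc_split = np.array([-0.54893446, -0.54617786, -0.53752065])

# Range (max - min) and population standard deviation across domains.
range_4d = ncc_4d.max() - ncc_4d.min()
range_split = ncc_split.max() - ncc_split.min()

print(f"4D NCC:       range = {range_4d:.4f}, std = {ncc_4d.std():.4f}")
print(f"split 3D NCC: range = {range_split:.4f}, std = {ncc_split.std():.4f}")
print(f"ratio of ranges: {range_4d / range_split:.1f}")
```

The 4D values span about 0.355, while the split 3D values span about 0.011, roughly a 31-fold difference.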
Do you have any thoughts on this phenomenon, and on whether 4D NCC would be better than split 3D NCC in other applications?
Thanks!
Top GitHub Comments
Sure. I don’t think I have used multi-channel NCC yet, but certainly we will want it!
Great, opened PR #314.