I find myself in need of a filter that takes a set of (binary mask, real-valued image) pairs and returns a single mask whose value at each voxel is the index of the pair with the largest real value. I use this constantly for flattening logits and for label interpolation.
Currently, I do this in NumPy with argmax:
# Let `logits` come from a neural network I have.
# It is [512 x 512 x 128] with 4 channels;
# each channel is a semantic class - say, grey matter, white matter, skull, skin.
logits.shape  # (512, 512, 128, 4)

# Get mask, which is a semantic segmentation
mask = np.argmax(logits, axis=-1)
mask.shape  # (512, 512, 128)
np.min(mask)  # 0
np.max(mask)  # 3
I can then do some additional filtering to map [0, 1, 2, 3] to my class labels and also suppress values below a threshold. For instance, if the logit is less than 0 for all labels, I usually assign that voxel to background. I do the exact same operation in other situations, such as label correction.
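A minimal NumPy sketch of that argmax-plus-threshold step, using a toy array in place of real network output (the 0.0 cutoff comes from the description above; the -1 background sentinel is my assumption, substitute whatever your label map uses):

```python
import numpy as np

rng = np.random.default_rng(0)
# toy stand-in for the [512, 512, 128, 4] logits
logits = rng.standard_normal((8, 8, 8, 4)).astype(np.float32)

# argmax over the channel axis gives class indices 0..3
mask = np.argmax(logits, axis=-1).astype(np.int16)

# suppress voxels where every class logit is below the threshold
background_label = -1  # hypothetical sentinel value
mask[np.max(logits, axis=-1) < 0.0] = background_label
```

Mapping [0, 1, 2, 3] to actual class labels is then a lookup-table step on `mask`.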
I imagine a C++ interface something like this:
// Types
constexpr unsigned int Dimension = 3;
using MaskPixelType = short;
using IntensityPixelType = float;
using MaskImageType = itk::Image<MaskPixelType, Dimension>;
using IntensityImageType = itk::Image<IntensityPixelType, Dimension>;

// Basic interface using pairs
using ImagePair = std::pair<MaskImageType::Pointer, IntensityImageType::Pointer>;
std::vector<ImagePair> image_pairs = get_image_pairs(); // Not defined here
auto filter = itk::FlattenLabelsImageFilter<MaskImageType, IntensityImageType>::New();
filter->SetPairs(image_pairs);
filter->Update();
auto segmentation = filter->GetOutput();
Additional interfaces (e.g., an AddPair() method) as needed.
The real benefit of this filter would come if it worked for images that align in physical space but do not share the same extent. The issue is that this kind of operation is very memory intensive. Say my network outputs images at [128, 128, 128] and I have 8 classes to segment. That is a 128 x 128 x 128 x 8 x sizeof(float) memory requirement (I could compress to int, yes). However, these images might be subsamples of a larger image, so I actually have to resample the logits to their native spacing first - which might be [512, 512, 1028] - before joining, which carries an extreme memory burden.
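To put numbers on that burden, using the shapes above and sizeof(float) == 4:

```python
# bytes for the network-resolution logits: 128^3 voxels x 8 classes x 4 bytes
native = 128 * 128 * 128 * 8 * 4
# bytes after resampling to native spacing before joining
resampled = 512 * 512 * 1028 * 8 * 4

print(native / 2**20)     # 64.0 MiB
print(resampled / 2**30)  # ~8.03 GiB
```

So resampling first inflates the working set by more than two orders of magnitude, which is why a filter that flattens in physical space without materializing the resampled logits would help.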
I think the basic "label flattening map" is not so difficult to implement. Is there already a solution for this? Is my approach wrong? The second, resampling-aware version might take more effort to define.