import SimpleITK as sitk
import numpy as np
from scipy import ndimage
from skimage import morphology
import argparse
import sys
import os
def read_nifti_file(file_path):
    """Load a NIfTI image from disk.

    Args:
        file_path: Path of the image file handed to ``sitk.ReadImage``.

    Returns:
        A ``(data, image)`` tuple: the voxel array produced by
        ``sitk.GetArrayFromImage`` and the SimpleITK image it was read from
        (kept so its geometry can be reused when saving).

    Any failure while reading or converting is reported on stdout and
    terminates the process with exit status 1.
    """
    try:
        sitk_image = sitk.ReadImage(file_path)
        # The array comes back in (z, y, x) axis order.
        voxels = sitk.GetArrayFromImage(sitk_image)
    except Exception as err:
        print(f"Error reading NIfTI file: {err}")
        sys.exit(1)
    else:
        return voxels, sitk_image
def save_nifti_file(data, reference_image, output_path):
    """Write an array to disk as an image that shares a reference geometry.

    Args:
        data: Array to convert with ``sitk.GetImageFromArray``.
        reference_image: SimpleITK image whose metadata is copied onto the
            output via ``CopyInformation``.
        output_path: Destination path handed to ``sitk.WriteImage``.

    Any failure is reported on stdout and terminates the process with exit
    status 1.
    """
    try:
        mask_image = sitk.GetImageFromArray(data)
        # Reuse the reference geometry so the result overlays the input.
        mask_image.CopyInformation(reference_image)
        sitk.WriteImage(mask_image, output_path)
        print(f"Result saved to: {output_path}")
    except Exception as err:
        print(f"Error saving NIfTI file: {err}")
        sys.exit(1)
def analyze_image_stats(data):
"""Analyze image statistical information"""
print("=== Image Statistics ===")
print(f"Data type: {data.dtype}")
print(f"Image dimensions: {data.shape}")
print(f"Value range: [{np.min(data):.2f}, {np.max(data):.2f}]")
print(f"Mean: {np.mean(data):.2f}")
print(f"Standard deviation: {np.std(data):.2f}")
# Display percentiles
percentiles = [0, 1, 5, 10, 25, 50, 75, 90, 95, 99, 100]
percentile_values = np.percentile(data, percentiles)
for p, val in zip(percentiles, percentile_values):
print(f"{p}% percentile: {val:.2f}")
def brain_ct_threshold(data, brain_tissue_min=20, brain_tissue_max=80):
    """Binarize a brain CT array into a candidate brain-tissue mask.

    Two strategies are used, picked from the minimum value of ``data``:

    * If the minimum is below -100 the values are treated as Hounsfield
      units and voxels inside ``[brain_tissue_min, brain_tissue_max]``
      (inclusive) become foreground.
    * Otherwise the data is treated as already normalized and an adaptive
      lower threshold is chosen: the first of the 60th percentile, the 70th
      percentile and ``mean + 0.5 * std`` that yields a foreground ratio
      between 20% and 60%; the 70th percentile is the fallback.

    Progress and summary information is printed to stdout.

    Args:
        data: NumPy array of image intensities.
        brain_tissue_min: Lower HU bound (inclusive) for the HU strategy.
        brain_tissue_max: Upper HU bound (inclusive) for the HU strategy.

    Returns:
        A ``uint8`` array with the shape of ``data``; 1 marks foreground.

    Raises:
        ValueError: If ``data`` contains no elements.
    """
    if data.size == 0:
        raise ValueError("Cannot threshold an empty image array")

    print("=== Brain CT Threshold Processing ===")

    # Analyze data range
    data_min, data_max = np.min(data), np.max(data)
    print(f"Data range: [{data_min:.2f}, {data_max:.2f}]")

    # Typical HU value ranges for brain CT:
    # - Air: -1000 HU
    # - Fat: -100 to -50 HU
    # - Water: 0 HU
    # - CSF: 15 HU
    # - Gray matter: 20-40 HU
    # - White matter: 20-40 HU
    # - Bone: 400-1000+ HU

    if data_min < -100:
        # Method 1: values this negative only occur in HU-scaled data, so
        # apply the fixed brain-tissue window (gray + white matter).
        print("CT HU values detected, using brain tissue threshold")
        print(f"Brain tissue HU range: [{brain_tissue_min}, {brain_tissue_max}]")
        in_window = (data >= brain_tissue_min) & (data <= brain_tissue_max)
        # Cast explicitly so both strategies return the same compact dtype.
        thresholded = in_window.astype(np.uint8)
    else:
        # Method 2: data looks normalized, so derive the threshold from it.
        print("Using adaptive threshold")
        percentile_70 = np.percentile(data, 70)
        threshold_strategies = [
            np.percentile(data, 60),  # 60th percentile
            percentile_70,  # 70th percentile
            np.mean(data) + 0.5 * np.std(data),  # Mean + 0.5 standard deviation
        ]

        best_threshold = None
        best_foreground_ratio = 0
        for threshold in threshold_strategies:
            foreground_ratio = np.count_nonzero(data >= threshold) / data.size
            # Ideal brain region should occupy 20-60% of the image
            if 0.2 <= foreground_ratio <= 0.6:
                best_threshold = threshold
                best_foreground_ratio = foreground_ratio
                break

        if best_threshold is None:
            # If no ideal threshold found, use 70th percentile
            best_threshold = percentile_70
            print(f"Using default threshold: {best_threshold:.2f} (70th percentile)")
        else:
            print(f"Selected best threshold: {best_threshold:.2f}, foreground ratio: {best_foreground_ratio:.2%}")

        thresholded = (data >= best_threshold).astype(np.uint8)

    foreground_voxels = np.count_nonzero(thresholded)
    print(f"Foreground voxels after thresholding: {foreground_voxels}")
    print(f"Foreground percentage after thresholding: {foreground_voxels / thresholded.size * 100:.2f}%")
    return thresholded
def process_brain_slice(slice_data, slice_idx):
    """Clean up one 2D binary slice and keep its dominant regions.

    Steps: 3x3 binary opening to drop small noise, hole filling,
    8-connected component labeling, retention of the two largest components
    (or the only one), then a 3x3 closing and a final hole fill.

    Args:
        slice_data: 2D array; non-zero pixels are treated as foreground.
        slice_idx: Index of the slice in its volume. It is not used by the
            processing and is accepted only to keep the call signature.

    Returns:
        A ``uint8`` array with the shape of ``slice_data``; 1 marks the
        retained region.
    """
    # The full 3x3 element is used both for the morphology and as the
    # 8-connectivity neighborhood for labeling.
    structure = np.ones((3, 3), dtype=np.uint8)

    # 1. Morphological processing - remove small noise
    binary_slice = ndimage.binary_opening(slice_data, structure=structure)

    # 2. Fill holes (especially in ventricular regions)
    binary_slice = ndimage.binary_fill_holes(binary_slice)

    # 3. Connected component analysis - using 8-connectivity
    labeled_slice, num_features = ndimage.label(binary_slice, structure=structure)
    if num_features > 0:
        # One pass over the label image yields every component size;
        # index 0 is the background and is dropped.
        component_sizes = np.bincount(
            labeled_slice.ravel(), minlength=num_features + 1
        )[1:]
        # Keep up to two components (brain may have left and right hemispheres).
        keep_count = min(2, num_features)
        largest_labels = np.argsort(component_sizes)[-keep_count:] + 1
        binary_slice = np.isin(labeled_slice, largest_labels)

    # 4. Final morphological cleanup
    binary_slice = ndimage.binary_closing(binary_slice, structure=structure)
    binary_slice = ndimage.binary_fill_holes(binary_slice)
    return binary_slice.astype(np.uint8)
def process_brain_volume_3d(binary_data, voxel_spacing):
    """Keep the largest 3D component of a mask and close small gaps.

    Args:
        binary_data: 3D array; non-zero voxels are foreground. The first
            axis is treated as the through-plane (z) axis.
        voxel_spacing: Sequence whose element 0 is the in-plane spacing and
            element 2 the through-plane spacing (the caller passes the
            ``(x, y, z)`` tuple from ``image.GetSpacing()``).

    Returns:
        A ``uint8`` array with the shape of ``binary_data``; 1 marks the
        retained region.
    """
    # 1. 3D connected component analysis. A full 3x3x3 element joins face,
    #    edge and corner neighbors, i.e. 26-connectivity.
    structure_3d = np.ones((3, 3, 3), dtype=np.uint8)
    labeled_volume, num_features = ndimage.label(binary_data, structure=structure_3d)

    if num_features > 0:
        # One pass over the labels yields every component size; index 0 is
        # the background and is dropped.
        component_sizes = np.bincount(
            labeled_volume.ravel(), minlength=num_features + 1
        )[1:]
        # Select the largest component (brain)
        largest_component = int(np.argmax(component_sizes)) + 1
        binary_data = labeled_volume == largest_component

    # 2. 3D morphological processing - considering anisotropy.
    #    The kernel is measured in voxels, so thicker slices need FEWER
    #    voxels along z to span the same physical distance. Scaling the
    #    other way makes the kernel taller than the volume for thick-slice
    #    scans, and the zero-padded erosion inside the closing then removes
    #    the whole mask.
    in_plane, through_plane = voxel_spacing[0], voxel_spacing[2]
    if in_plane > 0 and through_plane > 0:
        z_ratio = through_plane / in_plane
    else:
        # Unusable spacing metadata: fall back to isotropic handling.
        z_ratio = 1.0
    kernel_size_z = max(1, int(round(2 / z_ratio)))
    structure_3d_aniso = np.ones((kernel_size_z, 3, 3), dtype=np.uint8)

    # Gentle closing operation to fill small holes
    binary_data = ndimage.binary_closing(binary_data, structure=structure_3d_aniso)
    return binary_data.astype(np.uint8)
def extract_brain_region(input_file, output_file):
    """Run the full brain-extraction pipeline on one image file.

    The pipeline reads the input, prints its statistics, thresholds it,
    cleans the mask slice by slice, refines it in 3D and saves the result
    with the geometry of the input image. Progress is printed to stdout.

    Args:
        input_file: Path of the image to process. If it does not exist an
            error is printed and the process exits with status 1.
        output_file: Path the resulting mask is written to.
    """
    print(f"Input file: {input_file}")

    # Guard: stop early when there is nothing to read.
    if not os.path.exists(input_file):
        print(f"Error: Input file does not exist: {input_file}")
        sys.exit(1)

    # Step 1 - load the voxel array together with its source image.
    data, image = read_nifti_file(input_file)
    print(f"Image dimensions: {data.shape}")

    # Spacing is reported in (x, y, z) order.
    voxel_spacing = image.GetSpacing()
    print(f"Voxel spacing: {voxel_spacing}")

    # Step 2 - statistics report.
    analyze_image_stats(data)

    # Step 3 - initial threshold mask.
    print("Performing brain CT threshold processing...")
    binary_data = brain_ct_threshold(data)

    # An empty mask means the threshold was too strict; retry with a looser
    # cutoff chosen by the same HU-vs-normalized test.
    if not np.sum(binary_data):
        print("Warning: No foreground regions found after thresholding, trying more lenient threshold")
        looks_like_hu = np.min(data) < -100
        # HU data: all non-negative values; otherwise the 40th percentile.
        lenient_cutoff = 0 if looks_like_hu else np.percentile(data, 40)
        binary_data = (data >= lenient_cutoff).astype(np.uint8)
        print(f"Foreground voxels after lenient threshold: {np.sum(binary_data)}")

    # Step 4 - per-slice cleanup along the first axis.
    print("Performing slice-by-slice processing...")
    slice_total = binary_data.shape[0]
    processed_slices = np.zeros_like(binary_data, dtype=np.uint8)
    for z, slice_mask in enumerate(binary_data):
        if z % 20 == 0:
            # Progress line every 20 slices.
            print(f"Processing slice {z}/{slice_total}")
        processed_slices[z] = process_brain_slice(slice_mask, z)

    # Step 5 - 3D refinement.
    print("Performing 3D volume processing...")
    final_mask = process_brain_volume_3d(processed_slices, voxel_spacing)

    # Step 6 - summary of the final mask.
    brain_voxels = np.sum(final_mask)
    print("=== Final Results ===")
    print(f"Brain region voxel count: {brain_voxels}")
    print(f"Brain region percentage: {brain_voxels / final_mask.size * 100:.2f}%")

    # Step 7 - write the mask with the input image's geometry.
    save_nifti_file(final_mask, image, output_file)
    print("Brain extraction completed!")
# Paths used when the script is started without command-line arguments.
DEFAULT_INPUT_PATH = "/Users/liulaolao/Desktop/R2/HeadCtSample_2022_volume 2.nii"
DEFAULT_OUTPUT_PATH = "/Users/liulaolao/Desktop/R2/brain_extraction_result.nii.gz"


def main(argv=None):
    """Command-line entry point for the brain-extraction script.

    Both paths are optional positional arguments, so running the script
    with no arguments behaves as before and uses the default paths.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description="Extract the brain region from a head CT NIfTI volume."
    )
    parser.add_argument(
        "input_path",
        nargs="?",
        default=DEFAULT_INPUT_PATH,
        help="input image file (default: %(default)s)",
    )
    parser.add_argument(
        "output_path",
        nargs="?",
        default=DEFAULT_OUTPUT_PATH,
        help="file the brain mask is written to (default: %(default)s)",
    )
    args = parser.parse_args(argv)

    print("Starting brain CT image processing...")
    print(f"Input file: {args.input_path}")
    print(f"Output file: {args.output_path}")
    extract_brain_region(args.input_path, args.output_path)


if __name__ == "__main__":
    main()