import SimpleITK as sitk
import numpy as np
from scipy import ndimage
from skimage import morphology
import argparse
import sys
import os


def read_nifti_file(file_path):
    """Load a NIfTI volume with SimpleITK.

    Args:
        file_path: Path of the image to read.

    Returns:
        ``(voxels, sitk_image)``: the voxel array from
        ``sitk.GetArrayFromImage`` (axis order annotated in the original
        code as (z, y, x)) and the SimpleITK image, which callers keep for
        its geometry metadata.

    On any exception the error is printed and the process exits with
    status 1.
    """
    try:
        sitk_image = sitk.ReadImage(file_path)
        # NumPy view of the voxels; axis order is (z, y, x).
        voxels = sitk.GetArrayFromImage(sitk_image)
        return voxels, sitk_image
    except Exception as err:
        print(f"Error reading NIfTI file: {err}")
        sys.exit(1)


def save_nifti_file(data, reference_image, output_path):
    """Write ``data`` as an image that shares ``reference_image``'s geometry.

    Args:
        data: NumPy array to save (same axis order as produced by
            ``sitk.GetArrayFromImage``).
        reference_image: SimpleITK image whose meta-information is copied
            onto the output via ``CopyInformation``.
        output_path: Destination file path.

    Prints the destination on success. On any exception the error is
    printed and the process exits with status 1.
    """
    try:
        result_image = sitk.GetImageFromArray(data)
        # Reuse the source scan's geometry so the mask overlays it correctly.
        result_image.CopyInformation(reference_image)
        sitk.WriteImage(result_image, output_path)
        print(f"Result saved to: {output_path}")
    except Exception as err:
        print(f"Error saving NIfTI file: {err}")
        sys.exit(1)


def analyze_image_stats(data):
    """Print descriptive statistics of an image array to stdout.

    Reports dtype, shape, min/max, mean, standard deviation and the
    0/1/5/10/25/50/75/90/95/99/100th percentiles. Returns nothing.
    """
    print("=== Image Statistics ===")
    print(f"Data type: {data.dtype}")
    print(f"Image dimensions: {data.shape}")
    lowest, highest = np.min(data), np.max(data)
    print(f"Value range: [{lowest:.2f}, {highest:.2f}]")
    print(f"Mean: {np.mean(data):.2f}")
    print(f"Standard deviation: {np.std(data):.2f}")

    # Percentile summary, one line per level.
    levels = (0, 1, 5, 10, 25, 50, 75, 90, 95, 99, 100)
    level_values = np.percentile(data, levels)
    for idx, level in enumerate(levels):
        print(f"{level}% percentile: {level_values[idx]:.2f}")


def brain_ct_threshold(data, *, hu_range=(20, 80)):
    """Binarize a head CT volume into candidate brain-tissue voxels.

    The strategy is chosen from the data range:

    * If the minimum value is below -100, the volume is assumed to be in
      Hounsfield units. Voxels inside ``hu_range`` (inclusive on both ends)
      are kept.
    * Otherwise the data is assumed to be rescaled/normalized. The first of
      several candidate thresholds that yields a 20-60 % foreground fraction
      is used, falling back to the 70th percentile.

    Args:
        data: Array of intensities (any shape).
        hu_range: Inclusive ``(low, high)`` HU window used in the HU branch.
            Defaults to ``(20, 80)``, the window this function always used.

    Returns:
        ``uint8`` array of the same shape as ``data``; 1 marks foreground.

    Raises:
        ValueError: If ``data`` is empty.
    """
    data = np.asarray(data)
    if data.size == 0:
        raise ValueError("brain_ct_threshold() requires a non-empty array")

    print("=== Brain CT Threshold Processing ===")

    data_min, data_max = np.min(data), np.max(data)
    print(f"Data range: [{data_min:.2f}, {data_max:.2f}]")

    # Approximate HU reference values for head CT:
    # - Air: -1000 HU
    # - Fat: -100 to -50 HU
    # - Water: 0 HU
    # - CSF: 15 HU
    # - Gray matter: 20-40 HU
    # - White matter: 20-40 HU
    # - Bone: 400-1000+ HU

    if data_min < -100:  # Values this negative suggest raw HU (air is ~-1000)
        print("CT HU values detected, using brain tissue threshold")
        brain_tissue_min, brain_tissue_max = hu_range
        print(f"Brain tissue HU range: [{brain_tissue_min}, {brain_tissue_max}]")
        # Cast explicitly so both branches return the same uint8 dtype.
        thresholded = (
            (data >= brain_tissue_min) & (data <= brain_tissue_max)
        ).astype(np.uint8)
    else:
        print("Using adaptive threshold")
        threshold_strategies = [
            np.percentile(data, 60),  # 60th percentile
            np.percentile(data, 70),  # 70th percentile
            np.mean(data) + 0.5 * np.std(data),  # Mean + 0.5 standard deviation
        ]

        best_threshold = None
        best_foreground_ratio = 0
        for threshold in threshold_strategies:
            foreground_ratio = np.count_nonzero(data >= threshold) / data.size
            # Accept the first candidate whose foreground covers 20-60 % of the
            # image, the range assumed plausible for the brain region.
            if 0.2 <= foreground_ratio <= 0.6:
                best_threshold = threshold
                best_foreground_ratio = foreground_ratio
                break

        if best_threshold is None:
            best_threshold = np.percentile(data, 70)
            print(f"Using default threshold: {best_threshold:.2f} (70th percentile)")
        else:
            print(f"Selected best threshold: {best_threshold:.2f}, foreground ratio: {best_foreground_ratio:.2%}")

        thresholded = (data >= best_threshold).astype(np.uint8)

    foreground = int(np.count_nonzero(thresholded))
    print(f"Foreground voxels after thresholding: {foreground}")
    print(f"Foreground percentage after thresholding: {foreground / thresholded.size * 100:.2f}%")

    return thresholded


def process_brain_slice(slice_data, slice_idx):
    """Clean one thresholded slice and keep its largest blobs.

    Steps: 3x3 binary opening (drop small specks), hole filling, keep up to
    the two largest 8-connected components (intended to capture the left and
    right hemispheres), then a 3x3 closing and a second hole fill.

    Args:
        slice_data: 2D binary array; non-zero values are foreground.
        slice_idx: Index of the slice in the volume. Not used by the
            computation; kept for interface compatibility.

    Returns:
        ``uint8`` 2D array of the same shape as ``slice_data``.
    """
    structure = np.ones((3, 3), dtype=np.uint8)

    # 1. Opening removes foreground specks smaller than the 3x3 element.
    binary_slice = ndimage.binary_opening(slice_data, structure=structure)

    # 2. Fill enclosed holes (e.g. ventricular regions).
    binary_slice = ndimage.binary_fill_holes(binary_slice)

    # 3. 8-connected labeling; keep at most the two largest components.
    labeled_slice, num_features = ndimage.label(binary_slice, structure=np.ones((3, 3)))
    if num_features > 0:
        # One bincount pass instead of one full-slice comparison per label.
        # Index 0 is the background count and is dropped.
        component_sizes = np.bincount(labeled_slice.ravel(), minlength=num_features + 1)[1:]
        # argsort(...)[-2:] yields one index when there is a single component,
        # so this covers both the one- and many-component cases.
        keep_labels = np.argsort(component_sizes)[-2:] + 1
        binary_slice = np.isin(labeled_slice, keep_labels)

    # 4. Final morphological cleanup.
    binary_slice = ndimage.binary_closing(binary_slice, structure=structure)
    binary_slice = ndimage.binary_fill_holes(binary_slice)

    return binary_slice.astype(np.uint8)


def process_brain_volume_3d(binary_data, voxel_spacing):
    """Keep the largest 3D component and apply an anisotropy-aware closing.

    Args:
        binary_data: 3D array indexed (z, y, x), as produced by
            ``sitk.GetArrayFromImage``; non-zero voxels are foreground.
        voxel_spacing: SimpleITK spacing tuple in (x, y, z) order.

    Returns:
        ``uint8`` array of the same shape as ``binary_data``.
    """
    # 1. 3D connected components. A full 3x3x3 structure means 26-connectivity.
    structure_3d = np.ones((3, 3, 3), dtype=np.uint8)
    labeled_volume, num_features = ndimage.label(binary_data, structure=structure_3d)

    if num_features > 0:
        # One bincount pass instead of one full-volume comparison per label;
        # index 0 (background) is dropped.
        component_sizes = np.bincount(labeled_volume.ravel(), minlength=num_features + 1)[1:]
        largest_component = int(np.argmax(component_sizes)) + 1
        binary_data = (labeled_volume == largest_component).astype(np.uint8)

    # 2. Closing whose z extent tracks the physical in-plane extent.
    # The number of slices must shrink as slices get thicker, so the ratio is
    # in-plane / through-plane spacing. (The inverse ratio would give e.g. a
    # ~20-slice kernel for 5 mm slices at 0.5 mm pixels, over-closing the
    # volume and eroding the first/last slices against the array border.)
    in_plane_spacing, slice_spacing = voxel_spacing[0], voxel_spacing[2]
    xy_to_z = in_plane_spacing / slice_spacing if slice_spacing > 0 else 1.0
    kernel_size_z = max(1, int(round(2 * xy_to_z)))  # 2 slices for isotropic data
    structure_3d_aniso = np.ones((kernel_size_z, 3, 3), dtype=np.uint8)

    # Gentle closing to fill small holes / gaps between slices.
    binary_data = ndimage.binary_closing(binary_data, structure=structure_3d_aniso).astype(np.uint8)

    return binary_data


def extract_brain_region(input_file, output_file):
    """Run the full brain-extraction pipeline on one NIfTI file.

    The pipeline:

    1. Read ``input_file`` and print its statistics.
    2. Threshold it, retrying with a more lenient cut if nothing survives.
    3. Clean each slice along the first array axis.
    4. Keep the largest 3D component.
    5. Write the mask to ``output_file`` with the input image's geometry.

    Progress is printed to stdout. The process exits with status 1 if
    ``input_file`` does not exist.
    """
    print(f"Input file: {input_file}")

    if not os.path.exists(input_file):
        print(f"Error: Input file does not exist: {input_file}")
        sys.exit(1)

    # Load the volume; the array is indexed (z, y, x), the spacing is (x, y, z).
    volume, sitk_image = read_nifti_file(input_file)
    print(f"Image dimensions: {volume.shape}")
    spacing_xyz = sitk_image.GetSpacing()
    print(f"Voxel spacing: {spacing_xyz}")

    analyze_image_stats(volume)

    print("Performing brain CT threshold processing...")
    mask = brain_ct_threshold(volume)

    if not np.sum(mask):
        print("Warning: No foreground regions found after thresholding, trying more lenient threshold")
        # HU data: keep every non-negative voxel; otherwise keep the top 60 %.
        looks_like_hu = np.min(volume) < -100
        cutoff = 0 if looks_like_hu else np.percentile(volume, 40)
        mask = (volume >= cutoff).astype(np.uint8)
        print(f"Foreground voxels after lenient threshold: {np.sum(mask)}")

    print("Performing slice-by-slice processing...")
    slice_count = mask.shape[0]
    cleaned = np.zeros_like(mask, dtype=np.uint8)
    for z_idx, slice_2d in enumerate(mask):
        if not z_idx % 20:  # progress message every 20 slices
            print(f"Processing slice {z_idx}/{slice_count}")
        cleaned[z_idx] = process_brain_slice(slice_2d, z_idx)

    print("Performing 3D volume processing...")
    brain_mask = process_brain_volume_3d(cleaned, spacing_xyz)

    print("=== Final Results ===")
    brain_voxels = np.sum(brain_mask)
    print(f"Brain region voxel count: {brain_voxels}")
    print(f"Brain region percentage: {brain_voxels / brain_mask.size * 100:.2f}%")

    save_nifti_file(brain_mask, sitk_image, output_file)
    print("Brain extraction completed!")


# Default paths used when the script is run without command-line arguments.
DEFAULT_INPUT_PATH = "/Users/liulaolao/Desktop/R2/HeadCtSample_2022_volume 2.nii"
DEFAULT_OUTPUT_PATH = "/Users/liulaolao/Desktop/R2/brain_extraction_result.nii.gz"


def parse_args(argv=None):
    """Parse command-line arguments.

    Both positional arguments are optional; omitted ones fall back to
    ``DEFAULT_INPUT_PATH`` / ``DEFAULT_OUTPUT_PATH``.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with ``input`` and ``output`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Extract the brain region from a head CT NIfTI volume."
    )
    parser.add_argument(
        "input",
        nargs="?",
        default=DEFAULT_INPUT_PATH,
        help="input NIfTI file (default: %(default)s)",
    )
    parser.add_argument(
        "output",
        nargs="?",
        default=DEFAULT_OUTPUT_PATH,
        help="output mask file (default: %(default)s)",
    )
    return parser.parse_args(argv)


def main(argv=None):
    """Script entry point: parse arguments and run the extraction."""
    args = parse_args(argv)

    print("Starting brain CT image processing...")
    print(f"Input file: {args.input}")
    print(f"Output file: {args.output}")

    extract_brain_region(args.input, args.output)


if __name__ == "__main__":
    main()