# TongueSegmentation/py/segm.py
import glob
import os

import cv2
import numpy as np
import onnxruntime


def _preprocess(image, size):
    """Resize an image and convert it into a batched float32 model input.

    Args:
        image: H x W x 3 uint8 array as returned by ``cv2.imread``.
        size: Side length, in pixels, of the square image fed to the model.

    Returns:
        ``(resized, tensor)`` where ``resized`` is the ``size`` x ``size``
        uint8 image (for display) and ``tensor`` is a C-contiguous float32
        array of shape ``(1, 3, size, size)`` scaled to ``[0, 1]``.
    """
    resized = cv2.resize(image, (size, size))
    # NOTE(review): cv2.imread yields BGR channel order and it is passed to the
    # model unchanged; confirm the model was trained on BGR (not RGB) input.
    tensor = resized.astype(np.float32) / 255.0
    # HWC -> CHW, then add the leading batch axis (N=1).
    tensor = np.ascontiguousarray(tensor.transpose(2, 0, 1)[np.newaxis])
    return resized, tensor


def _postprocess(logits, threshold=0.5):
    """Convert a 2-D logit map into a binary uint8 mask (1 = tongue, 0 = other).

    Args:
        logits: 2-D array of raw model outputs (pre-sigmoid).
        threshold: Probability above which a pixel is marked as tongue.

    Returns:
        uint8 array with the same shape as ``logits`` containing 0s and 1s.
    """
    # Logistic sigmoid. Very negative logits overflow exp() to inf, which
    # still gives the correct limit 1 / (1 + inf) == 0, so silence the warning.
    with np.errstate(over="ignore"):
        probability_map = 1.0 / (1.0 + np.exp(-logits))
    # Strict ">" matches cv2.THRESH_BINARY semantics (src > thresh -> maxval).
    return (probability_map > threshold).astype(np.uint8)


def main(
    model_path="../TongueSegmentation/ResUNet.onnx",
    image_dir=".",
    pattern="*.jpg",
    size=256,
):
    """Run the ResUNet tongue-segmentation model on every matching image.

    For each image, the resized input is shown in an ``input`` window and the
    predicted mask in a ``tongue_mask`` window. Press any key to advance to
    the next image.

    Args:
        model_path: Path to the ONNX model file.
        image_dir: Directory searched for images.
        pattern: Glob pattern, relative to ``image_dir``, selecting images.
        size: Square resolution each image is resized to before inference.
    """
    session = onnxruntime.InferenceSession(
        model_path,
        providers=["CPUExecutionProvider"],
    )

    for input_ in session.get_inputs():
        print(input_.name, input_.shape, input_.type)

    for output in session.get_outputs():
        print(output.name, output.shape, output.type)

    # Tensor names do not change between images; resolve them once.
    input_name = session.get_inputs()[0].name
    output_names = [output.name for output in session.get_outputs()]

    # glob order is filesystem-dependent; sort for a deterministic run.
    image_paths = sorted(glob.glob(os.path.join(image_dir, pattern)))
    if not image_paths:
        print(f"No images matching {pattern!r} found in {image_dir!r}")
        return

    for image_path in image_paths:
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None, not by raising.
            print(f"Skipping unreadable image: {image_path}")
            continue

        resized, tensor = _preprocess(image, size)
        cv2.imshow("input", resized)
        print(tensor.shape)

        outputs = session.run(output_names, {input_name: tensor})
        # assumes the first output is (1, 1, H, W) logits — TODO confirm
        # against the model's printed output shape above.
        logits = np.asarray(outputs[0])[0, 0]
        tongue_mask = _postprocess(logits)

        # Scale {0, 1} to {0, 255} so the mask is visible in imshow.
        cv2.imshow("tongue_mask", tongue_mask * 255)
        cv2.waitKey(0)

    cv2.destroyAllWindows()


if __name__ == "__main__":
    # Entry point: only run the demo when executed directly, never on import.
    main()