diff --git a/.gitignore b/.gitignore index d4150f0..fb7775f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,5 @@ .vs/ *.onnx +venv/ +.vscode/ +*.jpg diff --git a/py/segm.py b/py/segm.py new file mode 100644 index 0000000..a3d9af9 --- /dev/null +++ b/py/segm.py @@ -0,0 +1,45 @@ +import glob +import os + +import cv2 +import numpy as np +import onnxruntime + + +def main(): + session = onnxruntime.InferenceSession( + "../TongueSegmentation/ResUNet.onnx", + providers=["CPUExecutionProvider"], + ) + + for input_ in session.get_inputs(): + print(input_.name, input_.shape, input_.type) + + for output in session.get_outputs(): + print(output.name, output.shape, output.type) + + image_paths = glob.glob(os.path.join(".", "*.jpg")) + + for image_path in image_paths: + x = cv2.imread(image_path) + x = cv2.resize(x, (256, 256)) + cv2.imshow("input", x) + x = x.astype(np.float32) / 255.0 + x = np.array([x.transpose(2, 0, 1)]) # HWC -> CHW + print(x.shape) + + input_name = session.get_inputs()[0].name + output_names = [output.name for output in session.get_outputs()] + output = session.run(output_names, {input_name: x}) + output_array = np.array(output[0][0][0]) + + probability_map = 1.0 / (1.0 + np.power(np.e, -output_array)) + _, tongue_mask = cv2.threshold(probability_map, 0.5, 1, cv2.THRESH_BINARY) + tongue_mask = tongue_mask.astype(np.uint8) + + cv2.imshow("tongue_mask", tongue_mask * 255) + cv2.waitKey(0) + + +if __name__ == "__main__": + main()