diff --git a/main.py b/main.py index 71369b2..8939656 100644 --- a/main.py +++ b/main.py @@ -23,6 +23,11 @@ XGBOOST_COLOR = tuple(map(int, os.getenv("XGBOOST_COLOR", "255,0,0").split(","))) # Default: Red LIGHTGBM_COLOR = tuple(map(int, os.getenv("LIGHTGBM_COLOR", "0,0,255").split(","))) # Default: Blue +# Get model execution settings +CONV_ENABLED = os.getenv("CONV_ENABLED", "True").lower() == "true" +XGBOOST_ENABLED = os.getenv("XGBOOST_ENABLED", "True").lower() == "true" +LIGHTGBM_ENABLED = os.getenv("LIGHTGBM_ENABLED", "True").lower() == "true" + def load_model(model_path, model_type="lgb"): with open(model_path, "rb") as model_file: @@ -131,25 +136,32 @@ rows.append(row) if rows: - fieldnames = list(rows[0].keys()) + [ - "conv_stethoscope_x", - "conv_stethoscope_y", - "Xgboost_stethoscope_x", - "Xgboost_stethoscope_y", - "lightGBM_stethoscope_x", - "lightGBM_stethoscope_y", - ] + fieldnames = list(rows[0].keys()) + if CONV_ENABLED: + fieldnames.extend(["conv_stethoscope_x", "conv_stethoscope_y"]) + if XGBOOST_ENABLED: + fieldnames.extend(["Xgboost_stethoscope_x", "Xgboost_stethoscope_y"]) + if LIGHTGBM_ENABLED: + fieldnames.extend(["lightGBM_stethoscope_x", "lightGBM_stethoscope_y"]) - lgb_model_x = load_model("./models/lgb_stethoscope_calc_x_best_model-Fold4.pkl") - lgb_model_y = load_model("./models/lgb_stethoscope_calc_y_best_model-Fold4.pkl") - xg_model_x = load_model("./models/xg_stethoscope_calc_x_best_model-Fold4.pkl") - xg_model_y = load_model("./models/xg_stethoscope_calc_y_best_model-Fold4.pkl") + if LIGHTGBM_ENABLED: + lgb_model_x = load_model("./models/lgb_stethoscope_calc_x_best_model-Fold4.pkl") + lgb_model_y = load_model("./models/lgb_stethoscope_calc_y_best_model-Fold4.pkl") + if XGBOOST_ENABLED: + xg_model_x = load_model("./models/xg_stethoscope_calc_x_best_model-Fold4.pkl") + xg_model_y = load_model("./models/xg_stethoscope_calc_y_best_model-Fold4.pkl") with open(csv_path, "w", newline="") as csvfile: writer = csv.DictWriter(csvfile, fieldnames=fieldnames) writer.writeheader() - prev_values = {"conv": (180, 180), "lightGBM": (180, 180), "Xgboost": (180, 180)} + prev_values = {} + if CONV_ENABLED: + prev_values["conv"] = (180, 180) + if LIGHTGBM_ENABLED: + prev_values["lightGBM"] = (180, 180) + if XGBOOST_ENABLED: + prev_values["Xgboost"] = (180, 180) for row in rows: source_points = np.array( @@ -165,8 +177,9 @@ for key in prev_values: row[f"{key}_stethoscope_x"], row[f"{key}_stethoscope_y"] = prev_values[key] else: - conv_stethoscope = calc_position.calc_affine(source_points, *stethoscope_point) - row["conv_stethoscope_x"], row["conv_stethoscope_y"] = conv_stethoscope + if CONV_ENABLED: + conv_stethoscope = calc_position.calc_affine(source_points, *stethoscope_point) + row["conv_stethoscope_x"], row["conv_stethoscope_y"] = conv_stethoscope normalized_points = normalize_quadrilateral_with_point(source_points.flatten(), stethoscope_point) row_convert = { @@ -180,12 +193,13 @@ input_data = pd.DataFrame([row_convert]) input_columns = list(row_convert.keys()) - for model_name, models in [ - ("lightGBM", (lgb_model_x, lgb_model_y)), - ("Xgboost", (xg_model_x, xg_model_y)), - ]: - row[f"{model_name}_stethoscope_x"] = int(models[0].predict(input_data[input_columns])[0]) - row[f"{model_name}_stethoscope_y"] = int(models[1].predict(input_data[input_columns])[0]) + if LIGHTGBM_ENABLED: + row["lightGBM_stethoscope_x"] = int(lgb_model_x.predict(input_data[input_columns])[0]) + row["lightGBM_stethoscope_y"] = int(lgb_model_y.predict(input_data[input_columns])[0]) + + if XGBOOST_ENABLED: + row["Xgboost_stethoscope_x"] = int(xg_model_x.predict(input_data[input_columns])[0]) + row["Xgboost_stethoscope_y"] = int(xg_model_y.predict(input_data[input_columns])[0]) for key in prev_values: prev_values[key] = (row[f"{key}_stethoscope_x"], row[f"{key}_stethoscope_y"]) @@ -204,7 +218,13 @@ results_dir = "images/body/results" os.makedirs(results_dir, exist_ok=True) - dirs = {"marked": "marked_images", "conv": "conv", "Xgboost": "Xgboost", "lightGBM": "lightGBM"} + dirs = {"marked": "marked_images"} + if CONV_ENABLED: + dirs["conv"] = "conv" + if XGBOOST_ENABLED: + dirs["Xgboost"] = "Xgboost" + if LIGHTGBM_ENABLED: + dirs["lightGBM"] = "lightGBM" # Create all necessary directories for dir_name in dirs.values(): @@ -212,7 +232,7 @@ os.makedirs(os.path.join(results_dir, f"{dir_name}_with_trajectory"), exist_ok=True) os.makedirs(os.path.join(results_dir, f"{dir_name}_without_trajectory"), exist_ok=True) - points = {key: [] for key in ["conv", "Xgboost", "lightGBM"]} + points = {key: [] for key in dirs.keys() if key != "marked"} colors = {"conv": CONV_COLOR, "Xgboost": XGBOOST_COLOR, "lightGBM": LIGHTGBM_COLOR} for _, row in df.iterrows():