diff --git a/config.py b/config.py index c6ad6fe..e20cb0a 100644 --- a/config.py +++ b/config.py @@ -11,8 +11,10 @@ # Model execution settings CONV_ENABLED = True -XGBOOST_ENABLED = False +XGBOOST_ENABLED = True LIGHTGBM_ENABLED = True +CATBOOST_ENABLED = True +NGBOOST_ENABLED = True NORMALIZE_ENABLED = False POSENET_ENABLED = False RTMPOSE_ENABLED = True diff --git a/main.py b/main.py index 62173aa..5e5f2d6 100644 --- a/main.py +++ b/main.py @@ -493,22 +493,23 @@ ) if XGBOOST_ENABLED: xg_model_x = load_model( - "./models/xg_stethoscope_calc_x_best_model-Fold4.pkl" + "./models/XGBoost/stethoscope_calc_x_best_model.pkl" ) xg_model_y = load_model( - "./models/xg_stethoscope_calc_y_best_model-Fold4.pkl" + "./models/XGBoost/stethoscope_calc_y_best_model.pkl" ) # Load models if EARSNET_ENABLED: earsnet_predictor = load_earsnet_model( - model_path="models/EARSNet/best_model.pth", + model_path="models/EARSNet/best_model-50-F2.pth", model_type="resnet", - model_version="18", + model_version="50", ) - with open(csv_path, "w", newline="") as csvfile, open( - normalized_csv_path, "w", newline="" - ) as norm_csvfile: + with ( + open(csv_path, "w", newline="") as csvfile, + open(normalized_csv_path, "w", newline="") as norm_csvfile, + ): writer = csv.DictWriter(csvfile, fieldnames=fieldnames) norm_writer = csv.DictWriter(norm_csvfile, fieldnames=fieldnames) writer.writeheader() diff --git a/requirements.txt b/requirements.txt index 9420f80..43eb7ba 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,9 @@ joblib==1.4.2 lightgbm==4.5.0 xgboost==2.1.1 +catboost==1.2.7 +ngboost==0.5.1 scipy==1.9.3 numpy==1.24.0 -scikit-learn == 1.5.1 +scikit-learn==1.5.1 matplotlib==3.9.2 \ No newline at end of file