diff --git a/analysis.py b/analysis.py index 470a567..720bf55 100644 --- a/analysis.py +++ b/analysis.py @@ -39,7 +39,8 @@ "HandyTCC_Web", ] # データセットの種類 input_list = ["shape", "color", "texture", "all"] # 入力の種類 -target_list = ["A", "B", "C", "Total"] # ターゲットスコア +target_list = ["A", "B", "C", "Total", "Raw", "Conv"] # ターゲットスコア +# target_list = ["Raw", "Conv"] # ターゲットスコア追加解析 model_list = [ RandomForestClassifier(), AdaBoostClassifier(), @@ -85,6 +86,7 @@ # print("Using all features...") x = all_data.loc[:, "shape-width":"fiveClick-tip-b-kurtosis"] # 全データ scores = all_data.loc[:, "A01":"C08"] # スコア + scores["id"] = all_data["ID"] # IDを追加 invert_list = ["A01", "A02", "A03"] for invert in invert_list: scores[invert] = 5 - scores[invert] # 逆転スコア @@ -93,10 +95,43 @@ scores["C"] = scores.loc[:, "C01":"C08"].sum(axis=1) # C群 scores["Total"] = scores.loc[:, "A":"C"].sum(axis=1) # 総合スコア target = df.loc[df_index, "Target"] # ターゲットスコア - # target = 'Total' - threshold = scores[target].median() # 中央値を閾値とする - # print(f'Threshold for {target}: {threshold}') + # 高ストレス ラベル決定 + scores["label"] = 0 + if target == "Raw": + for idx, row in scores.iterrows(): + if row["B"] >= 31 or (row["B"] >= 23 and (row["A"] + row["C"]) >= 39): + scores.loc[idx, "label"] = 1 # ラベル付け + elif target == "Conv": + for idx, row in scores.iterrows(): + A1 = row["A01"] + row["A02"] +row["A03"] # A群の1-3 + A2 = 15 - (row["A08"] + row["A09"] +row["A10"]) + B1 = row["B07"] + row["B08"] +row["B09"] # B群の7-9 + B2 = row["B10"] + row["B11"] +row["B12"] + B3 = row["B13"] + row["B14"] +row["B16"] + B4 = row["B27"] + B5 = row["B29"] + C1 = 15 - (row["C01"] + row["C04"] +row["C07"]) + C2 = 15 - (row["C02"] + row["C05"] +row["C08"]) + ConvA1 = 1 if A1 >= 12 else 2 if A1 >= 10 else 3 if A1 >= 8 else 4 if A1 >= 6 else 5 + ConvA2 = 1 if A2 <= 4 else 2 if A2 <= 6 else 3 if A2 <= 8 else 4 if A2 <= 10 else 5 + ConvB1 = 1 if B1 >= 11 else 2 if B1 >= 8 else 3 if B1 >= 5 else 4 if B1 >= 4 else 5 + ConvB2 = 1 if B2 >= 10 else 2 if B2 >= 8 else 3 if B2 >= 5 else 4 if B2 >= 4 else 5 + ConvB3 = 1 if B3 >= 10 else 2 if B3 >= 7 else 3 if B3 >= 5 else 4 if B3 >= 4 else 5 + ConvB4 = 1 if B4 >= 4 else 2 if B4 >= 3 else 3 if B4 >= 2 else 5 + ConvB5 = 1 if B5 >= 4 else 2 if B5 >= 3 else 3 if B5 >= 2 else 5 + ConvC1 = 1 if C1 <= 4 else 2 if C1 <= 6 else 3 if C1 <= 8 else 4 if C1 <= 10 else 5 + ConvC2 = 1 if C2 <= 5 else 2 if C2 <= 7 else 3 if C2 <= 9 else 4 if C2 <= 11 else 5 + ConvAC = ConvA1 + ConvA2 + ConvC1 + ConvC2 + ConvB = ConvB1 + ConvB2 + ConvB3 + ConvB4 + ConvB5 + if ConvB <= 11 or (ConvB <= 16 and ConvAC <= 8): + scores.loc[idx, "label"] = 1 + else: + threshold = scores[target].median() # 中央値を閾値とする + # print(f'Threshold for {target}: {threshold}') + scores.loc[scores[target] >= threshold, "label"] = 1 # ラベル付け + + # target = 'Total' # scores[target].plot.hist(bins=20, edgecolor='black') # ヒストグラムを描画 # import matplotlib.pyplot as plt # plt.title(f'{target} Score Distribution') @@ -105,9 +140,10 @@ # plt.show() # ヒストグラムを表示 # exit() - scores["label"] = 0 - scores.loc[scores[target] >= threshold, "label"] = 1 # ラベル付け - # print(scores.head(3)) + # print(scores[["id", "label"]]) + # scores[["id", "label"]].to_csv("labels.csv", index=False) # ラベルをCSVファイルに保存 + # print("num positive labels:", scores["label"].sum()) + # exit() return x, scores["label"]