{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "from matplotlib.pylab import rcParams\n", "from sklearn.datasets import make_blobs\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.metrics import accuracy_score, confusion_matrix\n", "from sklearn.svm import SVC\n", "\n", "%matplotlib inline\n", "rcParams['figure.figsize'] = 9, 6" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Generate Imbalanced Datasets" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# make 3-class dataset for classification\n", "center1 = [[0, 0]]\n", "X1, y1 = make_blobs(n_samples=800, centers=center1, random_state=42)\n", "center2 = [[0, 0], [1, 1]]\n", "X2, y2 = make_blobs(n_samples=200,\n", " centers=center2,\n", " random_state=42,\n", " cluster_std=0.3)\n", "X = np.vstack((X1, X2))\n", "y = np.hstack((y1, y2))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* 在圖上紅色點表示為壞人(1),藍色點是好人(0)\n", "* 今天身為一位警察該如何透過兩個 feature 從好人中去抓出壞人呢?\n", "* 這邊產生的資料只有 20% 是壞人,80% 是好人" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "color = \"br\"\n", "color = [color[y[i]] for i in range(len(y))]\n", "plt.scatter(X[:, 0], X[:, 1], c=color, alpha=.5)\n", "plt.xlabel('Feature1')\n", "plt.ylabel('Feature2')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Predict with Logistic Regression" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = LogisticRegression()\n", "model.fit(X, y)\n", "prediction = model.predict(X)\n", "print('Accuracy:%.2f' % model.score(X, y))\n", "color = \"br\"\n", "label = ['Good Sample', 'Bad Sample']\n", "# color = [color[prediction[i]] for i in range(len(y))]\n", "# label = [label[prediction[i]] for i in range(len(y))]\n", "for i in range(2):\n", " plt.scatter(X[y == i][:, 0],\n", " X[y == i][:, 1],\n", " c=color[i],\n", " label=label[i],\n", " alpha=.5)\n", "plt.legend()\n", "plt.xlabel('Feature1')\n", "plt.ylabel('Feature2')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 利用預測機率代替預測結果\n", "* ### 在分類問題中,不管 logistic regression 還是 svm 都是在求解 maxmum likelyhood,因此最後算出來的邊界都是一個軟性邊界(機率分佈) \n", "### 例如在上面的 logistic regression 模型預測機率如下 model.predict_proba(X) 所示 \n", "### 第一欄代表被預測為 0(好人)的機率,第二欄代表被預測為 1(壞人)的機率。\n", "### 在一般 model.predict() 給出的預測值是直接比較第一欄和第二欄的大小決定的,好人的機率大於壞人的機率,就會預測這筆資料為 0。\n", "### 但是在 imbalanced data 的情況下,我們會去調整 threshold 以達到目的。 \n", "### 假設我們在這個 case 中想要抓出大部分的壞人,那也許我們可以設定被預測為 1 的機率大於 20% 那我們就認定他是壞人,如此的話就可以抓出大部分的壞人,那 20% 就是我們所設定的 threshold。\n", "### 之前有提到一般以 F1_score 來權衡 precision 和 recall,因此在這樣的 case 下我們會去調整 threshold 來求得最大的 F1_score。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(model.predict_proba(X))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(model.predict_proba(X)[0], '--->', model.predict(X)[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 定義 function \n", "# 比較各個不同 threshold 所形成的分類圖。\n", "* ### 這邊主要是要將分類問圖形化給大家看,如果看不懂的話只要知道fucntion的輸入為\n", " * ### 1. model_type: 可選擇 Logistic 或 SVC\n", " * ### 2. plot_dict: key 值為設定的 threshold,value 代表的是 subplot 的 index" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def auc(model_type, plot_dict):\n", " model_dict = {\n", " 'Logistic': LogisticRegression(),\n", " 'SVC': SVC(kernel='rbf', probability=True)\n", " }\n", "\n", " model = model_dict[model_type]\n", " model.fit(X, y)\n", " tpr_temp = []\n", " fpr_temp = []\n", " precision_temp = []\n", " recall_temp = []\n", " f1_score_temp = []\n", " for threshold in plot_dict:\n", " # 將預測機率大於 threshold 的分類為 1\n", " pred = (model.predict_proba(X)[:, 1] >= threshold).astype(int)\n", "\n", " # 先算出 confusion matrix 再算出其他評估值\n", " cm = confusion_matrix(y, pred)\n", " tpr = cm[1][1] / cm.sum(axis=1)[1]\n", " fpr = cm[0][1] / cm.sum(axis=1)[0]\n", " pre = cm[1][1] / cm.sum(axis=0)[1] # precision\n", " re = cm[1][1] / cm.sum(axis=1)[1] # recall\n", " f1 = 2 * (pre * re) / (pre + re) # f1-score\n", "\n", " # 畫圖\n", " color = \"br\"\n", " color = [color[pred[j]] for j in range(len(pred))]\n", " plt.subplot(plot_dict[threshold])\n", " plt.tight_layout()\n", " plt.scatter(X[:, 0], X[:, 1], c=color, alpha=.5)\n", " plt.title('threshold:%.2f' % threshold + '\\n' +\n", " 'accuracy:%.2f' % accuracy_score(y, pred) + '\\n' +\n", " 'precision:%.2f' % pre + '\\n' + 'recall:%.2f' % re + '\\n' +\n", " 'f1-score:%.2f' % f1)\n", "\n", " # 將各種評估值存取成 list 回傳\n", " tpr_temp.append(tpr)\n", " fpr_temp.append(fpr)\n", " precision_temp.append(pre)\n", " recall_temp.append(re)\n", " f1_score_temp.append('nan' if f1 != f1 else round(f1, 3))\n", "\n", " return f1_score_temp, tpr_temp, fpr_temp, precision_temp, recall_temp" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Logistic Regression with varied threshold" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(10.5, 7))\n", "\n", "plot_dict = {\n", " 0.85: 241,\n", " 0.8: 242,\n", " 0.6: 243,\n", " 0.5: 244,\n", " 0.4: 245,\n", " 0.3: 246,\n", " 0.2: 247,\n", " 0.1: 248\n", "}\n", "# 自訂的 function 會回傳不同 threshold 的 f1_score, tpr, fpr, precision 和 recall\n", "f1_0, tpr0, fpr0, precision0, recall0 = auc('Logistic', plot_dict)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Precision & Recall & F1_Score\n", "* 圖形化 Precsion Recall curve,並在個點旁邊標出對應的 f1_score " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(7, 7))\n", "plt.scatter(recall0, precision0, s=100)\n", "plt.xlabel('Recall')\n", "plt.ylabel('Precision')\n", "plt.title('Precision-Recall Curve')\n", "# Show F1 Score on each point\n", "for i, txt in enumerate(f1_0):\n", " plt.annotate(txt, (recall0[i] - 0.05, precision0[i] + 0.015), fontsize=14)\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# SVM with Varied threshold" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(10.5, 7))\n", "plot_dict = {\n", " 0.1: 241,\n", " 0.2: 242,\n", " 0.3: 243,\n", " 0.4: 244,\n", " 0.5: 245,\n", " 0.6: 246,\n", " 0.8: 247,\n", " 1: 248\n", "}\n", "f1_1, tpr1, fpr1, precision1, recall1 = auc('SVC', plot_dict)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Compare Performance between Logistic Regression & SVM\n", "* 比較 logistic regression 和 SVM 的 auc score,大家應該可以清楚地看出來哪個模型在這個問題中表現比較好。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(7, 7))\n", "plt.plot(recall0, precision0, '.', label='Logistic Regression', markersize=20)\n", "plt.plot(recall1, precision1, '.', label='SVM', markersize=20)\n", "plt.xlabel('Recall')\n", "plt.ylabel('Precision')\n", "plt.legend()\n", "plt.title('Precision-Recall Curve')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(7, 7))\n", "plt.plot(fpr0, tpr0, '.', label='Logistic Regression', markersize=20)\n", "plt.plot(fpr1, tpr1, '.', label='SVM', markersize=20)\n", "plt.xlabel('FPR')\n", "plt.ylabel('TPR')\n", "plt.legend()\n", "plt.title('ROC')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* 有興趣的同學可以查詢 ROC 曲線\n", "* roc curve 底下的面積 AUC 是很常見的一種 evaluation 方式" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.2" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 4 }