{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Predicting Substance Dependence from Multi-Site Data\n", "\n", "This notebook explores an example using data from the ENIGMA Addiction Consortium. In this notebook we try to classify participants with any drug dependence (alcohol, cocaine, etc.) vs. healthy controls. The data are sourced from a number of individual studies from around the world, collected with different scanners, etc., which makes this a challenging problem with its own unique considerations. Structural FreeSurfer ROIs are used as features. The raw data cannot be made available due to data use agreements.\n", "\n", "The key idea explored in this notebook is a particularly tricky problem introduced by case-only sites, that is, sites that contributed data only from cases (dependent participants) and none from controls. This introduces a confound: you cannot easily tell whether the classifier is learning to predict site or the dependence status of interest.\n", "\n", "This notebook also features some helpful\n", "code snippets for converting code written for BPt versions earlier than 2.0 into valid BPt 2.0+ code." 
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/sage/anaconda3/envs/bpt/lib/python3.9/site-packages/nilearn/datasets/__init__.py:93: FutureWarning: Fetchers from the nilearn.datasets module will be updated in version 0.9 to return python strings instead of bytes and Pandas dataframes instead of Numpy arrays.\n", " warn(\"Fetchers from the nilearn.datasets module will be \"\n" ] } ], "source": [ "import pandas as pd\n", "import BPt as bp\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "from warnings import simplefilter\n", "from sklearn.exceptions import ConvergenceWarning\n", "simplefilter(\"ignore\", category=ConvergenceWarning)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Loading / Preparing Data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a general tip, it can be useful to wrap a series of steps that load and merge multiple DataFrames into a function. That said, when first writing such a function, it is often helpful to develop it step by step, taking advantage of the interactivity of the Jupyter notebook." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "def load_base_df():\n", " '''Loads and merges a DataFrame from multiple raw files'''\n", " \n", " na_vals = [' ', ' ', 'nan', 'NaN']\n", " \n", " # Load first part of data\n", " d1 = pd.read_excel('/home/sage/Downloads/e1.xlsx', na_values=na_vals)\n", " d2 = pd.read_excel('/home/sage/Downloads/e2.xlsx', na_values=na_vals)\n", " df = pd.concat([d1, d2])\n", "\n", " df['Subject'] = df['Subject'].astype('str')\n", " df.rename({'Subject': 'subject'}, axis=1, inplace=True)\n", " df.set_index('subject', inplace=True)\n", "\n", " # Load second part\n", " df2 = pd.read_excel('/home/sage/Downloads/e3.xlsx', na_values=na_vals)\n", " df2['Subject ID'] = df2['Subject ID'].astype('str')\n", " df2.rename({'Subject ID': 'subject'}, axis=1, inplace=True)\n", " df2.set_index('subject', inplace=True)\n", "\n", " # Merge\n", " data = df2.merge(df, on='subject', how='outer')\n", " \n", " # Rename age and sex\n", " data = data.rename({'Sex_y': 'Sex', 'Age_y': 'Age'}, axis=1)\n", " \n", " # Replace subject IDs with an integer index to obfuscate them\n", " data.index = list(range(len(data.index)))\n", " \n", " return data" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(3525, 224)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = load_base_df()\n", "df.shape" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Setting NaN threshold to: 1762.5\n", "Dropped 20 Columns\n", "Dropped 3 Columns\n", "Dropped 36 Columns\n", "Num. 
categorical variables in dataset: 3\n", "to binary cols: ['Dependent any drug', 'Handedness', 'Sex']\n", "Dropped 479 Rows\n", "scope covars = ['Age', 'Education', 'Handedness', 'ICV', 'Sex']\n", "Setting NaN threshold to: 82.5\n", "Dropped 38 Rows\n", "Dropped 5 Rows\n" ] } ], "source": [ "# Cast to dataset\n", "data = bp.Dataset(df)\n", "data.verbose = 1\n", "\n", "# Drop non-relevant columns\n", "data.drop_cols_by_nan(threshold=.5, inplace=True)\n", "data.drop_cols(scope='Dependent', inclusions='any drug', inplace=True)\n", "data.drop_cols(exclusions=['Half', '30 days', 'Site ',\n", " 'Sex_', 'Age_', 'Primary Drug', 'ICV.'], inplace=True)\n", "\n", "# Detect categorical vars and convert them to binary\n", "data.auto_detect_categorical(inplace=True)\n", "data.to_binary(scope='category', inplace=True)\n", "print('to binary cols:', data.get_cols('category'))\n", "\n", "# Set target and drop any NaNs\n", "data.set_role('Dependent any drug', 'target', inplace=True)\n", "data.drop_nan_subjects('target', inplace=True)\n", "\n", "# Save this set of vars under scope covars\n", "data = data.add_scope(['ICV', 'Sex', 'Age', 'Education', 'Handedness'], 'covars')\n", "print('scope covars = ', data.get_cols('covars'))\n", "\n", "# Set site as non input\n", "data = data.set_role('Site', 'non input')\n", "data = data.ordinalize(scope='non input')\n", "\n", "# Drop subjects with too many NaNs and big outliers\n", "data.drop_subjects_by_nan(threshold=.5, scope='all', inplace=True)\n", "data.filter_outliers_by_std(n_std=10, scope='float', inplace=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Legacy equivalent: pre-BPt 2.0 loading code**\n", "\n", "```\n", "ML = BPt_ML('Enigma_Alc',\n", " log_dr = None,\n", " n_jobs = 8)\n", " \n", "ML.Set_Default_Load_Params(subject_id = 'subject',\n", " na_values = [' ', ' ', 'nan', 'NaN'],\n", " drop_na = .5)\n", " \n", "ML.Load_Data(df=df,\n", " drop_keys = ['Unnamed:', 'Site', 'Half', 'PI', 'Dependent',\n", " 'Surface Area', 'Thickness', 
'ICV', 'Subcortical',\n", " 'Sex', 'Age', 'Primary Drug', 'Education', 'Handedness'],\n", " inclusion_keys=None,\n", " unique_val_warn=None,\n", " clear_existing=True)\n", "\n", "ML.Load_Targets(df=df,\n", " col_name = 'Dependent any drug',\n", " data_type = 'b')\n", "\n", "ML.Load_Covars(df=df,\n", " col_name = ['ICV', 'Sex', 'Age'],\n", " drop_na = False,\n", " data_type = ['f', 'b', 'f'])\n", "\n", "ML.Load_Covars(df = df,\n", " col_name = ['Education', 'Handedness'],\n", " data_type = ['f', 'b'],\n", " drop_na = False,\n", " filter_outlier_std = 10)\n", "\n", "ML.Load_Strat(df=df,\n", " col_name=['Sex', 'Site'],\n", " binary_col=[True, False]\n", " )\n", "\n", "ML.Prepare_All_Data()\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's take a look at what we prepared. We can see that visually the Dataset is grouped by role." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's plot some variables of interest" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dependent any drug: 3003 rows\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "data.plot('target')" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Site: 3003 rows\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "data.plot('Site')" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Site: 3003 rows\n", "Dependent any drug: 3003 rows\n", "Plotting 3003 overlap valid subjects.\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "data.plot_bivar('Site', 'Dependent any drug')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That bivariate plot is a little hard to read, but we can easily make it bigger. Let's say we also want to save it; in that case we need to pass `show=False`, so that the plot is not shown before we save it with `plt.savefig`." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Site: 3003 rows\n", "Dependent any drug: 3003 rows\n", "Plotting 3003 overlap valid subjects.\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(10, 10))\n", "data.plot_bivar('Site', 'Dependent any drug', show=False)\n", "plt.savefig('site_by_drug.png', dpi=200)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we define a custom validation strategy that keeps all subjects from the same site together, so that each site falls entirely within either the training or the testing fold.\n", "\n", "**Legacy Code**\n", "```\n", "from BPt import CV\n", "group_site = CV(groups='Site')\n", "```" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "CVStrategy(groups='Site')" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "group_site = bp.CVStrategy(groups='Site')\n", "group_site" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can then pass this CV strategy as an argument when defining the train/test split.\n", "\n", "**Legacy Code**\n", "```\n", "ML.Train_Test_Split(test_size =.2, cv=group_site)\n", "```" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Performing test split on: 3003 subjects.\n", "random_state: 5\n", "Test split size: 0.2\n", "\n", "Performed train/test split\n", "Train size: 2379\n", "Test size: 624\n" ] } ], "source": [ "data = data.set_test_split(size=.2, cv_strategy=group_site, random_state=5)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dependent any drug: 2379 rows\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Dependent any drug: 624 rows\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAFwCAYAAACGt6HXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAZUklEQVR4nO3deZxkZX3v8c9XBlABQUAQQWURQVyDGJeoF8VlIHJxQcAVjcvFJTeaGHFJEMw1iUvU+NJI8CqgIKJGIgqoKCiRRGAwgEMiOirKJjggw4CIgL/8cU5L0fRSM071U939eb9e/eqqc5469aunqr596jmnn0pVIUmae3drXYAkLVYGsCQ1YgBLUiMGsCQ1YgBLUiMGsCQ1YgBraEleluTbresYtSSV5EGt65iQ5ElJLlmH2zstycH95XX6nCZ5UZKvravtLXQGcGNJLk1yc5LVSa5P8u9JDkmyoJ+bcQu5VpIcnuTW/vlfneQHST6cZJuJNlX1b1W1y5DbOm62dlW1d1Uduw5q375/HpcMbPv4qnrG77vtxWJBv8nnkX2rahPggcDfA4cCH29bkqaSZL0RbPbE/vnfHHgOcF/g/MEQXhfS8T0/RnwyxkhVraqqk4EDgYOTPAwgyYZJ3pfkZ0muTnJkknv06/ZMcnmStyVZ2e9Rv2him0Pe9i+SXJPkqiQvH7jtFklOTnJDknOBnQbrTbJrktOTXJfkkiQHDKw7JslHkpzS79mdk2Snft1ZfbMLk9yY5MDJfZFkpyRnJLm2f1zHJ9lsYP2lSd6U5KIkq5KcmOTu/brlSfYdaLt+v40/mKrfk/xl/9ivTPInk9Ydk+SjSU5NchPwlCTfTPLKgTZ3+hif5Bl9f6xK8k9JvjXYfjpVdWtVXUz3/P8C+It+e3smuXxg+4cmuaLv10uS7JVkKfA24MC+Ty/s234zybuSnA38Cthxcv1ds3y4r/f7Sfaa1M9PG7g+uJc98Txe39/n46foiyckOa/f9nlJnjCw7ptJ/ibJ2f1j+VqSLWfrp4XEAB5DVXUucDnwpH7R3wMPBh4FPAjYFjhs4Cb3Bbbslx8MHJVklzW47ab98lcAH0ly737dR4BfA9sAf9L/AJBkI+B04NPAVsBBwD8l2W1g2wcBRwD3BlYA7+of35P79Y+sqo2r6sQpuiHA3wH3Ax4C3B84fFKbA4ClwA7AI4CX9cs/Cbx4oN0+wFVV9Z93uZMuuN4EPB3YGXja5DbAC/vaNwFmHC/tA+TzwFuBLYBLgCfMdJvJqup24Ivc8fwPbn8X4PXAY/q95mcCl1bVV4C/pdub3riqHjlws5cAr+7r/+kUd/lY4Ed0r6F3AF9IsvkQpU48j5v19/kfk2rdHDgF+BBdX7wfOCXJFgPNXgi8nO41tAHdc7FoGMDj60pg8yShe/O8saquq6rVdG+0gya1/+uquqWqvkX3oj9gyNveCryz3/s6FbgR2CXdR+3nAYdV1U1VtRwYHDd8Ft0b/+iquq0Pt38Bnj/Q5qSqOreqbgOOp/sjMJSqWlFVp/eP6Rd0b97/NanZh6rqyqq6DvjSwPaPA/ZJcq/++kuAT01zVwcAR1fV8qq6ibuGPMAXq+rsqvptVf16ltL3AS6uqi/0j/tDwM9nuc1UrqQbkpjsdmBDYLck61fVpVX1o1m2dUxVXdw/T7dOsf4a4IP9a+BEuj8af7wWNU/2x8APq+pT/X2fAHwf2HegzdFV9YOquhn4LGvwGlkIlszeRI1sC1wH3Ae4J92Y4MS6AINjkb/sw2PCT+n2HIe57bV9UEz4
FbBxf9slwGWTtjvhgcBjk1w/sGwJdw66weCZ2O5QkmwN/CPdXuAmdDsLv5zUbPL27wdQVVf2H7mfl+QkYG/gz6a5q/sB5w9cn2oP8bIplk3nfoPtq6oGhw/WwMTzfydVtSLJG+j+UDw0yVeBP6+qK2fY1mz1X1F3npVr4vXz+7ofd+3Pn9I9tglr/RpZCNwDHkNJHkP3Iv02sBK4GXhoVW3W/2xaVYMv1Hv3QwITHkC3BzXMbafzC+A2uo/+g9udcBnwrYHtTnwMfc0aP+Cp/S1QwMOr6l50QwqZ+SZ3cmx/m+cD/1FVV0zT7iqmf4wTJk8ZeBPdH7YJ9520ve0mrvSfQrZjDaQ7ULYv8G9Tra+qT1fVE+n+CBbw7mnqZJblE7bNwF9o7nj9wMyPdbbtXtnXOOgBwHTPxaJjAI+RJPdK8izgM8BxVfW9qvot8DHgA0m26tttm+SZk25+RJINkjyJbnjgc2tw27voxyG/ABye5J792O7BA02+DDw4yUv6g1zrJ3lMkocM+XCvBnacYf0mdMMhq5JsC/zlkNud8K/A7nR7vp+cod1ngZcl2S3JPenGQGdzAfDcvl8eRDd2PuEU4OFJnp3u9KzXcefQmlaSJX3/ndDf5v1TtNklyVOTbEg3Pn8z8Nt+9dXA9lnzMx22Av5v/xw+n27M/dSBx3pQv24PYP+B2/2iv+/pnsdT6V4jL+wf24HAbnSvHWEAj4svJVlNt1f5dro33ssH1h9KdxDrO0luAL4ODJ4X+nO6j+dX0o21HlJV3x/ytjN5Pd1Hwp8DxwBHT6zox5OfQTeefGXf5t1045PDOBw4Nt25zwdMsf4IugBdRRdqXxhyuxP13Uw3Jr3DTLetqtOADwJn0PXTGUNs/gPAb+gC71i6Pp/Y3kq6ve73ANfSBc4y4JYZtndgkhvpHuvJ/e0ePc2wwoZ0B1ZX0vX5VnQH/AA+1/++Nsl3h3gcE86hOwC5ku5g4/5VdW2/7q/pzn75Jd1z8umBx/qrvv3Z/fP4uMGN9tt4Ft3ZHNcCbwae1feRgDgh+/yWZE+6veU1+pi7GCQ5DHhwVb141sajq+FudGe0vKiqzmxVh8aTe8BakPpToF4BHNXgvp+ZZLN+mOBtdGPX35nrOjT+DGAtOEleRTecc1pVnTVb+xF4PN15tSvpDqY9ux8Ske7EIQhJasQ9YElqZN79I8bSpUvrK1/5SusyJGlNTHkO+7zbA1650jNYJC0M8y6AJWmhMIAlqREDWJIaMYAlqREDWJIaMYAlqREDWJIaMYAlqREDWJIaMYAlqREDWJIaMYAlqREDWJIaMYAlqZF5940YG913h9r1JUe0LkMaC+e/96WtS9BwFsZ8wJK0UBjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjYw0gJMsTXJJkhVJ3jLF+g2TnNivPyfJ9qOsR5LGycgCOMl6wEeAvYHdgBck2W1Ss1cAv6yqBwEfAN49qnokadyMcg/4D4EVVfXjqvoN8Blgv0lt9gOO7S9/HtgrSUZYkySNjVEG8LbAZQPXL++XTdmmqm4DVgFbTN5QklcnWZZk2W2/Wj2iciVpbs2Lg3BVdVRV7VFVeyy55yaty5GkdWKUAXwFcP+B69v1y6Zsk2QJsClw7QhrkqSxMcoAPg/YOckOSTYADgJOntTmZODg/vL+wBlVVSOsSZLGxpJR
bbiqbkvyeuCrwHrAJ6rq4iTvBJZV1cnAx4FPJVkBXEcX0pK0KIwsgAGq6lTg1EnLDhu4/Gvg+aOsQZLG1bw4CCdJC5EBLEmNGMCS1IgBLEmNGMCS1IgBLEmNGMCS1IgBLEmNGMCS1IgBLEmNGMCS1IgBLEmNGMCS1IgBLEmNGMCS1IgBLEmNGMCS1IgBLEmNGMCS1IgBLEmNGMCS1IgBLEmNGMCS1IgBLEmNGMCS1IgBLEmNGMCS1IgBLEmNGMCS1IgBLEmNGMCS1IgBLEmNGMCS1IgBLEmNGMCS1IgBLEmNGMCS1IgBLEmNGMCS1IgBLEmNGMCS1IgBLEmNGMCS1IgBLEmNGMCS1IgBLEmNGMCS1MgaB3CSeyd5xCiKkaTFZKgATvLNJPdKsjnwXeBjSd4/2tIkaWEbdg9406q6AXgu8MmqeizwtNGVJUkL37ABvCTJNsABwJdHWI8kLRrDBvA7ga8CP6qq85LsCPxwdGVJ0sK3ZJhGVfU54HMD138MPG9URUnSYjDsQbgHJ/lGkuX99Uck+avRliZJC9uwQxAfA94K3ApQVRcBB42qKElaDIYN4HtW1bmTlt22rouRpMVkqDFgYGWSnYACSLI/cNXIqprBQ7bbgmXvfWmLu5akdWrYAH4dcBSwa5IrgJ8ALx5ZVZK0CAx7FsSPgacl2Qi4W1WtHm1ZkrTwzRjASV5cVccl+fNJywGoKv8dWZLW0mx7wBv1vzcZdSGStNjMGMBV9c9J1gNuqKoPzFFNkrQozHoaWlXdDrxgDmqRpEVl2LMgzk7yYeBE4KaJhVX13ZFUJUmLwLAB/Kj+9zsHlhXw1HVajSQtIsOehvaUURciSYvNsJPxbJrk/UmW9T//kGTTURcnSQvZsHNBfAJYTTch+wHADcDRoypKkhaDYceAd6qqwfl/j0hywQjqkaRFY9g94JuTPHHiSpI/Am4eTUmStDgMuwf8GuDYftw3wHXAy0ZVlCQtBsOeBXEB8Mgk9+qv3zDKoiRpMRgqgKeZjGcVcH4fzpKkNTTsGPAewCHAtv3P/wGWAh9L8uYR1SZJC9qwY8DbAbtX1Y0ASd4BnAI8GTgfeM9oypOkhWvYPeCtgFsGrt8KbF1VN09aLkka0rB7wMcD5yT5Yn99X+DT/Tdk/NdIKpOkBW7YsyD+JslpwB/1iw6pqmX95ReNpDJJWuCGHYIAuDvdxOz/CPw0yQ4jqkmSFoVhJ+N5B3Ao8NZ+0frAcaMqSpIWg2H3gJ8D/G/6ydir6kr8njhJ+r0MG8C/qaqim4Sd/uCbJOn3MGwAfzbJPwObJXkV8HXg/4+uLEla+IY9C+J9SZ5ONw/wLsBhVXX6SCuTpAVu2Lkg3l1VhwKnT7FMkrQWhh2CePoUy/Zel4VI0mIz4x5wktcArwV2THLRwKpNgLNHWdh0fnPVxfzsnQ9vcdeSFrEHHPa9db7N2YYgPg2cBvwd8JaB5aur6rp1Xo0kLSIzBnBVraKb9/cFAEm2ovuPuI2TbFxVPxt9iZK0MA37n3D7Jvkh8BPgW8CldHvGkqS1NOxBuP8HPA74QVXtAOwFfGdkVUnSIjBsAN9aVdcCd0tyt6o6k+5bMiRJa2nY+YCvT7IxcBZwfJJr6OeFkCStndlOQ3sQsDWwH3Az8Ea6+X8fCPzpyKuTpAVstiGID9LNAXxTVf22qm6rqmOBk4DDR12cJC1kswXw1lV1l7OP+2Xbj6QiSVokZgvgzWZYd491WIckLTqzBfCyfvrJO0nySrqvo5ckraXZzoJ4A3BSkhdxR+DuAWxA9y0ZkqS1NNu/Il8NPCHJU4CH9YtPqaozRl6ZJC1ww07IfiZw5ohrkaRFZU2+ll6StA4ZwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEs
SY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUyMgCOMknklyTZPk065PkQ0lWJLkoye6jqkWSxtEo94CPAZbOsH5vYOf+59XAR0dYiySNnZEFcFWdBVw3Q5P9gE9W5zvAZkm2GVU9kjRuWo4BbwtcNnD98n7ZXSR5dZJlSZZdd9Ptc1KcJI3avDgIV1VHVdUeVbXH5hut17ocSVonWgbwFcD9B65v1y+TpEWhZQCfDLy0PxviccCqqrqqYT2SNKeWjGrDSU4A9gS2THI58A5gfYCqOhI4FdgHWAH8Cnj5qGqRpHE0sgCuqhfMsr6A143q/iVp3M2Lg3CStBAZwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0saV3Amtpgm4fygMOWtS5Dkn5v7gFLUiMGsCQ1YgBLUiMGsCQ1YgBLUiMGsCQ1YgBLUiMGsCQ1YgBLUiMGsCQ1YgBLUiMGsCQ1YgBLUiMGsCQ1YgBLUiOpqtY1rJEkq4FLWtcxjS2Bla2LmMK41gXWtrbGtbZxrQva1rayqpZOXjjvJmQHLqmqPVoXMZUky8axtnGtC6xtbY1rbeNaF4xnbQ5BSFIjBrAkNTIfA/io1gXMYFxrG9e6wNrW1rjWNq51wRjWNu8OwknSQjEf94AlaUEwgCWpkXkVwEmWJrkkyYokb2lcy6VJvpfkgiTL+mWbJzk9yQ/73/eeo1o+keSaJMsHlk1ZSzof6vvwoiS7N6jt8CRX9H13QZJ9Bta9ta/tkiTPHGFd909yZpL/SnJxkj/rlzfvtxlqG4d+u3uSc5Nc2Nd2RL98hyTn9DWcmGSDfvmG/fUV/frt57iuY5L8ZKDPHtUvn9P3wbSqal78AOsBPwJ2BDYALgR2a1jPpcCWk5a9B3hLf/ktwLvnqJYnA7sDy2erBdgHOA0I8DjgnAa1HQ68aYq2u/XP64bADv3zvd6I6toG2L2/vAnwg/7+m/fbDLWNQ78F2Li/vD5wTt8fnwUO6pcfCbymv/xa4Mj+8kHAiXNc1zHA/lO0n9P3wXQ/82kP+A+BFVX146r6DfAZYL/GNU22H3Bsf/lY4NlzcadVdRZw3ZC17Ad8sjrfATZLss0c1zad/YDPVNUtVfUTYAXd8z6Kuq6qqu/2l1cD/w1syxj02wy1TWcu+62q6sb+6vr9TwFPBT7fL5/cbxP9+XlgrySZw7qmM6fvg+nMpwDeFrhs4PrlzPyiHLUCvpbk/CSv7pdtXVVX9Zd/DmzdprQZaxmXfnx9/9HvEwNDNU1q6z8W/wHdXtNY9duk2mAM+i3JekkuAK4BTqfb476+qm6b4v5/V1u/fhWwxVzUVVUTffauvs8+kGTDyXVNUfOcmU8BPG6eWFW7A3sDr0vy5MGV1X3OGYtz/Maplt5HgZ2ARwFXAf/QqpAkGwP/Aryhqm4YXNe636aobSz6rapur6pHAdvR7Wnv2qKOySbXleRhwFvp6nsMsDlwaLsK72o+BfAVwP0Hrm/XL2uiqq7of18DnET3Qrx64mNM//uaVvXNUEvzfqyqq/s3y2+Bj3HHx+U5rS3J+nQBd3xVfaFfPBb9NlVt49JvE6rqeuBM4PF0H+En5pYZvP/f1dav3xS4do7qWtoP51RV3QIcTeM+
m2w+BfB5wM790dYN6Ab0T25RSJKNkmwycRl4BrC8r+fgvtnBwBdb1NebrpaTgZf2R4EfB6wa+Mg9JyaNtT2Hru8majuoP3K+A7AzcO6IagjwceC/q+r9A6ua99t0tY1Jv90nyWb95XsAT6cboz4T2L9vNrnfJvpzf+CM/pPFXNT1/YE/pqEblx7ss6bvA2D+nAVRdxy5/AHdmNPbG9axI91R5wuBiydqoRvb+gbwQ+DrwOZzVM8JdB9Jb6Uby3rFdLXQHfX9SN+H3wP2aFDbp/r7vojujbDNQPu397VdAuw9wrqeSDe8cBFwQf+zzzj02wy1jUO/PQL4z76G5cBhA++Jc+kOAH4O2LBffvf++op+/Y5zXNcZfZ8tB47jjjMl5vR9MN2P/4osSY3MpyEISVpQDGBJasQAlqRGDGBJasQAlqRGDGAtGEnum+QzSX7U/4v4qUkevA63v2eSJ6yr7UkGsBaE/kT7k4BvVtVOVfVoun9DXZfzcewJGMBaZwxgLRRPAW6tqiMnFlTVhcC3k7w3yfJ08zcfCL/bm/3yRNskH07ysv7ypUmOSPLd/ja79pPiHAK8sZ9X9klJnt9v98IkZ83lg9XCsGT2JtK88DDg/CmWP5du8ppHAlsC5w0Zliuravckr6Wbg/eVSY4Ebqyq9wEk+R7wzKq6YuLfYKU14R6wFronAidUN4nN1cC36GbGms3E5DznA9tP0+Zs4Jgkr6L7wgBpjRjAWiguBh69Bu1v486v/7tPWn9L//t2pvmkWFWHAH9FN6vW+UlGMs+tFi4DWAvFGcCGA5Pjk+QRwPXAgf1k3feh+4qkc4GfArv1M4htBuw1xH2spvuKoInt71RV51TVYcAvuPP0htKsHAPWglBVleQ5wAeTHAr8mu57+94AbEw3c10Bb66qnwMk+SzdLFk/oZtJazZfAj6fZD/gT+kOyO1MN7PWN/r7kIbmbGiS1IhDEJLUiAEsSY0YwJLUiAEsSY0YwJLUiAEsSY0YwJLUyP8Ae+Voo7ONYdoAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "data.plot('target', subjects='train')\n", "data.plot('target', subjects='test')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Running ML\n", "\n", "Next, we evaluate some machine learning models on the problem we have defined. Notably, not much has changed here with respect to the old version, apart from some cosmetic renames such as Problem_Spec to ProblemSpec. We also have a new `evaluate` function instead of calling Evaluate via the ML object (`ML.Evaluate`)." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "ProblemSpec(n_jobs=16, scorer=['matthews', 'roc_auc', 'balanced_accuracy'],\n", " subjects='train')" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# This just holds some commonly used values\n", "ps = bp.ProblemSpec(subjects='train',\n", " scorer=['matthews', 'roc_auc', 'balanced_accuracy'],\n", " n_jobs=16)\n", "ps" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ModelPipeline\n", "-------------\n", "imputers=\\\n", "[Imputer(obj='mean', scope='float'),\n", " Imputer(obj='median', scope='category')]\n", "\n", "scalers=\\\n", "Scaler(obj='standard')\n", "\n", "model=\\\n", "Model(obj='elastic', params=1)\n", "\n", "param_search=\\\n", "ParamSearch(cv=CV(cv_strategy=CVStrategy()), n_iter=64,\n", " search_type='DiscreteOnePlusOne')\n", "\n" ] } ], "source": [ "# Define a ModelPipeline with imputation, scaling, and an elastic-net model\n", "pipe = bp.ModelPipeline(imputers=[bp.Imputer(obj='mean', scope='float'),\n", " bp.Imputer(obj='median', scope='category')],\n", " scalers=bp.Scaler('standard'), \n", " model=bp.Model('elastic', params=1),\n", " param_search=bp.ParamSearch(\n", " search_type='DiscreteOnePlusOne', n_iter=64))\n", 
"pipe.print_all()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3-fold CV random splits\n", "\n", "Let's start by evaluating this model in a fairly naive way: just 3 folds of random CV." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicting target = Dependent any drug\n", "Using problem_type = binary\n", "Using scope = all (defining a total of 163 features).\n", "Evaluating 2379 total data points.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "096378b8b9d045239c0e5fce543c3999", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Folds: 0%| | 0/3 [00:00" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Site: 793 rows\n", "Dependent any drug: 793 rows\n", "Plotting 793 overlap valid subjects.\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "data.plot_bivar('Site', 'Dependent any drug', subjects=results.train_subjects[0])\n", "data.plot_bivar('Site', 'Dependent any drug', subjects=results.val_subjects[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Preserving Groups by Site\n", "\n", "We can see, for example, that the 149 subjects from site 1 in the validation set had 269 subjects, also from site 1, in the training set from which the model could potentially memorize site effects! The confound arises from sites that contain only cases. One way we can account for this is something we've already hinted at: the group-by-site CV strategy from earlier. We can create a new CV object that uses it." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "CV(cv_strategy=CVStrategy(groups='Site'))" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "site_cv = bp.CV(splits=3, cv_strategy=group_site)\n", "site_cv" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5dbfc07f21b24324a748ba0e1998d195", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Folds: 0%| | 0/3 [00:00" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Site: 908 rows\n", "Dependent any drug: 908 rows\n", "Plotting 908 overlap valid subjects.\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "data.plot_bivar('Site', 'Dependent any drug', subjects=results.train_subjects[0])\n", "data.plot_bivar('Site', 'Dependent any drug', subjects=results.val_subjects[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Another thing we might want to look at in this case is `weighted_mean_scores`. Each of the splits has training and validation sets of slightly different sizes, so how does the metric change if we weight each fold's score by the number of validation subjects in that fold?" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'matthews': 0.2962912181056522,\n", " 'roc_auc': 0.731185680809635,\n", " 'balanced_accuracy': 0.6545156761514691}" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "results.weighted_mean_scores" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c1125e8677974cdb9d85d73fd28d5bea", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Folds: 0%| | 0/3 [00:00" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Site: 30 rows\n", "Dependent any drug: 30 rows\n", "Plotting 30 overlap valid subjects.\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAWgAAAEGCAYAAABIGw//AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAAUDUlEQVR4nO3de5SlVX3m8e9T3eAwQAgXaToggREI0VxaBGIiQ0gYO2o0IF4SEgUdls3KBLzGBAwJoEPihBEnJgZskAVe0qJBFj3gIKSFkCZBukWUa5A4LYE0tIkEaUOEht/8cV7wTFld51R1VZ3dXd/PWu+qU+9ln11Q6+ldv3e/+6SqkCS1Z2zUHZAkTcyAlqRGGdCS1CgDWpIaZUBLUqMWjroDm7PDvsc7vUQ/4PH7zx51F9Skg7KlLUwlcx6/f8UWv98wHEFLUqOaHUFL0lxK2huvGtCSBIylvThsr0eSNAKOoCWpUcmc3PebEgNakoAW50wY0JKEJQ5JapYBLUmNchaHJDXKEbQkNcqAlqRGBafZSVKTHEFLUqPGxtqLw/Z6JEkj4QhakppkiUOSGmVAS1KjYolDktrkCFqSGjU2tmDUXfgBBrQkYYlDkppliUOSGmVAS1KjLHFIUqPio96S1CY/NFaSGtViiaO9HknSCCRjQ2+Tt5PnJbk+yV1J7kzy9m7/WUkeTHJbt71yUJ8cQUsSwMyVODYB766qW5PsDHw5yXXdsQ9V1f8ctiEDWpJgxuoJVbUeWN+9fizJ3cDeI+ySJG3lxsaG3pIsS7K2b1s2UZNJ9gNeBHyp23VKkq8luTjJrgO7NHM/nSRtxcaG36pqeVUd2rctH99ckp2Ay4F3VNV3gPOB5wNL6I2wPzioS5Y4JAmoGZxml2Q7euH8qar6HEBVPdx3/ELgqkHtOIKWJIBMYZusmd6E6o8Bd1fVeX37F/ed9hrgjkFdcgQtSQBjMzaCfinwJuD2JLd1+94LHJ9kCVDAOuDkQQ0Z0JIEMzbNrqpWM/E4+/NTbcuAliSABT7qLUltci0OSWpUe/lsQEsSMJM3CWeMAS1J4AhaklpVC9p7LMSAliRwBC1JzXIWhyQ1ypuEktSo9vLZgJYkwBKHJDXLR70lqVGOoCWpUe3lswEtSQDlLA5JapQlDklqVHv5bEBLEgCuxSFJjXIELUmN8iahJDXKgJakNlV7+WxASxLgTUJJapYlDklqVHsDaANakgCfJJSkZlnikKQ2lSNoSWrUQgNaktrU4Ai6wfuWkjQCYxl+m0SS5yW5PsldSe5M8vZu/25Jrkvy9e7rrgO7NEM/miRt3TKFbXKbgHdX1QuAlwC/leQFwGnAqqo6EFjVfT8pA1qS6H2iyrDbpO1Ura+qW7vXjwF3A3sDxwCXdqddChw7qE8GtCTBlEocSZYlWdu3LZuoyST7AS8CvgQsqqr13aGHgEWDuuRNQkkCWDD8TcKqWg4sn+ycJDsBlwPvqKrvpO8mZFVVkhr0Po6gJQl6sziG3QY2le3ohfOnqupz3e6Hkyzuji8GNgxqx4CWJJjJWRwBPgbcXVXn9R1aCZzYvT4RuHJQlyxxSBLM5KPeLwXeBNye5LZu33uBDwCfSXIS8E3gDYMaMqAliZl71LuqVrP5yXhHT6UtA1qSYEo3CeeKAS1J4Gp2ktQsA1qSGtVePhvQkgQMfIR7FAxoSYImlxs1oCUJnMUhSa0aa/C5agNakmiywmFASxIY0JLUrDSY0Aa0JGENWpKaFQNaktrUYIXDgJYkaHIpDgNaksARtCQ1y4CWpEaN+ai3JLXJEbQkNcqAlqRGGdCS1Cin2UlSoxxBS1KjnMUhSY1yBC1JjWoxoIdavyk9b0zyB933+yY5fHa7JklzJxl+myvDLrD358DPAsd33z8GfGRWeiRJIzCW4be5MmyJ42eq6pAkXwGoqkeSbD+L/ZKkOTW2YNQ9+EHDBvSTSRYABZDkucDTs9areWyfxbtx0Yf+G3s+dxeq4OK
/WMVHLr6Gn/zxffnTPzyJHXf8D3zzgW/xlrd9hMc2Pj7q7moETj/9T7jhhjXsvvsuXHWVf8jOlK22Bg18GLgC2DPJOcBq4I9mrVfz2Kannua0//5JDjn6Pfz8Mb/PyScs5eAD9+b8P17GGR/4NIct/V1WXrOWd578qlF3VSNy3HFHc9FFZ426G9ucJENvQ7R1cZINSe7o23dWkgeT3NZtrxzUzlABXVWfAn6HXiivB46tqs8Mc62m5qEN/8ptd6wDYON3/5177nuQH9lrNw7YfzGrv3Q3AF/8m69x7Cu9RztfHXbYT7DLLjuPuhvbnBm+SXgJ8PIJ9n+oqpZ02+cHNTJUiSPJJ6rqTcA9E+zb3DUHA8cAe3e7HgRWVtXdw7ynYN999mDJC/djzVfu4+57H+DVSw/lf1+7luN++SXss3j3UXdP2qbMZImjqm5Mst+WtjNsieOF/d909egXb+7kJL8LfBoIcEu3BViR5LRJrluWZG2StZs23jdk17ZNO/7H57Dio+/kPWd/nMc2Ps7J7/koy054GTddfQ477bQDTzy5adRdlLYpUxlB92dVty0b8m1OSfK1rgSy66CTJx1BJzkdeC+wQ5Lv0AtZgCeA5ZNcehLwwqp6clx75wF3Ah+Y6KKqWv5Muzvse3wN6vy2auHCBaz46Du57IqbuPKaNQDc+w//xKvf2Cv7H7D/XrziF5eMsIfStmfhFD7Vuz+rpuB84P30Jlu8H/gg8F8nu2DSLlXVH1XVzsC5VfVDVbVzt+1eVadPcunTwI9MsH8xzv4Y6IJzl/H39/0TH77o+yWq5+7+Q0DvRsZpb3sNF35y1ai6J22TxlJDb9NRVQ9X1VNV9TRwITDwRtKgEfTBVXUP8Nkkh0zwhrdu5tJ3AKuSfB34x27fvsABwCmDOjWf/dxhP8ZvvPZIbr/7fm7+P70R85l/fBkH7L8XJ5+wFIArr7mFj3/mhhH2UqP0rnedyy233M4jj3yHI498M6ee+uu8/vVLR92trd5sP4CSZHFVre++fQ1wx2TnA6Rq8/8aJFleVcuSXN+3+9kLquoXJ7l2jN6/EP03CddU1VODOgXzu8ShzXv8/rNH3QU16aAtjtdfvnb10Jlz9dIjJn2/JCuAo4A9gIeBM7vvl9DL0HXAyX2BPaFBszguSrJXVf1C96YnAq/tGj9rsgu7YfzNA9qXpCZMt3Qxkao6foLdH5tqO4PK4hfQuyFIkiPpzYO+FHiUqRfI6dq5ajrXSdJs2hrX4lhQVd/uXv8qsLyqLgcuT3LbNN/zrdO8TpJmzcKt8FHvBUmeCfGjgS/2HZvWWtKDai6SNApJDb3NlUEBvQL46yRXAo8DfwOQ5AB6ZY4JJdkpyfuS3Jnk0STfSnJzkjfPVMclaSZtdSWOqjonySp685evre9P+RgDTp3k0k/RW1zpl4A3ADvSe7LwjCQHVdV7t7jnkjSDpvCcypwZWKaoqh+YiVFV9w64bL+quqR7fV6SNVX1/iRvAe6i93SiJDVjJmdxzJTZ+kfju0mOAEjyK8C34dmpdw2W4iXNdwsz/DZnfZqldn8TuDDJgfTW3jgJnl3o3xXGJTVnLmvLw5qVgK6qr3YPtewN3FxVG7v930oyqDwiSXNu3pQ4kryN3k3CU4A7khzTd/gPZ+M9JWlLbHWzOLbAW4FDq2pjt2j1XybZr6r+BGvQkhq0Vc7imKaxvrLGuiRH0QvpH8WAltSgeVPiAB5OsuSZb7qwfhW9lZ1+cpbeU5KmbeHY8Ntcma23OgF4qH9HVW2qqhOAI2fpPSVp2samsM2V2ZrF8cAkx26ajfeUpC3RYoljtmrQkrRVmTfzoCVpazOfZnFI0lbFEbQkNWrBmDVoSWqSJQ5JapSzOCSpUdagJalRBrQkNWo7SxyS1CZH0JLUKANakhq1wICWpDY5gpakRjkPWpIatV2DI+gWn26UpDk3kx8am+TiJBuS3NG3b7ck1yX5evd114F92rIfSZK2DWO
pobchXAK8fNy+04BVVXUgsKr7fvI+TfWHkKRt0YIMvw1SVTcC3x63+xjg0u71pcCxg9oxoCWJqZU4kixLsrZvWzbEWyyqqvXd64eARYMu8CahJDG1T+uuquXA8um+V1VVMrhWYkBLErBg9qfZPZxkcVWtT7IY2DDoAksckkQvDIfdpmklcGL3+kTgykEXOIKWJGb2ScIkK4CjgD2SPACcCXwA+EySk4BvAm8Y1I4BLUnMbEBX1fGbOXT0VNoxoCWJOalBT5kBLUlMbRbHXDGgJQlXs5OkZrketCQ1yuVGJalRDZagDWhJAmvQktSs7cYscUhSkxxBS1KjDGhJapQ3CSWpUXEELUltssQhSY2yxCFJjRriE6jmnAEtSUCDFQ4DWpLAm4SS1KwG89mAliRwuVFJapYlDklqVIP5bEBLEhjQktQsnySUpEY1mM8GtCSBn0koSc1yFockNcrFkiSpUY6gJalRDeazAS1J4DQ7SWrWTAZ0knXAY8BTwKaqOnQ67RjQksSslDh+oar+eUsaMKAliTY/UaXFmSWSNOcylS1ZlmRt37ZsXHMFXJvkyxMcG5ojaEliatPsqmo5sHySU46oqgeT7Alcl+Seqrpxqn1yBC1JwIIpbINU1YPd1w3AFcDh0+mTAS1J9EbQw26Tt5Mdk+z8zGtgKXDHdPpkiUOSgBmcx7EIuCK9JF8I/EVVXTOdhgxoSQIyQwFdVd8Afnom2jKgJQlI2qv4GtCSBLS4GocBLUlAGpwzYUBLEpY4JKlhljgkqUkzNYtjJhnQkoQBLUnNSoZ5iHtuGdCSBFiDlqRGWeKQpGY5zU6SmuQIWpIalams2D9HDGhJAjLUUvxzy4CWJMBZHJLUKEscktQsA1qSmuRyo5LULEfQktSkMdeDlqRWGdCS1CSfJJSkZhnQktQk50FLUqNafNQ7VTXqPmiAJMuqavmo+6G2+Hux7WvvtqUmsmzUHVCT/L3YxhnQktQoA1qSGmVAbx2sM2oi/l5s47xJKEmNcgQtSY0yoCWpUQZ0Q5K8PMnfJ7kvyWkTHH9Oksu6419Kst8Iuqk5lOTiJBuS3LGZ40ny4e534mtJDpnrPmr2GNCNSLIA+AjwCuAFwPFJXjDutJOAR6rqAOBDwP+Y215qBC4BXj7J8VcAB3bbMuD8OeiT5ogB3Y7Dgfuq6htV9QTwaeCYceccA1zavf5L4Oi0uICAZkxV3Qh8e5JTjgE+Xj03Az+cZPHc9E6zzYBux97AP/Z9/0C3b8JzqmoT8Ciw+5z0Tq0a5vdGWykDWpIaZUC340HgeX3f79Ptm/CcJAuBXYB/mZPeqVXD/N5oK2VAt2MNcGCS/ZNsD/wasHLcOSuBE7vXrwO+WD5pNN+tBE7oZnO8BHi0qtaPulOaGa4H3Yiq2pTkFOALwALg4qq6M8n7gLVVtRL4GPCJJPfRu3H0a6PrseZCkhXAUcAeSR4AzgS2A6iqC4DPA68E7gP+DXjLaHqq2eCj3pLUKEscktQoA1qSGmVAS1KjDGhJapQBLUmNMqDngSRPJbktyZ1Jvprk3UlG9v8+yboke0zz2mMnWERqziQ5K8lvj+r9Nb8Y0PPD41W1pKpeCLyM3gpoZ464T9N1LL3V/prSPdkpzSgDep6pqg30lqU8pXv6bEGSc5Os6dYTPhkgyVFJbkxydbdG9QXPjLqTLE3yd0luTfLZJDt1+9clObvbf3uSg7v9uye5thvBXwQ8uwJfkjcmuaUb4X+0W3aVJBuTnNON+G9OsijJzwG/Apzbnf/8/p8tyau7dbK/kuSvkizq9p/Vrat8Q5JvJHlbt/99Sd7Rd/05Sd4+/r9Zkt9Lcm+S1cCP9e2/Icn/SrIWeHuSS5K8ru/4xu7rWJI/T3JPkuuSfL7/PGlzDOh5qKq+Qe9pxT3prTH9aFUdBhwGvDXJ/t2phwOn0huxPh84ritNnAH8l6o6BFgLvKuv+X/u9p8PPFMKOBNY3Y3grwD2BUjy48CvAi+
tqiXAU8BvdNfsCNxcVT8N3Ai8tar+lt6jze/p/iL4h3E/2mrgJVX1InrLtf5O37GDgV/qfqYzk2wHXAyc0PVljN6TmZ/sbzDJi7v9S+g9sXfYuPfcvqoOraoPsnnHAfvR++/4JuBnJzlXepZ/lmkp8FN9I7pd6C3+/gRwSxfmzzxyfATw7/SC5qZuKertgb/ra+9z3dcv0wsmgCOfeV1VVyd5pNt/NPBiYE3X1g7Ahu7YE8BVfW29bIifZR/gsm495O2B/9t37Oqq+h7wvSQbgEVVtS7JvyR5EbAI+EpVjV986j8DV1TVv3X/Hcavj3LZEP06AvhsVT0NPJTk+iGukQzo+SjJf6I3Wt1Ar9xwalV9Ydw5RwHj1wGo7vzrqur4zTT/ve7rUwz+/QpwaVWdPsGxJ/sWghqmLYA/Bc6rqpVd/8+aoF/j27sIeDOwF70R9VR9t+/1Jrq/SrsR+fbTaE96liWOeSbJc4ELgD/rAvALwG92f/KT5KAkO3anH96trjdGrxSxGrgZeGmSA7rzd0xy0IC3vRH49e78VwC7dvtXAa9Lsmd3bLckPzqgrceAnTdzbBe+v9TmiZs5Z7wr6H2k1GH0/ltM1Pdjk+yQZGfg1ZO0tY7eXwTQq5Vv172+CXhtV4teRG/xI2kgR9Dzww5JbqMXGJuATwDndccuolcfvTW9OsO36M2UgN4SqH8GHABcT+9P/aeTvBlYkeQ53XlnAPdO8v5nd+ffCfwtcD9AVd2V5Azg2u4fgSeB3wK+OUlbnwYu7G70vW5cHfos4LNdCeWLwP4TXP//qaonupLDv1bVUxMcvzXJZcBX6f3FsWaS5i4ErkzyVeAavj+6vpxeOecuep9+ciu9T8ORJuVqdppQVyL47ap61Yi7Mqu6fxhuBV5fVV+fxffZqao2JtkduIXejdGHZuv9tG1wBK15K70HXq6i95fBrIVz56okP0yvLv1+w1nDcAQtSY3yJqEkNcqAlqRGGdCS1CgDWpIaZUBLUqP+H2nw8thw4q0+AAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>predict</th><th>predict_proba</th><th>decision_function</th><th>y_true</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>3458</th><td>0.0</td><td>0.810241</td><td>-1.451580</td><td>0.0</td></tr>\n",
"    <tr><th>3459</th><td>0.0</td><td>0.902262</td><td>-2.222618</td><td>0.0</td></tr>\n",
"    <tr><th>3462</th><td>0.0</td><td>0.976899</td><td>-3.744522</td><td>0.0</td></tr>\n",
"    <tr><th>3470</th><td>0.0</td><td>0.954009</td><td>-3.032217</td><td>0.0</td></tr>\n",
"    <tr><th>3478</th><td>0.0</td><td>0.964846</td><td>-3.312243</td><td>0.0</td></tr>\n",
"    <tr><th>3479</th><td>0.0</td><td>0.900526</td><td>-2.203087</td><td>0.0</td></tr>\n",
"    <tr><th>3482</th><td>0.0</td><td>0.946305</td><td>-2.869253</td><td>0.0</td></tr>\n",
"    <tr><th>3484</th><td>0.0</td><td>0.681535</td><td>-0.760837</td><td>0.0</td></tr>\n",
"    <tr><th>3485</th><td>0.0</td><td>0.924836</td><td>-2.509949</td><td>0.0</td></tr>\n",
"    <tr><th>3486</th><td>0.0</td><td>0.957875</td><td>-3.124076</td><td>0.0</td></tr>\n",
"    <tr><th>3494</th><td>0.0</td><td>0.870815</td><td>-1.908185</td><td>0.0</td></tr>\n",
"    <tr><th>3497</th><td>0.0</td><td>0.959444</td><td>-3.163663</td><td>0.0</td></tr>\n",
"    <tr><th>3498</th><td>0.0</td><td>0.877308</td><td>-1.967185</td><td>0.0</td></tr>\n",
"    <tr><th>3502</th><td>0.0</td><td>0.892309</td><td>-2.114550</td><td>0.0</td></tr>\n",
"    <tr><th>3503</th><td>0.0</td><td>0.946427</td><td>-2.871650</td><td>1.0</td></tr>\n",
"    <tr><th>3507</th><td>0.0</td><td>0.968260</td><td>-3.417921</td><td>0.0</td></tr>\n",
"    <tr><th>3509</th><td>0.0</td><td>0.961510</td><td>-3.218103</td><td>0.0</td></tr>\n",
"    <tr><th>3512</th><td>0.0</td><td>0.907378</td><td>-2.282028</td><td>0.0</td></tr>\n",
"    <tr><th>3517</th><td>0.0</td><td>0.976704</td><td>-3.735907</td><td>0.0</td></tr>\n",
"    <tr><th>3521</th><td>0.0</td><td>0.665222</td><td>-0.686655</td><td>0.0</td></tr>\n",
"    <tr><th>3523</th><td>0.0</td><td>0.699328</td><td>-0.844100</td><td>0.0</td></tr>\n",
"    <tr><th>3423</th><td>0.0</td><td>0.951770</td><td>-2.982337</td><td>0.0</td></tr>\n",
"    <tr><th>3429</th><td>0.0</td><td>0.603810</td><td>-0.421364</td><td>0.0</td></tr>\n",
"    <tr><th>3434</th><td>0.0</td><td>0.955454</td><td>-3.065661</td><td>0.0</td></tr>\n",
"    <tr><th>3437</th><td>0.0</td><td>0.934065</td><td>-2.650873</td><td>0.0</td></tr>\n",
"    <tr><th>3442</th><td>0.0</td><td>0.912519</td><td>-2.344782</td><td>0.0</td></tr>\n",
"    <tr><th>3443</th><td>0.0</td><td>0.927844</td><td>-2.554032</td><td>0.0</td></tr>\n",
"    <tr><th>3447</th><td>0.0</td><td>0.916637</td><td>-2.397506</td><td>0.0</td></tr>\n",
"    <tr><th>3453</th><td>0.0</td><td>0.945561</td><td>-2.854691</td><td>0.0</td></tr>\n",
"    <tr><th>3455</th><td>0.0</td><td>0.745451</td><td>-1.074496</td><td>0.0</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>" ], "text/plain": [ " predict predict_proba decision_function y_true\n", "3458 0.0 0.810241 -1.451580 0.0\n", "3459 0.0 0.902262 -2.222618 0.0\n", "3462 0.0 0.976899 -3.744522 0.0\n", "3470 0.0 0.954009 -3.032217 0.0\n", "3478 0.0 0.964846 -3.312243 0.0\n", "3479 0.0 0.900526 -2.203087 0.0\n", "3482 0.0 0.946305 -2.869253 0.0\n", "3484 0.0 0.681535 -0.760837 0.0\n", "3485 0.0 0.924836 -2.509949 0.0\n", "3486 0.0 0.957875 -3.124076 0.0\n", "3494 0.0 0.870815 -1.908185 0.0\n", "3497 0.0 0.959444 -3.163663 0.0\n", "3498 0.0 0.877308 -1.967185 0.0\n", "3502 0.0 0.892309 -2.114550 0.0\n", "3503 0.0 0.946427 -2.871650 1.0\n", "3507 0.0 0.968260 -3.417921 0.0\n", "3509 0.0 0.961510 -3.218103 0.0\n", "3512 0.0 0.907378 -2.282028 0.0\n", "3517 0.0 0.976704 -3.735907 0.0\n", "3521 0.0 0.665222 -0.686655 0.0\n", "3523 0.0 0.699328 -0.844100 0.0\n", "3423 0.0 0.951770 -2.982337 0.0\n", "3429 0.0 0.603810 -0.421364 0.0\n", "3434 0.0 0.955454 -3.065661 0.0\n", "3437 0.0 0.934065 -2.650873 0.0\n", "3442 0.0 0.912519 -2.344782 0.0\n", "3443 0.0 0.927844 -2.554032 0.0\n", "3447 0.0 0.916637 -2.397506 0.0\n", "3453 0.0 0.945561 -2.854691 0.0\n", "3455 0.0 0.745451 -1.074496 0.0" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# The train and validation subjects from fold 0\n", "tr_subjs = results.train_subjects[0]\n", "val_subjs = results.val_subjects[0]\n", "\n", "# An object specifying just site 29's subjects (29 is the original, decoded Site value)\n", "site29 = bp.ValueSubset('Site', 29, decode_values=True)\n", "\n", "# Restrict each fold 0 subset to site 29 with an Intersection object\n", "site_29_tr = bp.Intersection([tr_subjs, site29])\n", "site_29_val = bp.Intersection([val_subjs, site29])\n", "\n", "# Plot dependence status for site 29 in the train and validation sets\n", "data.plot_bivar('Site', 'Dependent any drug', subjects=site_29_tr)\n", "data.plot_bivar('Site', 'Dependent any drug', subjects=site_29_val)\n", "\n", "# Resolve the Intersection into an actual index of site 29 validation subjects\n", "val_subjs = data.get_subjects(site_29_val)\n", 
"\n", "# Get a DataFrame with the predictions made for fold 0 only\n", "preds_df_fold0 = results.get_preds_dfs()[0]\n", "\n", "# View the predictions for just the site 29 validation subjects\n", "preds_df_fold0.loc[val_subjs]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see that the model predicts 0 (not dependent) for every one of these site 29 validation subjects. Because of the imbalance this is correct most of the time (only subject 3503 is actually dependent, and it is misclassified), but it does not tell us whether the model has learned to predict site or substance dependence." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Drop imbalanced subjects\n", "The last option is to simply not use the subjects from imbalanced sites at all. We will do this by selecting only the subjects from balanced sites." ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "# ~ inverts the previously defined boolean mask is_imbalanced, so\n", "# np.flatnonzero returns the (encoded) Site values of the balanced sites\n", "balanced_sites = np.flatnonzero(~is_imbalanced)\n", "\n", "# Select just the subjects from these sites with the ValueSubset wrapper;\n", "# decode_values=False because balanced_sites holds encoded Site values\n", "balanced_subjs = bp.ValueSubset('Site', values=balanced_sites, decode_values=False)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8d17d0d434dd452eb75b6b7019897c19", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Folds: 0%| | 0/3 [00:00