{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Predict Sex\n", "\n", "This notebook goes through a simple binary classification example, explaining general library functionality along the way.\n", "In this notebook we use data downloaded from Release 2.0.1 of the ABCD Study (https://abcdstudy.org/).\n", "This dataset is openly available to researchers (after signing a data use agreement) and, given the study's large sample size, is particularly well suited to neuroimaging-based ML.\n", "\n", "Specifically, we perform binary classification, predicting sex assigned at birth from tabular ROI-level structural MRI data.\n", "\n", "## Load Data" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import BPt as bp\n", "import pandas as pd\n", "import os\n", "\n", "from warnings import simplefilter\n", "from sklearn.exceptions import ConvergenceWarning\n", "\n", "# Silence scikit-learn convergence warnings during model fitting\n", "simplefilter(\"ignore\", category=ConvergenceWarning)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "def load_from_rds(names, eventname='baseline_year_1_arm_1'):\n", "    # Load only the requested columns, treating 777 and 999 as missing\n", "    data = pd.read_csv('data/nda_rds_201.csv',\n", "                       usecols=['src_subject_id', 'eventname'] + names,\n", "                       na_values=['777', 999, '999', 777])\n", "\n", "    # Keep only rows from the requested event (baseline by default)\n", "    data = data[data['eventname'] == eventname]\n", "    data = data.set_index('src_subject_id')\n", "    data = data.drop('eventname', axis=1)\n", "\n", "    # Obfuscate subject IDs for the public example\n", "    data.index = list(range(len(data)))\n", "\n", "    # Cast the pandas DataFrame to a BPt Dataset\n", "    return bp.Dataset(data)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Read only the header row to get the names of all available columns\n", "all_cols = list(pd.read_csv('data/nda_rds_201.csv', nrows=0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can search through all columns to find the ones we actually want to load. We will start with the brain imaging features." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['smri_thick_cort.destrieux_g.and.s.frontomargin.lh',\n", " 'smri_thick_cort.destrieux_g.and.s.occipital.inf.lh',\n", " 'smri_thick_cort.destrieux_g.and.s.paracentral.lh',\n", " 'smri_thick_cort.destrieux_g.and.s.subcentral.lh',\n", " 'smri_thick_cort.destrieux_g.and.s.transv.frontopol.lh',\n", " 'smri_thick_cort.destrieux_g.and.s.cingul.ant.lh',\n", " 'smri_thick_cort.destrieux_g.and.s.cingul.mid.ant.lh',\n", " 'smri_thick_cort.destrieux_g.and.s.cingul.mid.post.lh',\n", " 'smri_thick_cort.destrieux_g.cingul.post.dorsal.lh',\n", " 'smri_thick_cort.destrieux_g.cingul.post.ventral.lh']" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Substrings identifying each family of structural MRI features\n", "feat_keys = {'thick': 'smri_thick_cort.destrieux_g.',\n", "             'sulc': 'smri_sulc_cort.destrieux_g.',\n", "             'area': 'smri_area_cort.destrieux_g.',\n", "             'subcort': 'smri_vol_subcort.aseg_'}\n", "\n", "feat_cols = {key: [c for c in all_cols if feat_keys[key] in c] for key in feat_keys}\n", "\n", "# Flatten into a single list of feature columns (note: this overwrites all_cols)\n", "all_cols = sum(feat_cols.values(), [])\n", "\n", "# Preview the first 10 cortical thickness columns\n", "feat_cols['thick'][:10]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We also need our target variable, in this case sex.\n", "\n", "We will also load household income as a non input, i.e., a variable that is loaded but not used directly as a model input." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "target = 'sex'\n", "non_inputs = ['household.income']" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/sage/anaconda3/envs/bpt/lib/python3.9/site-packages/IPython/core/interactiveshell.py:3338: DtypeWarning: Columns (63641) have mixed types.Specify dtype option on import or set low_memory=False.\n", " if (await self.run_code(code, result, async_=asy)):\n" ] }, { "data": { "text/html": [ "
\n" ], "text/plain": [ " smri_thick_cort.destrieux_g.and.s.frontomargin.lh \\\n", "0 2.643 \n", "1 NaN \n", "2 2.798 \n", "3 2.570 \n", "4 2.589 \n", "... ... \n", "11870 2.604 \n", "11871 2.665 \n", "11872 2.517 \n", "11873 2.806 \n", "11874 2.817 \n", "\n", " smri_thick_cort.destrieux_g.and.s.occipital.inf.lh \\\n", "0 2.597 \n", "1 NaN \n", "2 2.635 \n", "3 3.008 \n", "4 2.495 \n", "... ... \n", "11870 2.839 \n", "11871 2.915 \n", "11872 2.743 \n", "11873 2.835 \n", "11874 2.267 \n", "\n", " smri_thick_cort.destrieux_g.and.s.paracentral.lh \\\n", "0 2.682 \n", "1 NaN \n", "2 2.620 \n", "3 2.771 \n", "4 2.732 \n", "... ... \n", "11870 2.642 \n", "11871 2.661 \n", "11872 2.607 \n", "11873 2.678 \n", "11874 2.639 \n", "\n", " smri_thick_cort.destrieux_g.and.s.subcentral.lh \\\n", "0 3.016 \n", "1 NaN \n", "2 2.963 \n", "3 3.116 \n", "4 2.982 \n", "... ... \n", "11870 3.017 \n", "11871 3.114 \n", "11872 3.210 \n", "11873 3.344 \n", "11874 2.805 \n", "\n", " smri_thick_cort.destrieux_g.and.s.transv.frontopol.lh \\\n", "0 2.776 \n", "1 NaN \n", "2 3.038 \n", "3 2.753 \n", "4 2.979 \n", "... ... \n", "11870 2.990 \n", "11871 2.968 \n", "11872 2.847 \n", "11873 2.975 \n", "11874 3.041 \n", "\n", " smri_thick_cort.destrieux_g.and.s.cingul.ant.lh \\\n", "0 3.012 \n", "1 NaN \n", "2 2.948 \n", "3 3.137 \n", "4 2.953 \n", "... ... \n", "11870 3.119 \n", "11871 3.167 \n", "11872 2.954 \n", "11873 3.134 \n", "11874 2.867 \n", "\n", " smri_thick_cort.destrieux_g.and.s.cingul.mid.ant.lh \\\n", "0 2.894 \n", "1 NaN \n", "2 2.966 \n", "3 3.222 \n", "4 2.732 \n", "... ... \n", "11870 3.014 \n", "11871 3.058 \n", "11872 2.965 \n", "11873 3.425 \n", "11874 2.906 \n", "\n", " smri_thick_cort.destrieux_g.and.s.cingul.mid.post.lh \\\n", "0 2.874 \n", "1 NaN \n", "2 2.728 \n", "3 3.062 \n", "4 2.819 \n", "... ... 
\n", "11870 2.871 \n", "11871 2.976 \n", "11872 2.846 \n", "11873 3.251 \n", "11874 3.049 \n", "\n", " smri_thick_cort.destrieux_g.cingul.post.dorsal.lh \\\n", "0 2.865 \n", "1 NaN \n", "2 3.263 \n", "3 3.315 \n", "4 2.908 \n", "... ... \n", "11870 3.240 \n", "11871 3.355 \n", "11872 3.211 \n", "11873 3.288 \n", "11874 3.440 \n", "\n", " smri_thick_cort.destrieux_g.cingul.post.ventral.lh ... \\\n", "0 2.350 ... \n", "1 NaN ... \n", "2 1.882 ... \n", "3 3.065 ... \n", "4 2.967 ... \n", "... ... ... \n", "11870 2.254 ... \n", "11871 2.168 ... \n", "11872 2.741 ... \n", "11873 2.535 ... \n", "11874 2.123 ... \n", "\n", " smri_vol_subcort.aseg_cc.mid.anterior \\\n", "0 396.9 \n", "1 NaN \n", "2 336.7 \n", "3 432.3 \n", "4 398.6 \n", "... ... \n", "11870 366.9 \n", "11871 367.8 \n", "11872 472.6 \n", "11873 424.1 \n", "11874 417.5 \n", "\n", " smri_vol_subcort.aseg_cc.anterior smri_vol_subcort.aseg_wholebrain \\\n", "0 546.9 1.099494e+06 \n", "1 NaN NaN \n", "2 684.0 1.444690e+06 \n", "3 720.6 1.421171e+06 \n", "4 824.5 1.186497e+06 \n", "... ... ... \n", "11870 761.9 1.139532e+06 \n", "11871 609.3 1.134203e+06 \n", "11872 855.6 1.301402e+06 \n", "11873 691.2 1.150473e+06 \n", "11874 863.2 1.214126e+06 \n", "\n", " smri_vol_subcort.aseg_latventricles \\\n", "0 4693.2 \n", "1 NaN \n", "2 13426.2 \n", "3 8375.3 \n", "4 19138.9 \n", "... ... \n", "11870 11129.1 \n", "11871 2855.1 \n", "11872 8278.4 \n", "11873 6483.5 \n", "11874 9234.5 \n", "\n", " smri_vol_subcort.aseg_allventricles \\\n", "0 6299.4 \n", "1 NaN \n", "2 18810.3 \n", "3 11828.6 \n", "4 21191.9 \n", "... ... \n", "11870 14259.9 \n", "11871 4925.1 \n", "11872 10434.1 \n", "11873 8978.0 \n", "11874 11169.3 \n", "\n", " smri_vol_subcort.aseg_intracranialvolume \\\n", "0 1.354788e+06 \n", "1 NaN \n", "2 1.703982e+06 \n", "3 1.679526e+06 \n", "4 1.561216e+06 \n", "... ... 
 \n", "11870 1.480336e+06 \n", "11871 1.470497e+06 \n", "11872 1.455727e+06 \n", "11873 1.480286e+06 \n", "11874 1.500072e+06 \n", "\n", " smri_vol_subcort.aseg_supratentorialvolume \\\n", "0 9.738411e+05 \n", "1 NaN \n", "2 1.290405e+06 \n", "3 1.283405e+06 \n", "4 1.072113e+06 \n", "... ... \n", "11870 1.001272e+06 \n", "11871 9.897016e+05 \n", "11872 1.172208e+06 \n", "11873 1.040864e+06 \n", "11874 1.079887e+06 \n", "\n", " smri_vol_subcort.aseg_subcorticalgrayvolume sex household.income \n", "0 54112.0 F [>=50K & <100K] \n", "1 NaN F NaN \n", "2 71188.0 M [>=100K] \n", "3 61985.0 M [>=100K] \n", "4 61855.0 M [<50K] \n", "... ... ... ... \n", "11870 59550.0 M [>=100K] \n", "11871 61090.0 F [>=100K] \n", "11872 64413.0 F [>=100K] \n", "11873 55505.0 F [<50K] \n", "11874 59816.0 F [>=50K & <100K] \n", "\n", "[11875 rows x 274 columns]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data = load_from_rds(all_cols + [target] + non_inputs)\n", "data.verbose = 1\n", "data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next we need to tell the dataset two things about sex: that it is a binary variable, and that it is our target variable." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
 \n" ], "text/plain": [ " sex\n", "0 0\n", "1 0\n", "2 1\n", "3 1\n", "4 1\n", "... ..\n", "11870 1\n", "11871 0\n", "11872 0\n", "11873 0\n", "11874 0\n", "\n", "[11875 rows x 1 columns]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data.to_binary('sex', inplace=True)\n", "data.set_target('sex', inplace=True)\n", "data['target']" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We need to do something similar for household income: tell the dataset that it is a categorical variable (via `ordinalize`) and that it has the role non input." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dropped 1018 Rows\n" ] } ], "source": [ "data = data.ordinalize('household.income').set_role('household.income', 'non input')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's look at some information about missing (NaN) values." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loaded NaN Info:\n", "There are: 83896 total missing values\n", "224 columns found with 304 missing values (column name overlap: ['cort.', 'smri_'])\n", "9 columns found with 305 missing values (column name overlap: ['cort.', 'smri_'])\n", "3 columns found with 408 missing values (column name overlap: ['smri_vol_subcort.aseg_', 'le'])\n", "3 columns found with 349 missing values (column name overlap: ['smri_area_cort.destrieux_g.', 'ar', 'd.'])\n", "3 columns found with 306 missing values (column name overlap: ['smri_area_cort.destrieux_g.'])\n", "2 columns found with 340 missing values (column name overlap: ['smri_area_cort.destrieux_g.', '.lh', '.s', 'an', 'l.'])\n", "2 columns found with 324 missing values (column name overlap: ['smri_area_cort.destrieux_g.and.s.cingul.mid.', 't.rh'])\n", "2 columns found with 314 missing values (column name overlap: ['smri_area_cort.destrieux_g.', 'temp.', '.lat'])\n", "2 columns found with 309 missing values (column name overlap: ['smri_vol_subcort.aseg_c'])\n", "2 columns found with 307 missing values (column name overlap: ['smri_area_cort.destrieux_g.'])\n", "2 columns found with 345 missing values (column name overlap: ['smri_area_cort.destrieux_g.', 'ngul.', 't.'])\n", "2 columns found with 347 missing values (column name overlap: ['smri_area_cort.destrieux_g.', '.rh', 'ar', 'nt', 'er'])\n", "\n" ] } ], "source": [ "data.nan_info()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What happens if we drop any subjects with NaN values in more than 1% of their loaded columns?" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Setting NaN threshold to: 2.74\n", "Dropped 408 Rows\n", "Loaded NaN Info:\n", "There are: 885 total missing values\n", "9 columns found with 1 missing values (column name overlap: ['cort.', 'smri_'])\n", "3 columns found with 2 missing values (column name overlap: ['smri_area_cort.destrieux_g.'])\n", "3 columns found with 38 missing values (column name overlap: ['smri_area_cort.destrieux_g.'])\n", "3 columns found with 3 missing values (column name overlap: [])\n", "3 columns found with 45 missing values (column name overlap: ['smri_area_cort.destrieux_g.', 'ar', 'd.'])\n", "2 columns found with 43 missing values (column name overlap: ['smri_area_cort.destrieux_g.', '.rh', 'ar', 'nt', 'er'])\n", "2 columns found with 41 missing values (column name overlap: ['smri_area_cort.destrieux_g.', 'ngul.', 't.'])\n", "2 columns found with 35 missing values (column name overlap: ['smri_area_cort.destrieux_g.', '.lh', '.s', 'an'])\n", "2 columns found with 5 missing values (column name overlap: ['smri_vol_subcort.aseg_c'])\n", "2 columns found with 10 missing values (column name overlap: ['smri_area_cort.destrieux_g.', 'temp.', '.lat'])\n", "\n" ] } ], "source": [ "data = data.drop_subjects_by_nan(threshold=.01)\n", "data.nan_info()" ] }, { 
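"cell_type": "markdown", "metadata": {}, "source": [ "The logged threshold of 2.74 is consistent with the 1% fraction being multiplied by the 274 loaded columns (0.01 × 274 = 2.74). As a rough illustration only, the next cell is a minimal pandas sketch of this kind of filtering. It is not BPt's actual implementation of `drop_subjects_by_nan`, and it assumes that subjects with more NaN values than the threshold are dropped. The helper `drop_rows_by_nan_fraction` and the toy DataFrame are hypothetical." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "\n", "\n", "def drop_rows_by_nan_fraction(df, threshold=0.01):\n", "    # Hypothetical helper: turn a fractional threshold into a NaN count,\n", "    # e.g., 0.01 * 274 columns = 2.74\n", "    max_nan = threshold * df.shape[1]\n", "    # Keep only rows whose number of NaN values does not exceed that count\n", "    return df[df.isna().sum(axis=1) <= max_nan]\n", "\n", "\n", "# Toy example with 3 columns: threshold=0.5 allows at most 1.5 NaNs per row\n", "toy = pd.DataFrame({'a': [1.0, np.nan, 3.0],\n", "                    'b': [np.nan, np.nan, 6.0],\n", "                    'c': [7.0, 8.0, 9.0]})\n", "drop_rows_by_nan_fraction(toy, threshold=0.5)  # keeps rows 0 and 2" ] }, { 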
"cell_type": "markdown", "metadata": {}, "source": [ "That greatly reduces the number of remaining missing values (from 83896 to 885). Next, let's consider outlier filtering as..." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "smri_thick_cort.destrieux_g.front.middle.rh -1.492846\n", "smri_thick_cort.destrieux_g.temporal.middle.lh -1.311317\n", "smri_thick_cort.destrieux_g.pariet.inf.angular.rh -1.284709\n", "smri_thick_cort.destrieux_g.precentral.rh -1.269309\n", "smri_thick_cort.destrieux_g.temporal.middle.rh -1.238609\n", " ... \n", "smri_vol_subcort.aseg_latventricles 4.022606\n", "smri_vol_subcort.aseg_lateral.ventricle.rh 4.504116\n", "smri_area_cort.destrieux_g.cingul.post.ventral.lh 5.093452\n", "smri_area_cort.destrieux_g.cingul.post.ventral.rh 5.106974\n", "smri_vol_subcort.aseg_wm.hypointensities 16.605102\n", "Length: 272, dtype: float64" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data.skew().sort_values()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dropped 3 Columns\n" ] } ], "source": [ "# Drop the white-matter hypointensity measures, which include the most skewed feature above and are not of interest here\n", "data = data.drop_cols(exclusions='aseg_wm.hypointensities')" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dropped 53 Rows\n" ] } ], "source": [ "# Drop subjects with values more than 10 standard deviations from a column's mean\n", "data = data.filter_outliers_by_std(n_std=10)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "sex: 10393 rows (3 NaN)\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAW4AAAF+CAYAAACidPAUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAAVJUlEQVR4nO3de7RmdX3f8fcHhvsgNwURRC5ekFKkhNQbGrwQ0UaN1ioUI7Yag6auGtuFEBMiiVkN0Rq1sUVsjbagiE1YErwg4aKJqxmcMVyjhNsY5ZrRIBcRYfj2j/075OFwzpxnZJ5zzu+Z92uts87ev73P3t/fzDOf+Z3fs5+9U1VIkvqxxVIXIEnaOAa3JHXG4JakzhjcktQZg1uSOmNwS1JnDG5NvST7JLknyZab6HinJ/nttnxkku9viuO2470gybWb6niaTga3upbkzUnWt2C+J8lNSf4kydNn9qmqv6+qlVW1foxj/dVC56yqE6rq9zZR/ZXkqSPH/suqesamOLaml8GtafD/qmolsBPwUuA+YE2Sgzf1iTbVqF16LAxuTVSS9yS5OcndSa5N8pLWvkWSk5LckOQHSc5Jsmvb9j+S/OnIMU5LclGSbOhcVbW+qm6oqncAXwPe135+3zayXdHW35zkxlbTTUmOS/JM4HTguW3kfmfb91Otni8luRd4UWt7/6x+/maSdUnWJjlupP3SJG8dWX94VJ/k6635inbON8yeeknyzHaMO5Nck+RVI9s+leRjSb7Y+rIqyQHj/t2oXwa3JibJM4D/APx8Ve0IvAxY2za/E/hl4BeAJwH/CHysbftPwD9vIfcC4C3A8bVx92f4M+AFc9S0A/BR4OWtpucBl1fVt4ETaKP3qtp55Mf+LfD7wI7AXFMpTwQeD+wFHA+c0fq+QVX1wrb4rHbOz82qdSvgz4GvArsz/JmdNevYxwCnArsA17c6NeUMbk3SemAb4KAkW1XV2qq6oW07AXhvVX2/qu5nGB2/LsmKqvox8CvAh4AzgXdW1ca+AXgLsOs82x4CDk6yXVXdWlXXLHCsL1TVN6rqoar6yTz7/HZV3V9VXwO+CLx+I+udy3OAlcAfVNVPq+pi4Hzg2JF9zq2qy6rqQeAs4NBNcF4tcwa3JqaqrgfexRDKdyQ5O8mT2uanAOe2KYA7gW8zBP0e7WdXATcCAc75GU6/F/DDOWq6F3gDw38ct7ZphgMXONb3Ftj+j+24M77L8FvEY/Uk4HtV9dCsY+81sn7byPKPGYJeU87g1kRV1Weq6giGoC7gtLbpewzTFTuPfG1bVTcDJPl1htH6LcCJP8OpXwP85Tw1XVBVRwF7At8BPjGzab5uLHCuXdoUzIx9GOoGuBfYfmTbExc41qhbgCcnGf13ug9w80YcQ1PI4NbEJHlGkhcn2Qb4CcPVHjOjx9OB30/ylLbvE5K8ui0/HXg/8EaGKZMTkxw6xvm2TLJfkv8GHMkw9zt7nz2SvLoF7f3APSM13Q7snWTrn6G7pybZus3J/xLw+dZ+OfDaJNu3y/7eMuvnbgf2n+eYqxhG0Scm2SrJkcArgbN/hvo0RQxuTdI2wB8A6xh+pd8dOLlt+whwHvDVJHcDfw08u135cSZwWlVdUVXXAb8J/J/2H8BcnpvkHuAu4FLgcQxviF41x75bAO9mGM3+kOHN0be3bRcD1wC3JVm3Ef28jeHN1VsY5plPqKrvtG1/BPyUIaA/3baPeh/w6TZl9Ih58ar6KUNQv5zhz/C/A28aObY2U/FBCpLUF0fcktQZg1uSOmNwS1JnDG5J6syKpS5gLkcffXR95StfWeoyJGmpzXl/nmU54l63bmOuxJKkzcuyDG5J0vwMbknqjMEtSZ0xuCWpMwa3JHXG4JakzhjcktQZg1uSOmNwS1JnDG5J6ozBLUmdMbglqTMGtyR1xuCWpM4sy4cF7/DE/erAXzl1qcuQpMdkzQfe9FgP0c/9uCVJ8zO4JakzBrckdcbglqTOGNy
S1BmDW5I6Y3BLUmcMbknqjMEtSZ0xuCWpMwa3JHXG4JakzhjcktQZg1uSOmNwS1JnDG5J6ozBLUmdMbglqTMGtyR1xuCWpM4Y3JLUGYNbkjpjcEtSZwxuSeqMwS1JnTG4JakzBrckdcbglqTOGNyS1BmDW5I6Y3BLUmcMbknqjMEtSZ0xuCWpMwa3JHXG4JakzhjcktQZg1uSOmNwS1JnDG5J6ozBLUmdMbglqTMGtyR1xuCWpM4Y3JLUGYNbkjpjcEtSZwxuSeqMwS1JnTG4JakzBrckdcbglqTOGNyS1BmDW5I6Y3BLUmcMbknqjMEtSZ0xuCWpMwa3JHXG4JakzhjcktSZiQV3kkpy5sj6iiT/kOT8SZ1TkjYHkxxx3wscnGS7tn4UcPMEzydJm4VJT5V8CfhXbflY4LMTPp8kTb1JB/fZwDFJtgUOAVbNt2OStyVZnWT1gz++e8JlSVK/JhrcVXUlsC/DaPtLC+x7RlUdXlWHr9h+x0mWJUldW7EI5zgP+CBwJLDbIpxPkqbaYgT3J4E7q+qqJEcuwvkkaapNPLir6vvARyd9HknaXEwsuKtq5RxtlwKXTuqckrQ58JOTktQZg1uSOmNwS1JnDG5J6ozBLUmdMbglqTMGtyR1xuCWpM4Y3JLUGYNbkjpjcEtSZwxuSeqMwS1JnTG4JakzBrckdcbglqTOGNyS1BmDW5I6Y3BLUmcMbknqjMEtSZ0xuCWpMwa3JHXG4JakzhjcktQZg1uSOmNwS1JnDG5J6ozBLUmdMbglqTMGtyR1xuCWpM4Y3JLUGYNbkjpjcEtSZwxuSeqMwS1JnTG4JakzBrckdcbglqTOGNyS1BmDW5I6Y3BLUmcMbknqjMEtSZ0xuCWpMwa3JHXG4JakzhjcktQZg1uSOrPRwZ1klySHTKIYSdLCxgruJJcmeVySXYFvAZ9I8qHJliZJmsu4I+6dquou4LXA/66qZwMvnVxZkqT5jBvcK5LsCbweOH+C9UiSFjBucP8ucAFwQ1V9M8n+wHWTK0uSNJ8V4+xUVZ8HPj+yfiPwrydVlCRpfuO+Ofn0JBclubqtH5LktyZbmiRpLuNOlXwCOBl4AKCqrgSOmVRRkqT5jTVVAmxfVZclGW17cAL1APDMvXdj9QfeNKnDS1LXxh1xr0tyAFAASV4H3DqxqiRJ8xp3xP3rwBnAgUluBm4C3jixqiRJ8xr3qpIbgZcm2QHYoqrunmxZkqT5bDC4k7yxqs5M8u5Z7QBUlR97l6RFttCIe4f2fcdJFyJJGs8Gg7uqPp5kS+CuqvqjRapJkrQBC15VUlXrgWMXoRZJ0hjGvarkG0n+GPgccO9MY1V9ayJVSZLmNW5wH9q+/+5IWwEv3qTVSJIWNO7lgC+adCGSpPGMe5OpnZJ8KMnq9vVfk+w06eIkSY827kfePwnczfAghdcDdwF/MqmiJEnzG3eO+4CqGr3/9qlJLp9APZKkBYw74r4vyREzK0meD9w3mZIkSRsy7oj77cCn27x2gB8Cb55UUZKk+Y17VcnlwLOSPK6t3zXJoiRJ8xsruOe5ydSPgDUt1CVJi2TcOe7DgROAvdrXrwFHA59IcuKEapMkzWHcOe69gcOq6h6AJL8DfBF4IbAG+MPJlCdJmm3cEffuwP0j6w8Ae1TVfbPaJUkTNu6I+yxgVZIvtPVXAp9pT8T524lUJkma07hXlfxeki8Dz29NJ1TV6rZ83EQqkyTNadypEoBtGR6o8BHgu0n2m1BNkqQNGPcmU78DvAc4uTVtBZw5qaIkSfMbd8T9GuBVtIcoVNUt+BxKSVoS4wb3T6uqGB6eQHtTUpK0BMYN7nOSfBzYOcmvAn8B/M/JlSVJms+4V5V8MMlRDPfhfgZwSlVdONHKJElzGvdeJadV1XuAC+dokyQtonGnSo6ao+3lm7IQSdJ4NjjiTvJ24B3A/kmuHNm0I/CNSRYmSZpbhotF5tk4PDhhF+C/ACeNbLq7qn44qaIO2Wu
7Ov/Xnjqpw0vazO1zylVLXcK4MlfjBkfcVfUjhvtuHwuQZHeGT1CuTLKyqv5+U1cpSdqwcT85+cok1wE3AV8D1gJfnmBdkqR5jPvm5PuB5wB/V1X7AS8B/npiVUmS5jVucD9QVT8AtkiyRVVdwvBUHEnSIhv3ftx3JlkJfB04K8kdtPuWSJIW10KXAz4V2AN4NXAf8BsM999+CvDOiVcnSXqUhaZKPsxwD+57q+qhqnqwqj4NnAu8b9LFSZIebaHg3qOqHnXBY2vbdyIVSZI2aKHg3nkD27bbhHVIksa0UHCvbrdxfYQkbwXWTKYkSdKGLHRVybuAc5Mcxz8F9eHA1gxPxZEkLbKFPvJ+O/C8JC8CDm7NX6yqiydemSRpTuM+SOES4JIJ1yJJGsO4n5yUJC0TBrckdcbglqTOGNyS1BmDW5I6Y3BLUmcMbknqjMEtSZ0xuCWpMwa3JHXG4JakzhjcktQZg1uSOmNwS1JnDG5J6ozBLUmdMbglqTMGtyR1xuCWpM4Y3JLUGYNbkjpjcEtSZwxuSeqMwS1JnTG4JakzBrckdcbglqTOGNyS1BmDW5I6Y3BLUmcMbknqjMEtSZ0xuCWpMwa3JHXG4JakzhjcktQZg1uSOmNwS1JnDG5J6ozBLUmdMbglqTMGtyR1xuCWpM4Y3JLUGYNbkjpjcEtSZwxuSerMisU4SZL1wFUjTb9cVWsX49ySNG0WJbiB+6rq0EU6lyRNNadKJKkzizXi3i7J5W35pqp6zewdkrwNeBvAXjtttUhlSVJ/ls1USVWdAZwBcMhe29ViFCVJPXKqRJI6Y3BLUmcMbknqzKIEd1WtXIzzSNLmwBG3JHXG4JakzhjcktQZg1uSOmNwS1JnDG5J6ozBLUmdMbglqTMGtyR1xuCWpM4Y3JLUGYNbkjpjcEtSZwxuSeqMwS1JnTG4JakzBrckdcbglqTOGNyS1BmDW5I6Y3BLUmcMbknqjMEtSZ0xuCWpMwa3JHXG4JakzhjcktQZg1uSOmNwS1JnDG5J6ozBLUmdMbglqTMGtyR1xuCWpM4Y3JLUGYNbkjpjcEtSZwxuSeqMwS1JnTG4JakzBrckdcbglqTOGNyS1BmDW5I6Y3BLUmcMbknqjMEtSZ0xuCWpMwa3JHXG4JakzhjcktQZg1uSOmNwS1JnDG5J6ozBLUmdMbglqTMGtyR1xuCWpM4Y3JLUmRVLXcBctt7zn7HPKauXugxJWpYccUtSZwxuSeqMwS1JnTG4JakzBrckdcbglqTOGNyS1BmDW5I6Y3BLUmcMbknqjMEtSZ0xuCWpMwa3JHXG4JakzhjcktSZVNVS1/AoSe4Grl3qOibo8cC6pS5iwuxj/6a9f7D8+7iuqo6e3bgsH6QAXFtVhy91EZOSZPU09w/s4zSY9v5Bv310qkSSOmNwS1Jnlmtwn7HUBUzYtPcP7OM0mPb+Qad9XJZvTkqS5rdcR9ySpHkY3JLUmWUV3EmOTnJtkuuTnLTU9WyMJJ9MckeSq0fadk1yYZLr2vddWnuSfLT188okh438zPFt/+uSHL8UfZlLkicnuSTJ3ya5Jsl/bO3T1Mdtk1yW5IrWx1Nb+35JVrW+fC7J1q19m7Z+fdu+78ixTm7t1yZ52RJ1aU5JtkzyN0nOb+vT1r+1Sa5KcnmS1a1tal6nAFTVsvgCtgRuAPYHtgauAA5a6ro2ov4XAocBV4+0/SFwUls+CTitLb8C+DIQ4DnAqta+K3Bj+75LW95lqfvWatsTOKwt7wj8HXDQlPUxwMq2vBWwqtV+DnBMaz8deHtbfgdwels+BvhcWz6ovX63AfZrr+stl7p/I/18N/AZ4Py2Pm39Wws8flbb1LxOq2pZBfdzgQtG1k8GTl7qujayD/vOCu5rgT3b8p4MHywC+Dhw7Oz9gGOBj4+0P2K/5fQFfAE4alr7CGwPfAt4NsMn61a09odfp8AFwHPb8oq2X2a/dkf3W+ovYG/gIuDFwPmt3qnpX6t
nruCeqtfpcpoq2Qv43sj691tbz/aoqlvb8m3AHm15vr528WfQfmX+Fwwj0qnqY5tGuBy4A7iQYTR5Z1U92HYZrffhvrTtPwJ2Y3n38cPAicBDbX03pqt/AAV8NcmaJG9rbVP1Ol2uH3mfOlVVSbq/9jLJSuBPgXdV1V1JHt42DX2sqvXAoUl2Bs4FDlzaijadJL8E3FFVa5IcucTlTNIRVXVzkt2BC5N8Z3TjNLxOl9OI+2bgySPre7e2nt2eZE+A9v2O1j5fX5f1n0GSrRhC+6yq+rPWPFV9nFFVdwKXMEwd7JxkZpAzWu/DfWnbdwJ+wPLt4/OBVyVZC5zNMF3yEaanfwBU1c3t+x0M//n+S6bsdbqcgvubwNPaO9xbM7wZct4S1/RYnQfMvBt9PMO88Ez7m9o72s8BftR+jbsA+MUku7R3vX+xtS25DEPr/wV8u6o+NLJpmvr4hDbSJsl2DHP432YI8Ne13Wb3cabvrwMurmFC9DzgmHZVxn7A04DLFqUTG1BVJ1fV3lW1L8O/r4ur6jimpH8ASXZIsuPMMsPr62qm6HUKLJ83J9sbAK9guFrhBuC9S13PRtb+WeBW4AGG+bC3MMwHXgRcB/wFsGvbN8DHWj+vAg4fOc6/B65vX/9uqfs1UtcRDHOHVwKXt69XTFkfDwH+pvXxauCU1r4/QzBdD3we2Ka1b9vWr2/b9x851ntb368FXr7UfZujr0fyT1eVTE3/Wl+uaF/XzOTINL1Oq8qPvEtSb5bTVIkkaQwGtyR1xuCWpM4Y3JLUGYNbkjpjcGsqJXlikrOT3NA++vylJE/fhMc/MsnzNtXxpI1hcGvqtA8LnQtcWlUHVNXPMdwYaY8N/+RGORIwuLUkDG5NoxcBD1TV6TMNVXUF8FdJPpDk6na/5jfAw6Pn82f2TfLHSd7cltcmOTXJt9rPHNhusnUC8Bvtns8vSPJv2nGvSPL1xeysNj/eZErT6GBgzRztrwUOBZ4FPB745pghu66qDkvyDuA/V9Vbk5wO3FNVHwRIchXwshpubrTzpuiENB9H3NqcHAF8tqrWV9XtwNeAnx/j52ZuqLWG4Z7rc/kG8Kkkv8rwUBBpYgxuTaNrgJ/biP0f5JH/Fradtf3+9n098/yWWlUnAL/FcEe5NUl224jzSxvF4NY0uhjYZuQm+iQ5BLgTeEN7WMITGB43dxnwXeCgdre7nYGXjHGOuxke4TZz/AOqalVVnQL8A4+8Jai0STnHralTVZXkNcCHk7wH+AnD46zeBaxkuHNcASdW1W0ASc5huCPgTQx3CFzInwP/N8mrgXcyvFH5NIa7zV3UziFNhHcHlKTOOFUiSZ0xuCWpMwa3JHXG4JakzhjcktQZg1uSOmNwS1Jn/j+M39/Arsnx1wAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "data.plot('target')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that 3 subjects are missing a value for the target variable; we can drop them." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dropped 3 Rows\n" ] } ], "source": [ "data = data.drop_nan_subjects('target')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Lastly, let's split our data into training and test sets." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Performing test split on: 10393 subjects.\n", "random_state: 2\n", "Test split size: 0.2\n", "\n", "Performed train/test split\n", "Train size: 8314\n", "Test size: 2079\n" ] }, { "data": { "text/html": [ "
\n" ], "text/plain": [ " smri_thick_cort.destrieux_g.and.s.frontomargin.lh \\\n", "0 2.643 \n", "2 2.798 \n", "3 2.570 \n", "6 2.720 \n", "7 2.561 \n", "... ... \n", "11870 2.604 \n", "11871 2.665 \n", "11872 2.517 \n", "11873 2.806 \n", "11874 2.817 \n", "\n", " smri_thick_cort.destrieux_g.and.s.occipital.inf.lh \\\n", "0 2.597 \n", "2 2.635 \n", "3 3.008 \n", "6 2.937 \n", "7 2.711 \n", "... ... \n", "11870 2.839 \n", "11871 2.915 \n", "11872 2.743 \n", "11873 2.835 \n", "11874 2.267 \n", "\n", " smri_thick_cort.destrieux_g.and.s.paracentral.lh \\\n", "0 2.682 \n", "2 2.620 \n", "3 2.771 \n", "6 2.678 \n", "7 2.869 \n", "... ... \n", "11870 2.642 \n", "11871 2.661 \n", "11872 2.607 \n", "11873 2.678 \n", "11874 2.639 \n", "\n", " smri_thick_cort.destrieux_g.and.s.subcentral.lh \\\n", "0 3.016 \n", "2 2.963 \n", "3 3.116 \n", "6 3.226 \n", "7 3.079 \n", "... ... \n", "11870 3.017 \n", "11871 3.114 \n", "11872 3.210 \n", "11873 3.344 \n", "11874 2.805 \n", "\n", " smri_thick_cort.destrieux_g.and.s.transv.frontopol.lh \\\n", "0 2.776 \n", "2 3.038 \n", "3 2.753 \n", "6 3.058 \n", "7 3.022 \n", "... ... \n", "11870 2.990 \n", "11871 2.968 \n", "11872 2.847 \n", "11873 2.975 \n", "11874 3.041 \n", "\n", " smri_thick_cort.destrieux_g.and.s.cingul.ant.lh \\\n", "0 3.012 \n", "2 2.948 \n", "3 3.137 \n", "6 3.185 \n", "7 3.059 \n", "... ... \n", "11870 3.119 \n", "11871 3.167 \n", "11872 2.954 \n", "11873 3.134 \n", "11874 2.867 \n", "\n", " smri_thick_cort.destrieux_g.and.s.cingul.mid.ant.lh \\\n", "0 2.894 \n", "2 2.966 \n", "3 3.222 \n", "6 3.198 \n", "7 3.108 \n", "... ... \n", "11870 3.014 \n", "11871 3.058 \n", "11872 2.965 \n", "11873 3.425 \n", "11874 2.906 \n", "\n", " smri_thick_cort.destrieux_g.and.s.cingul.mid.post.lh \\\n", "0 2.874 \n", "2 2.728 \n", "3 3.062 \n", "6 2.848 \n", "7 2.958 \n", "... ... 
\n", "11870 2.871 \n", "11871 2.976 \n", "11872 2.846 \n", "11873 3.251 \n", "11874 3.049 \n", "\n", " smri_thick_cort.destrieux_g.cingul.post.dorsal.lh \\\n", "0 2.865 \n", "2 3.263 \n", "3 3.315 \n", "6 3.283 \n", "7 3.114 \n", "... ... \n", "11870 3.240 \n", "11871 3.355 \n", "11872 3.211 \n", "11873 3.288 \n", "11874 3.440 \n", "\n", " smri_thick_cort.destrieux_g.cingul.post.ventral.lh ... \\\n", "0 2.350 ... \n", "2 1.882 ... \n", "3 3.065 ... \n", "6 2.422 ... \n", "7 2.751 ... \n", "... ... ... \n", "11870 2.254 ... \n", "11871 2.168 ... \n", "11872 2.741 ... \n", "11873 2.535 ... \n", "11874 2.123 ... \n", "\n", " smri_vol_subcort.aseg_cc.mid.anterior \\\n", "0 396.9 \n", "2 336.7 \n", "3 432.3 \n", "6 403.1 \n", "7 398.5 \n", "... ... \n", "11870 366.9 \n", "11871 367.8 \n", "11872 472.6 \n", "11873 424.1 \n", "11874 417.5 \n", "\n", " smri_vol_subcort.aseg_cc.anterior smri_vol_subcort.aseg_wholebrain \\\n", "0 546.9 1.099494e+06 \n", "2 684.0 1.444690e+06 \n", "3 720.6 1.421171e+06 \n", "6 873.3 1.263524e+06 \n", "7 786.8 1.218476e+06 \n", "... ... ... \n", "11870 761.9 1.139532e+06 \n", "11871 609.3 1.134203e+06 \n", "11872 855.6 1.301402e+06 \n", "11873 691.2 1.150473e+06 \n", "11874 863.2 1.214126e+06 \n", "\n", " smri_vol_subcort.aseg_latventricles \\\n", "0 4693.2 \n", "2 13426.2 \n", "3 8375.3 \n", "6 12905.8 \n", "7 12127.9 \n", "... ... \n", "11870 11129.1 \n", "11871 2855.1 \n", "11872 8278.4 \n", "11873 6483.5 \n", "11874 9234.5 \n", "\n", " smri_vol_subcort.aseg_allventricles \\\n", "0 6299.4 \n", "2 18810.3 \n", "3 11828.6 \n", "6 14878.8 \n", "7 14754.5 \n", "... ... \n", "11870 14259.9 \n", "11871 4925.1 \n", "11872 10434.1 \n", "11873 8978.0 \n", "11874 11169.3 \n", "\n", " smri_vol_subcort.aseg_intracranialvolume \\\n", "0 1.354788e+06 \n", "2 1.703982e+06 \n", "3 1.679526e+06 \n", "6 1.514361e+06 \n", "7 1.586405e+06 \n", "... ... 
\n", "11870 1.480336e+06 \n", "11871 1.470497e+06 \n", "11872 1.455727e+06 \n", "11873 1.480286e+06 \n", "11874 1.500072e+06 \n", "\n", " smri_vol_subcort.aseg_supratentorialvolume \\\n", "0 9.738411e+05 \n", "2 1.290405e+06 \n", "3 1.283405e+06 \n", "6 1.135326e+06 \n", "7 1.096084e+06 \n", "... ... \n", "11870 1.001272e+06 \n", "11871 9.897016e+05 \n", "11872 1.172208e+06 \n", "11873 1.040864e+06 \n", "11874 1.079887e+06 \n", "\n", " smri_vol_subcort.aseg_subcorticalgrayvolume sex household.income \n", "0 54112.0 0 2 \n", "2 71188.0 1 1 \n", "3 61985.0 1 1 \n", "6 65182.0 1 0 \n", "7 60637.0 1 0 \n", "... ... ... ... \n", "11870 59550.0 1 1 \n", "11871 61090.0 0 1 \n", "11872 64413.0 0 1 \n", "11873 55505.0 0 0 \n", "11874 59816.0 0 2 \n", "\n", "[8314 rows x 271 columns]" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_data, test_data = data.test_split(size=.2, random_state=2)\n", "train_data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluating Models\n", "\n", "We will start by evaluating some different choices of pipelines / models on just our training data" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "ProblemSpec(n_jobs=16, scorer=['roc_auc'])" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ps = bp.ProblemSpec(scorer=['roc_auc'],\n", " n_jobs=16)\n", "ps" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ModelPipeline\n", "-------------\n", "imputers=\\\n", "Imputer(obj='default')\n", "\n", "scalers=\\\n", "Scaler(obj='standard')\n", "\n", "model=\\\n", "Model(obj='dt')\n", "\n", "param_search=\\\n", "None\n", "\n" ] } ], "source": [ "model_pipeline = bp.ModelPipeline(model=bp.Model('dt'))\n", "model_pipeline.print_all()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see 
that there are a few default values. Specifically, we have a set of default imputers: one that replaces missing values in all float variables with the mean, and one that replaces missing values in all categorical / binary variables (if any; otherwise it is ignored) with the median.\n", "\n", "Next, we have just a standard scaler, which scales all features to have mean 0 and standard deviation 1.\n", "\n", "Then, we have our decision tree.\n", "\n", "Lastly, we have no param_search specified." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that we have an initial model, we are ready to use the `evaluate` function." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "88ee7bcaba0d43f1867b79517de12ce0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Folds: 0%| | 0/5 [00:00\n", "
\n", "" ], "text/plain": [ " predict predict_proba_0 predict_proba_1 decision_function y_true\n", "0 0.0 0.936433 0.063567 -2.689987 0.0\n", "3 1.0 0.044796 0.955204 3.059801 1.0\n", "6 1.0 0.114142 0.885858 2.049114 1.0\n", "7 1.0 0.122291 0.877709 1.970911 1.0\n", "15 1.0 0.025838 0.974162 3.629715 1.0\n", "... ... ... ... ... ...\n", "11837 1.0 0.412130 0.587870 0.355167 1.0\n", "11841 0.0 0.853139 0.146861 -1.759439 1.0\n", "11849 1.0 0.484163 0.515837 0.063370 0.0\n", "11857 0.0 0.988427 0.011573 -4.447442 0.0\n", "11873 0.0 0.989396 0.010604 -4.535859 0.0\n", "\n", "[1663 rows x 5 columns]" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preds = results.get_preds_dfs()\n", "\n", "# Just the first fold\n", "preds[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Sometimes it can be useful to look at predictions restricted to only one group of subjects. This is where the household income information comes in." ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{0: '[<50K]', 1: '[>=100K]', 2: '[>=50K & <100K]'}" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# See how these values are coded\n", "train_data.encoders['household.income']" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " predict predict_proba_0 predict_proba_1 decision_function y_true\n", "2049 1.0 0.051997 0.948003 2.903164 1.0\n", "6 1.0 0.114142 0.885858 2.049114 1.0\n", "7 1.0 0.122291 0.877709 1.970911 1.0\n", "6154 0.0 0.939120 0.060880 -2.736043 0.0\n", "10251 1.0 0.011423 0.988577 4.460610 1.0\n", "... ... ... ... ... ...\n", "8171 1.0 0.020307 0.979693 3.876290 1.0\n", "2032 0.0 0.979160 0.020840 -3.849802 0.0\n", "4082 1.0 0.026987 0.973013 3.585060 1.0\n", "4089 1.0 0.042656 0.957344 3.111002 1.0\n", "6140 1.0 0.307029 0.692971 0.814046 1.0\n", "\n", "[460 rows x 5 columns]" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Get the subjects that are both in, say, the first\n", "# fold's validation set and in the under-50K group\n", "fold_preds = preds[0]\n", "val_subjs = fold_preds.index\n", "vs = bp.ValueSubset('household.income', '[<50K]', decode_values=True)\n", "\n", "# Specify the intersection of those subsets of subjects\n", "subjs = bp.Intersection([val_subjs, vs])\n", "\n", "# Get the specific subject values (as a list, which .loc accepts)\n", "subset_subjects = train_data.get_subjects(subjs)\n", "subset_preds = fold_preds.loc[list(subset_subjects)]\n", "\n", "subset_preds" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, let's say we want to look at ROC AUC on just this subset." ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8824987132808478" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.metrics import roc_auc_score\n", "roc_auc_score(subset_preds['y_true'], subset_preds['predict_proba_1'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One thing to note about post-stratifying predictions by group is that it is only a diagnostic tool. 
For example, if we found that a subgroup did much worse, this would alert us to the problem, but it would not address it.\n", "\n", "That said, the above code may be useful for getting more familiar with the different internal saved attributes of the BPtEvaluator, but is it the easiest way to get this breakdown? No. There is actually a dedicated function for breaking down results by subset; let's check it out." ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['[>=50K & <100K]', '[>=100K]', '[<50K]']" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "subsets = results.subset_by(group='household.income', dataset=train_data)\n", "list(subsets)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "BPtEvaluatorSubset(household.income=[<50K])\n", "------------\n", "mean_scores = {'roc_auc': 0.8927459075402938}\n", "std_scores = {'roc_auc': 0.01241167653252817}\n", "\n", "Saved Attributes: ['estimators', 'preds', 'timing', 'train_subjects', 'val_subjects', 'feat_names', 'ps', 'mean_scores', 'std_scores', 'weighted_mean_scores', 'scores', 'fis_', 'coef_']\n", "\n", "Available Methods: ['get_X_transform_df', 'get_preds_dfs', 'get_fis', 'get_coef_', 'permutation_importance']\n", "\n", "Evaluated with:\n", "ProblemSpec(n_jobs=16, problem_type='binary',\n", " scorer={'roc_auc': make_scorer(roc_auc_score, needs_threshold=True)},\n", " subjects='all', target='sex')" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "subsets['[<50K]']" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Each of these objects can be treated the same as the main BPtEvaluator object, except that it covers only a subset of the validation subjects. For example, let's compare the ROC AUC we calculated above with the one saved here for fold 0."
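As a cross-check outside of BPt, the per-group breakdown can be sketched for a single fold using plain pandas and scikit-learn. This is a minimal sketch under stated assumptions and does not reproduce how `subset_by` is implemented. The helper `roc_auc_by_group` is hypothetical. The synthetic `demo_preds` DataFrame stands in for one fold's DataFrame from `results.get_preds_dfs()`, which has the `y_true` and `predict_proba_1` columns. The `demo_groups` labels stand in for decoded `household.income` values.

```python
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score


def roc_auc_by_group(preds: pd.DataFrame, groups: pd.Series) -> pd.Series:
    """Hypothetical helper: ROC AUC computed separately for each group.

    preds: one row per validation subject, with 'y_true' and
           'predict_proba_1' columns (like one fold of get_preds_dfs()).
    groups: group label per subject, indexed by subject ID.
    """
    # Keep only the labels of subjects that were actually predicted;
    # subjects without a label become NaN and are dropped by groupby
    aligned = groups.reindex(preds.index)
    scores = {}
    for label, idx in preds.groupby(aligned).groups.items():
        sub = preds.loc[idx]
        if sub['y_true'].nunique() < 2:
            # ROC AUC is undefined when a group contains a single class
            scores[label] = np.nan
        else:
            scores[label] = roc_auc_score(sub['y_true'], sub['predict_proba_1'])
    return pd.Series(scores, name='roc_auc')


# Synthetic stand-in data (NOT the ABCD results shown in this notebook)
rng = np.random.default_rng(0)
n = 300
y = rng.integers(0, 2, size=n)
proba = np.clip(0.5 + 0.3 * (y - 0.5) + rng.normal(0, 0.2, size=n), 0, 1)
demo_preds = pd.DataFrame({'y_true': y.astype(float), 'predict_proba_1': proba})
demo_groups = pd.Series(rng.choice(['[<50K]', '[>=50K & <100K]', '[>=100K]'], size=n),
                        index=demo_preds.index)

print(roc_auc_by_group(demo_preds, demo_groups))
```

The sketch scores one fold at a time. By contrast, the `subset_by` output above also reports `mean_scores` and `std_scores` across all folds.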
] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8824987132808478" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "subsets['[<50K]'].scores['roc_auc'][0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What if we wanted to, say, plot a confusion matrix? scikit-learn has a function dedicated to that, so let's see if we can use it.\n", "\n", "https://scikit-learn.org/stable/modules/generated/sklearn.metrics.plot_confusion_matrix.html\n", "\n", "(Note that `plot_confusion_matrix` was deprecated in scikit-learn 1.0 and removed in 1.2. In newer versions, `ConfusionMatrixDisplay.from_estimator(estimator, X, y)` is the equivalent.)\n", "\n", "This function needs a trained estimator along with the validation X and y, so let's grab those for just the first fold." ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "from sklearn.metrics import plot_confusion_matrix\n", "\n", "fold = 0\n", "\n", "# Trained estimator from the first fold\n", "estimator = results.estimators[fold]\n", "\n", "# Validation X and y for that same fold\n", "X, y = train_data.get_Xy(ps=results.ps,\n", " subjects=results.val_subjects[fold])\n", "\n", "plot_confusion_matrix(estimator, X, y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "How would this change if we wanted to plot the confusion matrix for only the subset of subjects we looked at before? We just need to specify a different set of subjects, which we have already computed:" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Reuse the fold 0, under-50K subjects defined earlier\n", "X, y = train_data.get_Xy(ps=results.ps,\n", " subjects=subjs)\n", "\n", "plot_confusion_matrix(estimator, X, y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Or, of course, we could just use the subset evaluator." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also look at feature importances averaged across all 5 folds." ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "smri_vol_subcort.aseg_caudate.rh -0.491889\n", "smri_vol_subcort.aseg_thalamus.proper.lh -0.469694\n", "smri_thick_cort.destrieux_g.postcentral.rh -0.393656\n", "smri_vol_subcort.aseg_cerebellum.white.matter.rh -0.352483\n", "smri_area_cort.destrieux_g.temp.sup.lateral.rh -0.345213\n", " ... \n", "smri_vol_subcort.aseg_putamen.rh 0.328204\n", "smri_vol_subcort.aseg_thalamus.proper.rh 0.364399\n", "smri_vol_subcort.aseg_cerebellum.cortex.rh 0.442560\n", "smri_vol_subcort.aseg_intracranialvolume 0.688781\n", "smri_vol_subcort.aseg_cerebral.white.matter.rh 0.703118\n", "Length: 269, dtype: float64" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "results.get_fis(mean=True).sort_values()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## LinearResidualizer\n", "\n", "What we find here is a bit trivial: basically, boys have bigger brains than girls. That said, this is just an example. 
What if we, say, residualize the features for intracranial volume in a nested way?\n" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ModelPipeline\n", "-------------\n", "imputers=\\\n", "Imputer(obj='default')\n", "\n", "scalers=\\\n", "[Scaler(obj='robust'),\n", " Scaler(obj=LinearResidualizer(to_resid_df= smri_vol_subcort.aseg_intracranialvolume\n", "0 1.354788e+06\n", "2 1.703982e+06\n", "3 1.679526e+06\n", "4 1.561216e+06\n", "6 1.514361e+06\n", "... ...\n", "11870 1.480336e+06\n", "11871 1.470497e+06\n", "11872 1.455727e+06\n", "11873 1.480286e+06\n", "11874 1.500072e+06\n", "\n", "[10393 rows x 1 columns]))]\n", "\n", "model=\\\n", "Model(obj='linear')\n", "\n", "param_search=\\\n", "None\n", "\n" ] } ], "source": [ "from BPt.extensions import LinearResidualizer\n", "\n", "# Residualizer with intracranial volume as the variable to regress out\n", "resid = LinearResidualizer(to_resid_df=data[['smri_vol_subcort.aseg_intracranialvolume']])\n", "# Apply it as a scaler to all float features\n", "resid_scaler = bp.Scaler(resid, scope='float')\n", "\n", "# Robust scaling, then residualization, then a linear model\n", "resid_pipeline = bp.ModelPipeline(scalers=[bp.Scaler('robust'), resid_scaler],\n", " model=bp.Model('linear'))\n", "\n", "resid_pipeline.print_all()" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3dc9d98d8da5462a9e2bfc63b9a5608a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Folds: 0%| | 0/5 [00:00