{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Predict Sex\n", "\n", "This notebook walks through a simple binary classification example, explaining general library functionality along the way.\n", "It uses data downloaded from Release 2.0.1 of the ABCD Study (https://abcdstudy.org/).\n", "This dataset is openly available to researchers (after signing a data use agreement) and is particularly well suited\n", "to neuroimaging-based ML given the study's large sample size.\n", "\n", "We will perform binary classification, predicting sex assigned at birth from tabular ROI structural MRI data.\n", "\n", "## Load Data" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import BPt as bp\n", "import pandas as pd\n", "import os\n", "\n", "from warnings import simplefilter\n", "from sklearn.exceptions import ConvergenceWarning\n", "simplefilter(\"ignore\", category=ConvergenceWarning)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "def load_from_rds(names, eventname='baseline_year_1_arm_1'):\n", "\n", "    # Load only the requested columns, treating 777 / 999 codes as missing\n", "    data = pd.read_csv('data/nda_rds_201.csv',\n", "                       usecols=['src_subject_id', 'eventname'] + names,\n", "                       na_values=['777', 999, '999', 777])\n", "\n", "    # Keep only the rows for the requested event\n", "    data = data.loc[data['eventname'] == eventname]\n", "    data = data.set_index('src_subject_id')\n", "    data = data.drop('eventname', axis=1)\n", "\n", "    # Obfuscate subject IDs for the public example\n", "    data.index = list(range(len(data)))\n", "\n", "    # Return as pandas DataFrame cast to BPt Dataset\n", "    return bp.Dataset(data)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Read only the header row so we can list all available columns\n", "all_cols = list(pd.read_csv('data/nda_rds_201.csv', nrows=0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can search through all columns
to find the ones we actually want to load. We will start with the brain imaging features." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['smri_thick_cort.destrieux_g.and.s.frontomargin.lh',\n", " 'smri_thick_cort.destrieux_g.and.s.occipital.inf.lh',\n", " 'smri_thick_cort.destrieux_g.and.s.paracentral.lh',\n", " 'smri_thick_cort.destrieux_g.and.s.subcentral.lh',\n", " 'smri_thick_cort.destrieux_g.and.s.transv.frontopol.lh',\n", " 'smri_thick_cort.destrieux_g.and.s.cingul.ant.lh',\n", " 'smri_thick_cort.destrieux_g.and.s.cingul.mid.ant.lh',\n", " 'smri_thick_cort.destrieux_g.and.s.cingul.mid.post.lh',\n", " 'smri_thick_cort.destrieux_g.cingul.post.dorsal.lh',\n", " 'smri_thick_cort.destrieux_g.cingul.post.ventral.lh']" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Substring that identifies each group of imaging features\n", "feat_keys = {'thick': 'smri_thick_cort.destrieux_g.',\n", "             'sulc': 'smri_sulc_cort.destrieux_g.',\n", "             'area': 'smri_area_cort.destrieux_g.',\n", "             'subcort': 'smri_vol_subcort.aseg_'}\n", "\n", "# Collect the matching column names for each group\n", "feat_cols = {key: [c for c in all_cols if feat_keys[key] in c] for key in feat_keys}\n", "\n", "# Flatten the per-group lists into one list of feature columns\n", "all_cols = sum(feat_cols.values(), [])\n", "\n", "# For example, the first 10 cortical thickness columns\n", "feat_cols['thick'][:10]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We also need our target variable, in this case sex.\n", "\n", "Let's load household income too as a non-input, i.e., a variable we won't use directly as an input feature."
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "target = 'sex'\n", "non_inputs = ['household.income']" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/sage/anaconda3/envs/bpt/lib/python3.9/site-packages/IPython/core/interactiveshell.py:3338: DtypeWarning: Columns (63641) have mixed types.Specify dtype option on import or set low_memory=False.\n", " if (await self.run_code(code, result, async_=asy)):\n" ] }, { "data": { "text/html": [ "

| | household.income | sex | smri_area_cort.destrieux_g.and.s.cingul.ant.lh | smri_area_cort.destrieux_g.and.s.cingul.ant.rh | smri_area_cort.destrieux_g.and.s.cingul.mid.ant.lh | smri_area_cort.destrieux_g.and.s.cingul.mid.ant.rh | smri_area_cort.destrieux_g.and.s.cingul.mid.post.lh | smri_area_cort.destrieux_g.and.s.cingul.mid.post.rh | smri_area_cort.destrieux_g.and.s.frontomargin.lh | smri_area_cort.destrieux_g.and.s.frontomargin.rh | ... | smri_vol_subcort.aseg_subcorticalgrayvolume | smri_vol_subcort.aseg_supratentorialvolume | smri_vol_subcort.aseg_thalamus.proper.lh | smri_vol_subcort.aseg_thalamus.proper.rh | smri_vol_subcort.aseg_ventraldc.lh | smri_vol_subcort.aseg_ventraldc.rh | smri_vol_subcort.aseg_wholebrain | smri_vol_subcort.aseg_wm.hypointensities | smri_vol_subcort.aseg_wm.hypointensities.lh | smri_vol_subcort.aseg_wm.hypointensities.rh |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | [>=50K & <100K] | F | 1540.0 | 1921.0 | 1237.0 | 1211.0 | 939.0 | 1022.0 | 872.0 | 596.0 | ... | 54112.0 | 9.738411e+05 | 6980.4 | 6806.6 | 3448.1 | 3372.7 | 1.099494e+06 | 2201.9 | 0.0 | 0.0 |
| 1 | NaN | F | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 2 | [>=100K] | M | 2108.0 | 2583.0 | 1289.0 | 1295.0 | 1066.0 | 1328.0 | 907.0 | 843.0 | ... | 71188.0 | 1.290405e+06 | 9091.3 | 8105.3 | 5058.5 | 5261.6 | 1.444690e+06 | 1254.8 | 0.0 | 0.0 |
| 3 | [>=100K] | M | 2196.0 | 2266.0 | 1012.0 | 1459.0 | 1326.0 | 1398.0 | 944.0 | 924.0 | ... | 61985.0 | 1.283405e+06 | 7470.7 | 7278.4 | 3924.8 | 3983.6 | 1.421171e+06 | 950.9 | 0.0 | 0.0 |
| 4 | [<50K] | M | 1732.0 | 1936.0 | 1024.0 | 916.0 | 900.0 | 1002.0 | 863.0 | 730.0 | ... | 61855.0 | 1.072113e+06 | 8152.9 | 7436.8 | 4085.2 | 4129.3 | 1.186497e+06 | 789.9 | 0.0 | 0.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 11870 | [>=100K] | M | 1583.0 | 1821.0 | 730.0 | 1040.0 | 709.0 | 872.0 | 938.0 | 745.0 | ... | 59550.0 | 1.001272e+06 | 7993.5 | 7239.5 | 3899.4 | 4024.6 | 1.139532e+06 | 651.7 | 0.0 | 0.0 |
| 11871 | [>=100K] | F | 1603.0 | 1841.0 | 899.0 | 1091.0 | 990.0 | 995.0 | 809.0 | 666.0 | ... | 61090.0 | 9.897016e+05 | 7113.5 | 6835.3 | 4029.6 | 3826.0 | 1.134203e+06 | 2304.8 | 0.0 | 0.0 |
| 11872 | [>=100K] | F | 1862.0 | 2245.0 | 1406.0 | 1502.0 | 882.0 | 1279.0 | 1105.0 | 1015.0 | ... | 64413.0 | 1.172208e+06 | 8123.0 | 7947.6 | 3893.0 | 4428.5 | 1.301402e+06 | 1654.0 | 0.0 | 0.0 |
| 11873 | [<50K] | F | 1803.0 | 1888.0 | 967.0 | 1101.0 | 866.0 | 1128.0 | 1040.0 | 642.0 | ... | 55505.0 | 1.040864e+06 | 6923.8 | 6459.8 | 3502.5 | 3674.0 | 1.150473e+06 | 1209.5 | 0.0 | 0.0 |
| 11874 | [>=50K & <100K] | F | 1957.0 | 1998.0 | 1142.0 | 1102.0 | 1144.0 | 1130.0 | 884.0 | 638.0 | ... | 59816.0 | 1.079887e+06 | 7665.4 | 5959.4 | 3736.0 | 4060.7 | 1.214126e+06 | 1321.1 | 0.0 | 0.0 |

11875 rows × 274 columns

| | sex |
|---|---|
| 0 | 0 |
| 1 | 0 |
| 2 | 1 |
| 3 | 1 |
| 4 | 1 |
| ... | ... |
| 11870 | 1 |
| 11871 | 0 |
| 11872 | 0 |
| 11873 | 0 |
| 11874 | 0 |

11875 rows × 1 columns

| | smri_area_cort.destrieux_g.and.s.cingul.ant.lh | smri_area_cort.destrieux_g.and.s.cingul.ant.rh | smri_area_cort.destrieux_g.and.s.cingul.mid.ant.lh | smri_area_cort.destrieux_g.and.s.cingul.mid.ant.rh | smri_area_cort.destrieux_g.and.s.cingul.mid.post.lh | smri_area_cort.destrieux_g.and.s.cingul.mid.post.rh | smri_area_cort.destrieux_g.and.s.frontomargin.lh | smri_area_cort.destrieux_g.and.s.frontomargin.rh | smri_area_cort.destrieux_g.and.s.occipital.inf.lh | smri_area_cort.destrieux_g.and.s.occipital.inf.rh | ... | smri_vol_subcort.aseg_pallidum.rh | smri_vol_subcort.aseg_putamen.lh | smri_vol_subcort.aseg_putamen.rh | smri_vol_subcort.aseg_subcorticalgrayvolume | smri_vol_subcort.aseg_supratentorialvolume | smri_vol_subcort.aseg_thalamus.proper.lh | smri_vol_subcort.aseg_thalamus.proper.rh | smri_vol_subcort.aseg_ventraldc.lh | smri_vol_subcort.aseg_ventraldc.rh | smri_vol_subcort.aseg_wholebrain |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1540.0 | 1921.0 | 1237.0 | 1211.0 | 939.0 | 1022.0 | 872.0 | 596.0 | 820.0 | 839.0 | ... | 1392.5 | 5471.6 | 5002.9 | 54112.0 | 9.738411e+05 | 6980.4 | 6806.6 | 3448.1 | 3372.7 | 1.099494e+06 |
| 2 | 2108.0 | 2583.0 | 1289.0 | 1295.0 | 1066.0 | 1328.0 | 907.0 | 843.0 | 1571.0 | 1056.0 | ... | 2102.4 | 6520.7 | 6929.8 | 71188.0 | 1.290405e+06 | 9091.3 | 8105.3 | 5058.5 | 5261.6 | 1.444690e+06 |
| 3 | 2196.0 | 2266.0 | 1012.0 | 1459.0 | 1326.0 | 1398.0 | 944.0 | 924.0 | 1209.0 | 1159.0 | ... | 2030.0 | 6521.7 | 5647.1 | 61985.0 | 1.283405e+06 | 7470.7 | 7278.4 | 3924.8 | 3983.6 | 1.421171e+06 |
| 6 | 1537.0 | 1986.0 | 1151.0 | 1178.0 | 1182.0 | 1389.0 | 839.0 | 678.0 | 1250.0 | 1207.0 | ... | 1859.0 | 6599.7 | 6317.0 | 65182.0 | 1.135326e+06 | 8437.8 | 8259.8 | 3734.3 | 4159.9 | 1.263524e+06 |
| 7 | 1824.0 | 2095.0 | 893.0 | 1066.0 | 1067.0 | 1046.0 | 974.0 | 828.0 | 890.0 | 963.0 | ... | 1456.7 | 5929.8 | 5642.2 | 60637.0 | 1.096084e+06 | 7891.9 | 7439.4 | 3945.9 | 3831.0 | 1.218476e+06 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 11870 | 1583.0 | 1821.0 | 730.0 | 1040.0 | 709.0 | 872.0 | 938.0 | 745.0 | 789.0 | 825.0 | ... | 1610.9 | 6090.9 | 5444.8 | 59550.0 | 1.001272e+06 | 7993.5 | 7239.5 | 3899.4 | 4024.6 | 1.139532e+06 |
| 11871 | 1603.0 | 1841.0 | 899.0 | 1091.0 | 990.0 | 995.0 | 809.0 | 666.0 | 1155.0 | 844.0 | ... | 1664.6 | 7042.9 | 6654.0 | 61090.0 | 9.897016e+05 | 7113.5 | 6835.3 | 4029.6 | 3826.0 | 1.134203e+06 |
| 11872 | 1862.0 | 2245.0 | 1406.0 | 1502.0 | 882.0 | 1279.0 | 1105.0 | 1015.0 | 1256.0 | 960.0 | ... | 1856.8 | 6331.3 | 6366.1 | 64413.0 | 1.172208e+06 | 8123.0 | 7947.6 | 3893.0 | 4428.5 | 1.301402e+06 |
| 11873 | 1803.0 | 1888.0 | 967.0 | 1101.0 | 866.0 | 1128.0 | 1040.0 | 642.0 | 939.0 | 892.0 | ... | 1470.9 | 5730.4 | 5469.3 | 55505.0 | 1.040864e+06 | 6923.8 | 6459.8 | 3502.5 | 3674.0 | 1.150473e+06 |
| 11874 | 1957.0 | 1998.0 | 1142.0 | 1102.0 | 1144.0 | 1130.0 | 884.0 | 638.0 | 1122.0 | 819.0 | ... | 1586.3 | 6171.0 | 5573.6 | 59816.0 | 1.079887e+06 | 7665.4 | 5959.4 | 3736.0 | 4060.7 | 1.214126e+06 |

8314 rows × 269 columns

| | sex |
|---|---|
| 0 | 0 |
| 2 | 1 |
| 3 | 1 |
| 6 | 1 |
| 7 | 1 |
| ... | ... |
| 11870 | 1 |
| 11871 | 0 |
| 11872 | 0 |
| 11873 | 0 |
| 11874 | 0 |

8314 rows × 1 columns

| | household.income |
|---|---|
| 0 | 2 |
| 2 | 1 |
| 3 | 1 |
| 6 | 0 |
| 7 | 0 |
| ... | ... |
| 11870 | 1 |
| 11871 | 1 |
| 11872 | 1 |
| 11873 | 0 |
| 11874 | 2 |

8314 rows × 1 columns

| | predict | predict_proba_0 | predict_proba_1 | decision_function | y_true |
|---|---|---|---|---|---|
| 0 | 0.0 | 0.936433 | 0.063567 | -2.689987 | 0.0 |
| 3 | 1.0 | 0.044796 | 0.955204 | 3.059801 | 1.0 |
| 6 | 1.0 | 0.114142 | 0.885858 | 2.049114 | 1.0 |
| 7 | 1.0 | 0.122291 | 0.877709 | 1.970911 | 1.0 |
| 15 | 1.0 | 0.025838 | 0.974162 | 3.629715 | 1.0 |
| ... | ... | ... | ... | ... | ... |
| 11837 | 1.0 | 0.412130 | 0.587870 | 0.355167 | 1.0 |
| 11841 | 0.0 | 0.853139 | 0.146861 | -1.759439 | 1.0 |
| 11849 | 1.0 | 0.484163 | 0.515837 | 0.063370 | 0.0 |
| 11857 | 0.0 | 0.988427 | 0.011573 | -4.447442 | 0.0 |
| 11873 | 0.0 | 0.989396 | 0.010604 | -4.535859 | 0.0 |

1663 rows × 5 columns

| | predict | predict_proba_0 | predict_proba_1 | decision_function | y_true |
|---|---|---|---|---|---|
| 2049 | 1.0 | 0.051997 | 0.948003 | 2.903164 | 1.0 |
| 6 | 1.0 | 0.114142 | 0.885858 | 2.049114 | 1.0 |
| 7 | 1.0 | 0.122291 | 0.877709 | 1.970911 | 1.0 |
| 6154 | 0.0 | 0.939120 | 0.060880 | -2.736043 | 0.0 |
| 10251 | 1.0 | 0.011423 | 0.988577 | 4.460610 | 1.0 |
| ... | ... | ... | ... | ... | ... |
| 8171 | 1.0 | 0.020307 | 0.979693 | 3.876290 | 1.0 |
| 2032 | 0.0 | 0.979160 | 0.020840 | -3.849802 | 0.0 |
| 4082 | 1.0 | 0.026987 | 0.973013 | 3.585060 | 1.0 |
| 4089 | 1.0 | 0.042656 | 0.957344 | 3.111002 | 1.0 |
| 6140 | 1.0 | 0.307029 | 0.692971 | 0.814046 | 1.0 |

460 rows × 5 columns
\n", "