{ "cells": [ { "cell_type": "markdown", "id": "understanding-messaging", "metadata": {}, "source": [ "# Predict BMI\n", "\n", "This notebook shows a real-world example of using BPt to study the relationship between BMI and the brain. The data used in this notebook cannot be made public because it comes from the ABCD Study, which requires a data use agreement.\n", "\n", "This notebook covers a number of different topics:\n", "\n", "- Preparing Data\n", "- Evaluating a single pipeline\n", "- Considering different options for how to use a test set\n", "- Introducing and using the Evaluate input option" ] }, { "cell_type": "code", "execution_count": 1, "id": "illegal-account", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import BPt as bp\n", "import numpy as np\n", "\n", "# Don't show sklearn convergence warnings\n", "from warnings import simplefilter\n", "from sklearn.exceptions import ConvergenceWarning\n", "simplefilter(\"ignore\", category=ConvergenceWarning)\n", "\n", "# Display floats with five decimal places\n", "pd.options.display.float_format = \"{:,.5f}\".format" ] }, { "cell_type": "markdown", "id": "approximate-swift", "metadata": {}, "source": [ "## Preparing Data" ] }, { "cell_type": "markdown", "id": "swiss-external", "metadata": {}, "source": [ "We first load the underlying dataset for this project, which has been saved as an Excel file. It contains multi-modal ROI data from the ABCD Study.\n", "\n", "This saved dataset doesn't include the real family IDs, but one notable feature of the ABCD Study is that it includes a number of subjects from the same family. We handle that in this example (albeit with a fake family structure, generated below) by ensuring that, for any cross-validation split, members of the same family stay in the same training or testing fold."
] }, { "cell_type": "code", "execution_count": 2, "id": "smart-defensive", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['Unnamed: 0',\n", " 'src_subject_id',\n", " 'b_averaged_puberty',\n", " 'b_agemos',\n", " 'b_sex',\n", " 'b_race_ethnicity_categories',\n", " 'b_demo_highest_education_categories',\n", " 'b_site_id_l',\n", " 'b_subjects_no_missing_data',\n", " 'bmi_keep']" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data = pd.read_excel('data/structure_base.xlsx')\n", "list(data)[:10]" ] }, { "cell_type": "markdown", "id": "chinese-thirty", "metadata": {}, "source": [ "This dataset contains a number of columns we don't need. We use the next cell both to group the variables of interest and to keep only the relevant columns." ] }, { "cell_type": "code", "execution_count": 3, "id": "chubby-length", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(9724, 668)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Our target variable\n", "targets = ['b_bmi']\n", "\n", "# Columns with traditional covariates\n", "covars = ['b_sex', 'b_demo_highest_education_categories', 'b_race_ethnicity_categories',\n", " 'b_agemos', 'b_mri_info_deviceserialnumber']\n", "\n", "# Let's also note which of these are categorical\n", "cat_covars = ['b_mri_info_deviceserialnumber',\n", " 'b_demo_highest_education_categories',\n", " 'b_race_ethnicity_categories',\n", " 'b_sex']\n", "\n", "# Variables we might want to use,\n", "# but not directly as input features! 
E.g., we\n", "# might want to use them to inform the choice of cross-validation.\n", "non_input = ['b_rel_family_id']\n", "\n", "# The different imaging features\n", "thick = [d for d in list(data) if 'thick' in d]\n", "area = [d for d in list(data) if 'smri_area_cort' in d]\n", "subcort = [d for d in list(data) if 'smri_vol' in d]\n", "dti_fa = [d for d in list(data) if 'dmri_dti_full_fa_' in d]\n", "dti_md = [d for d in list(data) if 'dmri_dti_full_md_' in d]\n", "brain = thick + area + subcort + dti_fa + dti_md\n", "\n", "# All columns to keep\n", "to_keep = brain + targets + covars + non_input\n", "\n", "data = data[to_keep]\n", "data.shape" ] }, { "cell_type": "markdown", "id": "editorial-vehicle", "metadata": {}, "source": [ "Now let's convert from a pandas DataFrame to a BPt Dataset." ] }, { "cell_type": "code", "execution_count": 4, "id": "mobile-berlin", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(9724, 668)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data = bp.Dataset(data)\n", "\n", "# Optional: print extra status messages\n", "data.verbose = 1\n", "data.shape" ] }, { "cell_type": "markdown", "id": "macro-matrix", "metadata": {}, "source": [ "Next, we perform some actions specific to the Dataset class. These include specifying which columns are 'target' and 'non input'; any column not assigned one of these roles keeps the default role, 'data'." ] }, { "cell_type": "code", "execution_count": 5, "id": "random-commissioner", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Setting NaN threshold to: 0.5\n", "Dropped 8 Rows\n" ] }, { "data": { "text/html": [ "

| | b_agemos | b_demo_highest_education_categories | b_mri_info_deviceserialnumber | b_race_ethnicity_categories | b_sex | dmri_dti_full_fa_subcort_aseg_accumbens_area_lh | dmri_dti_full_fa_subcort_aseg_accumbens_area_rh | dmri_dti_full_fa_subcort_aseg_amygdala_lh | dmri_dti_full_fa_subcort_aseg_amygdala_rh | dmri_dti_full_fa_subcort_aseg_caudate_lh | ... | smri_vol_scs_subcorticalgv | smri_vol_scs_suprateialv | smri_vol_scs_tplh | smri_vol_scs_tprh | smri_vol_scs_vedclh | smri_vol_scs_vedcrh | smri_vol_scs_wholeb | smri_vol_scs_wmhint | smri_vol_scs_wmhintlh | smri_vol_scs_wmhintrh |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 124 | 2 | HASHe4f6957a | 2 | 1 | 0.17195 | 0.17389 | 0.19923 | 0.16654 | 0.15518 | ... | 52,440.00000 | 935,475.83514 | 6,220.90000 | 5,787.40000 | 3,554.80000 | 3,427.90000 | 1,045,923.63514 | 478.40000 | 0.00000 | 0.00000 |
1 | 122 | 5 | HASH1314a204 | 1 | 2 | 0.16897 | 0.16477 | 0.20000 | 0.17713 | 0.15848 | ... | 62,550.00000 | 1,078,644.86257 | 8,222.60000 | 7,571.40000 | 3,872.50000 | 3,837.40000 | 1,197,394.06258 | 769.70000 | 0.00000 | 0.00000 |
2 | 114 | 3 | HASH69f406fa | 1 | 1 | 0.22881 | 0.25250 | 0.23221 | 0.22298 | 0.22545 | ... | 60,695.00000 | 1,133,471.50460 | 7,040.50000 | 6,752.90000 | 4,588.80000 | 4,631.50000 | 1,291,126.70460 | 1,114.70000 | 0.00000 | 0.00000 |
3 | 130 | 5 | HASH1314a204 | 1 | 2 | 0.16737 | 0.18840 | 0.15135 | 0.16215 | 0.18848 | ... | 65,614.00000 | 1,055,580.45187 | 8,365.70000 | 7,656.40000 | 4,600.10000 | 4,920.60000 | 1,189,841.05187 | 1,788.50000 | 0.00000 | 0.00000 |
4 | 115 | 5 | HASHc3bf3d9c | 1 | 2 | 0.18902 | 0.21375 | 0.23945 | 0.20581 | 0.20100 | ... | 59,174.00000 | 1,010,567.33270 | 6,577.90000 | 6,612.70000 | 3,434.30000 | 3,942.60000 | 1,144,069.73270 | 1,036.30000 | 0.00000 | 0.00000 |

5 rows × 666 columns

| | b_bmi |
---|---|
0 | 15.17507 |
1 | 16.45090 |
2 | 24.43703 |
3 | 17.38701 |
4 | 17.59670 |

| | b_rel_family_id |
---|---|
0 | 2257 |
1 | 11328 |
2 | 7607 |
3 | 11324 |
4 | 7608 |
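
The log printed before these tables reports a NaN threshold of 0.5 and 8 dropped rows. The sketch below shows that kind of filter in plain pandas. It is an illustration only and rests on one assumption: that the threshold is the largest fraction of missing values a subject (row) may have before it is dropped. The helper `drop_rows_by_nan` is hypothetical and is not BPt's implementation.

```python
import numpy as np
import pandas as pd


def drop_rows_by_nan(df: pd.DataFrame, threshold: float = 0.5) -> pd.DataFrame:
    """Hypothetical helper: drop rows whose fraction of NaN values exceeds `threshold`."""
    nan_frac = df.isna().mean(axis=1)  # fraction of missing columns in each row
    keep = nan_frac <= threshold
    print(f"Dropped {int((~keep).sum())} Rows")
    return df.loc[keep]


# Toy frame: row "b" is missing 3 of 4 values (75% > 50%) and is dropped;
# row "c" is missing exactly half of its values and is kept.
toy = pd.DataFrame(
    {"x1": [1.0, np.nan, 3.0], "x2": [1.0, np.nan, np.nan],
     "x3": [1.0, np.nan, np.nan], "x4": [1.0, 2.0, 4.0]},
    index=["a", "b", "c"],
)
clean = drop_rows_by_nan(toy, threshold=0.5)
print(clean.index.tolist())  # ['a', 'c']
```

Whether a row sitting exactly at the threshold is kept (`<=`) or dropped (`<`) is part of the same assumption.
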

| | b_agemos | b_demo_highest_education_categories | b_mri_info_deviceserialnumber | b_race_ethnicity_categories | b_sex | dmri_dti_full_fa_subcort_aseg_accumbens_area_lh | dmri_dti_full_fa_subcort_aseg_accumbens_area_rh | dmri_dti_full_fa_subcort_aseg_amygdala_lh | dmri_dti_full_fa_subcort_aseg_amygdala_rh | dmri_dti_full_fa_subcort_aseg_caudate_lh | ... | smri_vol_scs_subcorticalgv | smri_vol_scs_suprateialv | smri_vol_scs_tplh | smri_vol_scs_tprh | smri_vol_scs_vedclh | smri_vol_scs_vedcrh | smri_vol_scs_wholeb | smri_vol_scs_wmhint | smri_vol_scs_wmhintlh | smri_vol_scs_wmhintrh |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 124 | 1 | 26 | 1 | 0 | 0.17195 | 0.17389 | 0.19923 | 0.16654 | 0.15518 | ... | 52,440.00000 | 935,475.83514 | 6,220.90000 | 5,787.40000 | 3,554.80000 | 3,427.90000 | 1,045,923.63514 | 478.40000 | 0.00000 | 0.00000 |
1 | 122 | 4 | 2 | 0 | 1 | 0.16897 | 0.16477 | 0.20000 | 0.17713 | 0.15848 | ... | 62,550.00000 | 1,078,644.86257 | 8,222.60000 | 7,571.40000 | 3,872.50000 | 3,837.40000 | 1,197,394.06258 | 769.70000 | 0.00000 | 0.00000 |
2 | 114 | 2 | 13 | 0 | 0 | 0.22881 | 0.25250 | 0.23221 | 0.22298 | 0.22545 | ... | 60,695.00000 | 1,133,471.50460 | 7,040.50000 | 6,752.90000 | 4,588.80000 | 4,631.50000 | 1,291,126.70460 | 1,114.70000 | 0.00000 | 0.00000 |
3 | 130 | 4 | 2 | 0 | 1 | 0.16737 | 0.18840 | 0.15135 | 0.16215 | 0.18848 | ... | 65,614.00000 | 1,055,580.45187 | 8,365.70000 | 7,656.40000 | 4,600.10000 | 4,920.60000 | 1,189,841.05187 | 1,788.50000 | 0.00000 | 0.00000 |
4 | 115 | 4 | 20 | 0 | 1 | 0.18902 | 0.21375 | 0.23945 | 0.20581 | 0.20100 | ... | 59,174.00000 | 1,010,567.33270 | 6,577.90000 | 6,612.70000 | 3,434.30000 | 3,942.60000 | 1,144,069.73270 | 1,036.30000 | 0.00000 | 0.00000 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
9717 | 119 | 1 | 2 | 2 | 1 | 0.16392 | 0.17222 | 0.19906 | 0.17829 | 0.15375 | ... | 47,422.00000 | 893,537.35606 | 6,066.30000 | 5,342.00000 | 2,808.40000 | 3,469.60000 | 999,770.35606 | 1,152.40000 | 0.00000 | 0.00000 |
9720 | 129 | 4 | 7 | 0 | 1 | 0.23469 | 0.25338 | 0.25971 | 0.26321 | 0.22641 | ... | 59,348.00000 | 1,070,071.00812 | 7,036.60000 | 6,669.10000 | 3,817.40000 | 4,103.40000 | 1,198,435.10812 | 794.90000 | 0.00000 | 0.00000 |
9721 | 108 | 2 | 0 | 1 | 0 | 0.15469 | 0.20223 | 0.17145 | 0.15300 | 0.15900 | ... | 63,328.00000 | 1,093,051.50897 | 8,123.20000 | 7,346.20000 | 3,992.80000 | 4,219.80000 | 1,213,030.80897 | 805.50000 | 0.00000 | 0.00000 |
9722 | 110 | 2 | 17 | 0 | 1 | 0.16530 | 0.22484 | 0.19524 | 0.16084 | 0.16671 | ... | 57,037.00000 | 997,648.58273 | 6,692.50000 | 6,437.90000 | 3,833.40000 | 3,887.70000 | 1,127,133.48273 | 1,477.90000 | 0.00000 | 0.00000 |
9723 | 113 | 4 | 26 | 0 | 1 | 0.15050 | 0.16066 | 0.16982 | 0.16484 | 0.15355 | ... | 61,090.00000 | 989,701.56266 | 7,113.50000 | 6,835.30000 | 4,029.60000 | 3,826.00000 | 1,134,202.76266 | 2,304.80000 | 0.00000 | 0.00000 |

9401 rows × 666 columns

7481 rows × 666 columns - Train Set

1920 rows × 666 columns - Test Set
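
This display differs from the one before it in one visible way. The categorical covariates (and, in the non-input table, `b_rel_family_id`) have been replaced by integer codes. For example, `b_sex` values 1 and 2 became 0 and 1, and scanner `HASH1314a204` became 2. For the rows shown, the codes are consistent with mapping each category to its rank in sorted order. The sketch below shows that kind of ordinal encoding in pandas. It assumes sorted-order codes, and BPt's own encoder may differ in details such as its handling of missing or unseen categories. The five-level category set used for education is a hypothetical choice made for illustration.

```python
import pandas as pd

# b_sex for the first five subjects, copied from the raw display above
raw_sex = pd.Series([1, 2, 1, 2, 2], name="b_sex")

# Ordinal encoding: each value is replaced by the index of its category in the
# sorted list of observed categories (pandas sorts numeric categories).
codes = raw_sex.astype("category").cat.codes
print(codes.tolist())  # [0, 1, 0, 1, 1]

# Codes depend on the full set of categories, so when encoding a subset the
# category list should be fixed explicitly. Hypothetical five-level variable:
edu = pd.Series([2, 5, 3, 5, 5], name="edu")
edu_codes = pd.Categorical(edu, categories=[1, 2, 3, 4, 5]).codes
print(edu_codes.tolist())  # [1, 4, 2, 4, 4]
```

The second case shows why the category set should be fixed explicitly. Encoding only the five rows shown, without the explicit category list, would give different codes than encoding the full dataset.
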

| | b_bmi |
---|---|
0 | 15.17507 |
1 | 16.45090 |
2 | 24.43703 |
3 | 17.38701 |
4 | 17.59670 |
... | ... |
9717 | 16.26895 |
9720 | 31.95732 |
9721 | 14.19771 |
9722 | 24.42694 |
9723 | 18.25486 |

9401 rows × 1 columns

7481 rows × 1 columns - Train Set

1920 rows × 1 columns - Test Set

| | b_rel_family_id |
---|---|
0 | 1589 |
1 | 7867 |
2 | 5174 |
3 | 7864 |
4 | 5175 |
... | ... |
9717 | 8070 |
9720 | 1814 |
9721 | 5719 |
9722 | 6412 |
9723 | 1740 |

9401 rows × 1 columns

7481 rows × 1 columns - Train Set

1920 rows × 1 columns - Test Set
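
The introduction describes keeping members of the same family on the same side of every split. The train/test split above holds out 1920 of 9401 subjects, roughly 20%. The sketch below expresses the same idea with scikit-learn's `GroupShuffleSplit` for a single train/test split and `GroupKFold` for cross-validation folds. It uses made-up family IDs in place of `b_rel_family_id`, and it illustrates the technique rather than the splitter BPt uses internally. Note that `test_size` counts groups, not subjects, so the subject-level test fraction is only approximately 20%.

```python
import numpy as np
from sklearn.model_selection import GroupKFold, GroupShuffleSplit

rng = np.random.default_rng(0)

# Hypothetical data: 1,000 subjects drawn from 600 made-up families
n_subjects = 1000
family_id = rng.integers(0, 600, size=n_subjects)
X = rng.normal(size=(n_subjects, 5))
y = rng.normal(size=n_subjects)

# test_size=0.2 holds out about 20% of the *families*; every member of a
# family lands on the same side of the split.
splitter = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=0)
train_idx, test_idx = next(splitter.split(X, y, groups=family_id))

shared = set(family_id[train_idx]) & set(family_id[test_idx])
print(len(train_idx), len(test_idx), len(shared))  # shared families: 0

# The same guarantee for each fold of a group-aware cross-validation
for fold_train, fold_test in GroupKFold(n_splits=5).split(X, y, groups=family_id):
    assert not set(family_id[fold_train]) & set(family_id[fold_test])
```
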

scope | mean_scores_explained_variance | mean_scores_neg_mean_squared_error | std_scores_explained_variance | std_scores_neg_mean_squared_error | mean_timing_fit | mean_timing_score |
---|---|---|---|---|---|---|
covars | 0.11317 | -14.51761 | 0.01192 | 1.45243 | 4.73712 | 0.03311 |
brain | 0.25656 | -12.15231 | 0.02394 | 1.10996 | 15.58119 | 0.05231 |
all | 0.27908 | -11.78751 | 0.02294 | 1.10537 | 15.45347 | 0.07103 |
thick | 0.10118 | -14.69564 | 0.02711 | 1.38971 | 8.06942 | 0.01529 |
area | 0.02774 | -15.89653 | 0.00991 | 1.46288 | 7.71240 | 0.01417 |
subcort | 0.05189 | -15.50615 | 0.00831 | 1.39405 | 4.38906 | 0.02439 |
dti_fa | 0.10302 | -14.66791 | 0.00659 | 1.39366 | 8.24234 | 0.01827 |
dti_md | 0.14541 | -13.96297 | 0.01127 | 1.21864 | 8.31585 | 0.01754 |
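
Each row of this table is a different set of input features (a "scope"), such as the covariates alone or one imaging modality. The `mean_scores_*` and `std_scores_*` columns summarize per-fold scores. Conceptually, this resembles running one group-aware cross-validation per column subset and summarizing the fold scores. The sketch below does that with scikit-learn on synthetic data. The ridge pipeline, the five family-grouped folds and the scope names are all assumptions for illustration; the notebook's actual pipeline and fold scheme are not shown in this section.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import Ridge
from sklearn.model_selection import GroupKFold, cross_validate
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
n = 600
groups = rng.integers(0, 300, size=n)  # hypothetical family ids

# Synthetic features: columns 0-4 carry signal, columns 5-9 are pure noise
X = rng.normal(size=(n, 10))
y = X[:, :5] @ np.array([1.0, 0.8, 0.6, 0.4, 0.2]) + rng.normal(scale=1.0, size=n)

# Hypothetical scopes: named subsets of column indices
scopes = {"signal": list(range(5)), "noise": list(range(5, 10)), "all": list(range(10))}

rows = {}
for name, cols in scopes.items():
    res = cross_validate(
        make_pipeline(StandardScaler(), Ridge(alpha=1.0)),
        X[:, cols], y,
        groups=groups,
        cv=GroupKFold(n_splits=5),
        scoring=["explained_variance", "neg_mean_squared_error"],
    )
    # Summarize the five per-fold test scores for this scope
    rows[name] = {
        "mean_scores_explained_variance": res["test_explained_variance"].mean(),
        "mean_scores_neg_mean_squared_error": res["test_neg_mean_squared_error"].mean(),
        "std_scores_explained_variance": res["test_explained_variance"].std(),
        "std_scores_neg_mean_squared_error": res["test_neg_mean_squared_error"].std(),
    }

summary = pd.DataFrame(rows).T
summary.index.name = "scope"
print(summary)
```

scikit-learn reports mean squared error negated (`neg_mean_squared_error`) so that larger is always better. This explains why that column is negative in the table above.
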

| | scope (1) | scope (2) | t_stat | p_val |
---|---|---|---|---|
0 | covars | brain | -11.36825 | 0.00478 |
1 | covars | all | -14.32368 | 0.00193 |
2 | covars | thick | 0.86450 | 1.00000 |
3 | covars | area | 7.66506 | 0.02180 |
4 | covars | subcort | 9.07301 | 0.01145 |
5 | covars | dti_fa | 1.97369 | 1.00000 |
6 | covars | dti_md | -4.56856 | 0.14381 |
7 | brain | all | -15.67812 | 0.00135 |
8 | brain | thick | 30.46195 | 0.00010 |
9 | brain | area | 17.32117 | 0.00091 |
10 | brain | subcort | 17.40024 | 0.00090 |
11 | brain | dti_fa | 12.94696 | 0.00287 |
12 | brain | dti_md | 11.98864 | 0.00388 |
13 | all | thick | 31.58027 | 0.00008 |
14 | all | area | 19.16281 | 0.00061 |
15 | all | subcort | 20.49440 | 0.00047 |
16 | all | dti_fa | 15.58172 | 0.00139 |
17 | all | dti_md | 15.37322 | 0.00146 |
18 | thick | area | 5.10927 | 0.09713 |
19 | thick | subcort | 3.32372 | 0.40987 |
20 | thick | dti_fa | -0.14599 | 1.00000 |
21 | thick | dti_md | -3.56252 | 0.32946 |
22 | area | subcort | -2.76953 | 0.70498 |
23 | area | dti_fa | -11.38268 | 0.00476 |
24 | area | dti_md | -12.59938 | 0.00320 |
25 | subcort | dti_fa | -7.99938 | 0.01854 |
26 | subcort | dti_md | -31.20696 | 0.00009 |
27 | dti_fa | dti_md | -6.47580 | 0.04102 |
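
This table compares every pair of the 8 scopes, which gives 8 × 7 / 2 = 28 tests. The reported p-values appear to be corrected for multiple comparisons. The covars vs. brain t statistic (-11.36825) is the same as in the single comparison in the next table, whose p-value is 0.00017. Multiplying by the number of tests gives 0.00017 × 28 ≈ 0.0048, close to the 0.00478 shown here, and several weak comparisons are capped at 1.00000. This pattern is consistent with a Bonferroni correction, but that is an inference from the numbers, not something this section states. The sketch below runs paired t-tests over per-fold scores and applies a Bonferroni correction with SciPy. The fold scores are made up, and this is one common approach, not necessarily the test BPt runs.

```python
from itertools import combinations

import numpy as np
from scipy import stats

# Made-up per-fold explained-variance scores (5 folds) for three scopes
fold_scores = {
    "covars": np.array([0.10, 0.12, 0.11, 0.13, 0.11]),
    "brain": np.array([0.25, 0.27, 0.24, 0.28, 0.25]),
    "thick": np.array([0.09, 0.12, 0.10, 0.11, 0.09]),
}

pairs = list(combinations(fold_scores, 2))
n_tests = len(pairs)  # k * (k - 1) / 2 comparisons

results = []
for a, b in pairs:
    # Paired test: both scopes were scored on the same folds
    t_stat, p_val = stats.ttest_rel(fold_scores[a], fold_scores[b])
    # Bonferroni: multiply by the number of tests and cap at 1
    p_corr = min(p_val * n_tests, 1.0)
    results.append((a, b, t_stat, p_corr))

for row in results:
    print(row)
```

Cross-validation folds share training data, so their scores are not independent. A plain paired t-test like this one therefore tends to overstate significance.
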

| | mean_diff | std_diff | t_stat | p_val | better_prob | worse_prob | rope_prob |
---|---|---|---|---|---|---|---|
explained_variance | -0.14339 | -0.01202 | -11.36825 | 0.00017 | 0.00013 | 0.99977 | 0.00009 |
neg_mean_squared_error | -2.36530 | 0.34247 | -8.09697 | 0.00063 | 0.00062 | 0.99936 | 0.00002 |
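
The `mean_diff` and `std_diff` columns can be checked against the scope summary table. For covars (first) versus brain (second), both columns match, to five decimals, the first scope's value minus the second's. This means `std_diff` is a difference between two standard deviations, not the standard deviation of the per-fold differences, which explains how it can be negative. The arithmetic check below uses values copied from the tables above:

```python
import math

# (mean, std) per metric, copied from the scope summary table
covars = {"explained_variance": (0.11317, 0.01192), "neg_mean_squared_error": (-14.51761, 1.45243)}
brain = {"explained_variance": (0.25656, 0.02394), "neg_mean_squared_error": (-12.15231, 1.10996)}

# (mean_diff, std_diff) as reported in the comparison table
reported = {"explained_variance": (-0.14339, -0.01202), "neg_mean_squared_error": (-2.36530, 0.34247)}

checks = []
for metric, (mean_diff, std_diff) in reported.items():
    (m1, s1), (m2, s2) = covars[metric], brain[metric]
    mean_ok = math.isclose(m1 - m2, mean_diff, abs_tol=1e-5)
    std_ok = math.isclose(s1 - s2, std_diff, abs_tol=1e-5)  # difference of stds
    checks.append((metric, mean_ok, std_ok))
    print(f"{metric}: mean_diff matches={mean_ok}, std_diff matches={std_ok}")
```

Both lines print `matches=True` for each column.
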