{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Predict Waist Circumference with Diffusion Weighted Imaging\n", "\n", "This notebook uses diffusion weighted imaging data and subjects' waist circumference (in cm) from the ABCD Study.\n", "As input features, we will use restriction spectrum imaging (RSI) measures derived from diffusion weighted images. This notebook\n", "covers data loading as well as evaluation across a large number of different ML Pipelines. It may be useful\n", "for people looking for more examples of which Pipelines to try." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import BPt as bp\n", "import pandas as pd\n", "import os\n", "\n", "from warnings import simplefilter\n", "from sklearn.exceptions import ConvergenceWarning\n", "simplefilter(\"ignore\", category=ConvergenceWarning)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load the data needed\n", "\n", "Data is loaded from a large CSV file with all of the features from release 2 of the ABCD Study." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "def load_from_rds(names, eventname='baseline_year_1_arm_1'):\n", "\n", "    # Load only the requested columns, treating 777 / 999 codes as missing\n", "    data = pd.read_csv('data/nda_rds_201.csv',\n", "                       usecols=['src_subject_id', 'eventname'] + names,\n", "                       na_values=['777', 999, '999', 777])\n", "\n", "    # Keep only the rows from the requested event\n", "    data = data[data['eventname'] == eventname]\n", "    data = data.set_index('src_subject_id')\n", "    data = data.drop('eventname', axis=1)\n", "\n", "    # Obfuscate subject ID for public example\n", "    data.index = list(range(len(data)))\n", "\n", "    # Return as pandas DataFrame cast to BPt Dataset\n", "    return bp.Dataset(data)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['subjectid',\n", " 'src_subject_id',\n", " 'eventname',\n", " 'anthro_1_height_in',\n", " 'anthro_2_height_in',\n", " 'anthro_3_height_in',\n", " 'anthro_height_calc',\n", " 'anthro_weight_cast',\n", " 'anthro_weight_a_location',\n", " 'anthro_weight1_lb']" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# This way we can look at all of the available columns\n", "all_cols = list(pd.read_csv('data/nda_rds_201.csv', nrows=0))\n", "all_cols[:10]" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "294" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# The target variable\n", "target_cols = ['anthro_waist_cm']\n", "\n", "# Non-input features, i.e., columns that are not used directly as model input\n", "non_input_cols = ['sex', 'rel_family_id']\n", "\n", "# We will use the '_fiber.at' RSI (diffusion) measures\n", "dti_cols = [c for c in all_cols if '_fiber.at' in c and 'rsi.' 
in c]\n", "len(dti_cols)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can use the helper function defined at the start to load these features in as a Dataset." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(11875, 297)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data = load_from_rds(target_cols + non_input_cols + dti_cols)\n", "data.shape" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# This is optional, but will print some extra information when using the dataset operations\n", "data.verbose = 1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The first step is to tell the dataset which role each column has. See: https://sahahn.github.io/BPt/user_guide/role.html" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dropped 2 Rows\n", "Dropped 6 Rows\n" ] }, { "data": { "text/html": [ "
\n", "

Data

\n", "
dmri_rsi.n0_fiber.at_allfib.lhdmri_rsi.n0_fiber.at_allfib.rhdmri_rsi.n0_fiber.at_allfibersdmri_rsi.n0_fiber.at_allfibnocc.lhdmri_rsi.n0_fiber.at_allfibnocc.rhdmri_rsi.n0_fiber.at_atr.lhdmri_rsi.n0_fiber.at_atr.rhdmri_rsi.n0_fiber.at_ccdmri_rsi.n0_fiber.at_cgc.lhdmri_rsi.n0_fiber.at_cgc.rh...dmri_rsi.vol_fiber.at_scs.lhdmri_rsi.vol_fiber.at_scs.rhdmri_rsi.vol_fiber.at_sifc.lhdmri_rsi.vol_fiber.at_sifc.rhdmri_rsi.vol_fiber.at_slf.lhdmri_rsi.vol_fiber.at_slf.rhdmri_rsi.vol_fiber.at_tslf.lhdmri_rsi.vol_fiber.at_tslf.rhdmri_rsi.vol_fiber.at_unc.lhdmri_rsi.vol_fiber.at_unc.rh
00.3276230.3234200.3259570.3405590.3323640.3478370.3360720.3068030.3113470.304854...23672.013056.09648.09528.010152.011504.08384.08024.04968.07176.0
1NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
20.3253740.3114650.3190270.3412130.3263340.3466510.3353620.2881240.3264160.300990...33112.019256.011928.08688.013144.015344.010488.010936.06904.09480.0
30.3050950.3043570.3051700.3154770.3128660.3139720.3167290.2887420.2891660.290347...28480.016016.013024.011960.013600.014880.011416.010592.06952.08736.0
40.3168600.3152380.3163990.3282510.3272590.3339980.3181620.2940080.2978000.299230...29904.017968.012720.011336.013528.015672.011096.011816.05912.07336.0
..................................................................
118700.3357410.3360480.3363720.3498060.3477320.3499660.3456920.3120560.3240540.336676...28328.015400.09656.010080.011312.013496.08728.09176.04960.07392.0
118710.3205630.3175250.3194290.3273020.3221610.3330860.3154820.3085540.2999380.298093...23792.013632.09928.08912.09152.012288.07128.08912.05744.07376.0
118720.3270510.3253860.3265220.3409180.3348540.3454350.3356100.3057200.3086300.330612...28640.016384.09496.011216.012168.012312.09520.08952.04568.09056.0
118730.3235790.3193770.3218050.3349450.3294330.3322000.3340170.3043990.3038310.307037...26216.014672.09408.08872.010960.012584.08880.09176.03696.06168.0
118740.3835370.3714830.3778220.3944130.3734860.4046840.3744350.3662700.4153420.418280...26544.015624.09904.010360.09904.012216.07712.08000.05208.08816.0
\n", "

11867 rows × 294 columns

\n", "
\n", "
\n", "

Target

\n", "
anthro_waist_cm
031.00
130.50
226.75
323.50
430.00
......
1187026.00
1187130.00
1187219.00
1187325.00
1187432.00
\n", "

11867 rows × 1 columns

\n", "
\n", "
\n", "

Non Input

\n", "
rel_family_idsex
08780.0F
110207.0F
24720.0M
33804.0M
45358.0M
.........
118703791.0M
118712441.0F
118727036.0F
118736681.0F
118747588.0F
\n", "

11867 rows × 2 columns

\n", "
\n" ], "text/plain": [ " anthro_waist_cm dmri_rsi.n0_fiber.at_fx.rh \\\n", "0 31.00 0.246540 \n", "1 30.50 NaN \n", "2 26.75 0.146416 \n", "3 23.50 0.229894 \n", "4 30.00 0.192228 \n", "... ... ... \n", "11870 26.00 0.236385 \n", "11871 30.00 0.247628 \n", "11872 19.00 0.224581 \n", "11873 25.00 0.212500 \n", "11874 32.00 0.237271 \n", "\n", " dmri_rsi.n0_fiber.at_fx.lh dmri_rsi.n0_fiber.at_cgc.rh \\\n", "0 0.240964 0.304854 \n", "1 NaN NaN \n", "2 0.241515 0.300990 \n", "3 0.225981 0.290347 \n", "4 0.201559 0.299230 \n", "... ... ... \n", "11870 0.233723 0.336676 \n", "11871 0.244926 0.298093 \n", "11872 0.181725 0.330612 \n", "11873 0.225045 0.307037 \n", "11874 0.232131 0.418280 \n", "\n", " dmri_rsi.n0_fiber.at_cgc.lh dmri_rsi.n0_fiber.at_cgh.rh \\\n", "0 0.311347 0.255081 \n", "1 NaN NaN \n", "2 0.326416 0.230479 \n", "3 0.289166 0.204329 \n", "4 0.297800 0.249119 \n", "... ... ... \n", "11870 0.324054 0.245577 \n", "11871 0.299938 0.201506 \n", "11872 0.308630 0.251494 \n", "11873 0.303831 0.254441 \n", "11874 0.415342 0.241162 \n", "\n", " dmri_rsi.n0_fiber.at_cgh.lh dmri_rsi.n0_fiber.at_cst.rh \\\n", "0 0.244332 0.378148 \n", "1 NaN NaN \n", "2 0.232169 0.386119 \n", "3 0.200217 0.364850 \n", "4 0.248674 0.365131 \n", "... ... ... \n", "11870 0.259280 0.388121 \n", "11871 0.192919 0.386154 \n", "11872 0.265312 0.382624 \n", "11873 0.257420 0.380283 \n", "11874 0.263211 0.394936 \n", "\n", " dmri_rsi.n0_fiber.at_cst.lh dmri_rsi.n0_fiber.at_atr.rh ... \\\n", "0 0.388728 0.336072 ... \n", "1 NaN NaN ... \n", "2 0.400546 0.335362 ... \n", "3 0.365397 0.316729 ... \n", "4 0.359856 0.318162 ... \n", "... ... ... ... \n", "11870 0.388715 0.345692 ... \n", "11871 0.390666 0.315482 ... \n", "11872 0.379022 0.335610 ... \n", "11873 0.377752 0.334017 ... \n", "11874 0.414786 0.374435 ... 
\n", "\n", " dmri_rsi.vol_fiber.at_ifsfc.lh dmri_rsi.vol_fiber.at_fxcut.rh \\\n", "0 13224.0 2720.0 \n", "1 NaN NaN \n", "2 18080.0 2728.0 \n", "3 18768.0 2776.0 \n", "4 18112.0 2528.0 \n", "... ... ... \n", "11870 15848.0 2760.0 \n", "11871 14368.0 2728.0 \n", "11872 17616.0 2928.0 \n", "11873 16024.0 2728.0 \n", "11874 17096.0 1976.0 \n", "\n", " dmri_rsi.vol_fiber.at_fxcut.lh dmri_rsi.vol_fiber.at_allfibers \\\n", "0 1928.0 264000.0 \n", "1 NaN NaN \n", "2 2264.0 339840.0 \n", "3 1784.0 331024.0 \n", "4 2408.0 327192.0 \n", "... ... ... \n", "11870 2192.0 292304.0 \n", "11871 2072.0 271624.0 \n", "11872 2072.0 317816.0 \n", "11873 2032.0 286832.0 \n", "11874 1776.0 298280.0 \n", "\n", " dmri_rsi.vol_fiber.at_allfibnocc.rh \\\n", "0 91368.0 \n", "1 NaN \n", "2 119280.0 \n", "3 114912.0 \n", "4 115360.0 \n", "... ... \n", "11870 101352.0 \n", "11871 98328.0 \n", "11872 113016.0 \n", "11873 97896.0 \n", "11874 105096.0 \n", "\n", " dmri_rsi.vol_fiber.at_allfibnocc.lh dmri_rsi.vol_fiber.at_allfib.rh \\\n", "0 92144.0 133816.0 \n", "1 NaN NaN \n", "2 123784.0 166808.0 \n", "3 115472.0 167400.0 \n", "4 114592.0 165224.0 \n", "... ... ... \n", "11870 103848.0 145272.0 \n", "11871 94352.0 138504.0 \n", "11872 106280.0 162352.0 \n", "11873 97496.0 143704.0 \n", "11874 101088.0 152848.0 \n", "\n", " dmri_rsi.vol_fiber.at_allfib.lh sex rel_family_id \n", "0 131856.0 F 8780.0 \n", "1 NaN F 10207.0 \n", "2 174480.0 M 4720.0 \n", "3 165336.0 M 3804.0 \n", "4 164176.0 M 5358.0 \n", "... ... ... ... 
\n", "11870 148776.0 M 3791.0 \n", "11871 134832.0 F 2441.0 \n", "11872 157344.0 F 7036.0 \n", "11873 144768.0 F 6681.0 \n", "11874 147488.0 F 7588.0 \n", "\n", "[11867 rows x 297 columns]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Note: we reassign (data = data.func()) since these methods are not applied in place by default\n", "data = data.set_target(target_cols)\n", "data = data.set_non_input(non_input_cols)\n", "data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A few things to note right off the bat.\n", "\n", "1. The verbose output printed two statements about dropping rows. Columns with role 'non input' cannot contain any NaN / missing data, so these lines say that 2 rows with NaNs were dropped when setting the first non input column and 6 more when setting the next.\n", "\n", "2. The values for sex are still 'F' and 'M'; we will handle that next.\n", "\n", "3. Some columns with role 'data' have missing values. We will handle that as well." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "

Non Input

\n", "
rel_family_idsex
073210
186340
239711
331391
445431
.........
1187031281
1187121110
1187259070
1187355940
1187462380
\n", "

11867 rows × 2 columns

\n", "
\n" ], "text/plain": [ " rel_family_id sex\n", "0 7321 0\n", "1 8634 0\n", "2 3971 1\n", "3 3139 1\n", "4 4543 1\n", "... ... ..\n", "11870 3128 1\n", "11871 2111 0\n", "11872 5907 0\n", "11873 5594 0\n", "11874 6238 0\n", "\n", "[11867 rows x 2 columns]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# We explicitly say this variable should be binary\n", "data.to_binary('sex', inplace=True)\n", "\n", "# We will ordinalize rel_family_id too\n", "data = data.ordinalize(scope='rel_family_id')\n", "\n", "data['non input']" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next let's look into that NaN problem we saw before." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loaded NaN Info:\n", "There are: 332348 total missing values\n", "180 columns found with 1131 missing values (column name overlap: ['dmri_rsi.n', '_fiber.at_'])\n", "66 columns found with 1130 missing values (column name overlap: ['dmri_rsi.n', '_fiber.at_'])\n", "42 columns found with 1128 missing values (column name overlap: ['dmri_rsi.vol_fiber.at_'])\n", "6 columns found with 1133 missing values (column name overlap: ['_fiber.at_cgh.lh', 'dmri_rsi.n'])\n", "\n" ] } ], "source": [ "data.nan_info()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It seems that when a subject is missing data, it is missing for (nearly) every column: each column has roughly the same number of missing values. If the info above had found columns with only a few missing values, we might want to do something different.\n", "\n", "Below, we just drop any subjects with any NaN data across the target variable and the data columns." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dropped 1145 Rows\n" ] } ], "source": [ "data = data.drop_nan_subjects(scope='all')" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "Another thing we need to worry about with data like this is corrupted data, i.e., data with values that don't make sense due to a failure in the automatic processing pipeline. Let's look at the target variable first, then the data." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "anthro_waist_cm: 10722 rows\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "text/plain": [ "(anthro_waist_cm 90.0\n", " dtype: float64,\n", " anthro_waist_cm 0.0\n", " dtype: float64)" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data.plot('target')\n", "data['target'].max(), data['target'].min()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A waist circumference of 0 is clearly not plausible...\n", "The code below can be used to try different outlier thresholds, since by default the filtering is not applied in place." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dropped 33 Rows\n", "anthro_waist_cm: 10689 rows\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "data.filter_outliers_by_std(scope='target', n_std=5).plot('target')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "5 std seems okay, so let's actually apply it." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dropped 33 Rows\n" ] } ], "source": [ "data.filter_outliers_by_std(scope='target', n_std=5, inplace=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's look at the distribution of skew values for the dti data" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "dmri_rsi.nd_fiber.at_fmaj -2.420074\n", "dmri_rsi.nds2_fiber.at_fmaj -2.287048\n", "dmri_rsi.nd_fiber.at_cst.lh -2.213900\n", "dmri_rsi.nt_fiber.at_cst.lh -2.198591\n", "dmri_rsi.nds2_fiber.at_fmin -2.101969\n", " ... \n", "dmri_rsi.nts2_fiber.at_cst.rh 0.992159\n", "dmri_rsi.n0s2_fiber.at_fmin 1.005677\n", "dmri_rsi.nts2_fiber.at_fmin 1.006378\n", "dmri_rsi.n0s2_fiber.at_cst.lh 1.072836\n", "dmri_rsi.nts2_fiber.at_cst.lh 1.090904\n", "Length: 294, dtype: float64" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data['data'].skew().sort_values()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Looks okay, let's choose the variable with the most extreme skew to plot." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "dmri_rsi.nd_fiber.at_fmaj: 10689 rows\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "data.plot(scope='dmri_rsi.nd_fiber.at_fmaj')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "How about we just apply a fixed criterion of, say, 10 std." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dropped 26 Rows\n" ] } ], "source": [ "data.filter_outliers_by_std(scope='data', n_std=10, inplace=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define a Test Set\n", "\n", "In this example project we are going to test a bunch of different machine learning pipelines. To avoid overfitting to our dataset through repeated pipeline comparisons, we will define a train-test split. We will use the train set to try different pipelines, and only evaluate the best final pipeline on the test set.\n", "\n", "We will impose one extra constraint when applying the test split, namely that members of the same family, i.e., those with the same family id, stay together in either the train set or the test set." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Performing test split on: 10663 subjects.\n", "random_state: 6\n", "Test split size: 0.2\n", "\n", "Performed train/test split\n", "Train size: 8562\n", "Test size: 2101\n" ] }, { "data": { "text/html": [ "
\n", "

Data

\n", "
dmri_rsi.n0_fiber.at_allfib.lhdmri_rsi.n0_fiber.at_allfib.rhdmri_rsi.n0_fiber.at_allfibersdmri_rsi.n0_fiber.at_allfibnocc.lhdmri_rsi.n0_fiber.at_allfibnocc.rhdmri_rsi.n0_fiber.at_atr.lhdmri_rsi.n0_fiber.at_atr.rhdmri_rsi.n0_fiber.at_ccdmri_rsi.n0_fiber.at_cgc.lhdmri_rsi.n0_fiber.at_cgc.rh...dmri_rsi.vol_fiber.at_scs.lhdmri_rsi.vol_fiber.at_scs.rhdmri_rsi.vol_fiber.at_sifc.lhdmri_rsi.vol_fiber.at_sifc.rhdmri_rsi.vol_fiber.at_slf.lhdmri_rsi.vol_fiber.at_slf.rhdmri_rsi.vol_fiber.at_tslf.lhdmri_rsi.vol_fiber.at_tslf.rhdmri_rsi.vol_fiber.at_unc.lhdmri_rsi.vol_fiber.at_unc.rh
00.3276230.3234200.3259570.3405590.3323640.3478370.3360720.3068030.3113470.304854...23672.013056.09648.09528.010152.011504.08384.08024.04968.07176.0
20.3253740.3114650.3190270.3412130.3263340.3466510.3353620.2881240.3264160.300990...33112.019256.011928.08688.013144.015344.010488.010936.06904.09480.0
30.3050950.3043570.3051700.3154770.3128660.3139720.3167290.2887420.2891660.290347...28480.016016.013024.011960.013600.014880.011416.010592.06952.08736.0
40.3168600.3152380.3163990.3282510.3272590.3339980.3181620.2940080.2978000.299230...29904.017968.012720.011336.013528.015672.011096.011816.05912.07336.0
50.3235210.3267410.3254660.3360030.3352910.3262430.3373670.3053820.3118430.315721...23048.012032.09056.09248.09672.011048.07848.07520.05088.07448.0
..................................................................
118700.3357410.3360480.3363720.3498060.3477320.3499660.3456920.3120560.3240540.336676...28328.015400.09656.010080.011312.013496.08728.09176.04960.07392.0
118710.3205630.3175250.3194290.3273020.3221610.3330860.3154820.3085540.2999380.298093...23792.013632.09928.08912.09152.012288.07128.08912.05744.07376.0
118720.3270510.3253860.3265220.3409180.3348540.3454350.3356100.3057200.3086300.330612...28640.016384.09496.011216.012168.012312.09520.08952.04568.09056.0
118730.3235790.3193770.3218050.3349450.3294330.3322000.3340170.3043990.3038310.307037...26216.014672.09408.08872.010960.012584.08880.09176.03696.06168.0
118740.3835370.3714830.3778220.3944130.3734860.4046840.3744350.3662700.4153420.418280...26544.015624.09904.010360.09904.012216.07712.08000.05208.08816.0
\n", "

10663 rows × 294 columns

\n", "

8562 rows × 294 columns - Train Set

2101 rows × 294 columns - Test Set

\n", "
\n", "

Target

\n", "
anthro_waist_cm
031.00
226.75
323.50
430.00
528.00
......
1187026.00
1187130.00
1187219.00
1187325.00
1187432.00
\n", "

10663 rows × 1 columns

\n", "

8562 rows × 1 columns - Train Set

2101 rows × 1 columns - Test Set

\n", "
\n", "

Non Input

\n", "
rel_family_idsex
073210
239711
331391
445431
519331
.........
1187031281
1187121110
1187259070
1187355940
1187462380
\n", "

10663 rows × 2 columns

\n", "

8562 rows × 2 columns - Train Set

2101 rows × 2 columns - Test Set

\n" ], "text/plain": [ " anthro_waist_cm dmri_rsi.n0_fiber.at_fx.rh \\\n", "0 31.00 0.246540 \n", "2 26.75 0.146416 \n", "3 23.50 0.229894 \n", "4 30.00 0.192228 \n", "5 28.00 0.223994 \n", "... ... ... \n", "11870 26.00 0.236385 \n", "11871 30.00 0.247628 \n", "11872 19.00 0.224581 \n", "11873 25.00 0.212500 \n", "11874 32.00 0.237271 \n", "\n", " dmri_rsi.n0_fiber.at_fx.lh dmri_rsi.n0_fiber.at_cgc.rh \\\n", "0 0.240964 0.304854 \n", "2 0.241515 0.300990 \n", "3 0.225981 0.290347 \n", "4 0.201559 0.299230 \n", "5 0.230152 0.315721 \n", "... ... ... \n", "11870 0.233723 0.336676 \n", "11871 0.244926 0.298093 \n", "11872 0.181725 0.330612 \n", "11873 0.225045 0.307037 \n", "11874 0.232131 0.418280 \n", "\n", " dmri_rsi.n0_fiber.at_cgc.lh dmri_rsi.n0_fiber.at_cgh.rh \\\n", "0 0.311347 0.255081 \n", "2 0.326416 0.230479 \n", "3 0.289166 0.204329 \n", "4 0.297800 0.249119 \n", "5 0.311843 0.210526 \n", "... ... ... \n", "11870 0.324054 0.245577 \n", "11871 0.299938 0.201506 \n", "11872 0.308630 0.251494 \n", "11873 0.303831 0.254441 \n", "11874 0.415342 0.241162 \n", "\n", " dmri_rsi.n0_fiber.at_cgh.lh dmri_rsi.n0_fiber.at_cst.rh \\\n", "0 0.244332 0.378148 \n", "2 0.232169 0.386119 \n", "3 0.200217 0.364850 \n", "4 0.248674 0.365131 \n", "5 0.197801 0.400478 \n", "... ... ... \n", "11870 0.259280 0.388121 \n", "11871 0.192919 0.386154 \n", "11872 0.265312 0.382624 \n", "11873 0.257420 0.380283 \n", "11874 0.263211 0.394936 \n", "\n", " dmri_rsi.n0_fiber.at_cst.lh dmri_rsi.n0_fiber.at_atr.rh ... \\\n", "0 0.388728 0.336072 ... \n", "2 0.400546 0.335362 ... \n", "3 0.365397 0.316729 ... \n", "4 0.359856 0.318162 ... \n", "5 0.395570 0.337367 ... \n", "... ... ... ... \n", "11870 0.388715 0.345692 ... \n", "11871 0.390666 0.315482 ... \n", "11872 0.379022 0.335610 ... \n", "11873 0.377752 0.334017 ... \n", "11874 0.414786 0.374435 ... 
\n", "\n", " dmri_rsi.vol_fiber.at_ifsfc.lh dmri_rsi.vol_fiber.at_fxcut.rh \\\n", "0 13224.0 2720.0 \n", "2 18080.0 2728.0 \n", "3 18768.0 2776.0 \n", "4 18112.0 2528.0 \n", "5 14224.0 2216.0 \n", "... ... ... \n", "11870 15848.0 2760.0 \n", "11871 14368.0 2728.0 \n", "11872 17616.0 2928.0 \n", "11873 16024.0 2728.0 \n", "11874 17096.0 1976.0 \n", "\n", " dmri_rsi.vol_fiber.at_fxcut.lh dmri_rsi.vol_fiber.at_allfibers \\\n", "0 1928.0 264000.0 \n", "2 2264.0 339840.0 \n", "3 1784.0 331024.0 \n", "4 2408.0 327192.0 \n", "5 1536.0 259936.0 \n", "... ... ... \n", "11870 2192.0 292304.0 \n", "11871 2072.0 271624.0 \n", "11872 2072.0 317816.0 \n", "11873 2032.0 286832.0 \n", "11874 1776.0 298280.0 \n", "\n", " dmri_rsi.vol_fiber.at_allfibnocc.rh \\\n", "0 91368.0 \n", "2 119280.0 \n", "3 114912.0 \n", "4 115360.0 \n", "5 91944.0 \n", "... ... \n", "11870 101352.0 \n", "11871 98328.0 \n", "11872 113016.0 \n", "11873 97896.0 \n", "11874 105096.0 \n", "\n", " dmri_rsi.vol_fiber.at_allfibnocc.lh dmri_rsi.vol_fiber.at_allfib.rh \\\n", "0 92144.0 133816.0 \n", "2 123784.0 166808.0 \n", "3 115472.0 167400.0 \n", "4 114592.0 165224.0 \n", "5 92208.0 130792.0 \n", "... ... ... \n", "11870 103848.0 145272.0 \n", "11871 94352.0 138504.0 \n", "11872 106280.0 162352.0 \n", "11873 97496.0 143704.0 \n", "11874 101088.0 152848.0 \n", "\n", " dmri_rsi.vol_fiber.at_allfib.lh sex rel_family_id \n", "0 131856.0 0 7321 \n", "2 174480.0 1 3971 \n", "3 165336.0 1 3139 \n", "4 164176.0 1 4543 \n", "5 130960.0 1 1933 \n", "... ... ... ... 
\n", "11870 148776.0 1 3128 \n", "11871 134832.0 0 2111 \n", "11872 157344.0 0 5907 \n", "11873 144768.0 0 5594 \n", "11874 147488.0 0 6238 \n", "\n", "[10663 rows x 297 columns]" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Group by family id so that members of the same family are kept together\n", "preserve_family = bp.CVStrategy(groups='rel_family_id')\n", "\n", "# Apply the test split\n", "data = data.set_test_split(size=.2,\n", "                           cv_strategy=preserve_family,\n", "                           random_state=6)\n", "\n", "data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluate Different Pipelines\n", "\n", "First, let's save some commonly used parameters in an object called the ProblemSpec. We will use all defaults except for the number of jobs, which we set to n_jobs=8." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "ps = bp.ProblemSpec(n_jobs=8)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The function we will use to evaluate different pipelines is bp.evaluate. Let's start with an example using just a linear regression model." 
] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicting target = anthro_waist_cm\n", "Using problem_type = regression\n", "Using scope = all defining an initial total of 294 features.\n", "Evaluating 8562 total data points.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "662b37ffb59e4f7fb9573c17e5829b25", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Folds: 0%| | 0/5 [00:00\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "" ], "text/plain": [ " dmri_rsi.n0_fiber.at_allfib.lh dmri_rsi.n0_fiber.at_allfib.rh \\\n", "0 13.658717 10.182917 \n", "1 13.386769 -7.807120 \n", "2 14.601005 6.447403 \n", "3 12.660464 16.193197 \n", "4 16.212349 11.518905 \n", "\n", " dmri_rsi.n0_fiber.at_allfibers dmri_rsi.n0_fiber.at_allfibnocc.lh \\\n", "0 11.113703 2.741103 \n", "1 5.590872 2.074826 \n", "2 6.566593 5.105803 \n", "3 -0.193520 2.184492 \n", "4 7.905156 6.085897 \n", "\n", " dmri_rsi.n0_fiber.at_allfibnocc.rh dmri_rsi.n0_fiber.at_atr.lh \\\n", "0 2.096137 0.652331 \n", "1 1.611921 -5.200155 \n", "2 2.144557 -0.249971 \n", "3 3.968235 -7.643031 \n", "4 1.294112 -8.058172 \n", "\n", " dmri_rsi.n0_fiber.at_atr.rh dmri_rsi.n0_fiber.at_cc \\\n", "0 -19.689333 18.133593 \n", "1 -17.350840 18.954327 \n", "2 -17.505516 19.815592 \n", "3 -16.103781 15.539372 \n", "4 -11.786942 13.227197 \n", "\n", " dmri_rsi.n0_fiber.at_cgc.lh dmri_rsi.n0_fiber.at_cgc.rh ... \\\n", "0 6.170156 18.584436 ... \n", "1 4.418044 18.556810 ... \n", "2 10.121972 15.690188 ... \n", "3 9.867133 14.523899 ... \n", "4 10.812229 12.374775 ... 
\n", "\n", " dmri_rsi.vol_fiber.at_scs.lh dmri_rsi.vol_fiber.at_scs.rh \\\n", "0 0.000343 -0.000589 \n", "1 0.000263 -0.000505 \n", "2 0.000319 -0.000600 \n", "3 0.000341 -0.000441 \n", "4 0.000080 -0.000515 \n", "\n", " dmri_rsi.vol_fiber.at_sifc.lh dmri_rsi.vol_fiber.at_sifc.rh \\\n", "0 0.000551 0.000348 \n", "1 0.000523 0.000272 \n", "2 0.000455 0.000446 \n", "3 0.000587 0.000453 \n", "4 0.000614 0.000411 \n", "\n", " dmri_rsi.vol_fiber.at_slf.lh dmri_rsi.vol_fiber.at_slf.rh \\\n", "0 -0.000216 0.000278 \n", "1 -0.000371 0.000474 \n", "2 -0.000303 0.000618 \n", "3 -0.000284 0.000608 \n", "4 -0.000342 0.000784 \n", "\n", " dmri_rsi.vol_fiber.at_tslf.lh dmri_rsi.vol_fiber.at_tslf.rh \\\n", "0 1.902580e-04 -0.000544 \n", "1 1.721382e-04 -0.000582 \n", "2 1.192093e-07 -0.000615 \n", "3 1.287460e-04 -0.000701 \n", "4 2.737045e-04 -0.000738 \n", "\n", " dmri_rsi.vol_fiber.at_unc.lh dmri_rsi.vol_fiber.at_unc.rh \n", "0 0.000643 0.000572 \n", "1 0.000530 0.000525 \n", "2 0.000378 0.000662 \n", "3 0.000359 0.000616 \n", "4 0.000555 0.000540 \n", "\n", "[5 rows x 294 columns]" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Beta weights\n", "results.get_fis()" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " predict y_true\n", "28 24.345831 21.25\n", "33 24.806902 24.50\n", "36 26.968584 23.00\n", "40 27.805948 30.80\n", "47 25.691679 20.00\n", "... ... ...\n", "11848 28.006598 25.50\n", "11851 23.925535 24.00\n", "11855 30.389437 38.80\n", "11861 25.687609 26.00\n", "11868 29.838160 35.00\n", "\n", "[1713 rows x 2 columns]" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Raw predictions from the first fold (one DataFrame is returned per fold)\n", "results.get_preds_dfs()[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "All options are listed under 'Saved Attributes' and 'Available Methods'.\n", "\n", "Let's continue trying different models. Next we will use a ridge regression model. Because variables defined in a Jupyter notebook live in global scope, we can also wrap the evaluation call in a small helper function so we don't have to keep copying and pasting it." ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "def eval_pipe(pipeline, **kwargs):\n", "    return bp.evaluate(pipeline=pipeline,\n", "                       dataset=data,\n", "                       problem_spec=ps,\n", "                       **kwargs)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0a633cae7c9145159bb4f267f1b13b7b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Folds: 0%| | 0/5 [00:00\n", "
\n", "" ], "text/plain": [ " mean_scores_explained_variance mean_scores_neg_mean_squared_error \\\n", "pipeline \n", "sgd 0.226867 -13.632093 \n", "ridge 0.244241 -13.105543 \n", "elastic 0.214935 -13.612086 \n", "lgbm 0.159819 -14.628175 \n", "\n", " std_scores_explained_variance std_scores_neg_mean_squared_error \\\n", "pipeline \n", "sgd 0.015547 0.287249 \n", "ridge 0.014348 0.335205 \n", "elastic 0.017457 0.304521 \n", "lgbm 0.008612 0.317664 \n", "\n", " mean_timing_fit mean_timing_score \n", "pipeline \n", "sgd 16.182405 0.007756 \n", "ridge 11.024628 0.461483 \n", "elastic 8.584299 0.036844 \n", "lgbm 30.417882 0.063121 " ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Define a set of bp.Option objects wrapped in bp.Compare\n", "compare_pipes = bp.Compare([bp.Option(sgd_pipe, name='sgd'),\n", "                            bp.Option(ridge_search_pipe, name='ridge'),\n", "                            bp.Option(elastic_pipe, name='elastic'),\n", "                            bp.Option(lgbm_pipe, name='lgbm')])\n", "\n", "# Pass the Compare object just as we would a single pipeline\n", "results = eval_pipe(compare_pipes)\n", "results.summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Applying the Test Set\n", "\n", "So far we have only been running internal 5-fold CV on the training set. Suppose we are done with exploration and now want to confirm that the best model found through internal CV on the training set generalizes to unseen data. To do this, we re-train the best model on the full training set and evaluate it on the test set. In BPt this is done by passing cv='test' to evaluate.
" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicting target = anthro_waist_cm\n", "Using problem_type = regression\n", "Using scope = all defining an initial total of 294 features.\n", "Evaluating 10663 total data points.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "59efb8760a3e4133bdf31f90c8102051", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Folds: 0%| | 0/1 [00:00