Lecture 15: K-Fold CV for Classification
# Everyone's favorite standard imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import KFold
1. CV for a classification data set
Artwork by @allison_horst
For this lab, we are going to use the Palmer Penguins data set by Allison Horst, Alison Hill, and Kristen Gorman. This data set was originally released for R. It has helpfully been made available as an easy-to-load Python data set in the palmerpenguins package, which you can install using pip.
# You should only have to do this once:
%pip install palmerpenguins
# If it worked, this should load our dataset
from palmerpenguins import load_penguins
penguins = load_penguins()
penguins.head()
As always, when playing with a new data set, your first job is to get a feel for what’s in the data. We’re going to use this data to predict the species of each penguin from the other information.
✅ Questions:
How many penguins are in the data set?
What are the input variables?
What are the possible values of the output variable?
Which are categorical variables? Which are quantitative?
Are there any rows with missing data? How is missing data represented in this data set?
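If you want a starting point for these questions, the pandas calls below are one way to inspect a data frame. This is a minimal sketch on a tiny made-up frame called toy. It has hypothetical values and only some penguin-like columns, so it does not answer the questions for you. Swap in penguins to explore the real data.

```python
import numpy as np
import pandas as pd

# A tiny, made-up stand-in for the penguins data (hypothetical values,
# NOT the real data set), with a few similar kinds of columns.
toy = pd.DataFrame({
    "species": ["Adelie", "Gentoo", "Chinstrap", "Adelie"],
    "island": ["Torgersen", "Biscoe", "Dream", "Dream"],
    "bill_length_mm": [39.1, 46.1, np.nan, 37.8],
    "body_mass_g": [3750.0, 4500.0, 3500.0, np.nan],
    "sex": ["male", "female", np.nan, "female"],
})

n_rows, n_cols = toy.shape               # number of observations and columns
print(n_rows, "rows;", n_cols, "columns")
print(toy.dtypes)                        # text columns are categorical, float columns quantitative
print(toy["species"].unique())           # the distinct values of the output variable
print(toy.isna().sum())                  # missing entries per column (pandas shows them as NaN)
print(toy.isna().any(axis=1).sum())      # number of rows with at least one missing value
```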
Your answers here
✅ Do this: Spoiler alert, there are penguins with missing data. Replace the penguins dataframe with one where you have removed all those rows. (Hint: this should be a one-line operation.)
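Here is a sketch of the kind of one-line operation the hint has in mind. pandas' DataFrame.dropna() removes every row that contains at least one missing value. The frame below is a hypothetical toy example, not the penguins data.

```python
import numpy as np
import pandas as pd

# Hypothetical toy frame: two of the four rows have a missing entry.
toy = pd.DataFrame({
    "species": ["Adelie", "Gentoo", "Chinstrap", "Adelie"],
    "body_mass_g": [3750.0, np.nan, 3500.0, 3800.0],
    "sex": ["male", "female", np.nan, "female"],
})

# dropna() with its defaults removes every row containing at least one NaN.
clean = toy.dropna()
print(len(toy), "->", len(clean))        # 4 -> 2

# The surviving rows keep their original index labels (0 and 3 here);
# reset_index(drop=True) renumbers them 0..n-1 if you prefer.
clean = clean.reset_index(drop=True)
```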
# Your code here
Our next favorite thing to do with any data set is to start trying to visualize relationships between the variables.
sns.pairplot(penguins)
# Here is another nice visualization taken from the palmerpenguins GitHub repository
g = sns.lmplot(x="flipper_length_mm",
               y="body_mass_g",
               hue="species",
               height=7,
               data=penguins,
               palette=['#FF8C00', '#159090', '#A034F0'])
g.set_xlabels('Flipper Length (mm)')
g.set_ylabels('Body Mass (g)')
Step 1: Set up your dataframes
Ok, you have your penguins data frame.
Build a dataframe \(X\) containing the input variables (every column except species), with island and sex replaced with dummy variable(s).
Save a pandas series of the entries in penguins.species as \(y\).
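One way to build these is pd.get_dummies with its columns argument. The sketch below uses a hypothetical toy frame with penguin-like columns. Passing drop_first=True, which drops one redundant level per variable, is an assumption on my part and not a requirement of the lab.

```python
import pandas as pd

# Hypothetical toy rows standing in for the (already cleaned) penguins data.
toy = pd.DataFrame({
    "species": ["Adelie", "Gentoo", "Chinstrap", "Gentoo"],
    "island": ["Torgersen", "Biscoe", "Dream", "Biscoe"],
    "bill_length_mm": [39.1, 46.1, 46.5, 47.3],
    "sex": ["male", "female", "female", "male"],
})

# The response: a pandas Series of species labels.
y = toy["species"]

# The inputs: every other column, with the categorical ones turned into
# 0/1 dummy columns. drop_first=True drops one level per variable (it is
# implied by the others); dtype=int gives 0/1 instead of True/False.
X = pd.get_dummies(toy.drop(columns="species"),
                   columns=["island", "sex"],
                   drop_first=True, dtype=int)
print(X.columns.tolist())
print(X)
```

With drop_first=True, the first level in sorted order (Biscoe for island, female for sex) becomes the baseline that has no column of its own.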
# Your code here. Feel free to make more cells; I spread this out over at least
# 5 while I was trying to get everything up and running.
Step 2: Run logistic regression
Ok, you have your penguins data with the input variables as X, and we are going to predict penguins.species. While scikit-learn cannot handle input variables that are categorical (hence why we had to put in our dummy variables ourselves), it’s fine with a response variable that is. The following code will fit a logistic regression on the whole data set. Of course, you know better than to report the error of a model evaluated on the same data it was trained on, so in a moment we will be modifying this to get \(k\)-fold CV test errors.
logisticmodel = LogisticRegression(max_iter=1400)  # Note: I needed to increase the iterations
# to get rid of a convergence warning
logisticmodel.fit(X, y)
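Here is a minimal sketch showing that a categorical response really is fine. It uses made-up one-dimensional inputs with string labels (hypothetical data, not penguins). The fitted model stores the labels in classes_, and predict returns labels rather than numbers.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Hypothetical 1-D inputs with string class labels (not penguin data).
X_toy = np.array([[1.0], [1.5], [2.0], [8.0], [8.5], [9.0]])
y_toy = np.array(["small", "small", "small", "big", "big", "big"])

model = LogisticRegression(max_iter=1000)
model.fit(X_toy, y_toy)                  # string labels are accepted as-is

print(model.classes_)                    # the labels, in sorted order
print(model.predict([[1.2], [8.8]]))     # predictions come back as labels
print(model.predict_proba([[1.2]]))      # one probability per class, ordered like classes_
```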
Also, here’s some helpful code to remind you how to get accuracy/error rates out of classification models in scikit-learn.
# Now we can also get the error rate on the training set.
from sklearn.metrics import accuracy_score
yhat = logisticmodel.predict(X)
accuracy = accuracy_score(y, yhat)  # signature is accuracy_score(y_true, y_pred)
# Note that accuracy is the fraction of predictions that are correct
print('Accuracy:', accuracy)
# so the fraction incorrect (the error rate) is
print('Error:', 1 - accuracy)
# We can get the same info directly from the fitted model
print('\nAccuracy version 2:', logisticmodel.score(X, y))
✅ Do this: Ok, your job, should you choose to accept it, is to
Train a model predicting species from all the input variables using logistic regression.
Use \(k\)-fold cross validation to determine the test error. I would recommend using something like \(k=5\) to start building your code, but you can increase it to \(k=10\) once everything is working.
Hint: while I was building my version, I had to set max_iter for logistic regression pretty high to get the model to converge. However, my error results were still pretty reasonable with a lower max_iter, ignoring the massive number of pink warning boxes. Feel free to mess around with this parameter to see how it affects your output.
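Here is one possible shape for the \(k\)-fold loop, sketched on synthetic data. The data is three-class output from sklearn.datasets.make_classification, standing in for your penguin X and y. The data, \(k\), and the random_state values are all assumptions for illustration.

- Each fold is held out once, the model is fit on the remaining folds, and the \(k\) test errors are averaged.
- With your data frames, index with X.iloc[train_idx] and y.iloc[train_idx] instead of plain brackets.
- Using shuffle=True is worth considering if the rows happen to be grouped by species, so check penguins.species first.
- StratifiedKFold is an alternative that keeps class proportions similar across folds.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold, cross_val_score

# Synthetic 3-class data standing in for the penguins X and y (hypothetical).
X, y = make_classification(n_samples=300, n_features=6, n_informative=4,
                           n_classes=3, random_state=0)

k = 5
# shuffle=True so each fold mixes classes even if the rows are sorted by class.
kf = KFold(n_splits=k, shuffle=True, random_state=42)

fold_errors = []
for train_idx, test_idx in kf.split(X):
    model = LogisticRegression(max_iter=1400)
    # With pandas objects use X.iloc[train_idx], y.iloc[train_idx] here.
    model.fit(X[train_idx], y[train_idx])        # fit on the other k-1 folds
    yhat = model.predict(X[test_idx])            # predict the held-out fold
    fold_errors.append(1 - accuracy_score(y[test_idx], yhat))

cv_error = np.mean(fold_errors)
print("Per-fold test errors:", np.round(fold_errors, 3))
print("k-fold CV test error estimate:", cv_error)

# Cross-check with scikit-learn's helper, reusing the same KFold splitter.
scores = cross_val_score(LogisticRegression(max_iter=1400), X, y, cv=kf,
                         scoring="accuracy")
print("cross_val_score error:", 1 - scores.mean())
```

The last lines cross-check the hand-written loop against cross_val_score. Passing the same KFold object with a fixed integer random_state reproduces the same folds, so the two error estimates should agree.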
# Your code here
Congratulations, we’re done!
Written by Dr. Liz Munch, Michigan State University
This work is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License.