100% Accurate Mushroom Classification in Python

Mushrooms come in a wide variety of shapes, sizes and colors; some are edible, while others should be kept far away from the dinner table. For those who live in rural areas, classifying mushrooms can be crucial to survival, or at least it might have been centuries ago. A skilled horticulturist might not even need a second glance to classify a mushroom. Let’s see if we can build a model that classifies them automatically from a number of variables.

Our data comes to us from a 1987 dataset, found on Kaggle, with 8,124 records.

import warnings
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score, classification_report, roc_curve, auc
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
warnings.filterwarnings("ignore")

# Load the data, treating '?' as a missing value
dataset = pd.read_csv('mushrooms.csv', na_values='?')
dataset.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 8124 entries, 0 to 8123
Data columns (total 23 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 class 8124 non-null object
1 cap-shape 8124 non-null object
2 cap-surface 8124 non-null object
3 cap-color 8124 non-null object
4 bruises 8124 non-null object
5 odor 8124 non-null object
6 gill-attachment 8124 non-null object
7 gill-spacing 8124 non-null object
8 gill-size 8124 non-null object
9 gill-color 8124 non-null object
10 stalk-shape 8124 non-null object
11 stalk-root 5644 non-null object
12 stalk-surface-above-ring 8124 non-null object
13 stalk-surface-below-ring 8124 non-null object
14 stalk-color-above-ring 8124 non-null object
15 stalk-color-below-ring 8124 non-null object
16 veil-type 8124 non-null object
17 veil-color 8124 non-null object
18 ring-number 8124 non-null object
19 ring-type 8124 non-null object
20 spore-print-color 8124 non-null object
21 population 8124 non-null object
22 habitat 8124 non-null object
dtypes: object(23)
memory usage: 1.4+ MB

There are 23 variables including the class variable. The other variables describe the colour, shape and surface of the cap, gill, stalk and veil, as well as odor, bruising and population. Most of our variables seem to be free of missing values.
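
The missing values only show up because '?' is passed as na_values when the file is read. A minimal sketch of that effect, using a hypothetical three-row excerpt held in memory rather than the real mushrooms.csv:

```python
import io

import pandas as pd

# Hypothetical three-row excerpt; the real data lives in 'mushrooms.csv'.
csv_text = "class,odor,stalk-root\np,p,e\ne,a,?\ne,n,b\n"

# Without na_values, '?' is kept as an ordinary string category.
raw = pd.read_csv(io.StringIO(csv_text))

# With na_values='?', the same cell is parsed as a missing value (NaN).
parsed = pd.read_csv(io.StringIO(csv_text), na_values="?")

print("missing without na_values:", raw["stalk-root"].isnull().sum())
print("missing with na_values:   ", parsed["stalk-root"].isnull().sum())
```

Read without na_values, the '?' entries would simply be one more category of stalk-root.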

# Missing values in each column
print("Missing values for each column:\n", dataset.isnull().sum())
Missing values for each column:
class 0
cap-shape 0
cap-surface 0
cap-color 0
bruises 0
odor 0
gill-attachment 0
gill-spacing 0
gill-size 0
gill-color 0
stalk-shape 0
stalk-root 2480
stalk-surface-above-ring 0
stalk-surface-below-ring 0
stalk-color-above-ring 0
stalk-color-below-ring 0
veil-type 0
veil-color 0
ring-number 0
ring-type 0
spore-print-color 0
population 0
habitat 0
dtype: int64

Over 30% of the values in the ‘stalk-root’ column are missing (2,480 of 8,124). I decided to fill these in with the mode of the column, which is ‘b’.

# stalk-root value_counts
print(dataset['stalk-root'].value_counts())
b 3776
e 1120
c 556
r 192
Name: stalk-root, dtype: int64
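# Fill 2480 NA in stalk-root with mode
# (assign the result back; calling fillna(inplace=True) on a selected column is deprecated in recent pandas)
dataset['stalk-root'] = dataset['stalk-root'].fillna(dataset['stalk-root'].mode()[0])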
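
To make the effect of mode imputation concrete, here is a small sketch that rebuilds a column with the same counts as the value_counts output above (3,776 b, 1,120 e, 556 c, 192 r, plus 2,480 missing) instead of loading the real file. The variable names are illustrative only.

```python
import numpy as np
import pandas as pd

# Rebuild a column with the counts reported for stalk-root in the article.
counts = {"b": 3776, "e": 1120, "c": 556, "r": 192}
values = [k for k, n in counts.items() for _ in range(n)] + [np.nan] * 2480
stalk_root = pd.Series(values, name="stalk-root")

missing_share = stalk_root.isnull().mean()   # fraction of missing entries
mode_value = stalk_root.mode()[0]            # most frequent non-missing value
filled = stalk_root.fillna(mode_value)       # returns a new, fully populated Series

print(f"missing share: {missing_share:.3f}")
print("mode:", mode_value)
print(filled.value_counts())
```

After imputation, ‘b’ accounts for 6,256 of the 8,124 rows, so the column becomes noticeably more dominated by a single category. Keeping the missing entries as their own category would be one alternative worth comparing.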

We can now visualise our variables since we have no further data cleaning to undertake. Since all variables are categorical, we can use a for loop to create the countplots for us: once for each variable alone, and once with a hue for the class/target variable.

# Data Visualisation
# countplot of every variable
for col in dataset.columns:
    plt.figure()  # start a fresh figure so plots do not pile up on the same axes
    sns_plot = sns.countplot(x=col, data=dataset)
    sns_plot.figure.savefig("{} countplot.png".format(col))
    plt.close()

# countplot of every variable with hue = class/target
for col in dataset.columns:
    plt.figure()  # start a fresh figure so plots do not pile up on the same axes
    sns_plot = sns.countplot(x=col, hue='class', data=dataset)
    sns_plot.figure.savefig("{} class countplot.png".format(col))
    plt.close()

We’re quite happy to see that our data has a balanced target variable, between the poisonous and edible mushrooms.
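
The balance can also be checked numerically rather than read off a plot. A minimal sketch, where the small frame is a hypothetical stand-in for the raw ‘class’ column (before it is mapped to 0/1):

```python
import pandas as pd

# Hypothetical stand-in for dataset['class']; 'e' = edible, 'p' = poisonous.
toy = pd.DataFrame({"class": ["e", "p", "e", "p", "e"]})

# normalize=True returns shares instead of raw counts.
balance = toy["class"].value_counts(normalize=True)
print(balance)
```

On the real data, shares close to 0.5 for each class would confirm what the countplot suggests.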

Some insights we see are:

  • edible mushrooms are more likely to have bruising
  • edible mushrooms are most likely to have no odor, whilst poisonous mushrooms tend to have a foul odor
  • the stalk surface of edible mushrooms is usually smooth, whilst that of poisonous mushrooms is usually silky
  • the gill size of edible mushrooms is usually broad, whilst poisonous mushrooms are a mix of both broad and narrow
  • the spore print color of edible mushrooms is usually black or brown, whilst for poisonous mushrooms it is usually chocolate or white
  • the variables that look least useful for classifying mushrooms at first glance are: cap shape, cap surface, cap colour, gill attachment, gill spacing, stalk shape, ring number, veil type, and veil color.
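
Observations like these can be backed with numbers by cross-tabulating a feature against the class. The sketch below uses a made-up six-row frame, so the proportions are illustrative and not taken from the mushroom data:

```python
import pandas as pd

# Made-up rows; 'n' = no odor, 'a' = almond, 'f' = foul.
toy = pd.DataFrame({
    "class": ["e", "e", "e", "p", "p", "p"],
    "odor":  ["n", "n", "a", "f", "f", "n"],
})

# normalize='index' makes each row sum to 1: the share of each class per odor value.
table = pd.crosstab(toy["odor"], toy["class"], normalize="index")
print(table)
```

A feature whose rows are all close to 0 or 1 separates the classes well; rows near 0.5 suggest the feature carries little information on its own.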

We can now move on to pre-processing the data to prepare it for modelling. Let’s first create dummy variables for all variables except ‘class’, using get_dummies. We will also map the class variable to a binary value, 1 for poisonous and 0 for edible:

# Preprocessing
# get_dummies for all except target variable
dummies_columns = list(dataset.columns[1:])
dataset = pd.get_dummies(dataset, columns=dummies_columns)

# Map target variable class: 1 = poisonous, 0 = edible
dataset['class'] = dataset['class'].map({"p": 1, "e": 0})
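
To see what this step produces, here is the same pair of operations on a hypothetical three-row frame with only two features:

```python
import pandas as pd

# Hypothetical miniature of the dataset: target first, categorical features after.
toy = pd.DataFrame({
    "class":   ["p", "e", "e"],
    "odor":    ["f", "n", "n"],
    "bruises": ["f", "t", "f"],
})

# One indicator column per (feature, category) pair; 'class' is left untouched.
feature_cols = list(toy.columns[1:])
encoded = pd.get_dummies(toy, columns=feature_cols)

# Map the target: 1 = poisonous, 0 = edible.
encoded["class"] = encoded["class"].map({"p": 1, "e": 0})

print(encoded.columns.tolist())
print(encoded)
```

Each category becomes its own column, which is how the 22 categorical features of the real dataset expand into the 116 feature columns seen in the split below.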

We now split the data into an 80% training set and a 20% test set:

# Data splitting
# X,y split
x = dataset.iloc[:, 1:].values
y = dataset.iloc[:, 0].values

# Training set and Test set
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=0)
print("Shape of x_train dataset: ", x_train.shape)
print("Shape of y_train dataset: ", y_train.shape)
print("Shape of x_test dataset: ", x_test.shape)
print("Shape of y_test dataset: ", y_test.shape)
Shape of x_train dataset: (6499, 116)
Shape of y_train dataset: (6499,)
Shape of x_test dataset: (1625, 116)
Shape of y_test dataset: (1625,)
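
The split above is purely random. If one wanted the class proportions to be identical in both sets, train_test_split also accepts a stratify argument. This is an optional variation, not what the article's code does; the sketch uses synthetic labels with a 60/40 ratio:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic data: 100 samples, 60 of class 0 and 40 of class 1.
toy_x = np.arange(100).reshape(-1, 1)
toy_y = np.array([0] * 60 + [1] * 40)

# stratify=toy_y keeps the 60/40 ratio in both the training and the test set.
toy_x_train, toy_x_test, toy_y_train, toy_y_test = train_test_split(
    toy_x, toy_y, test_size=0.2, random_state=0, stratify=toy_y
)

print("share of class 1 in train:", toy_y_train.mean())
print("share of class 1 in test: ", toy_y_test.mean())
```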

One last thing to do is to undertake feature scaling using StandardScaler:

# Feature Scaling
sc = StandardScaler()
x_train = sc.fit_transform(x_train)
x_test = sc.transform(x_test)
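
Note that the scaler is fitted on the training set only and then reused on the test set, so no information from the test data leaks into the preprocessing. A tiny sketch with made-up numbers shows what fit_transform and transform each do:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Made-up single-feature data.
toy_train = np.array([[0.0], [2.0]])
toy_test = np.array([[3.0]])

scaler = StandardScaler()
scaled_train = scaler.fit_transform(toy_train)  # learns mean = 1 and std = 1 from the training data
scaled_test = scaler.transform(toy_test)        # reuses them: (3 - 1) / 1 = 2

print(scaled_train.ravel(), scaled_test.ravel())
```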

I’ll be using the Logistic Regression algorithm to create our model and classify our test data:

# Modelling
# Fitting Logistic Regression
lr = LogisticRegression()
lr.fit(x_train, y_train)
y_pred = lr.predict(x_test)
y_prob = lr.predict_proba(x_test)[:, 1]
# Classification report and scoring
print(classification_report(y_test, y_pred))
print(f'ROC AUC score: {roc_auc_score(y_test, y_prob)}')
print('Accuracy Score: ', accuracy_score(y_test, y_pred))
              precision    recall  f1-score   support
           0       1.00      1.00      1.00       852
           1       1.00      1.00      1.00       773
    accuracy                           1.00      1625
   macro avg       1.00      1.00      1.00      1625
weighted avg       1.00      1.00      1.00      1625
ROC AUC score: 1.0
Accuracy Score: 1.0

On the test set, our model achieves 100% accuracy, a 1.00 f1-score for both classes and a ROC AUC score of 1.0. Perfect model!

Let’s visualise this with a confusion matrix:
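
A perfect score on a single split invites a sanity check. One common option is k-fold cross-validation, sketched below on synthetic binary features (not the mushroom data) where the label is, by construction, fully determined by the first feature. Wrapping the scaler and the model in a pipeline is an assumption of this sketch; it keeps the scaling inside each fold.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in: 200 rows of 0/1 features, label equal to the first feature.
rng = np.random.default_rng(0)
toy_x = rng.integers(0, 2, size=(200, 5)).astype(float)
toy_y = toy_x[:, 0].astype(int)

# The pipeline refits the scaler on the training part of every fold.
model = make_pipeline(StandardScaler(), LogisticRegression())
scores = cross_val_score(model, toy_x, toy_y, cv=5, scoring="accuracy")

print("accuracy per fold:", scores)
print("mean accuracy:", scores.mean())
```

Applied to the real x and y, consistently high scores across folds would indicate that the result does not depend on one lucky split.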

# Visualising Confusion Matrix
cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, cmap='Blues', annot=True, fmt='d', cbar=False,
            yticklabels=['Edible', 'Poisonous'],
            xticklabels=['Predicted Edible', 'Predicted Poisonous'])
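
For reading the matrix, it helps to unpack its four cells. With the mapping used here (1 = poisonous, 0 = edible), a false negative is a poisonous mushroom predicted as edible, which is the costly mistake. A small sketch with made-up labels:

```python
from sklearn.metrics import confusion_matrix

# Made-up labels: 0 = edible, 1 = poisonous.
toy_true = [0, 0, 1, 1, 1]
toy_pred = [0, 1, 1, 1, 0]

# Rows are true classes, columns are predicted classes; ravel() flattens row by row.
tn, fp, fn, tp = confusion_matrix(toy_true, toy_pred).ravel()

print("edible predicted edible (TN):      ", tn)
print("edible predicted poisonous (FP):   ", fp)
print("poisonous predicted edible (FN):   ", fn)
print("poisonous predicted poisonous (TP):", tp)
```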

Our ROC curve likewise looks very appealing.

# ROC curve
false_positive_rate, true_positive_rate, thresholds = roc_curve(y_test, y_prob)
roc_auc = auc(false_positive_rate, true_positive_rate)
plt.figure()
plt.plot(false_positive_rate, true_positive_rate, label='AUC = %0.3f' % roc_auc)
plt.plot([0, 1], [0, 1], linestyle='--')  # diagonal = random classifier
plt.ylabel('True Positive Rate')
plt.xlabel('False Positive Rate')
plt.legend(loc='lower right')  # needed for the AUC label to be displayed
plt.show()
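
As a reminder of what the AUC measures, here is a four-point sketch with made-up scores. The AUC equals the share of (poisonous, edible) pairs in which the poisonous sample receives the higher score; here that is 3 of 4 pairs.

```python
from sklearn.metrics import roc_auc_score, roc_curve

# Made-up labels and predicted probabilities for the positive class.
toy_true = [0, 0, 1, 1]
toy_score = [0.1, 0.4, 0.35, 0.8]

# roc_curve sweeps the decision threshold and returns one (FPR, TPR) point per threshold.
fpr, tpr, thresholds = roc_curve(toy_true, toy_score)
toy_auc = roc_auc_score(toy_true, toy_score)

print("FPR:", fpr)
print("TPR:", tpr)
print("AUC:", toy_auc)
```

An AUC of 1.0, as obtained above, means every poisonous mushroom in the test set was scored higher than every edible one.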

We were very lucky to have an almost perfectly balanced dataset with enough data to create an accurate model. Had the dataset been imbalanced, we would have had to look into sampling techniques, and possibly into algorithms more complex than Logistic Regression. Our model, however, seems to be sufficient and could be used by an amateur ‘shroomer’.
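
For the imbalanced case, one lightweight option to try before resampling is class weighting, which scikit-learn's LogisticRegression supports through class_weight='balanced'. This is an alternative to the sampling techniques mentioned above, not something used in this article. A sketch of the weights it would assign on synthetic 90/10 labels:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.utils.class_weight import compute_class_weight

# Synthetic imbalanced labels: 90 samples of class 0, 10 of class 1.
toy_y = np.array([0] * 90 + [1] * 10)

# 'balanced' weights are n_samples / (n_classes * class_count).
weights = compute_class_weight(class_weight="balanced", classes=np.array([0, 1]), y=toy_y)
print(dict(zip([0, 1], weights)))

# The same weighting is applied during fitting when requested on the estimator.
weighted_lr = LogisticRegression(class_weight="balanced")
```

The rare class gets a weight of 5.0 against roughly 0.56 for the common one, so errors on the rare class cost about nine times as much during training.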