8.3.13. Different classifiers in decoding the Haxby dataset

Here we compare different classifiers on a visual object recognition decoding task.

8.3.13.1. We start by loading the data and applying simple transformations to it

# Retrieve the Haxby dataset through nilearn's dataset fetcher
from nilearn import datasets
# with no arguments, the 2nd subject's data is what gets fetched
haxby_dataset = datasets.fetch_haxby()

# report where the fetched images are stored
print('First subject anatomical nifti image (3D) located is at: %s' %
      haxby_dataset.anat[0])
print('First subject functional nifti image (4D) is located at: %s' %
      haxby_dataset.func[0])

# read the space-separated table of per-volume annotations
import numpy as np
import pandas as pd
labels = pd.read_csv(haxby_dataset.session_target[0], sep=" ")
stimuli = labels['labels']
# boolean mask that is False on resting-state volumes, so they can be dropped
task_mask = stimuli.ne('rest')

# names of the stimulus categories left once rest is excluded
categories = stimuli.loc[task_mask].unique()

# acquisition-run tag ('chunks' column) of every non-rest volume
session_labels = labels.loc[task_mask, 'chunks']

# Load the fMRI data
from nilearn.input_data import NiftiMasker

func_filename = haxby_dataset.func[0]
mask_filename = haxby_dataset.mask_vt[0]
# For decoding, standardizing is often very important
masker = NiftiMasker(mask_img=mask_filename, standardize=True)
# extract the masked signals, then keep only the non-rest volumes
masked_timecourses = masker.fit_transform(func_filename)
masked_timecourses = masked_timecourses[task_mask]

Out:

First subject anatomical nifti image (3D) located is at: /home/kshitij/nilearn_data/haxby2001/subj2/anat.nii.gz
First subject functional nifti image (4D) is located at: /home/kshitij/nilearn_data/haxby2001/subj2/bold.nii.gz

8.3.13.2. Then we define the various classifiers that we use

A support vector classifier

from sklearn.svm import SVC
svm = SVC(C=1., kernel="linear")

# The logistic regression
from sklearn.linear_model import (LogisticRegression,
                                  RidgeClassifier,
                                  RidgeClassifierCV,
                                  )
# l1- and l2-penalized variants; 'log l1 50' uses a larger C (50 instead of 1)
logistic = LogisticRegression(C=1., penalty="l1", solver='liblinear')
logistic_50 = LogisticRegression(C=50., penalty="l1", solver='liblinear')
logistic_l2 = LogisticRegression(C=1., penalty="l2", solver='liblinear')

# Cross-validated versions of these classifiers
from sklearn.model_selection import GridSearchCV
# GridSearchCV is slow, but note that it takes an 'n_jobs' parameter that
# can significantly speed up the fitting process on computers with
# multiple cores
# NOTE: the 'iid' argument is deliberately not passed. It was deprecated in
# scikit-learn 0.22 and removed in 0.24, where passing it raises a TypeError.
svm_cv = GridSearchCV(SVC(C=1., kernel="linear"),
                      param_grid={'C': [.1, 1., 10., 100.]},
                      scoring='f1', n_jobs=1, cv=3)

logistic_cv = GridSearchCV(
        LogisticRegression(C=1., penalty="l1", solver='liblinear'),
        param_grid={'C': [.1, 1., 10., 100.]},
        scoring='f1', cv=3,
        )
logistic_l2_cv = GridSearchCV(
        LogisticRegression(C=1., penalty="l2", solver='liblinear'),
        param_grid={
            'C': [.1, 1., 10., 100.]
            },
        scoring='f1', cv=3,
        )

# The ridge classifier has a specific 'CV' object that can set its
# parameters faster than using a GridSearchCV
ridge = RidgeClassifier()
ridge_cv = RidgeClassifierCV()

# A dictionary, to hold all our classifiers (keys are the display names
# used in the printed scores and in the figures)
classifiers = {'SVC': svm,
               'SVC cv': svm_cv,
               'log l1': logistic,
               'log l1 50': logistic_50,
               'log l1 cv': logistic_cv,
               'log l2': logistic_l2,
               'log l2 cv': logistic_l2_cv,
               'ridge': ridge,
               'ridge cv': ridge_cv
               }

8.3.13.3. Here we compute prediction scores

Run time for all these classifiers

# Cross-validation splitter: each split holds out one group
from sklearn.model_selection import LeaveOneGroupOut, cross_val_score
cv = LeaveOneGroupOut()

import time

# scores are stored as classifiers_scores[classifier name][category]
classifiers_scores = {}
# labels of the non-rest volumes; identical for every classifier and category
task_stimuli = stimuli[task_mask]

for classifier_name, classifier in sorted(classifiers.items()):
    print('_' * 70)
    category_scores = {}
    classifiers_scores[classifier_name] = category_scores

    for category in categories:
        # binary target: True where the volume belongs to this category
        classification_target = task_stimuli.isin([category])
        t0 = time.time()
        scores = cross_val_score(
            classifier, masked_timecourses, classification_target,
            cv=cv, groups=session_labels, scoring="f1")
        elapsed = time.time() - t0
        category_scores[category] = scores

        # one summary line per (classifier, category): mean/std of the
        # per-split f1 scores and the wall-clock time of the evaluation
        print("%10s: %14s -- scores: %1.2f +- %1.2f, time %.2fs" % (
            classifier_name, category, scores.mean(), scores.std(), elapsed))

Out:

______________________________________________________________________
       SVC:       scissors -- scores: 0.41 +- 0.31, time 1.22s
       SVC:           face -- scores: 0.58 +- 0.24, time 0.91s
       SVC:            cat -- scores: 0.60 +- 0.25, time 1.11s
       SVC:           shoe -- scores: 0.56 +- 0.23, time 1.14s
       SVC:          house -- scores: 0.81 +- 0.28, time 0.68s
       SVC:   scrambledpix -- scores: 0.79 +- 0.17, time 0.83s
       SVC:         bottle -- scores: 0.53 +- 0.21, time 1.26s
       SVC:          chair -- scores: 0.42 +- 0.25, time 1.14s
______________________________________________________________________
    SVC cv:       scissors -- scores: 0.41 +- 0.31, time 16.40s
    SVC cv:           face -- scores: 0.58 +- 0.24, time 14.03s
    SVC cv:            cat -- scores: 0.60 +- 0.25, time 16.66s
    SVC cv:           shoe -- scores: 0.56 +- 0.23, time 17.09s
    SVC cv:          house -- scores: 0.81 +- 0.28, time 10.73s
    SVC cv:   scrambledpix -- scores: 0.79 +- 0.17, time 13.52s
    SVC cv:         bottle -- scores: 0.53 +- 0.21, time 19.14s
    SVC cv:          chair -- scores: 0.42 +- 0.25, time 17.83s
______________________________________________________________________
    log l1:       scissors -- scores: 0.43 +- 0.28, time 0.92s
    log l1:           face -- scores: 0.74 +- 0.13, time 0.46s
    log l1:            cat -- scores: 0.63 +- 0.23, time 0.75s
    log l1:           shoe -- scores: 0.49 +- 0.19, time 0.84s
    log l1:          house -- scores: 0.89 +- 0.13, time 0.36s
    log l1:   scrambledpix -- scores: 0.69 +- 0.18, time 0.43s
    log l1:         bottle -- scores: 0.47 +- 0.15, time 0.97s
    log l1:          chair -- scores: 0.44 +- 0.14, time 0.85s
______________________________________________________________________
 log l1 50:       scissors -- scores: 0.41 +- 0.26, time 1.23s
 log l1 50:           face -- scores: 0.70 +- 0.12, time 0.72s
 log l1 50:            cat -- scores: 0.63 +- 0.19, time 0.99s
 log l1 50:           shoe -- scores: 0.48 +- 0.19, time 1.28s
 log l1 50:          house -- scores: 0.87 +- 0.12, time 0.47s
 log l1 50:   scrambledpix -- scores: 0.77 +- 0.21, time 0.48s
 log l1 50:         bottle -- scores: 0.49 +- 0.18, time 1.30s
 log l1 50:          chair -- scores: 0.46 +- 0.12, time 1.08s
______________________________________________________________________
 log l1 cv:       scissors -- scores: 0.34 +- 0.29, time 5.93s
 log l1 cv:           face -- scores: 0.70 +- 0.18, time 3.77s
 log l1 cv:            cat -- scores: 0.52 +- 0.28, time 5.13s
 log l1 cv:           shoe -- scores: 0.48 +- 0.19, time 5.67s
 log l1 cv:          house -- scores: 0.84 +- 0.19, time 2.98s
 log l1 cv:   scrambledpix -- scores: 0.70 +- 0.14, time 3.21s
 log l1 cv:         bottle -- scores: 0.41 +- 0.25, time 5.74s
 log l1 cv:          chair -- scores: 0.45 +- 0.13, time 5.71s
______________________________________________________________________
    log l2:       scissors -- scores: 0.51 +- 0.15, time 1.38s
    log l2:           face -- scores: 0.55 +- 0.19, time 1.24s
    log l2:            cat -- scores: 0.54 +- 0.20, time 1.41s
    log l2:           shoe -- scores: 0.51 +- 0.17, time 1.22s
    log l2:          house -- scores: 0.63 +- 0.22, time 1.11s
    log l2:   scrambledpix -- scores: 0.78 +- 0.18, time 1.21s
    log l2:         bottle -- scores: 0.43 +- 0.20, time 1.29s
    log l2:          chair -- scores: 0.54 +- 0.17, time 1.19s
______________________________________________________________________
 log l2 cv:       scissors -- scores: 0.50 +- 0.15, time 9.94s
 log l2 cv:           face -- scores: 0.56 +- 0.20, time 10.28s
 log l2 cv:            cat -- scores: 0.55 +- 0.20, time 11.43s
 log l2 cv:           shoe -- scores: 0.51 +- 0.17, time 11.06s
 log l2 cv:          house -- scores: 0.64 +- 0.22, time 10.05s
 log l2 cv:   scrambledpix -- scores: 0.76 +- 0.18, time 9.32s
 log l2 cv:         bottle -- scores: 0.44 +- 0.21, time 10.87s
 log l2 cv:          chair -- scores: 0.53 +- 0.17, time 11.03s
______________________________________________________________________
     ridge:       scissors -- scores: 0.52 +- 0.22, time 0.09s
     ridge:           face -- scores: 0.62 +- 0.18, time 0.07s
     ridge:            cat -- scores: 0.43 +- 0.18, time 0.07s
     ridge:           shoe -- scores: 0.48 +- 0.21, time 0.07s
     ridge:          house -- scores: 0.90 +- 0.09, time 0.07s
     ridge:   scrambledpix -- scores: 0.79 +- 0.17, time 0.07s
     ridge:         bottle -- scores: 0.41 +- 0.19, time 0.07s
     ridge:          chair -- scores: 0.50 +- 0.21, time 0.07s
______________________________________________________________________
  ridge cv:       scissors -- scores: 0.56 +- 0.26, time 0.72s
  ridge cv:           face -- scores: 0.70 +- 0.13, time 0.71s
  ridge cv:            cat -- scores: 0.54 +- 0.22, time 0.79s
  ridge cv:           shoe -- scores: 0.49 +- 0.18, time 0.72s
  ridge cv:          house -- scores: 0.90 +- 0.13, time 0.80s
  ridge cv:   scrambledpix -- scores: 0.85 +- 0.11, time 0.86s
  ridge cv:         bottle -- scores: 0.46 +- 0.24, time 0.71s
  ridge cv:          chair -- scores: 0.51 +- 0.21, time 0.68s

Then we make a rudimentary diagram

import matplotlib.pyplot as plt
plt.figure()

# one tick per stimulus category, labelled with the category name
tick_position = np.arange(len(categories))
plt.xticks(tick_position, categories, rotation=45)

# one bar series per classifier (in sorted name order); each series is
# shifted right of the previous one so bars of a category sit side by side
for color, classifier_name in zip(
        ['b', 'c', 'm', 'g', 'y', 'k', '.5', 'r', '#ffaaaa'],
        sorted(classifiers)):
    # mean cross-validation score of this classifier for every category
    score_means = [classifiers_scores[classifier_name][category].mean()
                   for category in categories]
    plt.bar(tick_position, score_means, label=classifier_name,
            width=.11, color=color)
    tick_position = tick_position + .09

plt.ylabel('Classification accuracy (f1 score)')
plt.xlabel('Visual stimuli category')
# 'bottom' replaces the 'ymin' keyword, which was deprecated in
# Matplotlib 3.0 and is no longer accepted by recent releases
plt.ylim(bottom=0)
plt.legend(loc='lower center', ncol=3)
plt.title(
    'Category-specific classification accuracy for different classifiers')
plt.tight_layout()
../../_images/sphx_glr_plot_haxby_different_estimators_001.png

Finally, we plot the face vs house map for the different classifiers

# Use the average EPI as a background
from nilearn import image
mean_epi_img = image.mean_img(func_filename)

# Restrict the decoding to face vs house
condition_mask = stimuli.isin(['face', 'house'])
# masked_timecourses only holds the non-rest volumes, so the condition mask
# has to be restricted to those volumes before it is used as an index
masked_timecourses = masked_timecourses[
    condition_mask[task_mask]]
stimuli = (stimuli[condition_mask] == 'face')
# Transform the stimuli to binary values (1 for face, 0 for house).
# The result must be assigned back: astype returns a new Series. The builtin
# 'int' is used because the 'np.int' alias was removed in NumPy 1.24.
stimuli = stimuli.astype(int)

from nilearn.plotting import plot_stat_map, show

for classifier_name, classifier in sorted(classifiers.items()):
    classifier.fit(masked_timecourses, stimuli)

    # plain estimators expose their weights directly; grid-searched ones
    # expose them through the refitted best estimator
    if hasattr(classifier, 'coef_'):
        weights = classifier.coef_[0]
    elif hasattr(classifier, 'best_estimator_'):
        weights = classifier.best_estimator_.coef_[0]
    else:
        # no weights available for this classifier: nothing to plot
        continue
    # map the weight vector back into image space
    weight_img = masker.inverse_transform(weights)
    # get_fdata() replaces the deprecated get_data(), which raises an error
    # in recent nibabel releases
    weight_map = weight_img.get_fdata()
    # hide weights smaller than 0.1% of the largest absolute weight
    threshold = np.max(np.abs(weight_map)) * 1e-3
    # a single cut along z, at coordinate -15
    plot_stat_map(weight_img, bg_img=mean_epi_img,
                  display_mode='z', cut_coords=[-15],
                  threshold=threshold,
                  title='%s: face vs house' % classifier_name)

show()
  • ../../_images/sphx_glr_plot_haxby_different_estimators_002.png
  • ../../_images/sphx_glr_plot_haxby_different_estimators_003.png
  • ../../_images/sphx_glr_plot_haxby_different_estimators_004.png
  • ../../_images/sphx_glr_plot_haxby_different_estimators_005.png
  • ../../_images/sphx_glr_plot_haxby_different_estimators_006.png
  • ../../_images/sphx_glr_plot_haxby_different_estimators_007.png
  • ../../_images/sphx_glr_plot_haxby_different_estimators_008.png
  • ../../_images/sphx_glr_plot_haxby_different_estimators_009.png
  • ../../_images/sphx_glr_plot_haxby_different_estimators_010.png

Total running time of the script: ( 5 minutes 2.013 seconds)

Generated by Sphinx-Gallery