5.6. Running scikit-learn functions for more control on the analysis
This section gives pointers for designing your own decoding pipelines with scikit-learn. It builds on the didactic introduction to decoding.
This documentation gives links and additional definitions needed to work correctly with scikit-learn. For a full code example, see: Advanced decoding using scikit learn
5.6.1. Performing decoding with scikit-learn
5.6.1.1. Using scikit-learn estimators
You can easily import estimators from the scikit-learn machine-learning library: those available in the Decoder object and many others. They all have the fit and predict methods. For example, you can directly import the versatile Support Vector Classifier (sklearn.svm.SVC, or SVC).
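As a minimal sketch of the fit/predict interface, the snippet below trains a linear SVC on synthetic data. The random array X is an assumption standing in for masked fMRI data of shape (n_samples, n_voxels); the labels y and the even/odd train/test split are purely illustrative.

```python
import numpy as np
from sklearn.svm import SVC

# Synthetic stand-in for masked fMRI data: 40 samples x 500 voxels (assumption).
rng = np.random.default_rng(0)
X = rng.standard_normal((40, 500))
y = np.repeat([0, 1], 20)   # two hypothetical conditions, 20 samples each
X[y == 1, :10] += 1.0       # inject a weak signal into the first 10 voxels

svc = SVC(kernel="linear")
svc.fit(X[::2], y[::2])              # train on even-indexed samples
predictions = svc.predict(X[1::2])   # predict the held-out odd-indexed samples
accuracy = np.mean(predictions == y[1::2])
```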
To learn more about the variety of classifiers available in scikit-learn, see the scikit-learn documentation on supervised learning.
5.6.1.2. Cross-validation with scikit-learn
To perform cross-validation using a scikit-learn estimator, you should first mask the data using a nilearn.maskers.NiftiMasker: this extracts only the voxels inside the mask of interest and transforms the 4D input fMRI data into 2D arrays (shape (n_timepoints, n_voxels)) that estimators can work on.
This example shows how to use masking: Simple example of NiftiMasker use
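The effect of masking on array shapes can be pictured with plain NumPy. This is only a conceptual sketch, assuming a synthetic 4D array and a hypothetical cube-shaped mask; it is not how NiftiMasker is implemented. In practice, NiftiMasker's fit_transform method does this for you on Nifti images and can also apply preprocessing such as standardization or smoothing.

```python
import numpy as np

# Synthetic 4D fMRI volume: 10 x 10 x 10 voxels, 50 time points (assumption).
rng = np.random.default_rng(0)
fmri_4d = rng.standard_normal((10, 10, 10, 50))

# Hypothetical boolean brain mask: keep a 4 x 4 x 4 cube in the middle.
mask_3d = np.zeros((10, 10, 10), dtype=bool)
mask_3d[3:7, 3:7, 3:7] = True

# Boolean indexing on the three spatial axes gives (n_voxels, n_timepoints);
# transposing yields the (n_timepoints, n_voxels) layout estimators expect.
X = fmri_4d[mask_3d].T
print(X.shape)  # (50, 64)
```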
Then use sklearn.model_selection.cross_val_score, which computes the score of your model for each fold of the cross-validation.
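A minimal sketch of cross_val_score, assuming random X and y that stand in for masked fMRI data and condition labels. With an integer cv and a classifier, scikit-learn uses stratified k-fold splitting.

```python
import numpy as np
from sklearn.model_selection import cross_val_score
from sklearn.svm import SVC

# Synthetic data (assumption): 60 samples x 200 voxels, two alternating conditions.
rng = np.random.default_rng(0)
X = rng.standard_normal((60, 200))
y = np.tile([0, 1], 30)
X[y == 1, :5] += 1.0  # weak signal in the first 5 voxels

svc = SVC(kernel="linear")
# cv=5 with a classifier: stratified 5-fold cross-validation, one score per fold.
scores = cross_val_score(svc, X, y, cv=5)
print(scores.mean())
```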
You can change many parameters of the cross-validation here, for example:
use a different cross-validation scheme, by passing any scikit-learn cross-validation splitter as the cv argument;
speed up the computation by using n_jobs=-1, which spreads the computation across all available processors;
use a different scoring function through the scoring argument, given either as a string keyword or as a scorer built with scikit-learn.
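The options above can be combined in a single call. This sketch assumes a hypothetical experiment with 6 runs of 10 samples each. The run labels are passed as groups so that LeaveOneGroupOut (one possible scikit-learn splitter, chosen here for illustration) holds out one run per fold, and scoring="roc_auc" replaces the default accuracy.

```python
import numpy as np
from sklearn.model_selection import LeaveOneGroupOut, cross_val_score
from sklearn.svm import SVC

rng = np.random.default_rng(0)
n_runs, n_per_run = 6, 10
X = rng.standard_normal((n_runs * n_per_run, 200))
y = np.tile([0, 1], n_runs * n_per_run // 2)       # both classes in every run
groups = np.repeat(np.arange(n_runs), n_per_run)   # hypothetical run labels
X[y == 1, :5] += 1.0

scores = cross_val_score(
    SVC(kernel="linear"),
    X,
    y,
    groups=groups,
    cv=LeaveOneGroupOut(),  # one fold per held-out run
    scoring="roc_auc",      # area under the ROC curve instead of accuracy
    n_jobs=-1,              # use all available processors
)
```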
5.6.1.3. Measuring the chance level
Dummy estimators: The simplest way to measure prediction performance at chance is to use a “dummy” classifier, sklearn.dummy.DummyClassifier, which makes predictions without looking at the input features.
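A sketch of a chance-level baseline with DummyClassifier, assuming synthetic data with balanced labels. With strategy="most_frequent", the classifier always predicts the majority class of the training fold, so on balanced test folds its accuracy is exactly 0.5; strategy="stratified" would instead give random predictions that follow the class frequencies.

```python
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.model_selection import cross_val_score

# Synthetic, balanced data (assumption): 30 samples per class.
rng = np.random.default_rng(0)
X = rng.standard_normal((60, 200))
y = np.tile([0, 1], 30)

# Ignores X entirely and predicts the most frequent training label.
dummy = DummyClassifier(strategy="most_frequent")
chance_scores = cross_val_score(dummy, X, y, cv=5)
```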
Permutation testing: A more controlled, but slower, way is to do permutation testing on the labels, with sklearn.model_selection.permutation_test_score.
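A sketch of permutation testing on synthetic data (an assumption, as in the examples above). permutation_test_score refits the estimator on n_permutations shuffled copies of the labels and returns the true cross-validated score, the scores obtained under permutation, and a p-value. The value of 100 permutations is arbitrary: more permutations give a finer-grained p-value at a higher computational cost.

```python
import numpy as np
from sklearn.model_selection import permutation_test_score
from sklearn.svm import SVC

# Synthetic data (assumption) with a weak signal in the first 5 voxels.
rng = np.random.default_rng(0)
X = rng.standard_normal((60, 200))
y = np.tile([0, 1], 30)
X[y == 1, :5] += 1.0

score, permutation_scores, pvalue = permutation_test_score(
    SVC(kernel="linear"),
    X,
    y,
    cv=5,
    n_permutations=100,  # number of label shuffles
    random_state=0,      # makes the shuffles reproducible
)
```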
5.6.2. Going further with scikit-learn
We have seen a very simple analysis with scikit-learn, but you can easily add intermediate processing steps if your analysis requires it. Some common examples are:
adding a feature selection step using scikit-learn pipelines
using any model available in scikit-learn (or compatible with it) at any step
adding more intermediate steps, such as clustering
5.6.2.1. Decoding without a mask: Anova-SVM using scikit-learn
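We can also implement feature selection before decoding as a scikit-learn pipeline (sklearn.pipeline.Pipeline). For this, we need to import the sklearn.feature_selection module and use sklearn.feature_selection.f_classif, which computes a simple F-score for each feature (a.k.a. Anova), as the scoring function of a univariate selector such as sklearn.feature_selection.SelectPercentile. The selector is then placed before the SVC in the pipeline.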
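A minimal Anova-SVM sketch, assuming synthetic data with 1000 features in which only the first 20 carry signal. SelectPercentile with f_classif keeps the 5% of features with the highest F-scores (the percentile is an arbitrary choice). Because the selector lives inside the Pipeline, it is refitted on the training data of each fold, so the test fold never influences which voxels are selected.

```python
import numpy as np
from sklearn.feature_selection import SelectPercentile, f_classif
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC

# Synthetic data (assumption): 60 samples x 1000 voxels, signal in 20 voxels.
rng = np.random.default_rng(0)
X = rng.standard_normal((60, 1000))
y = np.tile([0, 1], 30)
X[y == 1, :20] += 1.0

anova_svc = Pipeline([
    ("anova", SelectPercentile(f_classif, percentile=5)),  # top 5% by F-score
    ("svc", SVC(kernel="linear")),
])
scores = cross_val_score(anova_svc, X, y, cv=5)

# Fit on all data only to inspect how many voxels the selector keeps.
anova_svc.fit(X, y)
n_selected = anova_svc.named_steps["anova"].get_support().sum()
```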
5.6.2.2. Using any other model in the pipeline
Anova-SVM is a good baseline that will give reasonable results in common settings. However, it may be interesting for you to explore the wide variety of supervised learning algorithms in scikit-learn. These can readily replace the SVM in your pipeline and might be better suited to some use cases, as discussed in the previous section.
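Swapping the classifier only changes the last step of the pipeline. In this sketch, LogisticRegression and RidgeClassifier are illustrative picks rather than recommendations, and the data are synthetic as in the previous examples.

```python
import numpy as np
from sklearn.feature_selection import SelectPercentile, f_classif
from sklearn.linear_model import LogisticRegression, RidgeClassifier
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC

# Synthetic data (assumption).
rng = np.random.default_rng(0)
X = rng.standard_normal((60, 1000))
y = np.tile([0, 1], 30)
X[y == 1, :20] += 1.0

classifiers = {
    "svc": SVC(kernel="linear"),
    "logistic": LogisticRegression(max_iter=1000),
    "ridge": RidgeClassifier(),
}
mean_scores = {}
for name, clf in classifiers.items():
    pipeline = Pipeline([
        ("anova", SelectPercentile(f_classif, percentile=5)),
        ("clf", clf),  # only this step changes between iterations
    ])
    mean_scores[name] = cross_val_score(pipeline, X, y, cv=5).mean()
```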
The feature selection step can also be tuned. For example, we could use a more sophisticated scheme, such as Recursive Feature Elimination (RFE), or add a clustering step before feature selection. This always amounts to creating a pipeline that links those steps together and applying a sensible cross-validation scheme to it. Scikit-learn usually takes care of the rest for us.
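One possible way to chain a clustering step and RFE before the classifier, sketched on synthetic data. FeatureAgglomeration groups similar features and averages them into 50 clusters; RFE then recursively drops the least informative clusters according to the weights of a linear SVC. The numbers of clusters and of selected features are arbitrary. On real fMRI data you would typically give FeatureAgglomeration a spatial connectivity matrix (for example from sklearn.feature_extraction.image.grid_to_graph) so that clusters are spatially contiguous; that is omitted here.

```python
import numpy as np
from sklearn.cluster import FeatureAgglomeration
from sklearn.feature_selection import RFE
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC

# Synthetic data (assumption): 60 samples x 300 voxels.
rng = np.random.default_rng(0)
X = rng.standard_normal((60, 300))
y = np.tile([0, 1], 30)
X[y == 1, :20] += 1.0

pipeline = Pipeline([
    ("cluster", FeatureAgglomeration(n_clusters=50)),  # 300 voxels -> 50 cluster averages
    ("rfe", RFE(SVC(kernel="linear"), n_features_to_select=10, step=5)),
    ("svc", SVC(kernel="linear")),
])
scores = cross_val_score(pipeline, X, y, cv=5)

# Fit once on all data to inspect the number of features RFE keeps.
pipeline.fit(X, y)
```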
5.6.3. Setting estimator parameters
Most estimators have parameters that can be set to optimize their performance. Importantly, this must be done via nested cross-validation.
Indeed, there is noise in the cross-validation score, and when we vary the parameter, the curve showing the score as a function of the parameter will have bumps and peaks due to this noise. These will not generalize to new data and chances are that the corresponding choice of parameter will not perform as well on new data.
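Nested cross-validation can be sketched by putting a GridSearchCV object (the inner loop that chooses the parameter) inside cross_val_score (the outer loop that measures performance). The grid of C values and the synthetic data are assumptions for illustration. grid.best_score_ is the score of the winning parameter on the same data used to select it, so it benefits from the noise peaks described above, whereas nested_scores comes from outer test folds never seen during the search.

```python
import numpy as np
from sklearn.model_selection import GridSearchCV, cross_val_score
from sklearn.svm import SVC

# Synthetic data (assumption).
rng = np.random.default_rng(0)
X = rng.standard_normal((60, 200))
y = np.tile([0, 1], 30)
X[y == 1, :5] += 1.0

# Inner loop: 3-fold search over C on each outer training split.
grid = GridSearchCV(SVC(kernel="linear"), param_grid={"C": [0.01, 0.1, 1, 10]}, cv=3)

# Outer loop: each outer test fold is never used while choosing C.
nested_scores = cross_val_score(grid, X, y, cv=5)

# For contrast: selection and scoring on the same data (optimistically biased).
grid.fit(X, y)
print(grid.best_params_, grid.best_score_, nested_scores.mean())
```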
With scikit-learn, nested cross-validation is done via sklearn.model_selection.GridSearchCV. It is unfortunately time-consuming, but the n_jobs argument can spread the load on multiple