"""Wrapper for using the Scikit-Learn API with Keras models. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import copy import types import numpy as np from .. import losses from ..utils.np_utils import to_categorical from ..utils.generic_utils import has_arg from ..utils.generic_utils import to_list from ..models import Sequential class BaseWrapper(object): """Base class for the Keras scikit-learn wrapper. Warning: This class should not be used directly. Use descendant classes instead. # Arguments build_fn: callable function or class instance **sk_params: model parameters & fitting parameters The `build_fn` should construct, compile and return a Keras model, which will then be used to fit/predict. One of the following three values could be passed to `build_fn`: 1. A function 2. An instance of a class that implements the `__call__` method 3. None. This means you implement a class that inherits from either `KerasClassifier` or `KerasRegressor`. The `__call__` method of the present class will then be treated as the default `build_fn`. `sk_params` takes both model parameters and fitting parameters. Legal model parameters are the arguments of `build_fn`. Note that like all other estimators in scikit-learn, `build_fn` should provide default values for its arguments, so that you could create the estimator without passing any values to `sk_params`. `sk_params` could also accept parameters for calling `fit`, `predict`, `predict_proba`, and `score` methods (e.g., `epochs`, `batch_size`). fitting (predicting) parameters are selected in the following order: 1. Values passed to the dictionary arguments of `fit`, `predict`, `predict_proba`, and `score` methods 2. Values passed to `sk_params` 3. The default values of the `keras.models.Sequential` `fit`, `predict`, `predict_proba` and `score` methods When using scikit-learn's `grid_search` API, legal tunable parameters are those you could pass to `sk_params`, including fitting parameters. In other words, you could use `grid_search` to search for the best `batch_size` or `epochs` as well as the model parameters. """ def __init__(self, build_fn=None, **sk_params): self.build_fn = build_fn self.sk_params = sk_params self.check_params(sk_params) def check_params(self, params): """Checks for user typos in `params`. # Arguments params: dictionary; the parameters to be checked # Raises ValueError: if any member of `params` is not a valid argument. """ legal_params_fns = [Sequential.fit, Sequential.predict, Sequential.predict_classes, Sequential.evaluate] if self.build_fn is None: legal_params_fns.append(self.__call__) elif (not isinstance(self.build_fn, types.FunctionType) and not isinstance(self.build_fn, types.MethodType)): legal_params_fns.append(self.build_fn.__call__) else: legal_params_fns.append(self.build_fn) for params_name in params: for fn in legal_params_fns: if has_arg(fn, params_name): break else: if params_name != 'nb_epoch': raise ValueError( '{} is not a legal parameter'.format(params_name)) def get_params(self, **params): """Gets parameters for this estimator. # Arguments **params: ignored (exists for API compatibility). # Returns Dictionary of parameter names mapped to their values. """ res = copy.deepcopy(self.sk_params) res.update({'build_fn': self.build_fn}) return res def set_params(self, **params): """Sets the parameters of this estimator. # Arguments **params: Dictionary of parameter names mapped to their values. # Returns self """ self.check_params(params) self.sk_params.update(params) return self def fit(self, x, y, **kwargs): """Constructs a new model with `build_fn` & fit the model to `(x, y)`. # Arguments x : array-like, shape `(n_samples, n_features)` Training samples where `n_samples` is the number of samples and `n_features` is the number of features. y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)` True labels for `x`. **kwargs: dictionary arguments Legal arguments are the arguments of `Sequential.fit` # Returns history : object details about the training history at each epoch. """ if self.build_fn is None: self.model = self.__call__(**self.filter_sk_params(self.__call__)) elif (not isinstance(self.build_fn, types.FunctionType) and not isinstance(self.build_fn, types.MethodType)): self.model = self.build_fn( **self.filter_sk_params(self.build_fn.__call__)) else: self.model = self.build_fn(**self.filter_sk_params(self.build_fn)) if (losses.is_categorical_crossentropy(self.model.loss) and len(y.shape) != 2): y = to_categorical(y) fit_args = copy.deepcopy(self.filter_sk_params(Sequential.fit)) fit_args.update(kwargs) history = self.model.fit(x, y, **fit_args) return history def filter_sk_params(self, fn, override=None): """Filters `sk_params` and returns those in `fn`'s arguments. # Arguments fn : arbitrary function override: dictionary, values to override `sk_params` # Returns res : dictionary containing variables in both `sk_params` and `fn`'s arguments. """ override = override or {} res = {} for name, value in self.sk_params.items(): if has_arg(fn, name): res.update({name: value}) res.update(override) return res class KerasClassifier(BaseWrapper): """Implementation of the scikit-learn classifier API for Keras. """ def fit(self, x, y, sample_weight=None, **kwargs): """Constructs a new model with `build_fn` & fit the model to `(x, y)`. # Arguments x : array-like, shape `(n_samples, n_features)` Training samples where `n_samples` is the number of samples and `n_features` is the number of features. y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)` True labels for `x`. **kwargs: dictionary arguments Legal arguments are the arguments of `Sequential.fit` # Returns history : object details about the training history at each epoch. # Raises ValueError: In case of invalid shape for `y` argument. """ y = np.array(y) if len(y.shape) == 2 and y.shape[1] > 1: self.classes_ = np.arange(y.shape[1]) elif (len(y.shape) == 2 and y.shape[1] == 1) or len(y.shape) == 1: self.classes_ = np.unique(y) y = np.searchsorted(self.classes_, y) else: raise ValueError('Invalid shape for y: ' + str(y.shape)) self.n_classes_ = len(self.classes_) if sample_weight is not None: kwargs['sample_weight'] = sample_weight return super(KerasClassifier, self).fit(x, y, **kwargs) def predict(self, x, **kwargs): """Returns the class predictions for the given test data. # Arguments x: array-like, shape `(n_samples, n_features)` Test samples where `n_samples` is the number of samples and `n_features` is the number of features. **kwargs: dictionary arguments Legal arguments are the arguments of `Sequential.predict_classes`. # Returns preds: array-like, shape `(n_samples,)` Class predictions. """ kwargs = self.filter_sk_params(Sequential.predict_classes, kwargs) proba = self.model.predict(x, **kwargs) if proba.shape[-1] > 1: classes = proba.argmax(axis=-1) else: classes = (proba > 0.5).astype('int32') return self.classes_[classes] def predict_proba(self, x, **kwargs): """Returns class probability estimates for the given test data. # Arguments x: array-like, shape `(n_samples, n_features)` Test samples where `n_samples` is the number of samples and `n_features` is the number of features. **kwargs: dictionary arguments Legal arguments are the arguments of `Sequential.predict_classes`. # Returns proba: array-like, shape `(n_samples, n_outputs)` Class probability estimates. In the case of binary classification, to match the scikit-learn API, will return an array of shape `(n_samples, 2)` (instead of `(n_sample, 1)` as in Keras). """ kwargs = self.filter_sk_params(Sequential.predict_proba, kwargs) probs = self.model.predict(x, **kwargs) # check if binary classification if probs.shape[1] == 1: # first column is probability of class 0 and second is of class 1 probs = np.hstack([1 - probs, probs]) return probs def score(self, x, y, **kwargs): """Returns the mean accuracy on the given test data and labels. # Arguments x: array-like, shape `(n_samples, n_features)` Test samples where `n_samples` is the number of samples and `n_features` is the number of features. y: array-like, shape `(n_samples,)` or `(n_samples, n_outputs)` True labels for `x`. **kwargs: dictionary arguments Legal arguments are the arguments of `Sequential.evaluate`. # Returns score: float Mean accuracy of predictions on `x` wrt. `y`. # Raises ValueError: If the underlying model isn't configured to compute accuracy. You should pass `metrics=["accuracy"]` to the `.compile()` method of the model. """ y = np.searchsorted(self.classes_, y) kwargs = self.filter_sk_params(Sequential.evaluate, kwargs) loss_name = self.model.loss if hasattr(loss_name, '__name__'): loss_name = loss_name.__name__ if loss_name == 'categorical_crossentropy' and len(y.shape) != 2: y = to_categorical(y) outputs = self.model.evaluate(x, y, **kwargs) outputs = to_list(outputs) for name, output in zip(self.model.metrics_names, outputs): if name in ['accuracy', 'acc']: return output raise ValueError('The model is not configured to compute accuracy. ' 'You should pass `metrics=["accuracy"]` to ' 'the `model.compile()` method.') class KerasRegressor(BaseWrapper): """Implementation of the scikit-learn regressor API for Keras. """ def predict(self, x, **kwargs): """Returns predictions for the given test data. # Arguments x: array-like, shape `(n_samples, n_features)` Test samples where `n_samples` is the number of samples and `n_features` is the number of features. **kwargs: dictionary arguments Legal arguments are the arguments of `Sequential.predict`. # Returns preds: array-like, shape `(n_samples,)` Predictions. """ kwargs = self.filter_sk_params(Sequential.predict, kwargs) preds = np.array(self.model.predict(x, **kwargs)) if preds.shape[-1] == 1: return np.squeeze(preds, axis=-1) return preds def score(self, x, y, **kwargs): """Returns the mean loss on the given test data and labels. # Arguments x: array-like, shape `(n_samples, n_features)` Test samples where `n_samples` is the number of samples and `n_features` is the number of features. y: array-like, shape `(n_samples,)` True labels for `x`. **kwargs: dictionary arguments Legal arguments are the arguments of `Sequential.evaluate`. # Returns score: float Mean accuracy of predictions on `x` wrt. `y`. """ kwargs = self.filter_sk_params(Sequential.evaluate, kwargs) loss = self.model.evaluate(x, y, **kwargs) if isinstance(loss, list): return -loss[0] return -loss