from itertools import chain
import numpy as np
from sklearn.model_selection._split import _BaseKFold
from sklearn.utils.validation import indexable
from sklearn.utils import _safe_indexing, check_random_state
from ..data import Data
[docs]
class KFold(_BaseKFold):
    """
    KFold splitter which works on a naplib.Data object or a list-like sequence.

    Unlike the index-based splitters it derives from, ``split`` yields the
    already-indexed train and test portions of every input it is given.

    Parameters
    ----------
    n_splits : int
        Number of folds. Must be at least 2.
    shuffle : bool, default=False
        Whether to shuffle the data before splitting into batches.
        Note that the samples within each split will not be shuffled.
    random_state : int, RandomState instance or None, default=None
        When `shuffle` is True, `random_state` affects the ordering of the
        indices, which controls the randomness of each fold. It is only
        consulted when `shuffle` is True.
        Pass an int for reproducible output across multiple function calls.

    Examples
    --------
    >>> from naplib.model_selection import KFold
    >>> list1 = [1,2,3] # this could be a field of a Data object, like data['resp']
    >>> list2 = [5,6,7] # this could be another field of a Data object, like data['aud']
    >>> kfold = KFold(3)
    >>> for train_data, test_data, train_data2, test_data2 in kfold.split(list1, list2):
    ...     print(train_data, test_data, train_data2, test_data2)
    [2, 3] [1] [6, 7] [5]
    [1, 3] [2] [5, 7] [6]
    [1, 2] [3] [5, 6] [7]
    """

    def __init__(self, n_splits, shuffle=False, random_state=None):
        super().__init__(n_splits=n_splits, shuffle=shuffle, random_state=random_state)

    def _iter_test_indices(self, X, y=None, groups=None):
        """Yield the test indices of each fold as a slice of an index array.

        Only ``len(X)`` is used; ``y`` and ``groups`` are accepted but ignored.
        """
        n_samples = len(X)
        indices = np.arange(n_samples)
        if self.shuffle:
            # Permute once up front; each fold is then a contiguous slice of
            # the permuted index array.
            check_random_state(self.random_state).shuffle(indices)

        # Every fold gets n_samples // n_splits samples, and the first
        # n_samples % n_splits folds get one extra so all samples are used.
        fold_sizes = np.full(self.n_splits, n_samples // self.n_splits, dtype=int)
        fold_sizes[: n_samples % self.n_splits] += 1

        start = 0
        for fold_size in fold_sizes:
            stop = start + fold_size
            yield indices[start:stop]
            start = stop

    def split(self, *args):
        """Generate splits of the data.

        Parameters
        ----------
        *args : Data or list-like objects
            Sets of data which will be split into train and test groups.
            At least one must be given. The folds are computed from the
            first one and the same indices are applied to all of them.

        Yields
        ------
        splits : list
            One list per fold, laid out as
            ``[train_0, test_0, train_1, test_1, ...]`` with one train/test
            pair per input, in the order the inputs were passed. Inputs that
            are ``Data`` objects produce ``Data`` objects.

        Raises
        ------
        ValueError
            If no input is given, or if ``n_splits`` is greater than the
            number of samples. Because this is a generator, the error is
            raised when iteration starts, not when ``split`` is called.
        """
        if not args:
            # Without this guard the failure is a bare IndexError on data[0].
            raise ValueError("At least one array required as input")

        data = indexable(*args)
        n_samples = len(data[0])
        if self.n_splits > n_samples:
            raise ValueError(
                (
                    "Cannot have number of splits n_splits={0} greater"
                    " than the number of samples: n_samples={1}."
                ).format(self.n_splits, n_samples)
            )

        for train_idx, test_idx in super().split(data[0]):
            fold = []
            for subset in data:
                train_part = _safe_indexing(subset, train_idx)
                test_part = _safe_indexing(subset, test_idx)
                if isinstance(subset, Data):
                    # The indexed result is re-wrapped so callers get Data
                    # back (presumably indexing returns a plain list of
                    # trials -- confirm against Data.__getitem__).
                    train_part = Data(train_part, strict=False)
                    test_part = Data(test_part, strict=False)
                fold.extend((train_part, test_part))
            yield fold