Source code for naplib.data

from collections.abc import Iterable
from itertools import groupby
from copy import deepcopy
import numpy as np
from mne import Info


# Field names every trial must contain when a Data object is created or
# extended with ``strict=True`` (checked in Data._validate_new_out_data).
STRICT_FIELDS_REQUIRED = {'name', 'sound', 'soundf', 'resp', 'dataf'}


class Data(Iterable):
    '''
    Class for storing electrode response data along with task- and
    electrode-related variables. Under the hood, it consists of a list of
    dictionaries where each dictionary contains all the data for one trial.

    Parameters
    ----------
    data : dict or list of dictionaries
        If a list of dicts, then the Nth dictionary defines the Nth trial data,
        typically corresponding to the Nth stimulus. Each dictionary should
        contain the same keys (this is only enforced when ``strict=True``).
        If a single dict, then the keys specify the field names and the values
        specify the data across trials, and each value must be a list of
        length num_trials.
    strict : bool, default=False
        If True, requires strict adherence to the following standards:
        1) Each trial must contain at least the following fields:
        ['name','sound','soundf','resp','dataf']
        2) Each trial must contain the exact same set of fields

    Attributes
    ----------
    fields : list of strings
        Field names in the data (taken from the first trial).
    mne_info : mne.Info instance
        Measurement info object containing things like electrode locations
        (only if Data is created from reading a file format like BIDS).
    info : dict
        Extra info (not trial-specific) that a user wants to store using
        Data.set_info or Data.update_info

    Notes
    -----
    .. figure:: /figures/naplib-python-data-figure.png
        :width: 500px
        :alt: Data object layout
        :align: center

    The above is a depiction of the type of data that might be stored in an
    instance of the Data class. Any number of trials can be stored with any
    number and type of fields. Responses and information do not need to be
    aligned or the same length/shape across trials. Information can be
    retrieved from the Data instance by trial, by field, or by a combination
    of the two, using bracket indexing and slicing, as described below.

    Examples
    --------
    >>> import naplib as nl
    >>> import numpy as np
    >>> # Constructing Data from a dict, where keys give fields and values are lists of trial data
    >>> names = ['trial1', 'trial2'] # trial names
    >>> responses = [np.arange(6).reshape(3,2), np.arange(6,12).reshape(3,2)] # neural responses
    >>> dataf = [100, 100] # sampling rate
    >>> data = nl.Data({'name': names, 'resp': responses, 'dataf': dataf})
    >>> data
    Data object of 2 trials containing 3 fields
    [{"name": <class 'str'>, "resp": <class 'numpy.ndarray'>, "dataf": <class 'int'>}
    {"name": <class 'str'>, "resp": <class 'numpy.ndarray'>, "dataf": <class 'int'>}]
    >>> # Accessing a single trial returns a view of one trial as a dict
    >>> data[1]
    {'name': 'trial2', 'resp': array([[ 6,  7],
           [ 8,  9],
           [10, 11]]), 'dataf': 100}
    >>> # Accessing a single field returns a shallow copy of that field as a list over trials
    >>> data['name']
    ['trial1', 'trial2']
    >>> # Accessing multiple fields returns a shallow copy of those fields within a Data instance
    >>> data[['resp', 'dataf']]
    Data object of 2 trials containing 2 fields
    [{"resp": <class 'numpy.ndarray'>, "dataf": <class 'int'>}
    {"resp": <class 'numpy.ndarray'>, "dataf": <class 'int'>}]
    >>> # Accessing multiple trials with slice indexing returns a shallow copy of those
    >>> # trials in a Data instance
    >>> data[:2]
    Data object of 2 trials containing 3 fields
    [{"name": <class 'str'>, "resp": <class 'numpy.ndarray'>, "dataf": <class 'int'>}
    {"name": <class 'str'>, "resp": <class 'numpy.ndarray'>, "dataf": <class 'int'>}]
    '''

    def __init__(self, data, strict=False):
        if isinstance(data, dict):
            lengths = []
            for k, v in data.items():
                if not isinstance(v, list):
                    raise TypeError(
                        f'When creating a Data from a dict, each value in the '
                        f'dict must be a list, but for key "{k}" got type {type(v)}'
                    )
                lengths.append(len(v))
            if not _all_equal_list(lengths):
                raise ValueError(
                    f'When creating a Data from a dict, each value in the '
                    f'dict must be a list of the same length, but got different lengths: {lengths}'
                )
            # Transpose {field: [v_trial0, v_trial1, ...]} into one dict per trial.
            data = [dict(zip(data, vals)) for vals in zip(*data.values())]
        elif not isinstance(data, list):
            raise TypeError(f'Can only create Data from a dict or a list '
                            f'of dicts, but found type {type(data)}')
        self._data = data
        self._strict = strict
        self._validate_new_out_data(data, strict=strict)
        self._info = {}
        self._mne_info = None

    def set_field(self, fielddata, fieldname):
        '''
        Set the information in a single field with a new list of data.

        Parameters
        ----------
        fielddata : list
            List containing data to add to each trial for this field. Must be
            same length as this object.
        fieldname : string
            Name of field to add. If this field already exists in the Data then
            the current field will be overwritten.

        Raises
        ------
        TypeError
            If fielddata is not a list.
        ValueError
            If fielddata does not have one entry per trial.
        '''
        if not isinstance(fielddata, list):
            raise TypeError(f'Input data must be a list, but found {type(fielddata)}')
        if len(fielddata) != len(self):
            raise ValueError(f'Length of field ({len(fielddata)}) is not equal to length of this Data ({len(self)})')
        for trial, value in zip(self.data, fielddata):
            trial[fieldname] = value

    def delete_field(self, fieldname):
        '''
        Remove an entire field from the Data object.

        Parameters
        ----------
        fieldname : string
            Name of field to delete.

        Raises
        ------
        TypeError
            If fieldname is not a str.
        KeyError
            If any trial does not contain the field. In that case no trial is
            modified.
        '''
        if not isinstance(fieldname, str):
            raise TypeError(f'Field must be a str, but found {type(fieldname)}')
        # Check every trial first so a missing field does not leave a half-deleted field.
        if any(fieldname not in trial for trial in self.data):
            raise KeyError(f'Invalid fieldname: {fieldname} not found in data.')
        for trial in self.data:
            del trial[fieldname]

    def get_field(self, fieldname):
        '''
        Return all trials for a single field.

        Parameters
        ----------
        fieldname : string
            Which field to get.

        Returns
        -------
        field : list
            List containing each trial's value for this field.

        Raises
        ------
        KeyError
            If any trial does not contain the field.
        '''
        try:
            return [tmp[fieldname] for tmp in self.data]
        except KeyError:
            raise KeyError(f'Invalid fieldname: {fieldname} not found in data.') from None

    def __getitem__(self, index):
        '''
        Get either a trial or a field using bracket indexing. See notes and
        examples below for details.

        Parameters
        ----------
        index : int, string, slice, list, or np.ndarray
            Which trial to get, or which field to get. A list/array of strings
            selects fields, a list/array of integers selects trials, and a
            boolean np.ndarray with one entry per trial selects the trials
            where it is True.

        Returns
        -------
        data : dict, list, or Data
            If index is an integer, returns the corresponding trial as a dict.
            If index is a string, returns the corresponding field, and if it is
            a list of strings, returns those fields together in a new Data
            object. Slices, lists of integers, and boolean masks return the
            selected trials in a new Data object.

        Note
        ----
        Depending on how indexing and slicing is performed, the data returned
        may be a view of the underlying data, or it may be a shallow copy of
        the underlying data. The only way to get a view of the underlying data,
        meaning editing that view will also edit the underlying data, is to use
        integer indexing to get a single trial from the Data instance, which
        returns a dict for that trial. Indexing by field name first and
        indexing with slicing both return shallow copies of the data.
        For example, if we want to set the 'name' field in the first trial of
        our Data, we can only do it in the following way:

        >>> data[0]['name'] = 'trial0'

        Whereas the following code will NOT actually change the underlying
        trial name:

        >>> data['name'][0] = 'trial0'

        Examples
        --------
        >>> # Get a specific trial based on its index, which returns a dict
        >>> from naplib import Data
        >>> trial_data = [{'name': 'Zero', 'trial': 0, 'resp': [[0,1],[2,3]]},
        ...               {'name': 'One', 'trial': 1, 'resp': [[4,5],[6,7]]}]
        >>> data = Data(trial_data, strict=False)
        >>> data[0]
        {'name': 'Zero', 'trial': 0, 'resp': [[0, 1], [2, 3]]}
        >>> # Get a slice of trials, which returns a shallow copy of those trials in a Data instance
        >>> data[:2]
        Data object of 2 trials containing 3 fields
        [{"name": <class 'str'>, "trial": <class 'int'>, "resp": <class 'list'>}
        {"name": <class 'str'>, "trial": <class 'int'>, "resp": <class 'list'>}]
        >>> # Get a list of trial data from a single field, which returns a shallow copy of
        >>> # each trial in that field
        >>> data['name']
        ['Zero', 'One']
        >>> # Get multiple fields using a list of fieldnames, which returns a shallow copy of that
        >>> # subset of fields
        >>> data[['resp','trial']]
        Data object of 2 trials containing 2 fields
        [{"resp": <class 'list'>, "trial": <class 'int'>}
        {"resp": <class 'list'>, "trial": <class 'int'>}]
        '''
        if isinstance(index, slice):
            return Data(self.data[index], strict=self._strict)
        if isinstance(index, str):
            return self.get_field(index)
        if isinstance(index, (list, np.ndarray)):
            if len(index) == 0:
                # Selecting nothing yields an empty Data rather than crashing on index[0].
                return Data([], strict=False)
            if isinstance(index, np.ndarray) and index.dtype == bool:
                if len(index) != len(self):
                    raise IndexError(f'Boolean mask has length {len(index)} but this Data '
                                     f'has length {len(self)}.')
                return Data([trial for trial, keep in zip(self.data, index) if keep], strict=False)
            if isinstance(index[0], str):
                return Data([{field: trial[field] for field in index} for trial in self.data],
                            strict=False)
            return Data([self.data[i] for i in index], strict=False)
        try:
            return self.data[index]
        except IndexError:
            raise IndexError(f'Index invalid for this data. Tried to index {index} '
                             f'but length is {len(self)}.') from None

    def __setitem__(self, index, data):
        '''
        Set a specific trial, or set a specific field, using bracket indexing.
        See examples below for details.

        Parameters
        ----------
        index : int or string
            Which trial to set, or which field to set. If an integer, must be
            <= the length of the Data, since you can only set a currently
            existing trial or append to the end, but you cannot set a trial
            that is beyond that.
        data : dict or list of data
            Either trial data to add or field data to add. If index is an
            integer, dictionary should contain all the same fields as current
            Data object.

        Raises
        ------
        IndexError
            If an integer index is greater than the length of the Data.
        TypeError
            If index is neither a string nor an integer.

        Examples
        --------
        >>> # Set a trial of a Data
        >>> from naplib import Data
        >>> trial_data = [{'name': 'Zero', 'trial': 0, 'resp': [[0,1],[2,3]]},
        ...               {'name': 'One', 'trial': 1, 'resp': [[4,5],[6,7]]}]
        >>> data = Data(trial_data)
        >>> data[0] = {'name': 'New', 'trial': 10, 'resp': [[0,-1],[-2,-3]]}
        >>> data[0]
        {'name': 'New', 'trial': 10, 'resp': [[0, -1], [-2, -3]]}
        >>> # We can also set all values of a field across trials
        >>> data['name'] = ['TrialZero','TrialOne']
        >>> data['name']
        ['TrialZero', 'TrialOne']
        '''
        if isinstance(index, str):
            self.set_field(data, index)
            return
        if not isinstance(index, (int, np.integer)):
            raise TypeError(f'Found {type(index)} for index')
        if index > len(self):
            raise IndexError(f'Index is too large. Current data is length {len(self)} '
                             f'but tried to set index {index}. If you want to add to the end of the list '
                             f'of trials, use the Data.append() method.')
        if index == len(self):
            # Setting one past the end is treated as an append (with validation).
            self.append(data)
        else:
            self.data[index] = data

    def __delitem__(self, index):
        '''
        Delete a specific trial or set of trials, or delete a specific field,
        using bracket indexing. See examples below for details.

        Parameters
        ----------
        index : int, slice, or string
            Which trial(s) to delete, or which field to delete. If an integer,
            must be < the length of the Data, since you can only delete an
            existing trial.

        Raises
        ------
        IndexError
            If an integer index is out of range.
        TypeError
            If index is not a string, integer, or slice.

        Examples
        --------
        >>> # Delete a trial of a Data
        >>> from naplib import Data
        >>> trial_data = [{'name': 'Zero', 'trial': 0, 'resp': [[0,1],[2,3]]},
        ...               {'name': 'One', 'trial': 1, 'resp': [[4,5],[6,7]]}]
        >>> data = Data(trial_data)
        >>> del data[0]
        >>> data[0]
        {'name': 'One', 'trial': 1, 'resp': [[4, 5], [6, 7]]}
        >>> # We can also delete all values of a field across trials
        >>> trial_data = [{'name': 'Zero', 'trial': 0, 'resp': [[0,1],[2,3]]},
        ...               {'name': 'One', 'trial': 1, 'resp': [[4,5],[6,7]]}]
        >>> data = Data(trial_data)
        >>> del data['name']
        >>> data[0]
        {'trial': 0, 'resp': [[0, 1], [2, 3]]}
        '''
        if isinstance(index, str):
            self.delete_field(index)
        elif isinstance(index, slice):
            del self.data[index]
        elif isinstance(index, (int, np.integer)):
            if index >= len(self):
                raise IndexError(f'Index is too large. Current data is length {len(self)} '
                                 f'but tried to delete index {index}.')
            del self.data[index]
        else:
            raise TypeError(f'Found {type(index)} for index')

    def append(self, trial_data, strict=None):
        '''
        Append a single trial of data to the end of a Data.

        Parameters
        ----------
        trial_data : dict
            Dictionary containing all the same fields as current Data object.
        strict : bool, default=self._strict
            If true, enforces that new data contains the exact same set of
            fields as the current Data (or, if the Data is empty, the required
            strict fields). Default value is self._strict, which is set based
            on the input when creating a new Data from scratch with __init__()

        Raises
        ------
        TypeError
            If input data is not a dict.
        ValueError
            If the trial has no fields, or strict is `True` and the fields
            contained in the trial_data do not match the fields currently
            stored in the Data.

        Examples
        --------
        >>> # Append a trial to a Data
        >>> from naplib import Data
        >>> trial_data = [{'name': 'Zero', 'trial': 0, 'resp': [[0,1],[2,3]]},
        ...               {'name': 'One', 'trial': 1, 'resp': [[4,5],[6,7]]}]
        >>> data = Data(trial_data)
        >>> new_trial_data = {'name': 'Two', 'trial': 2, 'resp': [[8,9],[10,11]]}
        >>> data.append(new_trial_data)
        >>> len(data)
        3
        '''
        if strict is None:
            strict = self._strict
        self._validate_new_out_data([trial_data], strict=strict)
        self.data.append(trial_data)

    def set_info(self, info):
        '''
        Set the info dict for this Data. If there is already data in the
        `info` attribute, it is replaced with this.

        Parameters
        ----------
        info : dict
            Dictionary containing info to store in the Data's `info` attribute.
        '''
        if not isinstance(info, dict):
            raise TypeError(f'info must be a dict but got {type(info)}')
        self._info = info

    def update_info(self, info):
        '''
        Add data from a dict to this object's `info` attribute. If there is
        already data in the `info` attribute, this new info is simply added.
        Keys which exist in the current `info` dict and also in this new dict
        will be replaced, while others will be kept.

        Parameters
        ----------
        info : dict
            Dictionary containing info to add to the Data's `info` attribute.
        '''
        self._info.update(info)

    def set_mne_info(self, info):
        '''
        Set the mne_info attribute, which contains measurement information.

        Parameters
        ----------
        info : mne.Info instance
            Info to set.
        '''
        if not isinstance(info, Info):
            raise TypeError(f'input info must be an instance of mne.Info, but got {type(info)}')
        self._mne_info = info

    def __iter__(self):
        # Yields each trial dict (a view, same as integer indexing).
        return (self[i] for i in range(len(self)))

    def __len__(self):
        '''
        Get the number of trials in the Data object with ``len(Data)``.

        Examples
        --------
        >>> from naplib import Data
        >>> trial_data = [{'trial': 0, 'resp': [[0,1],[2,3]]}, {'trial': 1, 'resp': [[4,5],[6,7]]}]
        >>> data = Data(trial_data, strict=False)
        >>> len(data)
        2
        '''
        return len(self.data)

    def __repr__(self):
        return self.__str__()  # until we can think of a better __repr__

    def __str__(self):
        # Summarize each shown trial as {"field": <type>, ...}. Reads self.data
        # directly so printing never re-runs (strict) validation via slicing.
        def _describe(trial):
            return '{' + ', '.join(f'"{name}": {type(value)}' for name, value in trial.items()) + '}'

        header = f'Data object of {len(self)} trials containing {len(self.fields)} fields\n['
        if len(self) > 3:
            # Show the first two trials, an ellipsis, and the last trial.
            shown = '\n'.join(_describe(trial) for trial in self.data[:2])
            return header + shown + '\n\n...\n' + _describe(self.data[-1]) + ']\n'
        return header + '\n'.join(_describe(trial) for trial in self.data) + ']\n'

    def _validate_new_out_data(self, input_data, strict=True):
        '''
        Check that ``input_data`` is an iterable of valid trial dicts.

        Each trial must be a non-empty dict. If ``strict``, each trial must
        also have exactly the same field set as the reference trial (the
        current first trial, or the first new trial if this Data is empty) and
        contain every field in STRICT_FIELDS_REQUIRED.

        Raises
        ------
        TypeError
            If a trial is not a dict.
        ValueError
            If a trial has no fields, or fails a strict check.
        '''
        # Only trust self._data[0] as a reference if it is actually a dict; otherwise
        # the TypeError below reports the bad trial instead of an AttributeError here.
        reference_fields = None
        if self._data and isinstance(self._data[0], dict):
            reference_fields = set(self._data[0])
        for trial in input_data:
            if not isinstance(trial, dict):
                raise TypeError(f'input data is not a list of dicts, found {type(trial)}')
            trial_fields = set(trial.keys())
            if not trial_fields:
                raise ValueError('A trial should have at least one field.')
            if reference_fields is None:
                # Empty Data: the first new trial defines the reference field set.
                reference_fields = trial_fields
            if strict:
                if trial_fields != reference_fields:
                    raise ValueError('New data does not contain the same fields as the first trial.')
                for required_field in STRICT_FIELDS_REQUIRED:
                    if required_field not in trial_fields:
                        raise ValueError(f'For a "strict" Data object, the data does not contain the required field {required_field}.')

    @property
    def fields(self):
        '''List of strings containing names of all fields in this Data (from the first trial).'''
        return list(self._data[0]) if self._data else []

    @property
    def data(self):
        '''List of dictionaries containing data for each stimulus response and all associated variables.'''
        return self._data

    @property
    def info(self):
        '''Dictionary which can be used to store metadata info which does not change over trials, such as subject, recording, or task information.'''
        return self._info

    @property
    def mne_info(self):
        '''
        mne.Info instance which stores measurement information and can be used
        with mne's visualization functions. This is empty by default unless it
        is manually added or read in by a function like `naplib.io.load_bids`.

        Raises
        ------
        ValueError
            If no mne_info has been set.
        '''
        if self._mne_info is None:
            raise ValueError('No mne_info is available for this Data. This must '
                             'be read in from external data or added manually to the Data.')
        return self._mne_info
def concat(data_list, axis=0, copy=True):
    '''
    Concatenate Data objects across either trials or fields. This performs an
    inner join on the other dimension, meaning non-shared fields will be lost
    if concatenating over trials, and non-shared trials will be lost if
    concatenating over fields. If concatenating over fields and there are
    shared fields, then the field will only be taken from the first Data
    object in the input sequence and the rest will be ignored.

    Note: anything stored in the .info or .mne_info attributes of the objects
    will not be stored in the output (except when a single Data object is
    passed, in which case that object, or a deep copy of it, is returned).

    Parameters
    ----------
    data_list : list or tuple of Data instances
        Sequence containing the different Data objects to concatenate.
    axis : int, default=0
        To concatenate over trials (default), axis should be 0. To concatenate
        over fields, axis should be 1.
    copy : bool, default=True
        Whether to deep copy each Data object before concatenating. If False,
        the output shares field values with the inputs, but the input Data
        objects themselves are not modified.

    Returns
    -------
    data_merged : Data instance
        A Data instance of the merged objects.

    Raises
    ------
    TypeError
        If data_list is not a list/tuple, or contains something other than Data.
    ValueError
        If data_list is empty, axis is not 0 or 1, the objects share no fields
        when concatenating over trials, or their lengths differ when
        concatenating over fields.

    Examples
    --------
    >>> import naplib as nl
    >>> # First, try concatenating over trials from two different Data objects
    >>> d1 = nl.Data({'name': ['t1','t2'], 'resp': [[1,2],[3,4,5]], 'extra': ['ex1','ex2']})
    >>> d2 = nl.Data({'name': ['t3','t4'], 'resp': [[6,7],[9,10]], 'extra': ['ex3','ex4']})
    >>> d_concat = nl.concat((d1, d2))
    >>> len(d_concat)
    4
    >>> d_concat.fields
    ['name', 'resp', 'extra']
    >>> d_concat['name']
    ['t1', 't2', 't3', 't4']
    >>> d_concat['resp']
    [[1, 2], [3, 4, 5], [6, 7], [9, 10]]
    >>> d_concat['extra']
    ['ex1', 'ex2', 'ex3', 'ex4']
    >>> # We can also concatenate over fields if we have two Data objects for the same trials
    >>> # Duplicate fields will only be kept from the first Data object that they appear in
    >>> d3 = nl.Data({'name': ['t1-1','t2-1'], 'resp': [[1,2],[3,4,5]]})
    >>> d4 = nl.Data({'name': ['t1-2','t2-2'], 'meta_data': ['meta1', 'meta2']})
    >>> d_concat = nl.concat((d3, d4), axis=1)
    >>> len(d_concat)
    2
    >>> d_concat.fields
    ['name', 'resp', 'meta_data']
    >>> d_concat['name']
    ['t1-1', 't2-1']
    >>> d_concat['resp']
    [[1, 2], [3, 4, 5]]
    >>> d_concat['meta_data']
    ['meta1', 'meta2']
    '''
    if not isinstance(data_list, (list, tuple)):
        raise TypeError(f'data_list must be a list or tuple but got {type(data_list)}')
    if len(data_list) == 0:
        raise ValueError('need at least one Data object to concatenate')
    for out in data_list:
        if not isinstance(out, Data):
            raise TypeError(f'All inputs to data_list must be a Data instance but found {type(out)}')
    if axis not in (0, 1):
        raise ValueError(f'axis must be 0 or 1 but got {axis}')

    if len(data_list) == 1:
        # Honour ``copy`` even in the trivial single-input case.
        return deepcopy(data_list[0]) if copy else data_list[0]

    if axis == 0:
        # Inner join on fields, keeping the field order of the first Data.
        shared = set(data_list[0].fields)
        for data in data_list[1:]:
            shared &= set(data.fields)
        shared_fields = [ff for ff in data_list[0].fields if ff in shared]
        if not shared_fields:
            raise ValueError('The Data objects do not share any fields, so they cannot '
                             'be concatenated over trials (axis=0).')
        merged_trials = []
        for data in data_list:
            # data[list_of_fields] builds new trial dicts, so inputs are never mutated.
            subset = data[shared_fields]
            if copy:
                subset = deepcopy(subset)
            merged_trials.extend(subset.data)
        return Data(merged_trials, strict=False)

    # axis == 1: concatenate fields over the same trials.
    n_trials = len(data_list[0])
    if any(len(d) != n_trials for d in data_list):
        raise ValueError('All Data objects must be same length if concatenating over fields (axis=1).')
    if copy:
        data_merged = deepcopy(data_list[0])
    else:
        # Shallow-copy each trial dict so that adding fields does not mutate the first input.
        data_merged = Data([dict(trial) for trial in data_list[0].data], strict=False)
    for data in data_list[1:]:
        current_fields = set(data_merged.fields)
        for field in data.fields:
            if field not in current_fields:
                values = data[field]
                data_merged[field] = deepcopy(values) if copy else values
    return data_merged
def join_fields(data_list, fieldname='resp', axis=-1, return_as_data=False):
    '''
    Join trials from a field in multiple Data objects by zipping them together
    and concatenating each trial together. The field must be of type np.ndarray
    and concatenation is done with np.concatenate().

    Parameters
    ----------
    data_list : sequence of Data instances
        Sequence containing the different Data objects to join. All must
        contain the same number of trials.
    fieldname : string, default='resp'
        Name of the field to concatenate from each Data object. For each trial
        in each Data instance, this field must be of type np.ndarray or
        something which can be input to np.concatenate(). Only the first
        trial of each Data is type-checked.
    axis : int, default = -1
        Axis along which to concatenate each trial's data. The default
        corresponds to the channel dimension of the conventional 'resp' field
        of a Data object.
    return_as_data : bool, default=False
        If True, returns data as a Data object with a single field named fieldname.

    Returns
    -------
    joined_data : list of np.ndarrays, or Data instance
        Joined data of same length as each of the Data objects containing
        concatenated data for each trial.

    Raises
    ------
    TypeError
        If an element of data_list is not a Data, or the field's first trial
        is not an np.ndarray.
    ValueError
        If the Data objects contain different numbers of trials.

    Examples
    --------
    >>> import naplib as nl
    >>> data1 = nl.Data({'resp': [np.array([0,1,2]).reshape(-1,1), np.array([3,4]).reshape(-1,1)]})
    >>> data2 = nl.Data({'resp': [np.array([5,6,7]).reshape(-1,1), np.array([8,9]).reshape(-1,1)]})
    >>> data1['resp']
    [array([[0],
           [1],
           [2]]), array([[3],
           [4]])]
    >>> data2['resp']
    [array([[5],
           [6],
           [7]]), array([[8],
           [9]])]
    >>> resp_joined = nl.join_fields((data1, data2))
    >>> resp_joined
    [array([[0, 5],
           [1, 6],
           [2, 7]]), array([[3, 8],
           [4, 9]])]
    >>> resp_joined2 = nl.join_fields((data1, data2), axis=0)
    >>> resp_joined2
    [array([[0],
           [1],
           [2],
           [5],
           [6],
           [7]]), array([[3],
           [4],
           [8],
           [9]])]
    '''
    # Single pass: validates and collects each field list (works for generators too).
    starting_fields = []
    for out in data_list:
        if not isinstance(out, Data):
            raise TypeError(f'All inputs to data_list must be Data instance but found {type(out)}')
        field = out.get_field(fieldname)
        if field and not isinstance(field[0], np.ndarray):
            raise TypeError(f'Can only concatenate np.ndarrays, but found {type(field[0])} in this field')
        starting_fields.append(field)

    # zip() would silently drop trials if the lengths differed.
    lengths = [len(field) for field in starting_fields]
    if not _all_equal_list(lengths):
        raise ValueError(f'All Data objects must contain the same number of trials to join '
                         f'their fields, but got lengths {lengths}')

    to_return = [np.concatenate(trial_arrays, axis=axis) for trial_arrays in zip(*starting_fields)]
    if return_as_data:
        return Data([{fieldname: x} for x in to_return], strict=False)
    return to_return
def _all_equal_list(iterable):
    '''
    Return True if every element of ``iterable`` compares equal to its
    neighbour (True for an empty iterable), otherwise False.

    ``groupby`` starts a new group whenever consecutive items differ, so the
    items are all equal exactly when there is at most one group.
    '''
    runs = groupby(iterable)
    if next(runs, None) is None:
        # No groups at all: the iterable was empty.
        return True
    # A second group means some consecutive pair differed.
    return next(runs, None) is None