import logging
import os
from typing import Union, Tuple, List, Optional, Dict, Sequence, Callable
from functools import partial
from tqdm.auto import tqdm
import numpy as np
from scipy.signal import resample, welch, correlate
from scipy.interpolate import interp1d
from scipy.spatial.distance import pdist, squareform
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
from hdf5storage import loadmat
from naplib import preprocessing, logger, Data
from naplib.io import load_tdt, load_nwb, load_edf, load, load_wav_dir
from naplib.features import auditory_spectrogram
from naplib.preprocessing import make_contact_rereference_arr
from .alignment import align_stimulus_to_recording
ACCEPTED_DATA_TYPES = ['edf', 'tdt', 'nwb', 'pkl']
BUFFER_TIME = 2 # seconds of buffer in addition to befaft so that filtering doesn't produce edge effects
[docs]
def process_ieeg(
data_path: str,
alignment_dir: str,
/,
*,
stim_order: Optional[Union[str, Sequence[str]]]=None,
stim_dirs: Optional[Dict[str, str]]=None,
data_type: str='infer',
time_range: Union[float, Tuple[float, float]]=0,
elec_inds: Optional[Union[np.ndarray, Sequence[int]]]=None,
elec_names: Optional[Union[str, Sequence[str]]]=None,
rereference_grid: Optional[Union[np.ndarray, str]]=None,
rereference_method: str='avg',
store_reference: bool=False,
aud_channel: Union[str, int]='infer',
aud_channel_infer_method: str='crosscorr',
bands: Union[str, List[str], List[np.ndarray], List[float], np.ndarray]=['highgamma'],
phase_amp: str='amp',
befaft: Union[List, np.ndarray]=[1, 1],
intermediate_fs: Optional[int]=600,
final_fs: int=100,
alignment_kwargs: dict={},
line_noise_kwargs: dict={},
store_sounds: bool=False,
store_all_wav: bool=False,
aud_fn: Optional[Union[Callable, Dict[str, Callable]]]=auditory_spectrogram,
aud_kwargs: Optional[dict]=None,
n_jobs: int=1,
):
"""
Process raw iEEG data.
Parameters
----------
data_path : str path-like
String specifying data directory (for TDT) or file path for raw data file
alignment_dir : str path-like
Directory containing a set of stimulus waveforms as .wav files for alignment. This will be called 'aud'.
stim_order : Optional[Union[str, Sequence[str]]] path-like or sequence of strings, defaults to ``alignment_dir``
If a sequence of strings, must contain the order of the stimuli names corresponding to the names of the .wav files in
``alignment_dir``. If a file, must be either a StimOrder.mat file, or StimOrder.txt file containing the order of the
stimuli names as lines in the file. If a directory, the directory must contain such a file. If None, will search for
such a file within ``alignment_dir``.
stim_dirs : Optional[Dict[str, str]], defaults to ``alignment_dir``
If not provided, alignment_dir is assumed to contain the stimulus as well.
If provided, can be used to specify additional paths from which to load sounds which will be converted to the
chosen spectrotemporal features. This dict should have keys which are the name for the stimuli and values
which are the path to the stimulus directory of wav files. The files within this must have the same names
as those within ``stim_dir``. E.g. {'aud': './stimuli', 'aud_spk1': './stimuli_spk1', 'aud_spk2': './stimuli_spk2'}
data_type : str, default='infer'
One of {'edf', 'tdt', 'nwb', 'pkl', 'infer'}. The data type of the raw neural data to load.
time_range : float or (float, float), default=0
If a single float, the amount of time in seconds to skip at the start of the recording. If a 2-tuple `(start, end)`,
the time range of the recording to read between.
elec_inds : Optional[Union[np.ndarray, Sequence[int]]], default=None
If not None, the sorted indices of the data recording channels to keep. Important to note that this filtering is done
prior to manual setting of elec_names and rereferencing, so it might affect their results.
elec_names : Optional[Union[str, Sequence[str]]] path-like or sequence of strings, default=None
Electrode labels for all data channels read from ``data_path``. Should either be the path to a text file where
each line is the label of an electrode contact, or a list of strings where each element is the label of an electrode
contact. In both cases, the number of labels provided should match the number of data channels in ``data_path``. If
None, the labels included in the data file will be used.
rereference_grid : Optional[Union[np.ndarray, str]], default=None
If not None, then data are re-referenced based on this referencing scheme. If a numpy array, then
should specify categorical groupings of which electrodes to be grouped together for re-referencing,
and must be the same length as the number of electrodes in the raw data. If 'array', electrodes on
the same electrode array will be grouped together (e.g., RT1, RT2, RT3). If 'subject', all electrodes
will fall in the same group, which is equivalent to an NxN matrix of ones.
rereference_method : Optional[str], default='avg'
If provided, must specify a method for common rereferencing, either 'avg' (average), 'pca' (PCA),
or 'med' (median). Only used if ``rereference_grid`` is not None.
store_reference : bool, default=False
If True, include the reference which was subtracted from each channel in the output Data.
aud_channel : Union[str, int], default='infer',
If an int, specifies the index of the wav channel loaded from the raw recording which
should be used for alignment. If 'infer', then this is inferred.
aud_channel_infer_method : str, default='crosscorr'
Method for inferring aud channel used for alignment, either 'crosscorr','spectrum', or 'interactive'.
'crosscorr' computes cross correlation between stimulus waveform and each wav channel and selects
maximum. 'spectrum' compares the power spectra of each wav channel to that of the stimulus and chooses
the maximum (which is not very robust when using certain alignment stimuli like triggers).
'interactive' prints the name of each wav channel and asks the user to specify which one
should be used for alignment. This is only an option when labels_wav are present, which is
only for some data types (like edf).
bands : Union[str, list[str], list[np.ndarray], list[float], np.ndarray], default=['highgamma']
Frequency bands, specified as either strings or array-likes of length 2 giving the lower
and upper bounds. For example, [[8, 13], np.array([30, 70]), 'highgamma'] is equivalent
to ['theta', 'gamma', [30, 70]]. Or, can use 'raw' to specify raw neural data. Keep in mind,
this will still be resampled according to the ``final_fs`` parameter.
phase_amp : str, default='amp'
Whether to save the phase, amplitude, or both, of each extracted frequency band. Options
are {'phase', 'amp', 'both'}.
befaft : Union[List, np.ndarray], default=[1,1]
Extra time (in sec.) to store from the neural data before the start of and after the end of each
stimulus.
intermediate_fs : Optional[int], default=600
If provided downsamples the loaded raw neural data to this sampling rate before further
preprocessing. If this is greater than the raw sampling rate, no resampling is done.
final_fs : int, default=100
Final sampling rate for neural data and spectrograms.
alignment_kwargs : dict, default={}
If provided, will be passed to naplib.naplab.align_stimulus_to_recording to override keyword arguments.
line_noise_kwargs : dict, default={}
Dict of kwargs to naplib.preprocessing.filter_line_noise
store_sounds : bool, default=False
If True, store raw sound wave for each stimulus in stim_dirs in the output Data.
store_all_wav : bool, default=False
If True, store all recorded wav channels that were stored by the neural recording hardware. This may include
any other signals that were hooked up at the same time, such as EKG, triggers, etc.
aud_fn : optional callable or dict, default=naplib.features.auditory_spectrogram
Function(s) to be applied to each stimulus sound. If None, no audio transforms will be computed. By default,
`naplib.features.auditory_spectrogram` will be used to compute an auditory spectrogram. If a callable `f`,
the function `f` will be applied to each stimulus audio and should have signature
`(x: NDArray, sr: float, **kwargs) -> NDArray`, where `x` is 1-D audio signal with shape (in_samples,) and
`sr` is the sampling rate of the audio. The returned tensor should have shape (n_samples, n_features). If a
dictionary, the keys should be strings and will be used in field names of the output Data object, and the
values should be callable.
aud_kwargs : optional dict, default=None
Optional dictionary of extra arguments to be passed to `aud_fn`. Only used when `aud_fn` is a single callable.
If `aud_kwargs` is not None and `aud_fn` is not a single callable, an error will be raised.
n_jobs : int, default=1
Number of CPU cores to use for the parallelizable processes. Higher number of jobs also uses higher memory,
so there might even be a negative effect when working with large datasets.
Returns
-------
data : nl.Data
Data object containing all requested fields after preprocessing.
"""
# # Check aud_fn
aud_fn = _prep_aud_fn(aud_fn, aud_kwargs)
# # infer data type
if data_type is None or data_type not in ACCEPTED_DATA_TYPES:
data_type, data_path = _infer_data_type(data_path)
logger.info(f'Inferred data type to be {data_type} from the data directory')
if len(befaft) != 2:
raise ValueError(f'befaft must be a list or array of length 2.')
if isinstance(time_range, (int, float)):
t_start, t_end = time_range, 0
elif isinstance(time_range, tuple) and len(time_range) == 2:
t_start, t_end = time_range
else:
raise ValueError('time_range should be a float or a 2-tuple of floats')
# # load data and aud channels
if data_type == 'tdt':
logger.info(f'Loading tdt data...')
raw_data = load_tdt(data_path, t1=t_start, t2=t_end)
elif data_type == 'nwb':
logger.info(f'Loading nwb data...')
if not data_path.endswith(('.nwb', '.NWB')):
raise ValueError(f'data_type is nwb but data_path is not a nwb file: {data_path}')
raw_data = load_nwb(data_path)
elif data_type == 'edf':
logger.info(f'Loading edf data...')
if not data_path.endswith(('.edf', '.EDF')):
raise ValueError(f'data_type is edf but data_path is not an edf file: {data_path}')
raw_data = load_edf(data_path, t1=t_start, t2=t_end)
elif data_type == 'pkl':
if not data_path.endswith(('.pkl', '.p')):
raise ValueError(f'data_type is pkl but data_path is not a pkl file: {data_path}')
logger.info(f'Loading pkl data...')
raw_data = load(data_path)
if not isinstance(raw_data, dict) or ('data' not in raw_data or 'data_f' not in raw_data or 'wav' not in raw_data or 'wav_f' not in raw_data):
raise ValueError(f'pkl data is not formatted correctly. Must be a pickled dict containing "data", "data_f", "wav", "wav_f" keys at least')
if store_all_wav and 'labels_wav' not in raw_data:
raise ValueError('store_all_wav is True, but to store wav channels in final output there must be the key "labels_wav" in the pickled data.')
if 'labels_data' in raw_data:
raw_data['labels_data'] = np.asarray(raw_data['labels_data'])
else:
raise ValueError(f'Invalid data_type parameter. Must be one of {ACCEPTED_DATA_TYPES}')
# # check if any data skipped
t_skip = raw_data.get('t_skip', 0)
# # filter electrodes
if elec_inds:
elec_inds = np.asarray(elec_inds, dtype=int)
# make sure array is strictly increasing
for i in range(len(elec_inds)-1):
if elec_inds[i] >= elec_inds[i+1]:
raise ValueError('elec_inds must be strictly increasing sequence of ints')
raw_data['data'] = raw_data['data'][:, elec_inds]
raw_data['labels_data'] = raw_data['labels_data'][elec_inds]
# # set electrode labels
if isinstance(elec_names, str):
elec_names = _load_elec_names(elec_names)
if elec_names:
if len(elec_names) != raw_data['data'].shape[1]:
raise ValueError('List of electrode labels should have same size as number of data channels')
if 'labels_data' in raw_data:
logger.warning('Overriding original electrode labels with user-specified values')
raw_data['labels_data'] = np.asarray(elec_names)
# # load StimOrder
logger.info('Loading StimOrder...')
if stim_order is None:
stim_order = _load_stim_order(alignment_dir)
elif isinstance(stim_order, str):
stim_order = _load_stim_order(stim_order)
# # load stimuli files
logger.info('Loading stimuli...')
stim_data = load_wav_dir(alignment_dir, rescale=True, subset=set(stim_order))
if stim_dirs is not None:
extra_stim_data = {k: load_wav_dir(stim_dir2, rescale=True, subset=set(stim_order)) for k, stim_dir2 in stim_dirs.items()}
else:
extra_stim_data = {'aud': stim_data}
# # figure out which channel is used for alignment
if aud_channel == 'infer':
logger.info(f'Inferring alignment channel from wav channels...')
alignment_wav, alignment_ch = _infer_aud_channel(raw_data['wav'],
raw_data['wav_f'],
raw_data.get('labels_wav', None),
list(stim_data.values()),
method=aud_channel_infer_method,
debug=logger.isEnabledFor(logging.DEBUG))
logger.info(f'Inferred alignment channel is {alignment_ch}.')
else:
if raw_data['wav'].ndim > 1:
if not isinstance(aud_channel, int):
raise TypeError(f'Invalid aud_channel argument. Must either be "infer" or an int specifying'
f' an index of the raw audio channels to use, but got {aud_channel}')
alignment_wav = raw_data['wav'][:,aud_channel]
alignment_ch = aud_channel
else:
alignment_wav = raw_data['wav']
alignment_ch = aud_channel
# # perform alignment
alignment_times, alignment_confidence = align_stimulus_to_recording(
alignment_wav, raw_data['wav_f'], stim_data, stim_order, **alignment_kwargs
)
# truncate data around earliest and lastest time that we need
earliest_time = max(alignment_times[0][0] - befaft[0] - BUFFER_TIME, 0)
latest_time = alignment_times[-1][1] + befaft[1] + BUFFER_TIME
if befaft[0] > alignment_times[0][0]:
raise ValueError(f"Not enough data to use befaft[0]={befaft[0]}. First stimulus aligned to {alignment_times[0][0]} sec")
if befaft[1] > raw_data['data'].shape[0] / raw_data['data_f'] - alignment_times[-1][1]:
raise ValueError(
f"Not enough data to use befaft[1]={befaft[1]}. Last stimulus alignment ends at {alignment_times[-1][1]} "
f"sec but only have {raw_data['data'].shape[0] / raw_data['data_f']} sec of data"
)
alignment_times = np.asarray(alignment_times) - earliest_time # shift times back since we are going to truncate the data
earliest_sample, latest_sample = (int(raw_data['data_f'] * t) for t in (earliest_time, latest_time))
raw_data['data'] = raw_data['data'][earliest_sample:latest_sample]
earliest_sample, latest_sample = (int(raw_data['wav_f'] * t) for t in (earliest_time, latest_time))
raw_data['wav'] = raw_data['wav'][earliest_sample:latest_sample]
# # resample to intermediate_fs Hz
if intermediate_fs is not None and final_fs <= intermediate_fs < raw_data['data_f']:
new_len = int(intermediate_fs / raw_data['data_f'] * raw_data['data'].shape[0])
logger.info(f'Resampling data to {intermediate_fs} Hz')
channels = range(raw_data['data'].shape[1])
for ch in tqdm(channels) if logger.isEnabledFor(logging.INFO) else channels:
raw_data['data'][:new_len, ch] = resample(raw_data['data'][:, ch], new_len)
raw_data['data'] = raw_data['data'][:new_len]
raw_data['data_f'] = intermediate_fs
# # Make a copy if arrays are views
if raw_data['data'].base is not None:
raw_data['data'] = raw_data['data'].copy()
if raw_data['wav'].base is not None:
raw_data['wav'] = raw_data['wav'].copy()
# # append befaft zeros to the stims which were not used for alignment as well as the one which was (if it's in the dict too)
for stim_data_dict_ in extra_stim_data.values():
for wavname_, wavdata_ in stim_data_dict_.items():
bef_zeros = int(round(wavdata_[0] * befaft[0]))
aft_zeros = int(round(wavdata_[0] * befaft[1]))
if wavdata_[1].ndim == 1:
stim_data_dict_[wavname_] = wavdata_[0], np.pad(wavdata_[1], (bef_zeros,aft_zeros))
else:
stim_data_dict_[wavname_] = wavdata_[0], np.pad(wavdata_[1], ((bef_zeros,aft_zeros), (0,0)))
# # preprocessing
# # common referencing
if not isinstance(rereference_grid, str):
pass
elif rereference_grid == 'array':
if 'labels_data' not in raw_data:
raise ValueError('Implicit array-based rereferencing not allowed when electrode labels are not specified')
rereference_grid = make_contact_rereference_arr(raw_data['labels_data'])
elif rereference_grid == 'subject':
rereference_grid = np.ones((raw_data['data'].shape[1],) * 2, dtype=int)
else:
raise ValueError(f'Unknown string rereference_grid mode: {rereference_grid}')
if rereference_grid is not None:
logger.info(f'Performing common rereferencing using "{rereference_method}" method...')
if store_reference:
rereferenced_data, reference_to_store = preprocessing.rereference(rereference_grid, field=[raw_data['data']], method=rereference_method, return_reference=True)
else:
rereferenced_data = preprocessing.rereference(rereference_grid, field=[raw_data['data']], method=rereference_method, return_reference=False)
reference_to_store = None
raw_data['data'] = rereferenced_data[0]
else:
reference_to_store = None
# filter line noise (after this, raw_data['data'] is a list of length 1)
logger.info('Filtering line noise...')
raw_data['data'] = preprocessing.filter_line_noise(field=[raw_data['data']],
fs=raw_data['data_f'],
in_place=True,
**line_noise_kwargs)
# # Cut raw data up into blocks based on alignment
logger.info('Chunking responses based on alignment...')
data_by_trials_raw, effective_buffer_times = _split_data_on_alignment(
Data({'raw': raw_data['data']}), raw_data['data_f'], alignment_times, befaft, buffer_time=BUFFER_TIME
)
# # extract frequency bands
if 'raw' in bands:
include_raw = True
bands = [bb for bb in bands if bb != 'raw']
else:
include_raw = False
Wn = _infer_freq_bands(bands) # get frequency bands from string names
bandnames = []
for band, wn_ in zip(bands, Wn):
if isinstance(band, str):
bandnames.append(band)
else:
bandnames.append(str(wn_))
if len(Wn) > 0:
logger.info(f'Extracting frequency bands: {Wn} ...')
data_by_trials = preprocessing.phase_amplitude_extract(field=data_by_trials_raw['raw'],
fs=raw_data['data_f'],
Wn=Wn, bandnames=bandnames,
fs_out=final_fs,
n_jobs=n_jobs)
logger.info(f'Storing response bands of interest...')
# only keep amplitude or phase if that's what the user specified
if phase_amp == 'amp':
fields_to_keep = [xx for xx in data_by_trials.fields if ' amp' in xx]
elif phase_amp == 'phase':
fields_to_keep = [xx for xx in data_by_trials.fields if ' phase' in xx]
else:
fields_to_keep = data_by_trials.fields
data_by_trials = data_by_trials[fields_to_keep]
if include_raw:
data_by_trials['raw'] = data_by_trials_raw['raw']
else:
# if no other frequency bands, then default to output raw
data_by_trials = data_by_trials_raw
desired_lens = [round(final_fs / raw_data['data_f'] * len(xx)) for xx in data_by_trials_raw['raw']]
if 'raw' in data_by_trials.fields:
data_by_trials['raw'] = [resample(xx, d_len, axis=0) for xx, d_len in zip(data_by_trials_raw['raw'], desired_lens)]
if reference_to_store is not None:
reference_to_store, _ = _split_data_on_alignment(Data({'ref': reference_to_store}), raw_data['data_f'], alignment_times, befaft, buffer_time=BUFFER_TIME)
data_by_trials['reference'] = [resample(xx, d_len, axis=0) for xx, d_len in zip(reference_to_store['ref'], desired_lens)]
data_by_trials = _remove_buffer_time(data_by_trials, final_fs, effective_buffer_times)
if store_all_wav:
logger.info('Chunking wav channels based on alignment...')
wav_data_chunks, _ = _split_data_on_alignment(Data({'wav': [raw_data['wav']]}), raw_data['wav_f'], alignment_times, befaft, buffer_time=0)
# final output dict to be made into naplib.Data object
alignment_times = np.asarray(alignment_times)
final_output = {'name': stim_order,
'alignment_start': list(alignment_times[:,0] + earliest_time + t_skip),
'alignment_end': list(alignment_times[:,1] + earliest_time + t_skip),
'alignment_confidence': alignment_confidence,
'dataf': [final_fs for _ in stim_order],
'befaft': [befaft for _ in stim_order]}
# extract spectrograms
if aud_fn:
logger.info(f'Computing auditory spectrogram for each stimulus set in stim_dirs ...')
# mapping from name (like 'aud') to list of spectrograms
for k, stim_data_dict in extra_stim_data.items():
for name, fn in aud_fn.items():
final_output[f'{k} {name}' if name else k] = _transform_stims(
stim_data_dict, stim_order, final_fs, fn,
)
if store_sounds:
for k, stim_data_dict in extra_stim_data.items():
final_output[f'{k} sound'] = [stim_data_dict[stim_name][1] for stim_name in stim_order]
final_output[f'{k} soundf'] = [stim_data_dict[stim_name][0] for stim_name in stim_order]
del extra_stim_data
for fieldname in data_by_trials.fields:
final_output[fieldname] = data_by_trials[fieldname]
if store_all_wav:
final_output['wavf'] = [raw_data['wav_f'] for _ in stim_order]
for ww, wav_ch_name in enumerate(raw_data['labels_wav']):
final_output[wav_ch_name] = [xx[:,ww] for xx in wav_data_chunks['wav']]
# # Put output Data all together
final_output = Data(final_output)
final_output.set_info({
'channel_labels': raw_data.get('labels_data', None),
'rereference_grid': rereference_grid,
'data_type': data_type,
**raw_data.get('info', {})
})
logger.info('All done!')
return final_output
def _load_elec_names(elec_names_path: str) -> List[str]:
"""
Load txt file containg list of electrode labels, one per line, returning it as a list of strings.
Empty lines in file will be skipped, so empty labels are not possible in this file.
Parameter
---------
elec_names_path : str
Path to file
Returns
-------
elec_names : List[str]
Stimulus order as a list of stimulus names
"""
with open(elec_names_path, 'r') as infile:
lines = [x.strip() for x in infile.readlines() if not x.isspace()]
return lines
def _infer_aud_channel(wav_data: np.ndarray, wav_fs: int, wav_labels: Sequence[str],
stim_data: List[Tuple[float, np.ndarray]],
method: str='crosscorr', min_freq=20, debug=False):
"""
Infer which recorded wav channel matches the stimulus waveforms provided.
Parameters
----------
wav_data : np.ndarray, shape (time, channels)
Loaded wav channels from the recording system. Should be of shape (time, channels)
wav_fs : int
Sampling rate of wav data.
wav_labels : Sequence[str]
Name for each wav channel.
stim_data : List[Tuple[float, np.ndarray]]
List of tuples containing sampling rate and stimuli sounds/trigger waveforms,
each of shape (time, ) or (time, 2). If stereo, the left channel will be used.
method : str, default='spectrum'
Method for inferring correct channel from wav_data. Options are 'spectrum', 'envelope',
'crosscorr', or 'interactive'.
'crosscorr' computes cross correlation between stimulus waveform and each wav channel and selects
maximum. 'spectrum' compares the power spectra of each wav channel to that of the stimulus and chooses
the maximum (which is not very robust when using certain alignment stimuli like triggers).
'interactive' plots each wav channel and asks the user to specify which one should be used for alignment.
min_freq : float, default=20
Only used if method='spectrum'. Minimum frequency to include when calculating correlation
between spectrums.
debug : bool, default=False
If True, plots the spectrum of each channel and the spectrum of the stimulus.
Returns
-------
alignment_wav : np.ndarray
The channel from wav_data which matches the stimuli given. Will have shape (time, )
alignment_index : int
Index from wav_data channels which was picked as the alignment channel
"""
if wav_data.ndim == 1:
return wav_data, wav_fs
if wav_data.shape[1] == 1:
return wav_data[:,0], wav_fs
assert isinstance(stim_data, list)
if method == 'interactive':
if wav_labels is None:
raise ValueError('Interactive mode only supported when wav_labels is available.')
print(f'These are the available channels: {", ".join(wav_labels)}.')
ch_idx = None
while ch_idx is None:
pick = input('Which is the audio channel? ').strip()
for i, s in enumerate(wav_labels):
if pick == s:
ch_idx = i
break
return wav_data[:, ch_idx], ch_idx
elif method == 'spectrum':
fs0 = stim_data[0][0]
concat_stims = []
for i, (fs, stim_waveform) in enumerate(stim_data):
if fs != fs0:
raise ValueError(f'Sampling rates are not all the same. First stimulus has sampling rate of'
f' {fs0} Hz, and stimulus {i} has sampling rate of {fs}')
concat_stims.append(stim_waveform)
concat_stims = np.concatenate(concat_stims, axis=0)
# select left channel of stimuli only
if concat_stims.ndim > 1:
concat_stims = concat_stims[:,0][:,np.newaxis]
# compute spectrum of wav data and stimuli
f1, px1 = welch(wav_data, fs=wav_fs, axis=0)
f2, px2 = welch(concat_stims, fs=fs0, axis=0)
if wav_fs > fs0:
# downsample px1 and move to range of f2
new_px1 = []
for ii in range(px1.shape[1]):
interp = interp1d(f1, px1[:,ii])
new_px1.append(interp(f2))
px1 = np.vstack(new_px1).T # back to shape (freqs, channels)
shared_f = f1
elif fs0 > wav_fs:
# downsample px2 and move to range of f1
interp = interp1d(f2, px2.squeeze())
px2 = interp(f1)[:,np.newaxis]
shared_f = f1
else:
shared_f = f1
good_freqs = shared_f >= min_freq
shared_f = shared_f[good_freqs]
px1 = px1[good_freqs]
px2 = px2[good_freqs]
if px2.ndim == 1:
px2 = px2[:,np.newaxis]
cat_px = np.concatenate([px1, px2], axis=1)
dists = 1.0-pdist(cat_px.T, metric='correlation')
dists = squareform(dists)
best_ch_idx = np.nanargmax(dists[:,-1])
if debug:
plt.figure(figsize=(8,6))
plt.title('Spectrums of Stimulus and Wav Channels')
plt.plot(shared_f, px2/px2.max(), color='k', label='Stimulus')
for jj in range(px1.shape[1]):
plt.plot(shared_f, px1[:,jj]/px1[:,jj].max(), label='Ch {}: crr={:.3f}'.format(jj, dists[jj,-1]))
plt.legend()
plt.show()
return wav_data[:, best_ch_idx], best_ch_idx
elif method == 'crosscorr' or method == 'xcorr':
# Find longest stimulus for more robust inference
longest_stim = np.argmax(len(s)/f for f, s in stim_data)
stim_fs, stim_data = stim_data[longest_stim]
desired_len = int(wav_fs / stim_fs * len(stim_data))
if desired_len != len(stim_data):
stim_data = resample(stim_data, desired_len, axis=0)
if stim_data.ndim > 1 and stim_data.shape[1] > 1:
logger.warning('Performing alignment with stereo audio stimuli is not recommended.'
' It is recommended to use mono-channel audio for alignment, and any'
' additional stimuli (including stereo audio) desired in the final '
' Data object can be specified as extra stimulus directories.')
scores = []
for c in range(wav_data.shape[1]):
if stim_data.ndim == 1 or stim_data.shape[1] == 1:
pos = np.nanargmax(correlate(wav_data[:, c], stim_data.squeeze(), 'valid'))
score = _pearsonr(wav_data[pos:pos+len(stim_data), c], stim_data.squeeze())
else:
pos_left = np.nanargmax(correlate(wav_data[:, c], stim_data[:,0], 'valid'))
score_left = _pearsonr(wav_data[pos_left:pos_left+len(stim_data), c], stim_data[:,0])
pos_right = np.nanargmax(correlate(wav_data[:, c], stim_data[:,1], 'valid'))
score_right = _pearsonr(wav_data[pos_right:pos_right+len(stim_data), c], stim_data[:,1])
score = np.nanmax([score_left, score_right])
scores.append(score)
if debug:
logger.debug(f'Alignment xcorr scores: {", ".join(str(s) for s in scores)}')
best_ch_idx = np.nanargmax(scores)
return wav_data[:,best_ch_idx], best_ch_idx
else:
raise ValueError(f'Unsupported method argument: {method}')
def _pearsonr(x, y):
import warnings
with warnings.catch_warnings():
warnings.filterwarnings('ignore', module='scipy.stats')
warnings.filterwarnings('ignore', module='scipy.stats')
return pearsonr(x, y)[0]
def _infer_freq_bands(
bands: Union[str, List[str], List[np.ndarray], List[float], np.ndarray]
) -> List[List[Union[float, int]]]:
"""
Parameters
----------
bands : Union[str, List[str], List[np.ndarray], List[float], np.ndarray]
Bands to translate into lower and upper frequency ranges. Allowed strings
are 'theta', 'alpha', 'gamma', 'highgamma'.
Returns
-------
band_bounds : List[List[Union[float, int]]]
List of bands, each band specified as a list of length 2 containing the lower
and upper bound of the frequency band.
"""
FREQUENCY_BANDS = {'theta': [4,8], 'alpha': [8, 13],
'gamma': [30, 70], 'highgamma': [70, 150]}
new_bands = []
if isinstance(bands, list):
if len(bands) == 2 and all([isinstance(x, float) or isinstance(x, int) for x in bands]):
# just a list of length 2, so these are lower and upper bounds
new_bands.append(bands)
else:
for band in bands:
if isinstance(band, str):
if band not in FREQUENCY_BANDS:
raise ValueError(f'Invalid band name. If a string, must be one of {FREQUENCY_BANDS}, but got {band}')
else:
new_bands.append(FREQUENCY_BANDS[band])
else:
if (not isinstance(band, np.ndarray) and not isinstance(band, list)) or len(band) != 2:
raise ValueError(f'each band must be a list of numpy array of length 2, but found band {band}')
new_bands.append(list(band))
elif isinstance(bands, str):
new_bands.append(FREQUENCY_BANDS[bands])
else:
if not isinstance(bands, np.ndarray):
raise TypeError('bands is neither a string, nor a list, nor a numpy array.')
elif bands.squeeze().shape != (2,):
raise ValueError(f'bands must only have two elements if given as numpy array but got bands of shape {bands.shape}')
else:
new_bands.append(list(bands))
return new_bands
def _infer_data_type(data_path: str):
"""
Infer which data loader to use based on what files are in the directory or the file extension
if a single file is given.
Parameters
----------
data_path : str, path-like
Directory or file containing data
Returns
-------
data_type : str
One of 'tdt', 'edf', 'nwb', 'pkl'.
file_path : str
Path to file or directory which contains the data.
"""
if data_path.endswith(('.edf', '.EDF')):
return 'edf', data_path
if data_path.endswith(('.nwb', '.NWB')):
return 'nwb', data_path
if data_path.endswith(('.pkl', '.p')):
return 'pkl', data_path
files_in_dir = [x for x in os.listdir(data_path) if '.' in x and x[0]!='.']
file_suffixes = [x.split('.')[-1] for x in files_in_dir]
if 'sev' in file_suffixes or 'tev' in file_suffixes:
return 'tdt', data_path
elif 'edf' in file_suffixes or 'EDF' in file_suffixes:
if file_suffixes.count('edf') + file_suffixes.count('EDF') > 1:
raise ValueError(f'Inferred edf format, but more than one edf file found in given directory.')
if 'edf' in file_suffixes:
return 'edf', os.path.join(data_path, files_in_dir[file_suffixes.index('edf')])
else:
return 'edf', os.path.join(data_path, files_in_dir[file_suffixes.index('EDF')])
elif 'nwb' in file_suffixes or 'NWB' in file_suffixes:
if file_suffixes.count('nwb') + file_suffixes.count('NWB') > 1:
raise ValueError(f'Inferred nwb format, but more than one nwb file found in given directory.')
if 'nwb' in file_suffixes:
return 'nwb', os.path.join(data_path, files_in_dir[file_suffixes.index('nwb')])
else:
return 'nwb', os.path.join(data_path, files_in_dir[file_suffixes.index('NWB')])
elif 'pkl' in file_suffixes or 'p' in file_suffixes:
if file_suffixes.count('pkl') + file_suffixes.count('p') > 1:
raise ValueError(f'Inferred pkl format, but more than one pickle file found in given directory.')
if 'pkl' in file_suffixes:
return 'pkl', os.path.join(data_path, files_in_dir[file_suffixes.index('pkl')])
else:
return 'pkl', os.path.join(data_path, files_in_dir[file_suffixes.index('p')])
raise ValueError(f'Could not infer data type from directory.')
def _prep_aud_fn(aud_fn: Optional[Union[Callable, Dict]], aud_kwargs: Optional[Dict]) -> Dict:
if aud_kwargs is not None and not isinstance(aud_fn, Callable):
raise ValueError('aud_kwargs only supported when aud_fn is a single callable')
if aud_fn is None:
return {}
if isinstance(aud_fn, Callable):
return {'': partial(aud_fn, **aud_kwargs) if aud_kwargs else aud_fn}
if isinstance(aud_fn, dict):
for k, f in aud_fn.items():
if not isinstance(k, str):
raise ValueError('aud_fn dictionary keys should be of type string')
if not isinstance(f, Callable):
raise ValueError("aud_fn dictionary values should be callable")
return aud_fn
raise ValueError("aud_fn should be either None, callable, or dict")
def _transform_stims(stim_data_dict, stim_order, fs_out, aud_fn):
"""
Transform each stimulus in `stim_data_dict` using the provided function `aud_fn`,
then return a list of the resulting tensors ordered by stim_order (stimuli can repeat in stim_order).
Parameters
----------
stim_data_dict : dict
Dictionary mapping from string name of stimulus to a tuple of (fs, wav_data)
stim_order : list of strings
List of names in desired order. Each name must be a key that exists in stims_data_dict,
but they can repeat.
fs_out : int
Sampling rate of output spectrograms
aud_fn : Callable
Function for computing spectrogram from waveform. The function should have signature
``(x: NDArray, sr: float, **kwargs) -> NDArray`` where x has shape (in_samples,), sr
is the sampling rate of x, and the return value has shape (out_samples, freq_bins).
Returns
-------
specs : list of np.ndarray
List of same length as stim_order containing the spectrogram for each stimulus
"""
if logger.isEnabledFor(logging.INFO):
stim_data_dict = tqdm(stim_data_dict.items(), total=len(stim_data_dict))
else:
stim_data_dict = stim_data_dict.items()
spec_dict = {}
for k, (fs, sig) in stim_data_dict:
if k not in stim_order:
continue # skip this stimulus if don't need it for stim_order
if sig.ndim == 2:
specs = []
for ch in range(sig.shape[1]):
specs.append(aud_fn(sig[:,ch], fs)[:,:,np.newaxis])
spec = np.concatenate(specs, axis=-1)
elif sig.ndim == 1:
spec = aud_fn(sig, fs)
else:
raise ValueError(f'Waveform to compute spectrogram for is more than 2 dimensional. Got {sig.ndim} dimensions')
# resample to fs_out
desired_len = int(fs_out / fs * len(sig))
if desired_len != spec.shape[0]:
logger.warning(
f"Resampling transform '{aud_fn}' of stimulus '{k}' from {len(spec)} to {desired_len} samples"
)
spec = resample(spec, desired_len, axis=0)
spec_dict[k] = spec
output = [spec_dict[stim_name] for stim_name in stim_order]
return output
def _load_stim_order(stim_order_path: str) -> List[str]:
"""
Load either StimOrder.mat or StimOrder.txt file and return stimulus order as list of names
Parameter
---------
stim_order_path : str
Path to file or directory containing file.
Returns
-------
stim_order : List[str]
Stimulus order as a list of stimulus names
"""
if stim_order_path.endswith('.mat') or stim_order_path.endswith('.txt'):
good_filepath = stim_order_path
else:
file_names = os.listdir(stim_order_path)
found_file = False
for fname in file_names:
if fname == 'StimOrder.mat' or fname == 'StimOrder.txt':
found_file = True
good_filepath = os.path.join(stim_order_path, fname)
break
if not found_file:
raise FileNotFoundError(f'Tried to find stim order file but could not find "StimOrder.mat" or "StimOrder.txt"'
f' within directory "{stim_order_path}". Must specify either a direct path to one of those files, or'
f' a directory containing at least one of them.')
if good_filepath.endswith('.mat'):
stim_order_dict = loadmat(good_filepath)
if 'StimOrder' not in stim_order_dict:
raise ValueError(f'Successfully StimOrder.mat but it did not contain a variable named "StimOrder"')
stim_order = [x.item() for x in stim_order_dict['StimOrder'].squeeze()]
return stim_order
else:
with open(good_filepath, 'r') as infile:
lines = [x.strip() for x in infile.readlines() if not x.isspace()]
return lines
def _split_data_on_alignment(data, fs, alignment_startstops, befaft, buffer_time=1):
"""
data must be length 1, but can have as many fields as needed, each of which is a numpy array (time, ...)
"""
output = {}
effective_buffer_times = []
for field in data.fields:
split_field = []
duration = len(data[0][field]) / fs
for align_region in alignment_startstops:
effective_buffer_times.append([buffer_time, buffer_time])
start_time = align_region[0] - befaft[0]
if start_time < buffer_time:
effective_buffer_times[-1][0] = start_time
start_time = 0
else:
start_time -= buffer_time
end_time = align_region[1] + befaft[1]
if duration - end_time < buffer_time:
effective_buffer_times[-1][1] = duration - end_time
end_time = duration
else:
end_time += buffer_time
start_sample = int(round(start_time * fs))
end_sample = int(round(end_time * fs))
split_field.append(data[0][field][start_sample:end_sample])
output[field] = split_field
return Data(output), effective_buffer_times
def _remove_buffer_time(data, fs, buffer_times):
for trial in range(len(data)):
buffer_samples = [round(fs*t) for t in buffer_times[trial]]
for field in data.fields:
start_sample = buffer_samples[0]
end_sample = len(data[trial][field])-buffer_samples[1]
data[trial][field] = data[trial][field][start_sample:end_sample]
return data