import warnings
import numpy as np
import pandas as pd
import scipy.cluster.hierarchy as shc
import sklearn
from sklearn.cluster import AgglomerativeClustering
import matplotlib.pyplot as plt
from scipy import signal as sig
import seaborn as sns
from packaging import version
from .. import logger
[docs]
def kde_plot(data, groupings=None, hist=True, alpha=0.2, bins=None, **kwargs):
    """
    Plot kernel density estimate of distribution for data, along with histogram underneath.

    Can plot multiple densities, one per grouping. See Examples below for a depiction.

    Parameters
    ----------
    data : list or array-like or list of np.ndarrays
        Data to plot density of. If of shape (N_points,) and groupings is None, then it is assumed to
        be a single distribution to plot. An array of shape (N_points, 1) is treated exactly like a
        1-dimensional array of shape (N_points,). Otherwise, can be either an array of shape
        (N_points, M_groups), or a list of length M_groups containing arrays of shape (N_points_i).
        See ``groupings`` argument for more on how data should be formatted depending on the groupings desired.
    groupings : list or array-like, optional
        Grouping method for separating data into different distributions, or labels for those distributions.
        If groupings is given when data is 1-dimensional, groupings should provide categorical labels
        for each point in data and also be shape (N_points,). Alternatively,
        can be a list or array-like of shape/length M_groups, with each element specifying a
        label for each group/column in ``data``. You must specify groupings for the axis legend to be shown.
    hist : bool, default=True
        If True (default), plots a histogram underneath the kernel density estimate.
    alpha : float, default=0.2
        Alpha value for transparency of histogram. Ignored if ``hist=False``.
    bins : int or sequence or str, default = :rc:`hist.bins`
        Bins for histogram. Ignored if ``hist=False``.
    **kwargs : kwargs
        kwargs for seaborn.kdeplot. Cannot include 'data', 'x', 'y', or 'hue'. See below for
        some examples of frequently used kwargs.

        ax : matplotlib.axes.Axes, optional
            Axes to plot on. If not given (or None), the current axes are used.
        bw_method : string, scalar, or callable, optional
            Method for determining the smoothing bandwidth to use; passed to scipy.stats.gaussian_kde.
            Can be a single float to determine the bandwidth.
        color : str or matplotlib color, or list of colors, optional
            Color to use if only providing 1 grouping of data (e.g. if ``groupings=None``), or a list
            of the color to use for each group. Groups without a color use the default color cycle.

    Returns
    -------
    ax : matplotlib.axes.Axes
        matplotlib axes containing the plot. The x label that was on the axes before the call is
        restored. The y label is restored only if one was already set; otherwise the label left by
        the plotting calls is kept.

    Raises
    ------
    ValueError
        If per-point groupings do not match the length of data, if a list of arrays and its
        groupings have different lengths, or if the number of colors does not match the number
        of groups.
    TypeError
        If data has an unsupported type, a list mixes numpy arrays with other elements, or
        groupings does not fit multi-column data.

    Examples
    --------
    >>> from naplib.visualization import kde_plot
    >>> import numpy as np
    >>> rng = np.random.default_rng(1)
    >>> data = rng.normal(size=(100,))
    >>> data[50:] += 0.5 # shift the second half of the samples
    >>> groupings = np.array(['G0'] * 100) # define grouping vector
    >>> groupings[50:] = 'G1' # set a different label for the samples we shifted
    >>> # plot the density for each group, as well as a histogram underneath each
    >>> ax = kde_plot(data, groupings=groupings, bw_method=0.25, bins=15, color=['k','r'])

    .. figure:: /figures/kdeplot1.png
        :width: 400px
        :alt: kde_plot figure
        :align: center

    >>> # plot the exact same figure from a list of arrays and grouping labels of same length
    >>> data_list = [data[:50],data[50:]]
    >>> kde_plot(data_list, groupings=['G0','G1'], bw_method=0.25, bins=15, color=['k','r'])
    >>> # plot the exact same figure from a 2D numpy array
    >>> data_mat = np.concatenate([data[:50,np.newaxis],data[50:,np.newaxis]], axis=1)
    >>> kde_plot(data_mat, groupings=['G0','G1'], bw_method=0.25, bins=15, color=['k','r'])
    >>> # if we don't pass in groupings but data is still a 2D array or a list,
    >>> # then there just won't be a legend, but the plot will be the same
    >>> kde_plot(data_mat, bw_method=0.25, bins=15, color=['k','r'])
    """
    ax = kwargs.pop('ax', None)
    if ax is None:
        ax = plt.gca()

    # A plain list with no ndarray elements is a single 1-D sample.
    if isinstance(data, list) and all(not isinstance(xx, np.ndarray) for xx in data):
        data = np.asarray(data)

    # A legend is only drawn when the caller supplied group labels.
    original_groupings_none = groupings is None

    if isinstance(data, np.ndarray):
        if data.ndim == 1 or (data.ndim == 2 and data.shape[1] == 1):
            # Flatten an (N_points, 1) column: pandas requires 1-D column data.
            data = data.reshape(-1)
            if groupings is None:
                groupings2 = np.zeros(len(data), dtype='int')  # all one group
            else:
                groupings2 = [str(x) for x in groupings]
            if len(groupings2) != len(data):
                raise ValueError(f'data and groupings must be same length, but got data'
                                 f' with length {len(data)} and groupings with length {len(groupings)}')
            df = pd.DataFrame.from_dict({'data': data, 'group': groupings2})
        else:  # one distribution per column
            n_points, n_cols = data.shape[0], data.shape[1]
            if groupings is None:
                # Column index as group id, matching the column-major flatten below,
                # e.g. [0,0,0,1,1,1,2,2,2].
                groupings2 = np.repeat(np.arange(n_cols), n_points)
            elif len(groupings) == n_cols:
                groupings2 = [g for g in groupings for _ in range(n_points)]
            else:
                raise TypeError(f'Invalid format for groupings when data is multidimensional numpy array.'
                                f' Must be a list of length data.shape[1]')
            df = pd.DataFrame.from_dict({'data': data.flatten('F'), 'group': groupings2})
    elif isinstance(data, list):
        if not all(isinstance(xx, np.ndarray) for xx in data):
            raise TypeError(f'If data is a list, each element must be a numpy array')
        if groupings is None:
            groupings = list(range(len(data)))
        if len(data) != len(groupings):
            raise ValueError(f'groupings must be same length as data if data is given as list, '
                             f'but got data with length {len(data)}, groupings with length {len(groupings)}')
        # One label per element of each array, in concatenation order.
        groupings_flat = [g for d, g in zip(data, groupings) for _ in d]
        df = pd.DataFrame.from_dict({'data': np.concatenate(data, axis=0), 'group': groupings_flat})
    else:
        raise TypeError(f'data must be either a np.ndarray, a list of scalars, or a list of np.ndarray, but got {type(data)}')

    # seaborn overwrites axis labels, so remember what the caller had.
    xlbl_before = ax.xaxis.get_label().get_text()
    ylbl_before = ax.yaxis.get_label().get_text()

    group_ids = np.unique(df['group'].values)  # sorted unique group labels
    n_groups = len(group_ids)

    color = kwargs.pop('color', None)
    if color is None:
        color = [None] * n_groups
    elif not isinstance(color, list):
        color = [color]
    if len(color) != n_groups:
        raise ValueError(f'If specified, number of colors provided must match number of groups,'
                         f' but got {len(color)} colors and {n_groups} groups')

    # Use the rc color cycle explicitly so the KDE line and its histogram share one color.
    color_cycle_default = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    len_cycle = len(color_cycle_default)
    for i, grp in enumerate(group_ids):
        col = color_cycle_default[i % len_cycle] if color[i] is None else color[i]
        this_group_df = df.loc[df['group'] == grp]
        sns.kdeplot(data=this_group_df, ax=ax, x='data', color=col, label=grp, **kwargs)
        if hist:
            ax.hist(this_group_df['data'], bins=bins, color=col, density=True, alpha=alpha)

    if not original_groupings_none:
        ax.legend()
    ax.set_xlabel(xlbl_before)
    # Only restore a y label the caller had actually set; otherwise keep the plotted one.
    if ylbl_before != '':
        ax.set_ylabel(ylbl_before)
    return ax
[docs]
def shaded_error_plot(*args, ax=None, reduction='mean', err_method='stderr', color=None, alpha=0.4, plt_args=None, shade_args=None, nan_policy='omit'):
    '''
    Plot the average/median value at each time point and a shaded region indicating error or confidence
    level above and below the line. See Examples below for a depiction.

    Parameters
    ----------
    x : array-like, shape (n_samples,), optional
        *x* values are optional and default to ``range(len(y))``.
    y : array-like, shape (n_samples, n_points)
        Data to plot, providing the vertical coordinates. *y* values should be
        two-dimensional, and statistics used to compute shaded region interval
        are computed over the second dimension. A 1-dimensional *y* is treated
        as shape (n_samples, 1).
    fmt : str, optional
        A format string, e.g. 'ro' for red circles. See the matplotlib
        `Axes.plot <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.plot.html>`_
        Notes section for a full description of the format strings.
        Format strings are just an abbreviation for quickly setting
        basic line properties. All of these and more can also be
        controlled by keyword arguments within color or plt_args.
        This argument cannot be passed as keyword.
    ax : plt.Axes instance, optional
        Axes to use. If not specified, will use current axes.
    reduction : str, default='mean'
        Reduction method, either 'mean' or 'median'.
    err_method : string or float, default='stderr'
        The method to use to calculate error bars. If a string, one of ['stderr','std'].
        If a float, defines the confidence interval desired. For example 0.95 specifies
        a 95% confidence interval around the mean (i.e. the interval from the 2.5th percentile
        to the 97.5th percentile). Note, if the data have significant outliers and reduction='mean'
        then the confidence interval bounds might not surround the mean value line.
        With ``nan_policy='omit'``, 'stderr' divides by the number of non-nan values at each sample.
    color : str, default=None
        Color to plot line and shaded region. Defaults to next color in color cycle.
    alpha : float, default=0.4
        Shading alpha. Value between 0 and 1.
    plt_args : dict, optional
        Dict of args to be passed to plt.plot(). e.g. {'linewidth': 2}, etc. Not modified.
    shade_args : dict, optional
        Dict of args to be passed to plt.fill_between(). e.g. {'hatch': '/'}, etc. Not modified.
        Its 'alpha' and 'color' entries are overridden by ``alpha`` and the line color.
    nan_policy : string, default='omit'
        One of ['omit','raise','propogate']. If 'omit', will ignore any nan in the
        inputs, if 'raise', will raise a ValueError if nan is found in input, if
        'propogate', do not do anything special with nan values. The correct
        spelling 'propagate' is accepted as an alias of 'propogate'.

    Returns
    -------
    ax : matplotlib.axes.Axes
        matplotlib axes containing the plot

    Examples
    --------
    >>> from naplib.visualization import shaded_error_plot as sep
    >>> import matplotlib.pyplot as plt
    >>> import numpy as np
    >>> rng = np.random.default_rng(1)
    >>> x, y = np.linspace(0, 1, 10), rng.normal(size=(10,5))
    >>> fig, axes = plt.subplots(3,1)
    >>> sep(y, ax=axes[0]) # plot mean of y, with shaded error regions
    >>> sep(y, 'r--', ax=axes[1]) # same plot but color is red and line is dashed
    >>> sep(x, y, ax=axes[2], err_method='std') # plot vs specific x values and use standard deviation
    >>> plt.show()

    .. figure:: /figures/shadederrorplot1.png
        :width: 400px
        :alt: shaded_error_plot figure
        :align: center

    Raises
    ------
    ValueError
        If nan is found in the input and ``nan_policy`` is 'raise', or if no data,
        too many positional args, or an invalid ``nan_policy``, ``reduction`` or
        ``err_method`` is given.
    TypeError
        If ``fmt`` is not a string.
    '''
    # Work on copies so neither the caller's dicts nor a shared default are mutated.
    # Otherwise the color/alpha written here would leak into later calls.
    plt_kwargs = {} if plt_args is None else dict(plt_args)
    shade_kwargs = {} if shade_args is None else dict(shade_args)
    if color is not None:
        plt_kwargs['color'] = color
    shade_kwargs['alpha'] = alpha

    # Parse matplotlib-style positional args: (y), (y, fmt), (x, y), (x, y, fmt).
    fmt = ''
    x = None
    if len(args) == 0:
        raise ValueError('No data provided to plot.')
    elif len(args) == 1:
        y = args[0]
    elif len(args) == 2:
        if isinstance(args[1], str):
            y, fmt = args
        else:
            x, y = args
    elif len(args) == 3:
        x, y, fmt = args
    else:
        raise ValueError('Too many args passed. Expected at most 3 (x, y, fmt)')
    if not isinstance(fmt, str):
        raise TypeError(f'fmt must be of type string but got {type(fmt)}')

    # Accept any array-like y (e.g. nested lists).
    y = np.asarray(y)
    if x is None:
        x = np.arange(len(y))
    if y.ndim == 1:
        y = y[:, np.newaxis]

    if nan_policy not in ['omit', 'raise', 'propogate', 'propagate']:
        raise ValueError(f"nan_policy must be one of ['omit','raise','propogate'], but found {nan_policy}")
    if nan_policy == 'raise':
        if np.any(np.isnan(x)) or np.any(np.isnan(y)):
            raise ValueError('Found nan in input')
    omit_nan = nan_policy == 'omit'

    if ax is None:
        ax = plt.gca()

    if reduction == 'mean':
        reduction_func = np.nanmean if omit_nan else np.mean
    elif reduction == 'median':
        reduction_func = np.nanmedian if omit_nan else np.median
    else:
        raise ValueError(f'reduction must be either "mean" or "median", but got {reduction}')

    allowed_errors = ['stderr', 'std']
    if isinstance(err_method, str):
        if err_method not in allowed_errors:
            raise ValueError(f'err_method is a string but is not one of {allowed_errors}, but rather {err_method}')
    elif not isinstance(err_method, float):
        raise ValueError(f'err_method must be either a string or a float, but got {err_method}')
    elif err_method <= 0 or err_method > 1.0:
        raise ValueError(f'If err_method is a float then it must be in the range (0, 1]')

    # Pick nan-aware statistics for 'omit'; 'raise' has already rejected nan, so it and
    # 'propogate' use the plain versions.
    if omit_nan:
        std_func, pct_func = np.nanstd, np.nanpercentile
        n_obs = np.sum(~np.isnan(y), axis=1)  # non-nan count per sample
    else:
        std_func, pct_func = np.std, np.percentile
        n_obs = y.shape[1]

    y_mean = reduction_func(y, axis=1)
    if err_method == 'stderr':
        y_err = std_func(y, axis=1) / np.sqrt(n_obs)
    elif err_method == 'std':
        y_err = std_func(y, axis=1)
    else:
        # Central interval: err_method of the mass lies between these two percentiles.
        alpha_level = 1.0 - err_method
        y_err = [pct_func(y, 100 * alpha_level / 2., axis=1),
                 pct_func(y, 100 * (1 - alpha_level / 2.), axis=1)]

    line_args = (x, y_mean) if fmt == '' else (x, y_mean, fmt)
    line_, = ax.plot(*line_args, **plt_kwargs)
    # Shade in whatever color the line actually got (explicit, fmt, or color cycle).
    shade_kwargs['color'] = line_.get_color()

    if isinstance(err_method, str):
        lower, upper = y_mean - y_err, y_mean + y_err
    else:
        lower, upper = y_err
    ax.fill_between(x, lower, upper, **shade_kwargs)
    return ax
[docs]
def hierarchical_cluster_plot(data, axes=None, varnames=None, cmap='bwr', n_clusters=2, metric='euclidean', linkage='ward'):
    '''
    Perform hierarchical clustering and plot dendrogram and clustered values as an
    image underneath. See Examples below for a depiction.

    Parameters
    ----------
    data : array-like, shape (n_samples, n_features)
        Data to cluster and display. Converted with ``np.asarray``.
    axes : list of plt.Axes, length 2, optional
        Array of length 2 containing matplotlib axes to plot on.
        axes[0] will be for the dendrogram and axes[1] will be for the data. If not
        specified, will create new axes in subplots.
    varnames : list of strings, length must = n_features, default=None
        Variable names which will be printed as yticklabels on the data plot.
        May also be a numpy array of strings.
    cmap : string, default='bwr'
        colormap for the data plot. With 'bwr' the color limits are symmetric
        around 0 (+/- the max absolute value of data).
    n_clusters : int, default=2
        Number of clusters which will be used when computing cluster labels that are returned,
        and also for coloring the dendrogram by cluster.
    metric : str, default='euclidean'
        Distance metric. See scipy.spatial.distance.pdist for valid metrics.
    linkage : str, default='ward'
        Linkage method. scipy accepts 'single','complete','average','weighted','centroid',
        'median', or 'ward', but the cluster labels come from
        sklearn.cluster.AgglomerativeClustering, which only accepts 'ward', 'complete',
        'average', or 'single'. Some linkage methods are only valid for certain distance
        metrics.

    Returns
    -------
    cluster_dict : dict
        output from scipy.cluster.hierarchy.dendrogram
    cluster_labels : np.ndarray
        cluster labels from sklearn.cluster.AgglomerativeClustering, shape=(n_samples,)
    fig : matplotlib figure
        Figure where data was plotted. Only returned if axes were not passed in.
    axes : np.ndarray of matplotlib.axes.Axes
        Axes where data was plotted. Only returned if axes were not passed in.

    Examples
    --------
    >>> from naplib.visualization import hierarchical_cluster_plot as hcp
    >>> import matplotlib.pyplot as plt
    >>> import numpy as np
    >>> rng = np.random.default_rng(10)
    >>> x = rng.normal(size=(100,5))
    >>> x[:,1] += rng.normal(loc=1, scale=3, size=(100,))
    >>> x[:,2] += rng.normal(loc=-1, scale=3, size=(100,))
    >>> varnames = ['var1','var2','var3','var4','var5']
    >>> clust, labels, fig, axes = hcp(x, n_clusters=3, varnames=varnames)

    .. figure:: /figures/hierarchicalclusterplot1.png
        :width: 400px
        :alt: hierarchical_cluster_plot figure
        :align: center
    '''
    # Accept any array-like; .shape and fancy row indexing below need an ndarray.
    data = np.asarray(data)

    if axes is None:
        return_axes = True
        fig, axes = plt.subplots(2, 1, figsize=(10, 7), gridspec_kw={'height_ratios': [1, 2]})
    else:
        return_axes = False

    Z = shc.linkage(data, method=linkage, metric=metric)
    # Largest merge height; the color threshold below is a fraction of it.
    max_dist = Z[:, 2].max()

    # Bisection search for a color threshold whose dendrogram coloring yields
    # exactly n_clusters distinct leaf colors (bounded by max_while_loop tries).
    num_colors = -1
    color_thresh_bounds = [0, 1]
    # starting guess for color_thresh
    if n_clusters == 1:
        color_thresh = 1.1
    elif n_clusters >= data.shape[0]:
        color_thresh = 0
    else:
        color_thresh = 0.5
    max_while_loop = 25
    while_loops = 0
    while (num_colors != n_clusters) and (while_loops < max_while_loop):
        if while_loops > 0:
            if num_colors < n_clusters:  # threshold was too high
                color_thresh_bounds[1] = color_thresh  # it's the new upper bound
                color_thresh = (color_thresh + color_thresh_bounds[0]) / 2
            else:  # threshold was too low
                color_thresh_bounds[0] = color_thresh  # it's the new lower bound
                color_thresh = (color_thresh + color_thresh_bounds[1]) / 2
        dend = shc.dendrogram(Z, no_plot=True, show_leaf_counts=False, get_leaves=True,
                              no_labels=True, color_threshold=color_thresh * max_dist)
        num_colors = len(set(dend['leaves_color_list']))
        while_loops += 1
    if num_colors != n_clusters:
        logger.warning('Failed to identify the cut threshold to produce the correct number of colors in the dendrogram plot. The output will still be correct, just not colored correctly.')

    # now plot for real
    dend = shc.dendrogram(Z, show_leaf_counts=False, get_leaves=True, no_labels=True,
                          ax=axes[0], color_threshold=color_thresh * max_dist)
    axes[0].set_yticks([])
    leaves = dend['leaves']

    # The metric parameter is only available after sklearn 1.2.0. Before 1.2.0, it was called affinity
    if version.parse(sklearn.__version__) < version.parse("1.2.0"):
        cluster = AgglomerativeClustering(n_clusters=n_clusters, affinity=metric, linkage=linkage)
    else:
        cluster = AgglomerativeClustering(n_clusters=n_clusters, metric=metric, linkage=linkage)
    cluster_labels = cluster.fit_predict(data)

    # Rows reordered to dendrogram leaf order; features on the y axis.
    if cmap == 'bwr':
        mm = np.abs(data).max()  # symmetric limits so 0 maps to the colormap midpoint
        axes[1].imshow(data[leaves, :].T, cmap=cmap, aspect='auto', vmin=-mm, vmax=mm, interpolation='none')
    else:
        axes[1].imshow(data[leaves, :].T, cmap=cmap, aspect='auto', interpolation='none')

    # Explicit None/length check: truthiness of a numpy array of names is ambiguous.
    if varnames is not None and len(varnames) > 0:
        axes[1].set_yticks(list(range(len(varnames))))
        axes[1].set_yticklabels(varnames, fontsize=8)
    axes[1].set_xticks([])

    if return_axes:
        return dend, cluster_labels, fig, axes
    return dend, cluster_labels
[docs]
def strf_plot(coef, tmin=None, tmax=None, freqs=None, ax=None, smooth=True, vmax=None):
    '''
    Plot STRF weights as an image with a diverging colormap centered at zero.

    Zero maps to white, positive weights to red and negative weights to blue
    (matplotlib's 'bwr' colormap with symmetric limits).

    Parameters
    ----------
    coef : np.array, shape (freq, lag)
        STRF weights.
    tmin : float, optional
        Time of first lag (first column in coef)
    tmax : float, optional
        Time of final lag (last column in coef). The lag axis is labeled in
        seconds only when both tmin and tmax are given; otherwise it is labeled
        in samples.
    freqs : list or array-like, length=2, optional
        Frequency of lowest and highest frequency bin in STRF. Only the first
        and last entries are used, as tick labels for the bottom and top bins.
    ax : plt.Axes, optional
        Axes to plot on. If not specified, will use current axes.
    smooth : bool, default=True
        Whether or not to smooth the STRF image. Smoothing is
        done with 'gouraud' shading in plt.pcolormesh().
    vmax : float, optional
        If provided, colormap will be between [-vmax, vmax]. If not given,
        uses the max absolute value of the coef.

    Returns
    -------
    ax : matplotlib.axes.Axes
        Axes where STRF coef is plotted.

    Examples
    --------
    >>> from naplib.visualization import strf_plot
    >>> import numpy as np
    >>> from scipy.stats import multivariate_normal
    >>> # generate example STRF weights following mne's example:
    >>> # https://mne.tools/stable/auto_tutorials/machine-learning/30_strf.html
    >>> fs = 100
    >>> n_freqs = 32
    >>> tmin, tmax = 0, 0.4
    >>> delays_samp = np.arange(np.round(tmin * fs),
    ...                         np.round(tmax * fs) + 1).astype(int)
    >>> delays_sec = delays_samp / fs
    >>> freqs = np.linspace(50, 5000, n_freqs)
    >>> grid = np.array(np.meshgrid(delays_sec, freqs))
    >>> # We need data to be shaped as n_epochs, n_features, n_times, so swap axes here
    >>> grid = grid.swapaxes(0, -1).swapaxes(0, 1)
    >>> # Simulate a temporal receptive field with a Gabor filter
    >>> means_high = [.1, 500]
    >>> means_low = [.2, 2500]
    >>> cov = [[.001, 0], [0, 500000]]
    >>> gauss_high = multivariate_normal.pdf(grid, means_high, cov)
    >>> gauss_low = -1 * multivariate_normal.pdf(grid, means_low, cov)
    >>> weights = gauss_high + gauss_low # Combine to create the "true" STRF
    >>> strf_plot(weights, tmin=tmin, tmax=tmax)

    .. figure:: /figures/imSTRF1.png
        :width: 400px
        :alt: strf_plot figure
        :align: center
    '''
    if ax is None:
        ax = plt.gca()

    n_lags = coef.shape[1]
    if tmin is None or tmax is None:
        lag_axis = np.arange(0, n_lags)
        ax.set_xlabel('Lag (samples)')
    else:
        lag_axis = np.linspace(tmin, tmax, n_lags)
        ax.set_xlabel('Lag (s)')

    n_bins = coef.shape[0]
    bin_axis = np.arange(0, n_bins)
    ax.set_ylabel('Frequency')

    # Symmetric color limits keep 0 at the white midpoint of 'bwr'.
    peak = np.abs(coef).max()
    limit = peak if vmax is None else vmax
    mesh_opts = {'vmin': -limit, 'vmax': limit, 'cmap': 'bwr'}
    if smooth:
        mesh_opts['shading'] = 'gouraud'
    ax.pcolormesh(lag_axis, bin_axis, coef, **mesh_opts)

    if freqs is not None:
        # Label only the bottom and top frequency bins.
        ax.set_yticks([0, n_bins - 1])
        ax.set_yticklabels([freqs[0], freqs[-1]])
    return ax
[docs]
def freq_response(ba, fs, ax=None, units='Hz'):
    '''
    Plot frequency response of a digital filter.

    The response is computed with ``scipy.signal.freqz`` (digital filters) and
    plotted in dB on a logarithmic frequency axis.

    Parameters
    ----------
    ba : tuple of length 2
        Tuple containing (b, a), the filter numerator and denominator polynomials.
    fs : int or float
        Sampling rate in Hz.
    ax : plt.Axes instance, optional
        Axes to use. If not specified, will use current axes.
    units : string, default='Hz'
        One of {'Hz', 'rad/s'} specifying whether to plot frequencies in Hz or
        radians per second. Both show the same digital response (0 up to the
        Nyquist frequency), only the frequency units differ.

    Returns
    -------
    ax : matplotlib.axes.Axes
        Axes where the frequency response is plotted.

    Raises
    ------
    ValueError
        If ``units`` is not one of 'Hz' or 'rad/s'.

    Examples
    --------
    >>> import naplib as nl
    >>> from naplib.visualization import freq_response
    >>> from naplib.preprocessing import filter_butter
    >>> # Load sample data to filter
    >>> data = nl.io.load_speech_task_data()
    >>> alpha_band_data, filters = filter_butter(data, btype='bandpass',
    ...                                          Wn=[10, 20],
    ...                                          return_filters=True)
    >>> ax = freq_response(filters[0], fs=data[0]['dataf'])

    .. figure:: /figures/freq_responses1.png
        :width: 400px
        :alt: frequency response figure
        :align: center
    '''
    if units not in ['Hz', 'rad/s']:
        raise ValueError(f'units must be one of ["Hz", "rad/s"] but got {units}')
    if ax is None:
        ax = plt.gca()

    # ba holds *digital* filter coefficients, so always use freqz (freqs is for
    # analog filters). Passing the sampling rate in the desired units makes freqz
    # return w in those units: Hz, or angular rad/s when fs is given as 2*pi*fs.
    freqz_fs = fs if units == 'Hz' else 2 * np.pi * fs
    w, h = sig.freqz(ba[0], ba[1], fs=freqz_fs)

    # Zeros in the response give -inf dB; suppress only that specific warning.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message="divide by zero encountered in log10")
        ax.semilogx(w, 20 * np.log10(abs(h)))
    ax.set_title('Frequency Response')
    if units == 'Hz':
        ax.set_xlabel('Frequency (Hz)')
    else:
        ax.set_xlabel('Frequency (radians / second)')
    ax.set_ylabel('Amplitude (dB)')
    ax.margins(0, 0.1)
    ax.grid(which='both', axis='both')
    return ax