Source code for naplib.stats.encoding

import numpy as np
import scipy.stats as stats


def wilks_lambda_discriminability(D, L):
    '''
    Compute Wilks' Lambda F-ratio and p-value for a single data matrix (not over time).
    
    Parameters
    ----------
    D : array-like of shape (instance, features)
        Data features.
    L : array-like of shape (instance, )
        Labels for each instance.
    
    Returns
    -------
    f_stat : float
        F-statistic
    p_val : float
        p-value

    See Also
    --------
    discriminability
    lda_discriminability

    '''
    N, K = D.shape
    n_classes = len(np.unique(L))

    # Compute the overall mean of X and the class means for each category in Y
    overall_mean = D.mean(axis=0)
    class_means = [D[L == c].mean(axis=0) for c in np.unique(L)]

    # Compute the between-class scatter matrix and the within-class scatter matrix
    S_B = np.zeros((K, K))
    S_W = np.zeros((K, K))

    for i, c in enumerate(np.unique(L)):
        n_c = (L == c).sum()
        mean_diff = (class_means[i] - overall_mean).reshape(K, 1)
        S_B += n_c * mean_diff @ mean_diff.T
        class_diff = D[L == c] - class_means[i]
        S_W += class_diff.T @ class_diff

    # Step 1: Calculate Wilks' Lambda
    wilks_lambda = np.linalg.det(S_W) / np.linalg.det(S_B + S_W)

    # Step 2: Compute the F-statistic based on Wilks' Lambda
    numerator_df = (n_classes - 1) * K
    denominator_df = N - n_classes
    f_stat = (
        (1 - wilks_lambda ** (1 / K)) / (wilks_lambda ** (1 / K))
        * (denominator_df / numerator_df)
    )

    # Step 3: Calculate the p-value using the F-distribution
    p_value = 1 - stats.f.cdf(f_stat, numerator_df, denominator_df)

    return f_stat, p_value


def lda_discriminability(D, L):
    '''
    Compute LDA discriminability for a single data matrix (not over time).
    
    Parameters
    ----------
    D : array-like of shape (instance, features)
        Data features.
    L : array-like of shape (instance, )
        Labels for each instance.
    
    Returns
    -------
    f_all : np.ndarray
        All F-statistics.
    f_stat : float
        F-statistic.
    f_std : float
        Standard dev of f_all.

    See Also
    --------
    discriminability
    wilks_lambda_discriminability

    '''
    labels = list(np.unique(L))
    def where_label_is(label):
        return np.argwhere(L == label).squeeze(-1)
        
    D = np.concatenate([D[where_label_is(label)] for label in labels])
    L = np.concatenate([L[where_label_is(label)] for label in labels])
    
    global_mean = D.mean(0)
    
    sbs = np.zeros(len(labels))
    sws = np.zeros(len(labels))
    total = 0.0
    for i, label in enumerate(labels):
        index = where_label_is(label)

        # Between class variability
        group_mean = D[index].mean(0)
        sbs[i] = np.linalg.norm(group_mean - global_mean, 2)**2 * len(index)
        
        # Within class variability
        sws[i] = sum(np.linalg.norm(D[j] - group_mean, 2)**2 for j in index)
        
        total += len(index)
    
    # Set zeros to epsilon for numerical reasons
    sws[sws == 0] = 1e-6
    
    # Normalize
    sbs /= (len(labels) - 1)
    sws /= (total - len(labels))

    f_all = np.divide(sbs, sws)
    f_stat = sbs.sum() / sws.sum()
    f_std = np.std(len(labels)*sbs / sws.sum(), ddof=1) / np.sqrt(len(labels))

    return f_all, f_stat, f_std


[docs] def discriminability(D, L, elec_mode='all', method='lda'): ''' Compute discriminability of classes over time. Can use multiple electrodes jointly, or compute the discriminability by each electrode individually. Multiple methods available. Parameters ---------- D : array-like of shape (electrodes, time, instances) Data features over time. L : array-like containing labels for each instance L is either of shape (instances,) if each instance has the same label across the full time axis, or is of shape (time, instances). Must be categorical to indicate class membership for each instance. elec_mode : string, one of ['all', 'individual'] if 'all', computes f-ratio over all electrodes together, and returns f-ratio of shape (time,) if 'individual', computes f-ratio for each electrode individually and returns f-ratio of shape (electrodes, time) method : string, default='lda' Method for computing discriminability. Options are 'lda', 'wilks-lambda'. Returns ------- fratio : np.ndarray if elec_mode=='all', shape=(time,) if elec_mode=='individual', shape=(electrodes, time) pvalues : np.ndarray Only returned if method is 'wilks-lambda'. Shape is the same as ``fratio``. Examples -------- >>> from naplib.stats import discriminability >>> rng = np.random.default_rng(1) >>> elecs, t, instances = 2, 5, 50 >>> D = np.concatenate([rng.normal(size=(elecs, t, instances)), ... rng.normal(loc=1, scale=0.5, size=(elecs, t, instances))], ... axis=-1) >>> # labels for the data, where labels do not change over time >>> L = np.concatenate([np.zeros((instances,)), np.ones((instances,))]) >>> f_stat, p_val = discriminability(D, L, method='wilks-lambda') >>> f_stat array([16.71955996, 19.94997217, 19.0641678 , 15.95256107, 17.90728111]) >>> p_val array([5.65679222e-07, 5.38811997e-08, 1.01531199e-07, 1.00551882e-06, 2.35186401e-07])) ''' def _compute_discrim(x_data, labels_data): if method == 'lda': _, f_stat, _ = lda_discriminability(x_data.T, labels_data) return f_stat, np.nan elif method == 'wilks-lambda': return wilks_lambda_discriminability(x_data.T, labels_data) else: raise ValueError('Bad method. Must be one of "lda", "wilks-lambda"') f_stat, p_vals = None, None if elec_mode == 'all': f_stat = np.zeros(D.shape[1]) p_vals = np.zeros(D.shape[1]) for t in range(D.shape[1]): if L.ndim > 1: f_stat[t], p_vals[t] = _compute_discrim(D[:,t], L[t]) else: f_stat[t], p_vals[t] = _compute_discrim(D[:,t], L) elif elec_mode == 'individual': f_stat = np.zeros(D.shape[:2]) p_vals = np.zeros(D.shape[:2]) for t in range(D.shape[1]): for e in range(D.shape[0]): if L.ndim > 1: f_stat[e,t], p_vals[e,t] = _compute_discrim(D[e,t,None], L[t]) else: f_stat[e,t], p_vals[e,t] = _compute_discrim(D[e,t,None], L) else: raise Exception( f'Error: elec_mode should be one of ["all", "individual"], but got {elec_mode}' ) if method in ['wilks-lambda']: return f_stat, p_vals return f_stat
import numpy as np import numpy as np
[docs] def pairwise_correlation(A, B, axis=0): r""" Compute Pearson correlation between A and B along a specified axis. The correlation is computed pairwise for each corresponding element along the remaining dimensions. The output will have the same shape as the inputs, but with the specified ``axis`` removed. The correlation is calculated as: $$r = \frac{\sum (A_i - \bar{A})(B_i - \bar{B})}{\sqrt{\sum (A_i - \bar{A})^2 \sum (B_i - \bar{B})^2}}$$ Parameters ---------- A : np.ndarray First array. B : np.ndarray Second array. Must be the same shape as A. axis : int, default=0 The axis along which to compute the correlation (e.g., the time dimension). Returns ------- corr : np.ndarray or float Pairwise correlations. If inputs are 1D, returns a float. Otherwise, returns an array of shape equal to the input shape with the ``axis`` dimension removed. """ A = np.asarray(A) B = np.asarray(B) if A.shape != B.shape: raise ValueError(f"A and B must have the same shape, but got {A.shape} and {B.shape}") # 1. Center the data along the specified axis # keepdims=True is essential for broadcasting subtraction am = A - np.mean(A, axis=axis, keepdims=True) bm = B - np.mean(B, axis=axis, keepdims=True) # 2. Compute sum of squares (variance proxies) a_ss = np.sum(am**2, axis=axis) b_ss = np.sum(bm**2, axis=axis) # 3. Compute covariance proxy coscale = np.sum(am * bm, axis=axis) # 4. Return normalized correlation # 1e-15 prevents division by zero for constant signals return coscale / (np.sqrt(a_ss * b_ss) + 1e-15)