Stimulus Reconstruction Basics

Using TRF for auditory stimulus reconstruction.

This notebook demonstrates how to use the TRF model in naplib-python to do stimulus reconstruction from neural recordings, a technique in which the auditory stimulus is reconstructed from the electrode responses. Here, we train a linear TRF model to reconstruct the stimulus spectrogram from a 400 ms window of responses recorded from a group of electrodes.
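As a rough illustration of what the TRF model does in this direction, the sketch below builds a lagged copy of the responses and fits a closed-form ridge regression from it to a spectrogram, using only NumPy on synthetic data. It is not naplib's implementation. The helper names `lag_matrix` and `fit_ridge`, the zero-padding at the edges, the absence of an intercept, and the choice to use the response samples that follow each stimulus time point (because neural responses lag the sound) are all assumptions of this sketch.

```python
import numpy as np

def lag_matrix(resp, n_lags):
    """Stack resp[t], resp[t+1], ..., resp[t+n_lags-1] into one row per time t.

    resp: (T, n_elec). Rows that would run past the end are zero-padded.
    Returns an array of shape (T, n_elec * n_lags).
    """
    T, n_elec = resp.shape
    X = np.zeros((T, n_elec * n_lags))
    for k in range(n_lags):
        # columns for lag k hold the response k samples *after* time t
        X[: T - k, k * n_elec:(k + 1) * n_elec] = resp[k:]
    return X

def fit_ridge(X, Y, alpha=1.0):
    """Closed-form ridge without intercept: W = (X^T X + alpha*I)^-1 X^T Y."""
    n_feat = X.shape[1]
    return np.linalg.solve(X.T @ X + alpha * np.eye(n_feat), X.T @ Y)

# Synthetic data: 10 "electrodes" respond to a 4-bin "spectrogram" after a delay
rng = np.random.default_rng(0)
T, n_freq, n_elec, delay = 3000, 4, 10, 5
stim = rng.standard_normal((T, n_freq))
mix = rng.standard_normal((n_freq, n_elec))
resp = np.zeros((T, n_elec))
resp[delay:] = stim[:-delay] @ mix            # response lags the stimulus by `delay` samples
resp += 0.1 * rng.standard_normal(resp.shape)  # small measurement noise

X = lag_matrix(resp, n_lags=10)                # 10 lags cover the 5-sample delay
train, test = slice(0, 2000), slice(2000, T)
W = fit_ridge(X[train], stim[train], alpha=1.0)
recon = X[test] @ W                            # reconstructed stimulus on held-out samples

# Pearson correlation between true and reconstructed values, per frequency bin
r = [np.corrcoef(recon[:, f], stim[test][:, f])[0, 1] for f in range(n_freq)]
print(np.round(r, 3))
```

The reconstruction tracks the synthetic stimulus closely because the responses are a delayed, nearly noiseless linear mixture of it; real recordings are far noisier.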

For more information on stimulus reconstruction and its uses, please see the following papers:

# Author: Gavin Mischler
#
# License: MIT


import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import resample
from sklearn.linear_model import Ridge

import naplib as nl

Load and Preprocess Data

data = nl.io.load_speech_task_data()
print(f'This data has {len(data)} trials')


# Each response is (time, electrodes): this data has 10 simultaneously recorded electrodes
print(data['resp'][0].shape)

# get auditory spectrogram for each stimulus sound
data['spec'] = [nl.features.auditory_spectrogram(trial['sound'], 11025) for trial in data]

# make sure the spectrogram is the exact same size as the responses
data['spec'] = [resample(trial['spec'], trial['resp'].shape[0]) for trial in data]

# normalize responses
data['resp'] = nl.preprocessing.normalize(data, 'resp')

# The spectrogram has 128 frequency channels, which is large, so we downsample it to 32 channels
print(f"before resampling: {data['spec'][0].shape}")

resample_kwargs = {'num': 32, 'axis': 1}
data['spec_32'] = nl.array_ops.concat_apply(data['spec'], resample, function_kwargs=resample_kwargs)

print(f"after resampling:  {data['spec_32'][0].shape}")
This data has 10 trials
(6197, 10)
before resampling: (6197, 128)
after resampling:  (6197, 32)
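The three preprocessing steps above (matching the spectrogram's time axis to the responses, reducing 128 frequency bins to 32, and normalizing the responses) can be sketched with NumPy and `scipy.signal.resample` on random stand-in arrays. The array sizes are made up, and per-electrode z-scoring within a single trial is an assumption: naplib's `normalize` may compute its statistics differently, for example across all trials.

```python
import numpy as np
from scipy.signal import resample

rng = np.random.default_rng(0)
# Hypothetical stand-ins: a 128-bin spectrogram with more frames than the response
spec = np.abs(rng.standard_normal((7000, 128)))
resp = rng.standard_normal((6197, 10)) * 3.0 + 1.0

# 1) match the time axis: resample spectrogram frames to the response length
spec_t = resample(spec, resp.shape[0], axis=0)   # (6197, 128)

# 2) reduce frequency resolution: 128 -> 32 bins along axis 1
spec_32 = resample(spec_t, 32, axis=1)           # (6197, 32)

# 3) z-score each electrode over time (zero mean, unit variance per channel)
resp_z = (resp - resp.mean(axis=0)) / resp.std(axis=0)

print(spec_t.shape, spec_32.shape, resp_z.shape)
```

Like the code above, this uses Fourier-based resampling along the frequency axis, which treats the signal as periodic. Averaging each group of four adjacent bins (`spec_t.reshape(len(spec_t), 32, 4).mean(axis=2)`) is a simpler alternative without that assumption, but it is a different method from the one used in this example.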

Visualize Spectrogram

Let's look at the first 5 seconds (500 samples at 100 Hz) more closely to understand the data.

fig, axes = plt.subplots(2,1,figsize=(8,5), sharex=True)
axes[0].imshow(data[0]['spec_32'][:500].T, aspect=3, origin='lower')
axes[0].set_title('Spectrogram of stimulus (to reconstruct)')
axes[1].plot(data[0]['resp'][:500])
axes[1].set_title('Multichannel recording to use as input\nfeatures to reconstruction model')
plt.tight_layout()
plt.show()
[Figure: top, spectrogram of stimulus (to reconstruct); bottom, multichannel recording to use as input features to the reconstruction model]

Train TRF Model

Train a single TRF model to reconstruct all frequency bins as a single output channel with multiple dimensions.

First, we reshape the stimulus spectrograms to shape (time, 1, frequency). If they were left as (time, frequency), a separate model would be trained for each frequency bin.
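Reshaping changes how many models are fit, but with plain ridge regression and one shared `alpha` it does not by itself change the weights: the closed-form solution W = (X^T X + alpha*I)^-1 X^T Y decouples column by column, so fitting 32 outputs jointly gives the same weights as fitting each one alone. The NumPy sketch below checks this on random toy data. It treats the estimator as a single shared-penalty, no-intercept ridge solve, which is an assumption about the setup rather than a description of naplib's internals.

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.standard_normal((500, 40))   # toy lagged responses (time, features)
Y = rng.standard_normal((500, 32))   # toy spectrogram (time, 32 frequency bins)
alpha = 1.0

A = X.T @ X + alpha * np.eye(X.shape[1])

# Joint fit: all 32 outputs in one solve, analogous to the (time, 1, frequency) layout
W_joint = np.linalg.solve(A, X.T @ Y)

# Separate fits: one solve per frequency bin, analogous to the (time, frequency) layout
W_sep = np.column_stack([np.linalg.solve(A, X.T @ Y[:, f]) for f in range(Y.shape[1])])

print(np.allclose(W_joint, W_sep))
```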

data['reshaped_spec'] = [x[:,np.newaxis,:] for x in data['spec_32']]
print(data['reshaped_spec'][0].shape)

# hold out the first trial for testing and train on the remaining 9 trials
data_train = data[1:]
data_test = data[:1]

# model parameters
tmin = -0.40 # lags from -400 ms to 0 ms: a 400 ms window of neural response is used to reconstruct each stimulus time sample
tmax = 0
sfreq = 100  # sampling rate (Hz) of the responses and spectrograms

# train model with Ridge estimator
mdl = nl.encoding.TRF(tmin=tmin, tmax=tmax, sfreq=sfreq, estimator=Ridge())
mdl.fit(data_train, X='resp', y='reshaped_spec')

reconstructed_stims = mdl.predict(data_test, X='resp')

# compute correlation score
corr = mdl.corr(data_test, X='resp', y='reshaped_spec')


region = slice(0, 500)

fig, axes = plt.subplots(2,1,figsize=(12,6))

axes[0].imshow(data_test[0]['reshaped_spec'].squeeze()[region].T, aspect=3, origin='lower')
axes[0].set_title('True stimulus')
axes[1].imshow(reconstructed_stims[0].squeeze()[region].T, aspect=3, origin='lower')
axes[1].set_title('Reconstructed stimulus, corr={:.3f}'.format(corr.item()))

plt.show()
[Figure: true stimulus (top) and reconstructed stimulus (bottom), corr=0.670]
(6197, 1, 32)
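The window given in seconds translates into integer sample lags. The sketch below assumes that both `tmin` and `tmax` are included and that each lag contributes one copy of every electrode as an input feature; `lag_samples` is a hypothetical helper, not part of naplib.

```python
import numpy as np

def lag_samples(tmin, tmax, sfreq):
    """Integer lags spanned by [tmin, tmax] seconds, both ends included (assumed)."""
    return np.arange(int(round(tmin * sfreq)), int(round(tmax * sfreq)) + 1)

lags = lag_samples(tmin=-0.40, tmax=0.0, sfreq=100)
n_electrodes = 10

print(len(lags), lags[0], lags[-1])   # 41 -40 0
print(n_electrodes * len(lags))       # 410 input features per time point (no intercept)
```

Under these assumptions, each reconstructed spectrogram frame is predicted from 410 weighted response values.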


Train a TRF model to reconstruct each frequency bin individually

For this, we use the spectrogram of shape (time, frequency) rather than the reshaped spectrogram, since the TRF model automatically trains a separate model for each channel along the second dimension of the y variable.
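The two approaches in this example summarize accuracy differently: the first reports a single value via `corr.item()`, while the code in this section averages 32 per-bin correlations with `.mean()`. A Pearson correlation computed over all time-frequency points at once is not the same quantity as the mean of per-bin correlations, because the flattened version also rewards matching the level differences between bins. The NumPy sketch below shows that gap on toy data. How naplib's `corr` treats the (time, 1, frequency) layout is not stated here, so the scores from the two sections may not be directly comparable; the helper names are hypothetical.

```python
import numpy as np

def per_bin_corr(true, pred):
    """Pearson r for each frequency bin (column); returns shape (n_bins,)."""
    t = true - true.mean(axis=0)
    p = pred - pred.mean(axis=0)
    return (t * p).sum(axis=0) / np.sqrt((t ** 2).sum(axis=0) * (p ** 2).sum(axis=0))

def flattened_corr(true, pred):
    """A single Pearson r over all time-frequency points at once."""
    return np.corrcoef(true.ravel(), pred.ravel())[0, 1]

rng = np.random.default_rng(2)
# Toy spectrogram whose 32 bins have very different mean levels
true = rng.standard_normal((1000, 32)) + np.linspace(0, 10, 32)
pred = true + 2.0 * rng.standard_normal(true.shape)   # noisy "reconstruction"

mean_bin_r = per_bin_corr(true, pred).mean()   # only within-bin tracking counts
flat_r = flattened_corr(true, pred)            # between-bin offsets also count
print(round(mean_bin_r, 3), round(flat_r, 3))
```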

# train model
mdl = nl.encoding.TRF(tmin=tmin, tmax=tmax, sfreq=sfreq, estimator=Ridge())
mdl.fit(data_train, X='resp', y='spec_32')

reconstructed_stims_bychannel = mdl.predict(data_test, X='resp')
reconstructed_stims_bychannel[0].shape

# compute correlation score
corr = mdl.corr(data_test, X='resp', y='spec_32').mean()

region = slice(0, 500)

fig, axes = plt.subplots(2,1,figsize=(12,6))

axes[0].imshow(data_test[0]['spec_32'].squeeze()[region].T, aspect=3, origin='lower')
axes[0].set_title('True stimulus')
axes[1].imshow(reconstructed_stims_bychannel[0].squeeze()[region].T, aspect=3, origin='lower')
axes[1].set_title('Reconstructed stimulus, corr={:.3f}'.format(corr))

plt.show()
[Figure: true stimulus (top) and reconstructed stimulus (bottom), corr=0.590]

