An easy-to-use platform for EEG experimentation in the classroom
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

moving files

-175
app/utils/jupyter/cells.ts app/utils/pyodide/cell.js
app/utils/jupyter/functions.ts app/utils/pyodide/functions.js
app/utils/jupyter/pipes.ts app/utils/pyodide/pipes.js
-175
app/utils/jupyter/utils.py
··· 1 - from glob import glob 2 - import os 3 - from collections import OrderedDict 4 - from mne import create_info, concatenate_raws, viz 5 - from mne.io import RawArray 6 - from mne.channels import read_montage 7 - import pandas as pd 8 - import numpy as np 9 - import seaborn as sns 10 - from matplotlib import pyplot as plt 11 - 12 - sns.set_context('talk') 13 - sns.set_style('white') 14 - 15 - 16 - def load_data(fnames, sfreq=128., replace_ch_names=None): 17 - """Load CSV files from the /data directory into a Raw object. 18 - 19 - Args: 20 - fnames (array): CSV filepaths from which to load data 21 - 22 - Keyword Args: 23 - sfreq (float): EEG sampling frequency 24 - replace_ch_names (dict or None): dictionary containing a mapping to 25 - rename channels. Useful when an external electrode was used. 26 - 27 - Returns: 28 - (mne.io.array.array.RawArray): loaded EEG 29 - """ 30 - 31 - raw = [] 32 - print(fnames) 33 - for fname in fnames: 34 - # read the file 35 - data = pd.read_csv(fname, index_col=0) 36 - 37 - data = data.dropna() 38 - 39 - # get estimation of sampling rate and use to determine sfreq 40 - # yes, this could probably be improved 41 - srate = 1000 / (data.index.values[1] - data.index.values[0]) 42 - if srate >= 200: 43 - sfreq = 256 44 - else: 45 - sfreq = 128 46 - 47 - # name of each channel 48 - ch_names = list(data.columns) 49 - 50 - # indices of each channel 51 - ch_ind = list(range(len(ch_names))) 52 - 53 - if replace_ch_names is not None: 54 - ch_names = [c if c not in replace_ch_names.keys() 55 - else replace_ch_names[c] for c in ch_names] 56 - 57 - # type of each channels 58 - ch_types = ['eeg'] * (len(ch_ind) - 1) + ['stim'] 59 - montage = read_montage('standard_1005') 60 - 61 - # get data and exclude Aux channel 62 - data = data.values[:, ch_ind].T 63 - 64 - # create MNE object 65 - info = create_info(ch_names=ch_names, ch_types=ch_types, 66 - sfreq=sfreq, montage=montage) 67 - raw.append(RawArray(data=data, info=info)) 68 - 69 - # concatenate all raw objects 70 - raws = concatenate_raws(raw) 71 - 72 - return raws 73 - 74 - 75 - def plot_topo(epochs, conditions=OrderedDict()): 76 - palette = sns.color_palette("hls", len(conditions) + 1) 77 - evokeds = [epochs[name].average() for name in (conditions)] 78 - 79 - evoked_topo = viz.plot_evoked_topo( 80 - evokeds, vline=None, color=palette[0:len(conditions)], show=False) 81 - evoked_topo.patch.set_alpha(0) 82 - evoked_topo.set_size_inches(10, 8) 83 - for axis in evoked_topo.axes: 84 - for line in axis.lines: 85 - line.set_linewidth(2) 86 - 87 - legend_loc = 0 88 - labels = [e.comment if e.comment else 'Unknown' for e in evokeds] 89 - legend = plt.legend(labels, loc=legend_loc, prop={'size': 20}) 90 - txts = legend.get_texts() 91 - for txt, col in zip(txts, palette): 92 - txt.set_color(col) 93 - 94 - return evoked_topo 95 - 96 - 97 - def plot_conditions(epochs, ch_ind=0, conditions=OrderedDict(), ci=97.5, n_boot=1000, 98 - title='', palette=None, 99 - diff_waveform=(4, 3)): 100 - """Plot Averaged Epochs with ERP conditions. 101 - 102 - Args: 103 - epochs (mne.epochs): EEG epochs 104 - 105 - Keyword Args: 106 - conditions (OrderedDict): dictionary that contains the names of the 107 - conditions to plot as keys, and the list of corresponding marker 108 - numbers as value. E.g., 109 - 110 - conditions = {'Non-target': [0, 1], 111 - 'Target': [2, 3, 4]} 112 - 113 - ch_ind (int): index of channel to plot data from 114 - ci (float): confidence interval in range [0, 100] 115 - n_boot (int): number of bootstrap samples 116 - title (str): title of the figure 117 - palette (list): color palette to use for conditions 118 - ylim (tuple): (ymin, ymax) 119 - diff_waveform (tuple or None): tuple of ints indicating which 120 - conditions to subtract for producing the difference waveform. 121 - If None, do not plot a difference waveform 122 - 123 - Returns: 124 - (matplotlib.figure.Figure): figure object 125 - (list of matplotlib.axes._subplots.AxesSubplot): list of axes 126 - """ 127 - if isinstance(conditions, dict): 128 - conditions = OrderedDict(conditions) 129 - 130 - if palette is None: 131 - palette = sns.color_palette("hls", len(conditions) + 1) 132 - 133 - X = epochs.get_data() 134 - times = epochs.times 135 - y = pd.Series(epochs.events[:, -1]) 136 - fig, ax = plt.subplots() 137 - 138 - for cond, color in zip(conditions.values(), palette): 139 - sns.tsplot(X[y.isin(cond), ch_ind], time=times, color=color, 140 - n_boot=n_boot, ci=ci) 141 - 142 - if diff_waveform: 143 - diff = (np.nanmean(X[y == diff_waveform[1], ch_ind], axis=0) - 144 - np.nanmean(X[y == diff_waveform[0], ch_ind], axis=0)) 145 - ax.plot(times, diff, color='k', lw=1) 146 - 147 - ax.set_title(epochs.ch_names[ch_ind]) 148 - ax.axvline(x=0, color='k', lw=1, label='_nolegend_') 149 - 150 - ax.set_xlabel('Time (s)') 151 - ax.set_ylabel('Amplitude (uV)') 152 - ax.set_xlabel('Time (s)') 153 - ax.set_ylabel('Amplitude (uV)') 154 - 155 - # Round y axis tick labels to 2 decimal places 156 - # ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f')) 157 - 158 - if diff_waveform: 159 - legend = (['{} - {}'.format(diff_waveform[1], diff_waveform[0])] + 160 - list(conditions.keys())) 161 - else: 162 - legend = conditions.keys() 163 - ax.legend(legend) 164 - sns.despine() 165 - plt.tight_layout() 166 - 167 - if title: 168 - fig.suptitle(title, fontsize=20) 169 - 170 - fig.set_size_inches(10, 8) 171 - 172 - return fig, ax 173 - 174 - def get_epochs_info(epochs): 175 - return [*[{x: len(epochs[x])} for x in epochs.event_id], {"Drop Percentage": round((1 - len(epochs.events)/len(epochs.drop_log)) * 100, 2)}, {"Total Epochs": len(epochs.events)}]