···11+from glob import glob
22+import os
33+from collections import OrderedDict
44+from mne import create_info, concatenate_raws, viz
55+from mne.io import RawArray
66+from mne.channels import read_montage
77+import pandas as pd
88+import numpy as np
99+import seaborn as sns
1010+from matplotlib import pyplot as plt
1111+1212+sns.set_context('talk')
1313+sns.set_style('white')
1414+1515+1616+def load_data(fnames, sfreq=128., replace_ch_names=None):
1717+ """Load CSV files from the /data directory into a Raw object.
1818+1919+ Args:
2020+ fnames (array): CSV filepaths from which to load data
2121+2222+ Keyword Args:
2323+ sfreq (float): EEG sampling frequency
2424+ replace_ch_names (dict or None): dictionary containing a mapping to
2525+ rename channels. Useful when an external electrode was used.
2626+2727+ Returns:
2828+ (mne.io.array.array.RawArray): loaded EEG
2929+ """
3030+3131+ raw = []
3232+ print(fnames)
3333+ for fname in fnames:
3434+ # read the file
3535+ data = pd.read_csv(fname, index_col=0)
3636+3737+ data = data.dropna()
3838+3939+ # get estimation of sampling rate and use to determine sfreq
4040+ # yes, this could probably be improved
4141+ srate = 1000 / (data.index.values[1] - data.index.values[0])
4242+ if srate >= 200:
4343+ sfreq = 256
4444+ else:
4545+ sfreq = 128
4646+4747+ # name of each channel
4848+ ch_names = list(data.columns)
4949+5050+ # indices of each channel
5151+ ch_ind = list(range(len(ch_names)))
5252+5353+ if replace_ch_names is not None:
5454+ ch_names = [c if c not in replace_ch_names.keys()
5555+ else replace_ch_names[c] for c in ch_names]
5656+5757+ # type of each channels
5858+ ch_types = ['eeg'] * (len(ch_ind) - 1) + ['stim']
5959+ montage = read_montage('standard_1005')
6060+6161+ # get data and exclude Aux channel
6262+ data = data.values[:, ch_ind].T
6363+6464+ # create MNE object
6565+ info = create_info(ch_names=ch_names, ch_types=ch_types,
6666+ sfreq=sfreq, montage=montage)
6767+ raw.append(RawArray(data=data, info=info))
6868+6969+ # concatenate all raw objects
7070+ raws = concatenate_raws(raw)
7171+7272+ return raws
7373+7474+7575+def plot_topo(epochs, conditions=OrderedDict()):
7676+ palette = sns.color_palette("hls", len(conditions) + 1)
7777+ evokeds = [epochs[name].average() for name in (conditions)]
7878+7979+ evoked_topo = viz.plot_evoked_topo(
8080+ evokeds, vline=None, color=palette[0:len(conditions)], show=False)
8181+ evoked_topo.patch.set_alpha(0)
8282+ evoked_topo.set_size_inches(10, 8)
8383+ for axis in evoked_topo.axes:
8484+ for line in axis.lines:
8585+ line.set_linewidth(2)
8686+8787+ legend_loc = 0
8888+ labels = [e.comment if e.comment else 'Unknown' for e in evokeds]
8989+ legend = plt.legend(labels, loc=legend_loc, prop={'size': 20})
9090+ legend.get_frame().set_facecolor(axis.facecolor)
9191+ txts = legend.get_texts()
9292+ for txt, col in zip(txts, palette):
9393+ txt.set_color(col)
9494+9595+ return evoked_topo
9696+9797+9898+def plot_conditions(epochs, ch_ind=0, conditions=OrderedDict(), ci=97.5, n_boot=1000,
9999+ title='', palette=None,
100100+ diff_waveform=(4, 3)):
101101+ """Plot Averaged Epochs with ERP conditions.
102102+103103+ Args:
104104+ epochs (mne.epochs): EEG epochs
105105+106106+ Keyword Args:
107107+ conditions (OrderedDict): dictionary that contains the names of the
108108+ conditions to plot as keys, and the list of corresponding marker
109109+ numbers as value. E.g.,
110110+111111+ conditions = {'Non-target': [0, 1],
112112+ 'Target': [2, 3, 4]}
113113+114114+ ch_ind (int): index of channel to plot data from
115115+ ci (float): confidence interval in range [0, 100]
116116+ n_boot (int): number of bootstrap samples
117117+ title (str): title of the figure
118118+ palette (list): color palette to use for conditions
119119+ ylim (tuple): (ymin, ymax)
120120+ diff_waveform (tuple or None): tuple of ints indicating which
121121+ conditions to subtract for producing the difference waveform.
122122+ If None, do not plot a difference waveform
123123+124124+ Returns:
125125+ (matplotlib.figure.Figure): figure object
126126+ (list of matplotlib.axes._subplots.AxesSubplot): list of axes
127127+ """
128128+ if isinstance(conditions, dict):
129129+ conditions = OrderedDict(conditions)
130130+131131+ if palette is None:
132132+ palette = sns.color_palette("hls", len(conditions) + 1)
133133+134134+ X = epochs.get_data() * 1e6
135135+ times = epochs.times
136136+ y = pd.Series(epochs.events[:, -1])
137137+ fig, ax = plt.subplots()
138138+139139+ for cond, color in zip(conditions.values(), palette):
140140+ sns.tsplot(X[y.isin(cond), ch_ind], time=times, color=color,
141141+ n_boot=n_boot, ci=ci)
142142+143143+ if diff_waveform:
144144+ diff = (np.nanmean(X[y == diff_waveform[1], ch_ind], axis=0) -
145145+ np.nanmean(X[y == diff_waveform[0], ch_ind], axis=0))
146146+ ax.plot(times, diff, color='k', lw=1)
147147+148148+ ax.set_title(epochs.ch_names[ch_ind])
149149+ ax.axvline(x=0, color='k', lw=1, label='_nolegend_')
150150+151151+ ax.set_xlabel('Time (s)')
152152+ ax.set_ylabel('Amplitude (uV)')
153153+ ax.set_xlabel('Time (s)')
154154+ ax.set_ylabel('Amplitude (uV)')
155155+156156+ if diff_waveform:
157157+ legend = (['{} - {}'.format(diff_waveform[1], diff_waveform[0])] +
158158+ list(conditions.keys()))
159159+ else:
160160+ legend = conditions.keys()
161161+ ax.legend(legend)
162162+ sns.despine()
163163+ plt.tight_layout()
164164+165165+ if title:
166166+ fig.suptitle(title, fontsize=20)
167167+168168+ fig.set_size_inches(10, 8)
169169+170170+ return fig, ax
171171+172172+173173+def plot_highlight_regions(x, y, hue, hue_thresh=0, xlabel='', ylabel='',
174174+ legend_str=()):
175175+ """Plot a line with highlighted regions based on additional value.
176176+177177+ Plot a line and highlight ranges of x for which an additional value
178178+ is lower than a threshold. For example, the additional value might be
179179+ pvalues, and the threshold might be 0.05.
180180+181181+ Args:
182182+ x (array_like): x coordinates
183183+ y (array_like): y values of same shape as `x`
184184+185185+ Keyword Args:
186186+ hue (array_like): values to be plotted as hue based on `hue_thresh`.
187187+ Must be of the same shape as `x` and `y`.
188188+ hue_thresh (float): threshold to be applied to `hue`. Regions for which
189189+ `hue` is lower than `hue_thresh` will be highlighted.
190190+ xlabel (str): x-axis label
191191+ ylabel (str): y-axis label
192192+ legend_str (tuple): legend for the line and the highlighted regions
193193+194194+ Returns:
195195+ (matplotlib.figure.Figure): figure object
196196+ (list of matplotlib.axes._subplots.AxesSubplot): list of axes
197197+ """
198198+ fig, axes = plt.subplots(1, 1, figsize=(10, 5), sharey=True)
199199+200200+ axes.plot(x, y, lw=2, c='k')
201201+ plt.xlabel(xlabel)
202202+ plt.ylabel(ylabel)
203203+204204+ kk = 0
205205+ a = []
206206+ while kk < len(hue):
207207+ if hue[kk] < hue_thresh:
208208+ b = kk
209209+ kk += 1
210210+ while kk < len(hue):
211211+ if hue[kk] > hue_thresh:
212212+ break
213213+ else:
214214+ kk += 1
215215+ a.append([b, kk - 1])
216216+ else:
217217+ kk += 1
218218+219219+ st = (x[1] - x[0]) / 2.0
220220+ for p in a:
221221+ axes.axvspan(x[p[0]]-st, x[p[1]]+st, facecolor='g', alpha=0.5)
222222+ plt.legend(legend_str)
223223+ sns.despine()
224224+225225+ return fig, axes
226226+227227+228228+def get_epochs_info(epochs):
229229+ return [*[{x: len(epochs[x])} for x in epochs.event_id], {"Drop Percentage": round((1 - len(epochs.events)/len(epochs.drop_log)) * 100, 2)}, {"Total Epochs": len(epochs.events)}]