···11-from glob import glob
22-import os
33-from collections import OrderedDict
44-from mne import create_info, concatenate_raws, viz
55-from mne.io import RawArray
66-from mne.channels import read_montage
77-import pandas as pd
88-import numpy as np
99-import seaborn as sns
1010-from matplotlib import pyplot as plt
1111-1212-sns.set_context('talk')
1313-sns.set_style('white')
1414-1515-1616-def load_data(fnames, sfreq=128., replace_ch_names=None):
1717- """Load CSV files from the /data directory into a Raw object.
1818-1919- Args:
2020- fnames (array): CSV filepaths from which to load data
2121-2222- Keyword Args:
2323- sfreq (float): EEG sampling frequency
2424- replace_ch_names (dict or None): dictionary containing a mapping to
2525- rename channels. Useful when an external electrode was used.
2626-2727- Returns:
2828- (mne.io.array.array.RawArray): loaded EEG
2929- """
3030-3131- raw = []
3232- print(fnames)
3333- for fname in fnames:
3434- # read the file
3535- data = pd.read_csv(fname, index_col=0)
3636-3737- data = data.dropna()
3838-3939- # get estimation of sampling rate and use to determine sfreq
4040- # yes, this could probably be improved
4141- srate = 1000 / (data.index.values[1] - data.index.values[0])
4242- if srate >= 200:
4343- sfreq = 256
4444- else:
4545- sfreq = 128
4646-4747- # name of each channel
4848- ch_names = list(data.columns)
4949-5050- # indices of each channel
5151- ch_ind = list(range(len(ch_names)))
5252-5353- if replace_ch_names is not None:
5454- ch_names = [c if c not in replace_ch_names.keys()
5555- else replace_ch_names[c] for c in ch_names]
5656-5757- # type of each channels
5858- ch_types = ['eeg'] * (len(ch_ind) - 1) + ['stim']
5959- montage = read_montage('standard_1005')
6060-6161- # get data and exclude Aux channel
6262- data = data.values[:, ch_ind].T
6363-6464- # create MNE object
6565- info = create_info(ch_names=ch_names, ch_types=ch_types,
6666- sfreq=sfreq, montage=montage)
6767- raw.append(RawArray(data=data, info=info))
6868-6969- # concatenate all raw objects
7070- raws = concatenate_raws(raw)
7171-7272- return raws
7373-7474-7575-def plot_topo(epochs, conditions=OrderedDict()):
7676- palette = sns.color_palette("hls", len(conditions) + 1)
7777- evokeds = [epochs[name].average() for name in (conditions)]
7878-7979- evoked_topo = viz.plot_evoked_topo(
8080- evokeds, vline=None, color=palette[0:len(conditions)], show=False)
8181- evoked_topo.patch.set_alpha(0)
8282- evoked_topo.set_size_inches(10, 8)
8383- for axis in evoked_topo.axes:
8484- for line in axis.lines:
8585- line.set_linewidth(2)
8686-8787- legend_loc = 0
8888- labels = [e.comment if e.comment else 'Unknown' for e in evokeds]
8989- legend = plt.legend(labels, loc=legend_loc, prop={'size': 20})
9090- txts = legend.get_texts()
9191- for txt, col in zip(txts, palette):
9292- txt.set_color(col)
9393-9494- return evoked_topo
9595-9696-9797-def plot_conditions(epochs, ch_ind=0, conditions=OrderedDict(), ci=97.5, n_boot=1000,
9898- title='', palette=None,
9999- diff_waveform=(4, 3)):
100100- """Plot Averaged Epochs with ERP conditions.
101101-102102- Args:
103103- epochs (mne.epochs): EEG epochs
104104-105105- Keyword Args:
106106- conditions (OrderedDict): dictionary that contains the names of the
107107- conditions to plot as keys, and the list of corresponding marker
108108- numbers as value. E.g.,
109109-110110- conditions = {'Non-target': [0, 1],
111111- 'Target': [2, 3, 4]}
112112-113113- ch_ind (int): index of channel to plot data from
114114- ci (float): confidence interval in range [0, 100]
115115- n_boot (int): number of bootstrap samples
116116- title (str): title of the figure
117117- palette (list): color palette to use for conditions
118118- ylim (tuple): (ymin, ymax)
119119- diff_waveform (tuple or None): tuple of ints indicating which
120120- conditions to subtract for producing the difference waveform.
121121- If None, do not plot a difference waveform
122122-123123- Returns:
124124- (matplotlib.figure.Figure): figure object
125125- (list of matplotlib.axes._subplots.AxesSubplot): list of axes
126126- """
127127- if isinstance(conditions, dict):
128128- conditions = OrderedDict(conditions)
129129-130130- if palette is None:
131131- palette = sns.color_palette("hls", len(conditions) + 1)
132132-133133- X = epochs.get_data()
134134- times = epochs.times
135135- y = pd.Series(epochs.events[:, -1])
136136- fig, ax = plt.subplots()
137137-138138- for cond, color in zip(conditions.values(), palette):
139139- sns.tsplot(X[y.isin(cond), ch_ind], time=times, color=color,
140140- n_boot=n_boot, ci=ci)
141141-142142- if diff_waveform:
143143- diff = (np.nanmean(X[y == diff_waveform[1], ch_ind], axis=0) -
144144- np.nanmean(X[y == diff_waveform[0], ch_ind], axis=0))
145145- ax.plot(times, diff, color='k', lw=1)
146146-147147- ax.set_title(epochs.ch_names[ch_ind])
148148- ax.axvline(x=0, color='k', lw=1, label='_nolegend_')
149149-150150- ax.set_xlabel('Time (s)')
151151- ax.set_ylabel('Amplitude (uV)')
152152- ax.set_xlabel('Time (s)')
153153- ax.set_ylabel('Amplitude (uV)')
154154-155155- # Round y axis tick labels to 2 decimal places
156156- # ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
157157-158158- if diff_waveform:
159159- legend = (['{} - {}'.format(diff_waveform[1], diff_waveform[0])] +
160160- list(conditions.keys()))
161161- else:
162162- legend = conditions.keys()
163163- ax.legend(legend)
164164- sns.despine()
165165- plt.tight_layout()
166166-167167- if title:
168168- fig.suptitle(title, fontsize=20)
169169-170170- fig.set_size_inches(10, 8)
171171-172172- return fig, ax
173173-174174-def get_epochs_info(epochs):
175175- return [*[{x: len(epochs[x])} for x in epochs.event_id], {"Drop Percentage": round((1 - len(epochs.events)/len(epochs.drop_log)) * 100, 2)}, {"Total Epochs": len(epochs.events)}]