-
Notifications
You must be signed in to change notification settings - Fork 112
Expand file tree
/
Copy pathfm.py
More file actions
306 lines (240 loc) · 11.3 KB
/
fm.py
File metadata and controls
306 lines (240 loc) · 11.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
"""Plots for the FOOOF object.
Notes
-----
This file contains plotting functions that take as input a FOOOF object.
"""
import numpy as np
from fooof.core.io import fname
from fooof.core.utils import nearest_ind
from fooof.core.modutils import safe_import, check_dependency
from fooof.sim.gen import gen_periodic
from fooof.utils.data import trim_spectrum
from fooof.utils.params import compute_fwhm
from fooof.plts.spectra import plot_spectrum
from fooof.plts.settings import PLT_FIGSIZES, PLT_COLORS
from fooof.plts.utils import check_ax, check_plot_kwargs
from fooof.plts.style import check_n_style, style_spectrum_plot
plt = safe_import('.pyplot', 'matplotlib')
###################################################################################################
###################################################################################################
@check_dependency(plt, 'matplotlib')
def plot_fm(fm, plot_peaks=None, plot_aperiodic=True, plt_log=False, add_legend=True,
            file_name=None, ax=None, plot_style=style_spectrum_plot,
            data_kwargs=None, model_kwargs=None, aperiodic_kwargs=None, peak_kwargs=None):
    """Plot the power spectrum and model fit results from a FOOOF object.

    Parameters
    ----------
    fm : FOOOF
        Object containing a power spectrum and (optionally) results from fitting.
    plot_peaks : None or {'shade', 'dot', 'outline', 'line', 'width'}, optional
        What kind of approach to take to plot peaks. If None, peaks are not specifically plotted.
        Can also be a combination of approaches, separated by '-', for example: 'shade-line'.
        Peaks are only added if the object has model results.
    plot_aperiodic : boolean, optional, default: True
        Whether to plot the aperiodic component of the model fit.
    plt_log : boolean, optional, default: False
        Whether to plot the frequency values in log10 spacing.
    add_legend : boolean, optional, default: True
        Whether to add a legend describing the plot components.
    file_name : str, optional, default: None
        Name with format to save as, including absolute or relative path.
        If None, the figure is not saved.
    ax : matplotlib.Axes, optional, default: None
        Figure axes upon which to plot.
    plot_style : callable, optional, default: style_spectrum_plot
        A function to call to apply styling & aesthetics to the plot.
    data_kwargs, model_kwargs, aperiodic_kwargs, peak_kwargs : None or dict, optional
        Keyword arguments to pass into the plot call for each plot element.

    Notes
    -----
    Since FOOOF objects store power values in log spacing,
    the y-axis (power) is plotted in log spacing by default.

    When saving, the figure that owns the plotted axes is written out, rather than
    whichever figure happens to be the current one in pyplot.
    """

    ax = check_ax(ax, PLT_FIGSIZES['spectral'])

    # Log settings - note that power values in FOOOF objects are already logged
    log_freqs = plt_log
    log_powers = False

    # Plot the data, if available
    if fm.has_data:
        data_kwargs = check_plot_kwargs(
            data_kwargs,
            {'color' : PLT_COLORS['data'], 'linewidth' : 2.0,
             'label' : 'Original Spectrum' if add_legend else None})
        plot_spectrum(fm.freqs, fm.power_spectrum, log_freqs, log_powers,
                      ax=ax, plot_style=None, **data_kwargs)

    # Add the full model fit, and components (if requested)
    if fm.has_model:
        model_kwargs = check_plot_kwargs(
            model_kwargs,
            {'color' : PLT_COLORS['model'], 'linewidth' : 3.0, 'alpha' : 0.5,
             'label' : 'Full Model Fit' if add_legend else None})
        plot_spectrum(fm.freqs, fm.fooofed_spectrum_, log_freqs, log_powers,
                      ax=ax, plot_style=None, **model_kwargs)

        # Plot the aperiodic component of the model fit
        if plot_aperiodic:
            aperiodic_kwargs = check_plot_kwargs(
                aperiodic_kwargs,
                {'color' : PLT_COLORS['aperiodic'], 'linewidth' : 3.0, 'alpha' : 0.5,
                 'linestyle' : 'dashed', 'label' : 'Aperiodic Fit' if add_legend else None})
            plot_spectrum(fm.freqs, fm._ap_fit, log_freqs, log_powers,
                          ax=ax, plot_style=None, **aperiodic_kwargs)

        # Plot the periodic components of the model fit
        if plot_peaks:
            _add_peaks(fm, plot_peaks, plt_log, ax=ax, peak_kwargs=peak_kwargs)

    # Apply style to plot
    check_n_style(plot_style, ax, log_freqs, True)

    # Save out figure, if requested
    if file_name is not None:
        # Save the figure the axes belong to: `plt.savefig` would write out the *current*
        #   pyplot figure, which need not be the one containing a user-supplied `ax`.
        # If the owning object exposes no `savefig` (presumably possible for nested
        #   sub-figures), fall back to the pyplot-level call.
        save_func = getattr(ax.get_figure(), 'savefig', plt.savefig)
        save_func(fname(file_name, 'png'))
def _add_peaks(fm, approach, plt_log, ax, peak_kwargs):
    """Add peaks to a model plot.

    Parameters
    ----------
    fm : FOOOF
        FOOOF object containing results from fitting.
    approach : {'shade', 'dot', 'outline', 'line', 'width'}
        What kind of approach to take to plot peaks.
        Can also be a combination of approaches, separated by '-' (for example 'shade-line').
    plt_log : boolean
        Whether to plot the frequency values in log10 spacing.
    ax : matplotlib.Axes
        Figure axes upon which to plot.
    peak_kwargs : None or dict
        Keyword arguments to pass into the plot call.
        This can be a flat dictionary, with plot keyword arguments,
        or a dictionary of dictionaries, with keys as labels indicating an `approach`,
        and values which contain a dictionary of plot keywords for that approach.
        An approach without its own nested dictionary receives only the flat entries.

    Raises
    ------
    ValueError
        If a requested approach is not a key of `ADD_PEAK_FUNCS`.

    Notes
    -----
    This is a pass through function, that takes a specification of one
    or multiple add peak approaches to use, and calls the relevant function(s).
    Approaches are applied in the order given.
    """

    # Input for kwargs could be None, so check if dict and typecast if not
    peak_kwargs = peak_kwargs if isinstance(peak_kwargs, dict) else {}

    # Flat plot keywords shared by every approach: drop any per-approach sub-dictionaries,
    #   so that settings meant for one approach are never passed on as plot keywords
    shared_kwargs = {key : val for key, val in peak_kwargs.items()
                     if not (key in ADD_PEAK_FUNCS and isinstance(val, dict))}

    # Split up approaches, in case multiple are specified, and apply each
    for cur_approach in approach.split('-'):

        # Only the lookup is guarded, so that a KeyError raised inside a
        #   plotting function is not misreported as an unknown approach
        try:
            add_peak_func = ADD_PEAK_FUNCS[cur_approach]
        except KeyError:
            raise ValueError("Plot peak type not understood.") from None

        # This unpacks kwargs, if it's embedded dictionaries for each approach
        nested_kwargs = peak_kwargs.get(cur_approach)
        plot_kwargs = nested_kwargs if isinstance(nested_kwargs, dict) else shared_kwargs

        # Pass through to the peak plotting function
        add_peak_func(fm, plt_log, ax, **plot_kwargs)
def _add_peaks_shade(fm, plt_log, ax, **plot_kwargs):
    """Add a shading in of all peaks.

    Parameters
    ----------
    fm : FOOOF
        FOOOF object containing results from fitting.
    plt_log : boolean
        Whether to plot the frequency values in log10 spacing.
    ax : matplotlib.Axes
        Figure axes upon which to plot.
    **plot_kwargs
        Keyword arguments to pass into the plot call.

    Notes
    -----
    Each peak is shaded as the area between the aperiodic fit and
    the aperiodic fit plus the periodic component generated for that peak.
    """

    kwargs = check_plot_kwargs(plot_kwargs,
                               {'color' : PLT_COLORS['periodic'], 'alpha' : 0.25})

    # The x-axis values and the aperiodic baseline do not depend on the
    #   individual peak, so compute them once rather than on every iteration
    plot_freqs = np.log10(fm.freqs) if plt_log else fm.freqs
    ap_fit = fm._ap_fit

    for peak in fm.get_params('gaussian_params'):

        # Reconstruct this peak on top of the aperiodic fit, and fill in between the two
        peak_line = ap_fit + gen_periodic(fm.freqs, peak)
        ax.fill_between(plot_freqs, peak_line, ap_fit, **kwargs)
def _add_peaks_dot(fm, plt_log, ax, **plot_kwargs):
    """Mark each peak with a vertical stem rising from the aperiodic fit, capped by a dot.

    Parameters
    ----------
    fm : FOOOF
        FOOOF object containing results from fitting.
    plt_log : boolean
        Whether to plot the frequency values in log10 spacing.
    ax : matplotlib.Axes
        Figure axes upon which to plot.
    **plot_kwargs
        Keyword arguments to pass into the plot call.

    Notes
    -----
    The base of each stem is the aperiodic fit, linearly interpolated at the
    frequency of the peak; the tip is the base plus the second peak parameter.
    """

    defaults = {'color': PLT_COLORS['periodic'], 'alpha': 0.6, 'lw': 2.5, 'ms': 6}
    dot_kwargs = check_plot_kwargs(plot_kwargs, defaults)

    for peak_params in fm.get_params('peak_params'):

        peak_freq = peak_params[0]

        # Vertical extent: from the aperiodic fit at the peak frequency, up to the peak tip
        base_val = np.interp(peak_freq, fm.freqs, fm._ap_fit)
        tip_val = base_val + peak_params[1]

        # Horizontal position, in log10 units if so requested
        if plt_log:
            x_val = np.log10(peak_freq)
        else:
            x_val = peak_freq

        # Draw the stem first, then the dot that caps it
        ax.plot([x_val, x_val], [base_val, tip_val], **dot_kwargs)
        ax.plot(x_val, tip_val, marker='o', **dot_kwargs)
def _add_peaks_outline(fm, plt_log, ax, **plot_kwargs):
    """Trace the outline of every peak, drawn on top of the aperiodic fit.

    Parameters
    ----------
    fm : FOOOF
        FOOOF object containing results from fitting.
    plt_log : boolean
        Whether to plot the frequency values in log10 spacing.
    ax : matplotlib.Axes
        Figure axes upon which to plot.
    **plot_kwargs
        Keyword arguments to pass into the plot call.

    Notes
    -----
    Each outline is restricted to the peak frequency +/- three times
    the third gaussian parameter of that peak.
    """

    defaults = {'color': PLT_COLORS['periodic'], 'alpha': 0.7, 'lw': 1.5}
    outline_kwargs = check_plot_kwargs(plot_kwargs, defaults)

    for gauss in fm.get_params('gaussian_params'):

        # Frequency window around the peak: three widths to either side
        half_span = gauss[2]*3
        window = [gauss[0] - half_span, gauss[0] + half_span]

        # Rebuild the peak over the aperiodic fit, then cut it down to the window
        full_line = fm._ap_fit + gen_periodic(fm.freqs, gauss)
        out_freqs, out_line = trim_spectrum(fm.freqs, full_line, window)

        # Convert the x-axis values to log10 spacing, if requested, and draw
        if plt_log:
            out_freqs = np.log10(out_freqs)
        ax.plot(out_freqs, out_line, **outline_kwargs)
def _add_peaks_line(fm, plt_log, ax, **plot_kwargs):
    """Draw a full-height vertical line through each peak, with a marker at the top.

    Parameters
    ----------
    fm : FOOOF
        FOOOF object containing results from fitting.
    plt_log : boolean
        Whether to plot the frequency values in log10 spacing.
    ax : matplotlib.Axes
        Figure axes upon which to plot.
    **plot_kwargs
        Keyword arguments to pass into the plot call.

    Notes
    -----
    The vertical extent of the lines is taken from the y-axis limits of `ax`
    as they are when this function is called.
    """

    defaults = {'color': PLT_COLORS['periodic'], 'alpha': 0.7, 'lw': 1.4, 'ms': 10}
    line_kwargs = check_plot_kwargs(plot_kwargs, defaults)

    # Capture the y-limits once, before any lines are added
    y_bounds = ax.get_ylim()
    y_top = y_bounds[1]

    for peak_params in fm.get_params('peak_params'):

        if plt_log:
            x_val = np.log10(peak_params[0])
        else:
            x_val = peak_params[0]

        # The line spans the y-limits, and a downward pointing marker sits at the upper limit
        ax.plot([x_val, x_val], y_bounds, '-', **line_kwargs)
        ax.plot(x_val, y_top, 'v', **line_kwargs)
def _add_peaks_width(fm, plt_log, ax, **plot_kwargs):
    """Draw a horizontal line spanning the width of each peak.

    Parameters
    ----------
    fm : FOOOF
        FOOOF object containing results from fitting.
    plt_log : boolean
        Whether to plot the frequency values in log10 spacing.
    ax : matplotlib.Axes
        Figure axes upon which to plot.
    **plot_kwargs
        Keyword arguments to pass into the plot call.

    Notes
    -----
    The line is meant to convey the bandwidth (width, or gaussian standard deviation)
    of the peak, but what is literally drawn is the full-width half-max, centered on the peak.

    Vertically, the line is placed at the value of the power spectrum nearest to
    the peak frequency, lowered by half of the second gaussian parameter.
    """

    defaults = {'color': PLT_COLORS['periodic'], 'alpha': 0.6, 'lw': 2.5, 'ms': 6}
    width_kwargs = check_plot_kwargs(plot_kwargs, defaults)

    for gauss in fm.gaussian_params_:

        center = gauss[0]

        # Horizontal extent: half of the full-width half-max to either side of the center
        half_width = 0.5 * compute_fwhm(gauss[2])
        x_span = [center - half_width, center + half_width]
        if plt_log:
            x_span = np.log10(x_span)

        # Vertical position: data value nearest the peak, dropped by half the peak height
        data_top = fm.power_spectrum[nearest_ind(fm.freqs, center)]
        y_val = data_top-(0.5*gauss[1])

        ax.plot(x_span, [y_val, y_val], **width_kwargs)
# Registry of the available `add_peak_*` functions, keyed by their approach label
ADD_PEAK_FUNCS = dict(
    shade=_add_peaks_shade,
    dot=_add_peaks_dot,
    outline=_add_peaks_outline,
    line=_add_peaks_line,
    width=_add_peaks_width,
)