Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions specparam/algorithms/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,15 @@ class Algorithm():
Data object with spectral data and metadata.
results : Results
Results object with model fit results and metrics.
model : SpectralModel, optional
The model object this object is linked to, to provide access to other attributes.
debug : bool
Whether to run in debug state, raising an error if encountered during fitting.
"""

def __init__(self, name, description, public_settings, private_settings=None,
data_format='spectrum', modes=None, data=None, results=None, debug=False):
data_format='spectrum', modes=None, data=None, results=None, model=None,
debug=False):
"""Initialize Algorithm object."""

self.name = name
Expand Down Expand Up @@ -66,6 +69,8 @@ def __init__(self, name, description, public_settings, private_settings=None,

self.set_debug(debug)

self._model = model


def _fit_prechecks(self, verbose):
"""Pre-checks to run before the fit function - if are some, overload this function."""
Expand Down Expand Up @@ -178,13 +183,14 @@ class AlgorithmCF(Algorithm):
"""

def __init__(self, name, description, public_settings, private_settings=None,
data_format='spectrum', modes=None, data=None, results=None, debug=False):
data_format='spectrum', modes=None, data=None, results=None,
model=None, debug=False):
"""Initialize Algorithm object."""

Algorithm.__init__(self, name=name, description=description,
public_settings=public_settings, private_settings=private_settings,
data_format=data_format, modes=modes, data=data, results=results,
debug=debug)
model=model, debug=debug)

self._cf_settings_desc = CURVE_FIT_SETTINGS
self._cf_settings = SettingsValues(self._cf_settings_desc.names)
Expand Down
4 changes: 2 additions & 2 deletions specparam/algorithms/spectral_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class SpectralFitAlgorithm(AlgorithmCF):
def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_height=0.0,
peak_threshold=2.0, ap_percentile_thresh=0.025, ap_guess=None, ap_bounds=None,
cf_bound=1.5, bw_std_edge=1.0, gauss_overlap_thresh=0.75, maxfev=5000,
tol=0.00001, modes=None, data=None, results=None, debug=False):
tol=0.00001, modes=None, data=None, results=None, model=None, debug=False):
"""Initialize base model object"""

# Initialize base algorithm object with algorithm metadata
Expand All @@ -115,7 +115,7 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h
description='Original parameterizing neural power spectra algorithm.',
public_settings=SPECTRAL_FIT_SETTINGS_DEF,
private_settings=SPECTRAL_FIT_PRIVATE_SETTINGS_DEF,
modes=modes, data=data, results=results, debug=debug)
modes=modes, data=data, results=results, model=model, debug=debug)

## Public settings
self.settings.peak_width_limits = peak_width_limits
Expand Down
66 changes: 59 additions & 7 deletions specparam/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from specparam.sim.gen import gen_freqs
from specparam.data import SpectrumMetaData, ModelChecks
from specparam.utils.array import unlog
from specparam.utils.spectral import trim_spectrum
from specparam.utils.checks import check_input_options
from specparam.reports.strings import gen_data_str
Expand Down Expand Up @@ -36,6 +37,8 @@ class Data():
Whether to check the spectral data. If so, raises an error for any NaN / Inf values.
format : {'power'}
The representation format of the data.
model : SpectralModel, optional
The model object this object is linked to, to provide access to other attributes.

Attributes
----------
Expand All @@ -55,7 +58,7 @@ class Data():
All power values are stored internally in log10 scale.
"""

def __init__(self, check_freqs=True, check_data=True, format='power'):
def __init__(self, check_freqs=True, check_data=True, format='power', model=None):
"""Initialize Data object."""

self._reset_data(True, True)
Expand All @@ -70,6 +73,7 @@ def __init__(self, check_freqs=True, check_data=True, format='power'):
check_input_options(format, FORMATS, 'format')
self.format = format

self._model = model

@property
def has_data(self):
Expand Down Expand Up @@ -154,6 +158,54 @@ def get_meta_data(self):
return SpectrumMetaData(**{key : getattr(self, key) for key in self._meta_fields})


def get_data(self, component='full', space='log'):
    """Get a data component.

    Parameters
    ----------
    component : {'full', 'aperiodic', 'peak'}
        Which data component to return.
        'full' - full power spectrum
        'aperiodic' - isolated aperiodic data component
        'peak' - isolated peak data component
    space : {'log', 'linear'}
        Which space to return the data component in.
        'log' - returns in log10 space.
        'linear' - returns in linear space.

    Returns
    -------
    output : 1d array
        Specified data component, in specified spacing.

    Raises
    ------
    NoDataError
        If no data is available.
    ValueError
        If `component` or `space` is not a valid option, or if the 'aperiodic' or
        'peak' component is requested while this object is not linked to a model.

    Notes
    -----
    The 'space' parameter doesn't just define the spacing of the data component
    values, but rather defines the space of the additive data definition such that
    `power_spectrum = aperiodic_component + peak_component`.
    With space set as 'log', this combination holds in log space.
    With space set as 'linear', this combination holds in linear space.

    The 'aperiodic' and 'peak' components are derived from the fit results of the
    linked model object, so they are only available when this object was created
    with a `model`.
    """

    if not self.has_data:
        raise NoDataError("No data available to fit, can not proceed.")

    # Explicit checks rather than `assert`: asserts are stripped under `python -O`,
    # which would let an invalid `space` silently fall through to the linear branches.
    if space not in ('linear', 'log'):
        raise ValueError("Input for 'space' invalid.")
    if component not in ('full', 'aperiodic', 'peak'):
        raise ValueError('Input for component invalid.')

    in_log = space == 'log'

    if component == 'full':
        return self.power_spectrum if in_log else unlog(self.power_spectrum)

    # The isolated components come from the fit results held by the linked model;
    # `_model` defaults to None, so fail with a clear message instead of an AttributeError.
    model = getattr(self, '_model', None)
    if model is None:
        raise ValueError("No model is linked to this data object, so the '{}' "
                         "component can not be computed.".format(component))
    fit = model.results.model

    if component == 'aperiodic':
        # Linear space: divide the unlogged spectrum by the unlogged peak fit.
        return fit._spectrum_peak_rm if in_log else \
            unlog(self.power_spectrum) / unlog(fit._peak_fit)

    # component == 'peak'
    # Linear space: subtract the unlogged aperiodic fit from the unlogged spectrum.
    return fit._spectrum_flat if in_log else \
        unlog(self.power_spectrum) - unlog(fit._ap_fit)


def plot(self, plt_log=False, **plt_kwargs):
"""Plot the power spectrum."""

Expand Down Expand Up @@ -339,10 +391,10 @@ class Data2D(Data):
All power values are stored internally in log10 scale.
"""

def __init__(self):
def __init__(self, *args, **kwargs):
"""Initialize Data2D object."""

Data.__init__(self)
Data.__init__(self, *args, **kwargs)

self.power_spectra = None

Expand Down Expand Up @@ -451,10 +503,10 @@ class Data2DT(Data2D):
All power values are stored internally in log10 scale.
"""

def __init__(self):
def __init__(self, *args, **kwargs):
"""Initialize Data2DT object."""

Data2D.__init__(self)
Data2D.__init__(self, *args, **kwargs)


@property
Expand Down Expand Up @@ -521,10 +573,10 @@ class Data3D(Data2DT):
All power values are stored internally in log10 scale.
"""

def __init__(self):
def __init__(self, *args, **kwargs):
"""Initialize Data3D object."""

Data2DT.__init__(self)
Data2DT.__init__(self, *args, **kwargs)

self.spectrograms = None

Expand Down
51 changes: 1 addition & 50 deletions specparam/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from copy import deepcopy

from specparam.utils.array import unlog
from specparam.utils.checks import check_array_dim
from specparam.modes.modes import Modes
from specparam.modutils.errors import NoDataError
Expand Down Expand Up @@ -56,7 +55,7 @@ def add_modes(self, aperiodic_mode, periodic_mode):
Mode for periodic component, or string specifying which mode to use.
"""

self.modes = Modes(aperiodic=aperiodic_mode, periodic=periodic_mode)
self.modes = Modes(aperiodic=aperiodic_mode, periodic=periodic_mode, model=self)

if getattr(self, 'results', None):
self.results.modes = self.modes
Expand All @@ -66,54 +65,6 @@ def add_modes(self, aperiodic_mode, periodic_mode):
self.algorithm._reset_subobjects(modes=self.modes, results=self.results)


def get_data(self, component='full', space='log'):
"""Get a data component.

Parameters
----------
component : {'full', 'aperiodic', 'peak'}
Which data component to return.
'full' - full power spectrum
'aperiodic' - isolated aperiodic data component
'peak' - isolated peak data component
space : {'log', 'linear'}
Which space to return the data component in.
'log' - returns in log10 space.
'linear' - returns in linear space.

Returns
-------
output : 1d array
Specified data component, in specified spacing.

Raises
------
NoDataError
If `self.data.has_data` is False.
AssertionError
If `space` is not 'linear' or 'log'.
ValueError
If `component` is not one of the listed options.

Notes
-----
The 'space' parameter doesn't just define the spacing of the data component
values, but rather defines the space of the additive data definition such that
`power_spectrum = aperiodic_component + peak_component`.
With space set as 'log', this combination holds in log space.
With space set as 'linear', this combination holds in linear space.
"""

# Every component is derived from the stored power spectrum, so require data first.
if not self.data.has_data:
raise NoDataError("No data available to fit, can not proceed.")
# NOTE(review): `assert` is stripped under `python -O`; an invalid `space` would then
# silently take the linear-space (`else`) branches below.
assert space in ['linear', 'log'], "Input for 'space' invalid."

if component == 'full':
output = self.data.power_spectrum if space == 'log' \
else unlog(self.data.power_spectrum)
elif component == 'aperiodic':
# Log space: the `_spectrum_peak_rm` array from the fit results.
# Linear space: unlogged spectrum divided by the unlogged peak fit.
output = self.results.model._spectrum_peak_rm if space == 'log' else \
unlog(self.data.power_spectrum) / unlog(self.results.model._peak_fit)
elif component == 'peak':
# Log space: the `_spectrum_flat` array from the fit results.
# Linear space: unlogged spectrum minus the unlogged aperiodic fit.
output = self.results.model._spectrum_flat if space == 'log' else \
unlog(self.data.power_spectrum) - unlog(self.results.model._ap_fit)
else:
raise ValueError('Input for component invalid.')

return output


def print_settings(self, description=False, concise=False):
"""Print out the current settings.

Expand Down
5 changes: 3 additions & 2 deletions specparam/models/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,12 @@ def __init__(self, *args, **kwargs):
verbose=kwargs.pop('verbose', True),
**kwargs)

self.data = Data3D()
self.data = Data3D(model=self)

self.results = Results3D(modes=self.modes,
metrics=kwargs.pop('metrics', None),
bands=kwargs.pop('bands', None))
bands=kwargs.pop('bands', None),
model=self)

self.algorithm._reset_subobjects(data=self.data, results=self.results)

Expand Down
5 changes: 3 additions & 2 deletions specparam/models/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,12 @@ def __init__(self, *args, **kwargs):
verbose=kwargs.pop('verbose', True),
**kwargs)

self.data = Data2D()
self.data = Data2D(model=self)

self.results = Results2D(modes=self.modes,
metrics=kwargs.pop('metrics', None),
bands=kwargs.pop('bands', None))
bands=kwargs.pop('bands', None),
model=self)

self.algorithm._reset_subobjects(data=self.data, results=self.results)

Expand Down
6 changes: 3 additions & 3 deletions specparam/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,14 @@ def __init__(self, aperiodic_mode='fixed', periodic_mode='gaussian',
update_converters(DEFAULT_CONVERTERS, converters)
BaseModel.__init__(self, aperiodic_mode, periodic_mode, converters, verbose)

self.data = Data()
self.data = Data(model=self)

self.results = Results(modes=self.modes, metrics=metrics, bands=bands)
self.results = Results(modes=self.modes, metrics=metrics, bands=bands, model=self)

algorithm_settings = {} if algorithm_settings is None else algorithm_settings
self.algorithm = check_algorithm_definition(algorithm, ALGORITHMS)(
**algorithm_settings, modes=self.modes, data=self.data,
results=self.results, debug=debug, **model_kwargs)
results=self.results, debug=debug, model=self, **model_kwargs)


@replace_docstring_sections([docs_get_section(Data.add_data.__doc__, 'Parameters'),
Expand Down
5 changes: 3 additions & 2 deletions specparam/models/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,12 @@ def __init__(self, *args, **kwargs):
verbose=kwargs.pop('verbose', True),
**kwargs)

self.data = Data2DT()
self.data = Data2DT(model=self)

self.results = Results2DT(modes=self.modes,
metrics=kwargs.pop('metrics', None),
bands=kwargs.pop('bands', None))
bands=kwargs.pop('bands', None),
model=self)

self.algorithm._reset_subobjects(data=self.data, results=self.results)

Expand Down
6 changes: 5 additions & 1 deletion specparam/modes/modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ class Modes():
Aperiodic mode.
periodic : str or Mode
Periodic mode.
model : SpectralModel, optional
The model object this object is linked to, to provide access to other attributes.
"""

def __init__(self, aperiodic, periodic):
def __init__(self, aperiodic, periodic, model=None):
"""Initialize modes."""

# Set list of component names
Expand All @@ -29,6 +31,8 @@ def __init__(self, aperiodic, periodic):
self.aperiodic = check_mode_definition(aperiodic, AP_MODES)
self.periodic = check_mode_definition(periodic, PE_MODES)

self.model = model


def check_params(self):
"""Check the description of the parameters for each mode."""
Expand Down
Loading