diff --git a/fooof/core/utils.py b/fooof/core/utils.py index f0fcdb588..3b50e44c7 100644 --- a/fooof/core/utils.py +++ b/fooof/core/utils.py @@ -13,13 +13,13 @@ def group_three(vec): Parameters ---------- - vec : 1d array - Array of items to group by 3. Length of array must be divisible by three. + vec : list or 1d array + List or array of items to group by 3. Length must be divisible by three. Returns ------- - list of list - List of lists, each with three items. + array or list of list + Items grouped by three: a 2d array of shape [n, 3] if `vec` is an array, otherwise a list of lists. Raises ------ @@ -30,7 +30,11 @@ def group_three(vec): if len(vec) % 3 != 0: raise ValueError("Wrong size array to group by three.") - return [list(vec[ii:ii+3]) for ii in range(0, len(vec), 3)] + # Reshape if an array, as it's faster; otherwise assume a list (or other sequence) + if isinstance(vec, np.ndarray): + return np.reshape(vec, (-1, 3)) + else: + return [list(vec[ii:ii+3]) for ii in range(0, len(vec), 3)] def nearest_ind(array, value): diff --git a/fooof/objs/fit.py b/fooof/objs/fit.py index fba745d27..ccdb9b16f 100644 --- a/fooof/objs/fit.py +++ b/fooof/objs/fit.py @@ -1007,18 +1007,16 @@ def _create_peak_params(self, gaus_params): with `freqs`, `fooofed_spectrum_` and `_ap_fit` all required to be available. """ - peak_params = np.empty([0, 3]) + peak_params = np.empty((len(gaus_params), 3)) for ii, peak in enumerate(gaus_params): # Gets the index of the power_spectrum at the frequency closest to the CF of the peak - ind = min(range(len(self.freqs)), key=lambda ii: abs(self.freqs[ii] - peak[0])) + ind = np.argmin(np.abs(self.freqs - peak[0])) # Collect peak parameter data - peak_params = np.vstack((peak_params, - [peak[0], - self.fooofed_spectrum_[ind] - self._ap_fit[ind], - peak[2] * 2])) + peak_params[ii] = [peak[0], self.fooofed_spectrum_[ind] - self._ap_fit[ind], + peak[2] * 2] return peak_params @@ -1037,8 +1035,8 @@ def _drop_peak_cf(self, guess): Guess parameters for gaussian peak fits.
Shape: [n_peaks, 3]. """ - cf_params = [item[0] for item in guess] - bw_params = [item[2] * self._bw_std_edge for item in guess] + cf_params = guess[:, 0] + bw_params = guess[:, 2] * self._bw_std_edge # Check if peaks within drop threshold from the edge of the frequency range keep_peak = \ diff --git a/fooof/objs/utils.py b/fooof/objs/utils.py index 4ef05d97d..b7a6f052f 100644 --- a/fooof/objs/utils.py +++ b/fooof/objs/utils.py @@ -219,9 +219,14 @@ def fit_fooof_3d(fg, freqs, power_spectra, freq_range=None, n_jobs=1): >>> fgs = fit_fooof_3d(fg, freqs, power_spectra, freq_range=[3, 30]) # doctest:+SKIP """ - fgs = [] - for cond_spectra in power_spectra: - fg.fit(freqs, cond_spectra, freq_range, n_jobs) - fgs.append(fg.copy()) + # Reshape 3d data to 2d and fit, in order to fit with a single group model object + shape = np.shape(power_spectra) + powers_2d = np.reshape(power_spectra, (shape[0] * shape[1], shape[2])) + + fg.fit(freqs, powers_2d, freq_range, n_jobs) + + # Reorganize 2d results into a list of model group objects, to reflect original shape + fgs = [fg.get_group(range(dim_a * shape[1], (dim_a + 1) * shape[1])) + for dim_a in range(shape[0])] return fgs diff --git a/fooof/tests/objs/test_utils.py b/fooof/tests/objs/test_utils.py index 6af08372a..6f9479141 100644 --- a/fooof/tests/objs/test_utils.py +++ b/fooof/tests/objs/test_utils.py @@ -120,13 +120,17 @@ def test_combine_errors(tfm, tfg): def test_fit_fooof_3d(tfg): - n_spectra = 2 + n_groups = 2 + n_spectra = 3 xs, ys = gen_group_power_spectra(n_spectra, *default_group_params()) - ys = np.stack([ys, ys], axis=0) + ys = np.stack([ys] * n_groups, axis=0) + spectra_shape = np.shape(ys) tfg = FOOOFGroup() fgs = fit_fooof_3d(tfg, xs, ys) - assert len(fgs) == 2 + assert len(fgs) == n_groups == spectra_shape[0] for fg in fgs: assert fg + assert len(fg) == n_spectra + assert fg.power_spectra.shape == spectra_shape[1:]