Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions brainiak/reprsimil/brsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -2818,11 +2818,13 @@ class GBRSA(BRSA):
In this case, the standard deviation of log(SNR) is set
by the parameter logS_range.
If set to 'unif', a uniform prior in [0,1] is imposed.
In all these cases, SNR is numerically
In all the above cases, SNR is numerically
marginalized over a grid of SNR values, so the parameter SNR_bins
determines how accurate the numerical integration is: the more
bins are used, the more accurate the numerical
integration becomes.
If set to 'equal', all voxels are assumed to have the same fixed
SNR. Pseudo-SNR is 1.0 for all voxels.
In all cases, the grids used for pseudo-SNR do not actually
set an upper bound on SNR, because the real SNR is determined
by both pseudo-SNR and U, the shared covariance structure.
Expand Down Expand Up @@ -3003,11 +3005,14 @@ def __init__(
if type(logS_range) is int:
logS_range = float(logS_range)
self.logS_range = logS_range
assert SNR_prior in ['unif', 'lognorm', 'exp'], \
assert SNR_prior in ['unif', 'lognorm', 'exp', 'equal'], \
'SNR_prior can only be chosen from ''unif'', ''lognorm''' \
' and ''exp'''
' ''exp'' and ''equal'''
self.SNR_prior = SNR_prior
self.SNR_bins = SNR_bins
if self.SNR_prior == 'equal':
self.SNR_bins = 1
else:
self.SNR_bins = SNR_bins
self.rho_bins = rho_bins
self.tol = tol
self.optimizer = optimizer
Expand Down Expand Up @@ -3094,9 +3099,14 @@ def fit(self, X, design, nuisance=None, scan_onsets=None):
# However, in fit(), we keep the scikit-learn API that
# X is the input data to fit and y, a reserved name not used, is
# the label to map to from X.
assert self.SNR_bins >= 10 and self.rho_bins >= 10, \
assert self.SNR_bins >= 10 and self.SNR_prior != 'equal' or \
self.SNR_bins == 1 and self.SNR_prior == 'equal', \
'At least 10 bins are required to perform the numerical'\
' integration over SNR, unless choosing SNR_prior=''equal'','\
' in which case SNR_bins should be 1.'
assert self.rho_bins >= 10, \
'At least 10 bins are required to perform the numerical'\
' integration over SNR and rho'
' integration over rho'
assert self.logS_range * 6 / self.SNR_bins < 0.5 \
or self.SNR_prior != 'lognorm', \
'The minimum grid of log(SNR) should not be larger than 0.5 '\
Expand Down Expand Up @@ -4128,9 +4138,12 @@ def _set_SNR_grids(self):
# Center of mass of each segment between consecutive
# bounds are set as the grids for SNR.
SNR_weights = np.ones(self.SNR_bins) / self.SNR_bins
else: # SNR_prior == 'exp'
elif self.SNR_prior == 'exp':
SNR_grids = self._bin_exp(self.SNR_bins)
SNR_weights = np.ones(self.SNR_bins) / self.SNR_bins
else:
SNR_grids = np.ones(1)
SNR_weights = np.ones(1)
SNR_weights = SNR_weights / np.sum(SNR_weights)
return SNR_grids, SNR_weights

Expand Down
7 changes: 7 additions & 0 deletions tests/reprsimil/test_gbrsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,13 @@ def test_SNR_grids():
and np.all(np.diff(SNR_grids) > 0)
), 'SNR_grids or SNR_weights not correct for exponential prior'

s = brainiak.reprsimil.brsa.GBRSA(SNR_prior='equal')
SNR_grids, SNR_weights = s._set_SNR_grids()
assert (np.all(SNR_grids == 1)
and np.all(SNR_weights == 1)
and np.size(SNR_grids) == 1
), 'SNR_grids or SNR_weights not correct for equal prior'


def test_n_nureg():
import brainiak.reprsimil.brsa
Expand Down