diff --git a/brainiak/reprsimil/brsa.py b/brainiak/reprsimil/brsa.py index af1b8c9b6..5069c321a 100755 --- a/brainiak/reprsimil/brsa.py +++ b/brainiak/reprsimil/brsa.py @@ -2818,11 +2818,13 @@ class GBRSA(BRSA): In this case, the standard deviation of log(SNR) is set by the parameter logS_range. If set to 'unif', a uniform prior in [0,1] is imposed. - In all these cases, SNR is numerically + In all the above cases, SNR is numerically marginalized on a grid of parameters. So the parameter SNR_bins determines how accurate the numerical integration is. The more number of bins are used, the more accurate the numerical integration becomes. + If set to 'equal', all voxels are assumed to have the same fixed + SNR. Pseudo-SNR is 1.0 for all voxels. In all the cases, the grids used for pseudo-SNR do not really set an upper bound for SNR, because the real SNR is determined by both pseudo-SNR and U, the shared covariance structure. @@ -3003,11 +3005,14 @@ def __init__( if type(logS_range) is int: logS_range = float(logS_range) self.logS_range = logS_range - assert SNR_prior in ['unif', 'lognorm', 'exp'], \ + assert SNR_prior in ['unif', 'lognorm', 'exp', 'equal'], \ 'SNR_prior can only be chosen from ''unif'', ''lognorm''' \ - ' and ''exp''' + ', ''exp'' and ''equal''' self.SNR_prior = SNR_prior - self.SNR_bins = SNR_bins + if self.SNR_prior == 'equal': + self.SNR_bins = 1 + else: + self.SNR_bins = SNR_bins self.rho_bins = rho_bins self.tol = tol self.optimizer = optimizer @@ -3094,9 +3099,14 @@ def fit(self, X, design, nuisance=None, scan_onsets=None): # However, in fit(), we keep the scikit-learn API that # X is the input data to fit and y, a reserved name not used, is # the label to map to from X.
- assert self.SNR_bins >= 10 and self.rho_bins >= 10, \ + assert (self.SNR_bins >= 10 and self.SNR_prior != 'equal') or \ + (self.SNR_bins == 1 and self.SNR_prior == 'equal'), \ + 'At least 10 bins are required to perform the numerical'\ + ' integration over SNR, unless choosing SNR_prior=''equal'','\ + ' in which case SNR_bins should be 1.' + assert self.rho_bins >= 10, \ 'At least 10 bins are required to perform the numerical'\ - ' integration over SNR and rho' + ' integration over rho' assert self.logS_range * 6 / self.SNR_bins < 0.5 \ or self.SNR_prior != 'lognorm', \ 'The minimum grid of log(SNR) should not be larger than 0.5 '\ @@ -4128,9 +4138,12 @@ def _set_SNR_grids(self): # Center of mass of each segment between consecutive # bounds are set as the grids for SNR. SNR_weights = np.ones(self.SNR_bins) / self.SNR_bins - else: # SNR_prior == 'exp' + elif self.SNR_prior == 'exp': SNR_grids = self._bin_exp(self.SNR_bins) SNR_weights = np.ones(self.SNR_bins) / self.SNR_bins + else: # SNR_prior == 'equal': a single grid point, pseudo-SNR fixed at 1 + SNR_grids = np.ones(1) + SNR_weights = np.ones(1) SNR_weights = SNR_weights / np.sum(SNR_weights) return SNR_grids, SNR_weights diff --git a/tests/reprsimil/test_gbrsa.py b/tests/reprsimil/test_gbrsa.py index a74a4dcb2..7038d35b4 100644 --- a/tests/reprsimil/test_gbrsa.py +++ b/tests/reprsimil/test_gbrsa.py @@ -361,6 +361,13 @@ def test_SNR_grids(): and np.all(np.diff(SNR_grids) > 0) ), 'SNR_grids or SNR_weights not correct for exponential prior' + s = brainiak.reprsimil.brsa.GBRSA(SNR_prior='equal') + SNR_grids, SNR_weights = s._set_SNR_grids() + assert (np.all(SNR_grids == 1) + and np.all(SNR_weights == 1) + and np.size(SNR_grids) == 1 + ), 'SNR_grids or SNR_weights not correct for equal prior' + def test_n_nureg(): import brainiak.reprsimil.brsa