import numpy as np
import torch
from dataclasses import dataclass
from vbll.utils.distributions import Normal, DenseNormal, get_parameterization
from collections.abc import Callable
import torch.nn as nn


def KL(p, q_scale):
    """Return the KL-style regulariser of ``p`` against a scaled prior.

    Evaluates ``0.5 * (|mean|^2 / q_scale + tr(cov) / q_scale
    + d * log(q_scale) - logdet(cov))`` where ``d`` is the size of the last
    axis of ``p.mean``. This has the form of a Gaussian KL divergence to a
    zero-mean prior with covariance ``q_scale * I``; the additive constant
    is left out.

    Parameters
    ----------
    p
        Distribution object exposing ``mean``, ``trace_covariance`` and
        ``logdet_covariance``.
        NOTE(review): the last two are presumably per-output-row quantities,
        hence the extra ``sum(-1)`` -- confirm against the distribution class.
    q_scale : float
        Scale of the prior covariance.
    """
    dim = p.mean.shape[-1]

    # Squared norm of the mean, reduced over the two trailing axes.
    sq_mean = (p.mean ** 2).sum(-1).sum(-1) / q_scale
    cov_trace = (p.trace_covariance / q_scale).sum(-1)
    log_dets = (dim * np.log(q_scale) - p.logdet_covariance).sum(-1)

    # currently exclude constant
    return 0.5 * (sq_mean + cov_trace + log_dets)


@dataclass
class VBLLReturn:
    """Bundle of outputs produced by a VBLL layer's forward pass."""

    # Predictive distribution. Could return distribution or mean/cov.
    predictive: Normal | DenseNormal
    # Maps targets to the training loss.
    train_loss_fn: Callable[[torch.Tensor], torch.Tensor]
    # Maps targets to the validation loss.
    val_loss_fn: Callable[[torch.Tensor], torch.Tensor]
    # Optional out-of-distribution scoring function.
    ood_scores: None | Callable[[torch.Tensor], torch.Tensor] = None
class Regression(nn.Module):
    """Variational Bayesian Linear Regression

    Calling the layer on features returns a ``VBLLReturn`` holding the
    predictive distribution together with closures that compute the training
    and validation losses for a batch of targets.

    Parameters
    ----------
    in_features : int
        Number of input features
    out_features : int
        Number of output features
    regularization_weight : float
        Weight on regularization term in ELBO
    parameterization : str
        Parameterization of covariance matrix. Currently supports
        {'dense', 'diagonal', 'lowrank'}
    prior_scale : float
        Scale of prior covariance matrix
    wishart_scale : float
        Scale of Wishart prior on noise covariance
    cov_rank : int, optional
        Size of the trailing axis of the low-rank covariance factor.
        Required when ``parameterization == 'lowrank'``; unused otherwise.
    dof : float
        Degrees of freedom of Wishart prior on noise covariance

    Raises
    ------
    ValueError
        If ``parameterization == 'lowrank'`` and ``cov_rank`` is ``None``.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 regularization_weight,
                 parameterization='dense',
                 prior_scale=1.,
                 wishart_scale=1e-2,
                 cov_rank=None,
                 dof=1.):
        # Fail early with a clear message instead of an opaque error from
        # torch.randn(..., None) further down.
        if parameterization == 'lowrank' and cov_rank is None:
            raise ValueError(
                "cov_rank must be provided when parameterization='lowrank'")

        super().__init__()

        self.wishart_scale = wishart_scale
        # Stored in the transformed form that multiplies the noise
        # log-determinant in the training loss.
        self.dof = (dof + out_features + 1.) / 2.
        self.regularization_weight = regularization_weight

        # define prior, currently fixing zero mean and arbitrarily scaled cov
        self.prior_scale = prior_scale

        # noise distribution (mean is fixed at zero, only the log-scale is learned)
        self.noise_mean = nn.Parameter(torch.zeros(out_features),
                                       requires_grad=False)
        self.noise_logdiag = nn.Parameter(torch.randn(out_features) - 1)

        # last layer distribution
        self.W_dist = get_parameterization(parameterization)
        self.W_mean = nn.Parameter(torch.randn(out_features, in_features))
        self.W_logdiag = nn.Parameter(
            torch.randn(out_features, in_features) - 0.5 * np.log(in_features))
        if parameterization == 'dense':
            self.W_offdiag = nn.Parameter(
                torch.randn(out_features, in_features, in_features) / in_features)
        elif parameterization == 'lowrank':
            self.W_offdiag = nn.Parameter(
                torch.randn(out_features, in_features, cov_rank) / in_features)

    def W(self):
        """Build the weight distribution from the current parameters.

        Raises
        ------
        ValueError
            If ``self.W_dist`` is not one of the supported distribution classes.
        """
        cov_diag = torch.exp(self.W_logdiag)

        if self.W_dist is Normal:
            return self.W_dist(self.W_mean, cov_diag)

        if self.W_dist is DenseNormal:
            # Strictly-lower triangle from the free parameters, positive
            # diagonal from the exponentiated log-diagonal.
            tril = (torch.tril(self.W_offdiag, diagonal=-1)
                    + torch.diag_embed(cov_diag))
            return self.W_dist(self.W_mean, tril)

        # LowRankNormal is not among the file-level imports; import it locally
        # so the 'lowrank' parameterization does not fail with a NameError.
        from vbll.utils.distributions import LowRankNormal
        if self.W_dist is LowRankNormal:
            return self.W_dist(self.W_mean, self.W_offdiag, cov_diag)

        raise ValueError(
            f"Unsupported weight distribution: {self.W_dist!r}")

    def noise(self):
        """Return the noise distribution (zero mean, learned diagonal scale)."""
        return Normal(self.noise_mean, torch.exp(self.noise_logdiag))

    def forward(self, x):
        """Return a ``VBLLReturn`` with the predictive and loss closures for ``x``.

        ``x`` is expected to have ``in_features`` as its trailing dimension
        (it is multiplied by the ``(out_features, in_features)`` weight mean).
        """
        out = VBLLReturn(self.predictive(x),
                         self._get_train_loss_fn(x),
                         self._get_val_loss_fn(x))
        return out

    def predictive(self, x):
        """Return the predictive distribution: weights applied to ``x`` plus noise."""
        return (self.W() @ x[..., None]).squeeze(-1) + self.noise()

    def _get_train_loss_fn(self, x):
        """Return a closure mapping targets ``y`` to the negative ELBO for ``x``."""

        def loss_fn(y):
            # construct predictive density N(W @ phi, Sigma)
            W = self.W()
            noise = self.noise()
            pred_density = Normal((W.mean @ x[..., None]).squeeze(-1),
                                  noise.scale)
            pred_likelihood = pred_density.log_prob(y)

            # Correction for weight uncertainty.
            # NOTE(review): presumably x^T Cov x per output, scaled by the
            # noise precision -- confirm against the distribution helpers.
            trace_term = 0.5 * (
                W.covariance_weighted_inner_prod(x.unsqueeze(-2)[..., None])
                * noise.trace_precision)

            kl_term = KL(W, self.prior_scale)
            wishart_term = (self.dof * noise.logdet_precision
                            - 0.5 * self.wishart_scale * noise.trace_precision)

            # Data term is averaged over the batch; the regularisers are
            # added once, weighted by regularization_weight.
            total_elbo = torch.mean(pred_likelihood - trace_term)
            total_elbo = total_elbo + self.regularization_weight * (
                wishart_term - kl_term)
            return -total_elbo

        return loss_fn

    def _get_val_loss_fn(self, x):
        """Return a closure mapping targets ``y`` to the mean negative log-likelihood."""

        def loss_fn(y):
            # compute log likelihood under variational posterior via marginalization
            return -torch.mean(self.predictive(x).log_prob(y))

        return loss_fn