ValueError('max_value 在训练或验证期间不得为零或 nan')

ValueError('max_value must not be zero or nan during training or validation')

提问人:arpita halder 提问时间:10/28/2023 更新时间:10/28/2023 访问量:14

问:

我可以看到 transforms.py 中的 max_value 是 nan。

(此处原为截图,图片未能保留)

屏幕截图 2023-10-27 于 10.14.45 PM.png

因此,代码中对 max_value 的检查抛出了上述 ValueError。

(此处原为截图,图片未能保留)

-> maxval = [v.reshape(-1) for v in batch.max_value]

(Pdb) n

--返回--

/home/hpc/rzku/mlvl109h/cine-vn-vortex/cinevn/pl/varnet_module.py(317)()->[tensor([0.002...torch.float64)]

-> maxval = [v.reshape(-1) for v in batch.max_value]

调试时,我看到 maxval 按上述方式计算,结果看起来正常。

我猜测 batch.max_value 没有返回有效值。在 transforms.py 中,max_value 只在说明中被描述为“最大绝对图像值”,代码并没有对它做任何处理;maxval = [v.reshape(-1) for v in batch.max_value] 只出现在 varnet_module 中。

批次数据有问题,但我不确定问题出在哪里。如果你有任何猜测,请告诉我。调试时,我还发现健全性检查(sanity check)正常执行了(这里我只用了 2 个样本,但也用实际数据试过)。(此处原为截图,图片未能保留)varnet_module 中的代码如下:

def train_val_forward(self, batch):
        """Run the shared training/validation forward pass and compute the loss.

        The undersampled k-space (``batch.masked_kspace``) is optionally normalised
        by a per-sample factor. It is passed through the network, and the result is
        center-cropped against ``batch.target``. ``self.loss`` is then evaluated in
        the normalised domain.

        When ``batch.mask`` is not None, a second pass is run on
        ``batch.noised_kspace`` with the same normalisation. The value of
        ``self.consistency_loss_fn`` between the two un-normalised outputs is added
        to the loss.

        Args:
            batch: batch object exposing ``max_value``, ``target``, ``masked_kspace``,
                ``mask``, ``num_low_frequencies``, ``sens_maps`` and, when ``mask`` is
                not None, ``noised_kspace``.

        Returns:
            Tuple ``(target, output, loss)``:

            - ``target``: the cropped target.
            - ``output``: the un-normalised, cropped network output.
            - ``loss``: the value from ``self.loss``, plus the consistency term when
              it applies.

        Raises:
            ValueError: if any entry of ``batch.max_value`` is zero or NaN.
        """
        # NOTE(review): max_value is presumably produced by the dataset/transform
        # pipeline; a NaN here points to a data problem upstream of this module.
        if torch.any(batch.max_value == 0) or torch.any(batch.max_value.isnan()):
            raise ValueError('max_value must not be zero or nan during training or validation')

        # Bind the mask before it is tested below.
        mask = batch.mask

        # Per-sample normalisation factor: max |target| over every non-batch dim.
        # NOTE(review): an all-zero target gives norm_val == 0, and the divisions
        # below then produce inf/NaN.
        norm_val = batch.target.abs().flatten(1).max(dim=1).values if self.normalize else None

        def _expand(val, ref):
            # PyTorch broadcasts from right to left, so append trailing singleton
            # dims to the per-sample vector ``val`` until it matches ``ref.ndim``.
            return val[(...,) + (None,) * (ref.ndim - 1)]

        def _forward(kspace):
            # Normalise (if enabled), run the network, and crop phase oversampling.
            # Returns ``(cropped_target, cropped_output)``; the output is still normalised.
            if norm_val is not None:
                kspace = kspace / _expand(norm_val, kspace)
            output = self(kspace, mask, batch.num_low_frequencies, batch.sens_maps)
            return transforms.center_crop_to_smallest(batch.target, output, ndim=2)

        def _unnormalize(output):
            # Undo the normalisation; don't use an inplace operation here!
            if norm_val is None:
                return output
            return output * _expand(norm_val, output)

        # Supervised branch on the undersampled k-space.
        target, output = _forward(batch.masked_kspace)
        # Normalise the target like the output for the loss; unsqueeze adds a channel dim.
        if norm_val is None:
            target_for_loss = target
            data_range = batch.max_value
        else:
            target_for_loss = target / _expand(norm_val, target)
            data_range = batch.max_value / norm_val
        loss = self.loss(pred=output.unsqueeze(1), targ=target_for_loss.unsqueeze(1), data_range=data_range)
        output = _unnormalize(output)

        # NOTE(review): the consistency term is enabled whenever a mask is present;
        # confirm this is the intended gate (e.g. vs. checking batch.noised_kspace).
        if mask is not None:
            # Consistency branch: run the same network on the noised k-space, using
            # the same per-sample normalisation factor.
            target, noised_output = _forward(batch.noised_kspace)
            noised_output = _unnormalize(noised_output)
            consistency_loss = self.consistency_loss_fn(noised_output, output)
            # Out-of-place add, so the tensor returned by self.loss is not mutated.
            loss = loss + consistency_loss

        return target, output, loss

    
        

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch, log it per step, and return it."""
        _target, _output, train_loss = self.train_val_forward(batch)
        self.log('train_loss', train_loss, on_step=True, on_epoch=False, sync_dist=True)
        return train_loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Run the shared forward pass and package its results for the batch-end hook."""
        target, output, val_loss = self.train_val_forward(batch)
        return dict(output=output, target=target, val_loss=val_loss)

    def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
        """Feed one validation batch into the running loss and image-quality metrics.

        Args:
            outputs: dict returned by ``validation_step`` with keys ``'output'``,
                ``'target'`` and ``'val_loss'``.
            batch: the validation batch; ``max_value``, ``annotations``, ``fname``
                and ``slice_num`` are read.
            batch_idx: unused.
            dataloader_idx: unused.

        Raises:
            RuntimeError: if ``outputs`` is not a dict.
        """
        if not isinstance(outputs, dict):
            raise RuntimeError('outputs must be a dict')

        # Metrics are computed on magnitudes.
        gt = outputs['target'].abs()
        pred = outputs['output'].abs()
        # One flattened max value per entry of batch.max_value.
        maxvals = [m.reshape(-1) for m in batch.max_value]
        # Any NaN annotation disables the x-t SSIM metric for this batch.
        center = None if batch.annotations.isnan().any() else [a[0].to(int) for a in batch.annotations]

        self.val_loss.update(outputs['val_loss'])
        fname, slice_num = batch.fname, batch.slice_num
        self.nmse.update(fname, slice_num, gt, pred)
        self.ssim.update(fname, slice_num, gt, pred, maxvals=maxvals)
        self.psnr.update(fname, slice_num, gt, pred, maxvals=maxvals)
        self.hfen.update(fname, slice_num, gt, pred)
        if self.ssim_xt is not None and center is not None:
            self.ssim_xt.update(fname, slice_num, gt, pred, center, maxvals=maxvals)

    def on_validation_epoch_end(self):
        """Log the aggregated validation loss and metrics at the end of an epoch."""
        # (log key, metric object, extra keyword arguments for self.log)
        entries = (
            ('validation_loss', self.val_loss, {'prog_bar': True}),
            ('val_metrics/nmse', self.nmse, {}),
            ('val_metrics/ssim', self.ssim, {'prog_bar': True}),
            ('val_metrics/psnr', self.psnr, {}),
            ('val_metrics/hfen', self.hfen, {}),
        )
        for key, metric, extra in entries:
            self.log(key, metric, **extra)
        # The x-t SSIM is only logged when it received at least one update.
        xt_metric = self.ssim_xt
        if xt_metric is not None and xt_metric._update_count > 0:
            self.log('val_metrics/ssim_xt', xt_metric)

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """Reconstruct a test batch and crop it using ``batch.header``."""
        reconstruction = self(batch.masked_kspace, batch.mask, batch.num_low_frequencies, batch.sens_maps)
        cropped = transforms.batched_crop_to_recon_size(reconstruction, batch.header)
        return dict(output=cropped)

    def configure_optimizers(self):
        """Create an Adam optimizer with a StepLR learning-rate schedule.

        Returns:
            ``([optimizer], [scheduler])``, built from ``self.lr``,
            ``self.weight_decay``, ``self.lr_step_size`` and ``self.lr_gamma``.
        """
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, self.lr_step_size, self.lr_gamma)
        return [optimizer], [lr_scheduler]
深度学习 训练数据 PyTorch-Lightning MRI

评论


答: 暂无答案