多图像超分辨率ESPCN网络输入错误

Multi-Image Super-Resolution ESPCN Network Input Error

提问人:west42 提问时间:10/23/2023 最后编辑:west42 更新时间:10/23/2023 访问量:34

问:

我正在尝试修改 ESPCN 网络,使其能够接受多幅低分辨率输入图像,但遇到了与输入图像数组形状相关的错误。很抱歉下面的代码比较长,但我想提供足够完整的信息,(希望)能借此找出问题所在。

## Load required packages ##
import os, math, cv2, time, ntpath, pathlib, random, glob
import numpy as np
import pydotplus as pydot
import pandas as pd
import pandas_datareader
from numpy import asarray, zeros, ones, vstack
from os import listdir

import tensorflow as tf
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.compat.v1.Session(config=config)
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import load_img, img_to_array, array_to_img, smart_resize, ImageDataGenerator
from tensorflow.keras.preprocessing import image_dataset_from_directory
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Cropping2D, AveragePooling2D, Conv2D, Input, Add, Concatenate, Conv2DTranspose
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler, EarlyStopping
from tensorflow.keras.optimizers import Adam, RMSprop, Adagrad
from tensorflow.keras.losses import MeanSquaredError
from opencv_jupyter_ui import cv2_imshow
from IPython.display import display

import PIL
from PIL import Image
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import image, pyplot
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset

## Initialize global variables ##
# Enhance image resolution by a factor of 4
# NOTE(review): crop_size / input_size are not used by the training code shown
# here, which crops at hrCropSize=200 in the main section.
crop_size = 400
upscale_factor = 4
input_size = crop_size // upscale_factor

# Set training parameters
num_input_images = 4        # number of low-res images fed to the MISR network (one Keras Input each)
input_shape = (50, 50, 3)   # shape of EACH low-res input image: (height, width, channels), Keras channels-last order
batch_size = 50             # batch size for training
val_split = 0.2             # fraction of the samples passed to fit() held out (taken from the end) for validation
epochs = 10
lr = 0.001                  # initial learning rate
decay_rate = 0.1            # exponent of the per-epoch exp(-decay_rate) learning-rate decay (second half of training)
loss_fn = keras.losses.MeanSquaredError()     # define loss function

# optimizer
optimizer = keras.optimizers.Adam(learning_rate=lr)

# define training data paths
lr_train_path = '/path/to/train_data/lr/*.png'
hr_train_path = '/path/to/train_data/hr/*.png'

# define test data paths
lr_test_path = '/path/to/test_data/lr/*.png'
hr_test_path = '/path/to/test_data/hr/*.png'

## Utility functions ##
# Function to load images from file directory
def load_data(lr_path, hr_path):
    """Load paired low-/high-resolution images from two glob patterns.

    Files matching each pattern are read in sorted filename order, so the
    i-th low-res image is paired with the i-th high-res image.

    Args:
        lr_path: glob pattern for the low-resolution images, e.g. '/dir/lr/*.png'.
        hr_path: glob pattern for the high-resolution images.

    Returns:
        (lrImgs, hrImgs): numpy arrays built from the images returned by
        ``plt.imread`` (empty arrays when a pattern matches nothing).

    Raises:
        ValueError: if the two patterns match a different number of files,
            which would otherwise silently misalign the LR/HR pairs.
    """
    # Import the *function* explicitly: the module-level name ``glob`` is the
    # module (``import glob``), so ``glob(pattern)`` would raise
    # "TypeError: 'module' object is not callable".
    from glob import glob as _glob

    lr_files = sorted(_glob(lr_path))
    hr_files = sorted(_glob(hr_path))
    if len(lr_files) != len(hr_files):
        raise ValueError(
            f'found {len(lr_files)} low-res images but {len(hr_files)} '
            f'high-res images; the LR/HR pairs would be misaligned'
        )

    lrImgs = np.array([plt.imread(path) for path in lr_files])
    hrImgs = np.array([plt.imread(path) for path in hr_files])

    return (lrImgs, hrImgs)
 
# Build CNN Model 
def make_misr_espcn(input_shape, upscale_factor, num_input_images, channels):
    """Build a multi-image super-resolution (MISR) variant of ESPCN.

    The model has ``num_input_images`` separate Keras Inputs. Each input goes
    through its own first convolution, then through three convolutions whose
    weights are shared by all inputs, then its own sub-pixel
    (``depth_to_space``) upsampling head. The per-input super-resolved images
    are summed with ``Add`` into the single model output.

    Args:
        input_shape: shape of ONE low-res input image, (height, width, channels).
        upscale_factor: integer spatial upscaling factor r.
        num_input_images: number of low-res inputs the model accepts (>= 1).
        channels: number of channels of the super-resolved output.

    Returns:
        A ``keras.Model`` with ``num_input_images`` inputs. It must be fed a
        LIST of ``num_input_images`` arrays, each shaped (batch, *input_shape);
        a single stacked array is treated as ONE input. The output shape is
        (batch, height * r, width * r, channels).

    Raises:
        ValueError: if ``num_input_images`` is less than 1.
    """
    if num_input_images < 1:
        raise ValueError('num_input_images must be >= 1, got %r' % (num_input_images,))

    conv_args = {
        'activation': 'relu',
        'kernel_initializer': 'Orthogonal',
        'padding': 'same',
    }

    # One Input layer per low-res image.
    inputs = [Input(shape=input_shape) for _ in range(num_input_images)]

    # Convolutions shared (same weights) by every branch.
    conv1 = Conv2D(64, (5, 5), **conv_args)
    conv2 = Conv2D(64, (3, 3), **conv_args)
    conv3 = Conv2D(32, (3, 3), **conv_args)

    sr_branches = []
    for branch_input in inputs:
        x = Conv2D(3 * upscale_factor ** 2, (3, 3), **conv_args)(branch_input)
        x = conv1(x)
        x = conv2(x)
        x = conv3(x)
        # depth_to_space maps (H, W, C * r^2) -> (H * r, W * r, C). With only
        # r^2 filters the output had a single channel and could not match the
        # 3-channel high-res targets.
        x = Conv2D(channels * upscale_factor ** 2, (3, 3), **conv_args)(x)
        x = tf.nn.depth_to_space(x, upscale_factor)
        sr_branches.append(x)

    # NOTE(review): Add sums the branch outputs, so the output scale grows with
    # num_input_images; keras.layers.Average may be what was intended — confirm.
    output = Add()(sr_branches) if len(sr_branches) > 1 else sr_branches[0]

    return Model(inputs=inputs, outputs=output)

# Define Callbacks to monitor training
class Get_PSNR_Callback(keras.callbacks.Callback):
    """Keras callback that prints the mean validation PSNR of each epoch.

    PSNR is derived from each evaluation batch's loss as ``10 * log10(1 / loss)``.
    This assumes the loss is an MSE on images with a maximum value of 1.
    Every 5th epoch (starting at epoch 0) it also runs ``upscale_image`` on
    ``test_img`` and plots the result with ``plot_results``.

    Args:
        test_img: image passed to ``upscale_image`` for the periodic preview.
        upscale_factor: upscaling factor of the model (stored for reference).
    """

    def __init__(self, test_img, upscale_factor):
        super(Get_PSNR_Callback, self).__init__()
        self.test_img = test_img
        self.upscale_factor = upscale_factor
        self.psnr = []

    def on_epoch_begin(self, epoch, logs=None):
        # Reset the per-batch PSNR values collected during this epoch.
        self.psnr = []

    def on_epoch_end(self, epoch, logs=None):
        # on_test_batch_end only fires while validating; without validation
        # data the list stays empty and np.mean would return nan with a warning.
        if self.psnr:
            print('Mean PSNR for epoch: %.2f' % (np.mean(self.psnr)))
        else:
            print('Mean PSNR for epoch: n/a (no validation batches)')
        if epoch % 5 == 0:
            # NOTE(review): upscale_image is not defined in the visible source;
            # a multi-input MISR model presumably needs a list of
            # num_input_images arrays rather than a single image — confirm.
            prediction = upscale_image(self.model, self.test_img)
            # Parenthesize before formatting: "'...%i...' %epoch+1" parsed as
            # ('...' % epoch) + 1 and raised TypeError (str + int).
            plot_results(prediction, 'epoch-' + str(epoch),
                         'Epoch-%i Prediction' % (epoch + 1), 'model_out_imgs/')

    def on_test_batch_end(self, batch, logs=None):
        loss = logs['loss']
        # A perfect batch (loss == 0) would divide by zero; record the same
        # 100 dB cap the psnr() helper uses for identical images.
        if loss > 0:
            self.psnr.append(10 * math.log10(1 / loss))
        else:
            self.psnr.append(100.0)

# Create checkpoint to save current model parameters
def model_checkpoint_callback(filepath):
    """Return a ModelCheckpoint that writes the full model to ``filepath``.

    The model is saved only when the training ``loss`` reaches a new minimum
    (``save_best_only=True`` with ``mode='min'``); weights and architecture
    are both saved (``save_weights_only=False``).
    """
    checkpoint_options = dict(
        monitor='loss',
        mode='min',
        save_best_only=True,
        save_weights_only=False,
    )
    return ModelCheckpoint(filepath, **checkpoint_options)
        
# Stop training once the training 'loss' has not improved for `patience`
# epochs, then roll the model back to the weights of the best epoch.
# NOTE(review): this monitors the training loss rather than a validation
# metric, and patience=20 exceeds epochs=10, so a single fit() call with the
# configured epochs can never trigger it — confirm intended values.
early_stop_callback = EarlyStopping(
    monitor='loss',
    restore_best_weights=True,
    patience=20,
    verbose=1,
)

# Create a LearningRateScheduler callback for variable learning rate
def lr_schedule(epoch, epochs, decay_rate, lr):
    """Learning-rate schedule for ``keras.callbacks.LearningRateScheduler``.

    The rate is left unchanged for the first ``epochs // 2`` epochs.
    Afterwards the incoming rate is multiplied by ``exp(-decay_rate)``.
    Because the scheduler passes in the optimizer's *current* rate, the decay
    compounds from epoch to epoch.

    Args:
        epoch: zero-based index of the epoch about to start.
        epochs: total number of training epochs.
        decay_rate: decay exponent applied per epoch in the second half.
        lr: current learning rate.

    Returns:
        The learning rate to use for ``epoch``. It is a plain Python float
        when ``lr`` is a float.
    """
    if epoch < epochs // 2:
        return lr
    # math.exp returns a plain float rather than building a TF tensor op
    # every epoch; LearningRateScheduler accepts either.
    return lr * math.exp(-decay_rate)

# Run Model Prediction and Plot Results 
def model_predict(model, lr_img, hr_img, upscale_factor, index):
    """Super-resolve one image, plot LR/HR/prediction, and score the result.

    Args:
        model: trained super-resolution model, passed to ``upscale_image``.
        lr_img: low-resolution input image.
        hr_img: matching high-resolution ground-truth image.
        upscale_factor: not used here; kept for interface compatibility.
        index: identifier forwarded to ``plot_results`` (e.g. for file names).

    Returns:
        (test_psnr, pr_hr_ssim, prediction):
        - test_psnr: the result of ``tf.image.psnr`` with ``max_val=255``.
        - pr_hr_ssim: the ``(score, diff)`` tuple returned by ``SSIM``.
        - prediction: the raw output of ``upscale_image``.
    """
    # NOTE(review): upscale_image is defined outside the visible source; a
    # multi-input MISR model presumably needs several LR inputs — confirm.
    prediction = upscale_image(model, lr_img)

    # Convert to arrays for the metric computations.
    highres_img_arr = img_to_array(hr_img)
    predict_img_arr = img_to_array(prediction[0])

    # LR/HR are scaled by 255 for plotting while the prediction is not, which
    # suggests LR/HR are in [0, 1] and the prediction is in [0, 255].
    plot_results(lr_img * 255, index, 'Low-Res Input', 'model_out_imgs/')
    plot_results(hr_img * 255, index, 'High-Res Target', 'model_out_imgs/')
    plot_results(prediction[0], index, 'High-Res Prediction', 'model_out_imgs/')

    # Peak Signal to Noise Ratio.
    # NOTE(review): highres_img_arr is NOT scaled by 255 here; if hr_img is in
    # [0, 1] this compares mismatched value ranges — confirm.
    test_psnr = tf.image.psnr(predict_img_arr, highres_img_arr, max_val=255)

    # Structural Similarity Index
    pr_hr_ssim = SSIM(highres_img_arr, predict_img_arr)

    return test_psnr, pr_hr_ssim, prediction

# Evaluate Model Performance 
def psnr(original, enhanced):
    """Return the peak signal-to-noise ratio (dB) between two images.

    Pixel values are assumed to lie in [0, 255] (``max_pixel = 255``).

    Args:
        original: reference image (array-like).
        enhanced: image to compare with the reference, same shape.

    Returns:
        PSNR in decibels. Returns 100 when the images are identical
        (MSE == 0), where the true PSNR would be infinite.
    """
    # Work in float64: subtracting uint8 arrays wraps around (0 - 10 == 246)
    # and squaring them overflows.
    diff = np.asarray(enhanced, dtype=np.float64) - np.asarray(original, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return 100
    max_pixel = 255.0
    # Bare ``log10``/``sqrt`` are not imported in this file; use ``math``.
    return 20 * math.log10(max_pixel / math.sqrt(mse))

def SSIM(original, enhanced):
    """Compute the structural similarity (SSIM) of two color images.

    Both images are converted to grayscale and compared with
    ``compare_ssim(..., full=True)``.

    Args:
        original: reference image in RGB channel order.
        enhanced: image to compare, same shape and channel order.

    Returns:
        (score, diff): the SSIM score and the per-pixel SSIM map. The map is
        clipped to [0, 1], scaled by 255 and cast to uint8.
    """
    # Images in this file are loaded with plt.imread, which returns RGB; using
    # COLOR_BGR2GRAY would swap the red/blue luminance weights.
    orig_gray = cv2.cvtColor(original, cv2.COLOR_RGB2GRAY)
    enhance_gray = cv2.cvtColor(enhanced, cv2.COLOR_RGB2GRAY)
    # NOTE(review): compare_ssim is not imported in the visible source; it is
    # presumably skimage's (superseded by skimage.metrics.structural_similarity
    # in newer releases) — confirm it is in scope.
    (score, diff) = compare_ssim(orig_gray, enhance_gray, full=True)
    # SSIM map values can be negative; casting negative floats to uint8 is
    # undefined, so clip to [0, 1] before scaling.
    diff = (np.clip(diff, 0, 1) * 255).astype('uint8')
    return score, diff

## Image processing functions ##
# randomly crop smaller image segments
def random_crop(lrImage, hrImage, hrCropSize, scale):
    """Take aligned random square crops from a low-res / high-res image pair.

    A random top-left corner is drawn on the low-res grid. The high-res crop
    starts at that corner multiplied by ``scale``, so both crops cover the
    same region. Only the first 3 channels are kept.

    Args:
        lrImage: low-resolution image, rank-3 (H, W, C).
        hrImage: high-resolution image, presumably (H * scale, W * scale, C).
        hrCropSize: side length of the square high-res crop.
        scale: upscaling factor between the two images.

    Returns:
        (lr_crop, hr_crop) with spatial sizes hrCropSize // scale and hrCropSize.
    """
    lr_size = hrCropSize // scale
    lr_hw = tf.shape(lrImage)[:2]

    # Random offsets on the low-res grid (maxval is exclusive). The x offset
    # is drawn before the y offset.
    lr_x = tf.random.uniform(shape=(), maxval=lr_hw[1] - lr_size + 1, dtype=tf.int32)
    lr_y = tf.random.uniform(shape=(), maxval=lr_hw[0] - lr_size + 1, dtype=tf.int32)

    hr_x, hr_y = lr_x * scale, lr_y * scale

    lr_crop = tf.slice(lrImage, [lr_y, lr_x, 0], [lr_size, lr_size, 3])
    hr_crop = tf.slice(hrImage, [hr_y, hr_x, 0], [hrCropSize, hrCropSize, 3])
    return (lr_crop, hr_crop)

# crop center section of image
def get_center_crop(lrImage, hrImage, hrCropSize, scale):
    """Crop the central square of a low-res / high-res image pair.

    The high-res image dimensions are derived from the low-res ones
    (low-res size times ``scale``), not read from ``hrImage``.

    Args:
        lrImage: low-resolution image tensor.
        hrImage: high-resolution image tensor.
        hrCropSize: side length of the square high-res crop.
        scale: upscaling factor; the low-res crop side is hrCropSize // scale.

    Returns:
        (lr_crop, hr_crop) produced by tf.image.crop_to_bounding_box.
    """
    lr_size = hrCropSize // scale
    lr_shape = tf.shape(lrImage)
    lr_h, lr_w = lr_shape[0], lr_shape[1]

    # Top-left corner = centre of the image minus half of the crop side.
    lr_top = lr_h // 2 - lr_size // 2
    lr_left = lr_w // 2 - lr_size // 2
    hr_top = (lr_h * scale) // 2 - hrCropSize // 2
    hr_left = (lr_w * scale) // 2 - hrCropSize // 2

    lr_crop = tf.image.crop_to_bounding_box(lrImage, lr_top, lr_left, lr_size, lr_size)
    hr_crop = tf.image.crop_to_bounding_box(hrImage, hr_top, hr_left, hrCropSize, hrCropSize)
    return lr_crop, hr_crop

# randomly flip images from left to right
def random_flip(lrImage, hrImage):
    """Flip an LR/HR pair left-to-right together with probability 0.5.

    Args:
        lrImage: low-resolution image tensor.
        hrImage: matching high-resolution image tensor.

    Returns:
        (lrImage, hrImage), either both unchanged or both mirrored.
    """
    flipProb = tf.random.uniform(shape=(), maxval=1)
    # A bare ``flip_left_right`` is not imported anywhere in this file
    # (NameError); use the tf.image function explicitly.
    (lrImage, hrImage) = tf.cond(
        flipProb < 0.5,
        lambda: (lrImage, hrImage),
        lambda: (tf.image.flip_left_right(lrImage), tf.image.flip_left_right(hrImage)),
    )

    # return the randomly flipped low and high resolution images
    return (lrImage, hrImage)

# randomly rotate images in 90 degree increments
def random_rotate(lrImage, hrImage):
    """Rotate an LR/HR pair by the same random multiple of 90 degrees.

    Args:
        lrImage: low-resolution image tensor.
        hrImage: matching high-resolution image tensor.

    Returns:
        (lrImage, hrImage), both rotated by k * 90 degrees with k drawn
        uniformly from {0, 1, 2, 3}.
    """
    n = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)

    # A bare ``rot90`` is not imported anywhere in this file (NameError);
    # use tf.image.rot90 explicitly.
    lrImage = tf.image.rot90(lrImage, n)
    hrImage = tf.image.rot90(hrImage, n)

    return (lrImage, hrImage)

# augment the 500 image dataset to increase the number of training images
def augment_train_data(lrImageTrain, hrImageTrain, hrCropSize, scale, num_copies=10):
    """Expand the training set with random crops, flips and 90-degree rotations.

    Every LR/HR pair is augmented ``num_copies`` times. The output is ordered
    by round: all pairs of round 0 come first, then round 1, and so on.

    Args:
        lrImageTrain: sequence of low-resolution images.
        hrImageTrain: sequence of matching high-resolution images.
        hrCropSize: side of the square high-res crop; the low-res crop side is
            hrCropSize // scale.
        scale: upscaling factor between LR and HR.
        num_copies: augmented versions produced per pair (default 10, the
            previously hard-coded value).

    Returns:
        (lrImageAug, hrImageAug): lists of tensors reshaped to
        (hrCropSize // scale, hrCropSize // scale, 3) and
        (hrCropSize, hrCropSize, 3).

    Raises:
        ValueError: if the LR and HR sequences differ in length.
    """
    if len(lrImageTrain) != len(hrImageTrain):
        raise ValueError(
            'got %d low-res but %d high-res images' % (len(lrImageTrain), len(hrImageTrain))
        )

    lr_crop_shape = (hrCropSize // scale, hrCropSize // scale, 3)
    hr_crop_shape = (hrCropSize, hrCropSize, 3)
    lrImageAug = []
    hrImageAug = []

    for _ in range(num_copies):
        for lr_src, hr_src in zip(lrImageTrain, hrImageTrain):
            (lrImage, hrImage) = random_crop(lr_src, hr_src, hrCropSize, scale)
            (lrImage, hrImage) = random_flip(lrImage, hrImage)
            (lrImage, hrImage) = random_rotate(lrImage, hrImage)

            # Pin the static shapes of the augmented crops.
            lrImageAug.append(tf.reshape(lrImage, lr_crop_shape))
            hrImageAug.append(tf.reshape(hrImage, hr_crop_shape))

    return (lrImageAug, hrImageAug)

# create multiple, slightly different copies of each training image for input to the MISR network
def create_misr_inputs(lr_image, hr_image, num_variations):
    """Make ``num_variations`` slightly perturbed copies of one low-res image.

    Each copy gets its own small random transform from ``ImageDataGenerator``:
    - rotation of up to 5 degrees;
    - shift of up to 1% of the width and height;
    - shear of 0.01.

    Zoom and flips are disabled. Pixels moved in from outside the frame take
    the nearest edge value. The high-res image is repeated unchanged once per
    copy.

    Returns:
        (lr_variations, hr_replicas): two lists of length ``num_variations``.
    """
    datagen = ImageDataGenerator(
        rotation_range=5,
        width_shift_range=0.01,
        height_shift_range=0.01,
        shear_range=0.01,
        zoom_range=0,
        horizontal_flip=False,
        fill_mode='nearest',
    )

    lr_variations = [datagen.random_transform(lr_image) for _ in range(num_variations)]
    hr_replicas = [hr_image] * num_variations
    return (lr_variations, hr_replicas)

## Main code ##
# load training data
img_trainLR, img_trainHR = load_data(lr_train_path, hr_train_path)

# load test data
img_testLR, img_testHR = load_data(lr_test_path, hr_test_path)

# augment the dataset
(lrTrainAug, hrTrainAug) = augment_train_data(img_trainLR, img_trainHR, hrCropSize=200, scale=4)

# convert augmented dataset lists to arrays
lrTrainAug = np.asarray(lrTrainAug)
hrTrainAug = np.asarray(hrTrainAug)

# Define checkpoint parameters
checkpoint_filepath = 'model_checkpoints/'
test_img = img_testLR[13]

# Instantiate the LearningRateScheduler callback
lr_scheduler_callback = LearningRateScheduler(
    schedule=lambda epoch, lr: lr_schedule(epoch, epochs, decay_rate, lr)
)

# Instantiate model checkpoint callback
#model_checkpoint_callback = model_checkpoint(checkpoint_filepath)

# Initialize callbacks
callbacks = [Get_PSNR_Callback(test_img, upscale_factor), 
             model_checkpoint_callback(checkpoint_filepath), early_stop_callback, lr_scheduler_callback]

# Build the CNN model
model = make_misr_espcn(input_shape, upscale_factor, num_input_images, channels=3)

# compile the model
model.compile(optimizer=optimizer, loss=loss_fn)

# training loop
for epoch in range(epochs):
    print(f'Epoch {epoch+1}/{epochs}')
    
    # Shuffle the training data (X_train and y_train) in the same order
    permutation = np.random.permutation(len(lrTrainAug))
    X_train = lrTrainAug[permutation]
    y_train = hrTrainAug[permutation]
    
    num_batches = len(X_train) // batch_size
    
    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        end_idx = (batch_idx + 1) * batch_size
        
        # Extract a batch of data
        batch_X = X_train[start_idx:end_idx]
        batch_y = y_train[start_idx:end_idx]
        
        # Generate multiple input image variations
        for i in range(len(batch_X)):
            misr_X, misr_y = create_misr_inputs(batch_X[i], batch_y[i], num_input_images)
            misr_X = np.asarray(misr_X)
            misr_y = np.asarray(misr_y)

            # reshape misr_y to have 4 axes
            #misr_y = misr_y[np.newaxis, ...]
            
            print('LR MISR shape: ', misr_X.shape)
            print('HR MISR shape: ', misr_y.shape)
        
        # train the model
            history = model.fit(
                misr_X, misr_y,            
                batch_size = batch_size,
                epochs = epochs,
                validation_split = val_split,
                callbacks = callbacks
            #callbacks = [psnr_callback, lr_scheduler_callback, early_stop_callback]
            )

我尝试了不同的网络输入,例如

LR MISR input shape:  (4, 50, 50, 3)
HR MISR target shape:  (4, 200, 200, 3)

LR MISR input shape:  (4, 50, 50, 3)
HR MISR target shape:  (1, 200, 200, 3)

每次都会收到类似下面的错误:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [18], in <cell line: 3>()
     31             print('HR MISR shape: ', misr_y.shape)
     33         # Train the model on the batch
     34             #loss = model.train_on_batch(batch_X, batch_y)
---> 35             history = model.fit(
     36                 #X_train, y_train,
     37                 misr_X, misr_y,            
     38                 batch_size = batch_size,
     39                 epochs = epochs,
     40                 validation_split = val_split,
     41                 callbacks = callbacks
     42             #callbacks = [psnr_callback, lr_scheduler_callback, early_stop_callback]
     43             )
     45 # Track model history to plot MSE curve
     46 #history = model.fit(lrTrainAug, hrTrainAug, epochs=epochs, batch_size=batch_size, callbacks=callbacks, validation_data=img_testLR, verbose=2)
     47 #history = model.fit(lrTrainAug, hrTrainAug, epochs=epochs, batch_size=batch_size, callbacks=callbacks, validation_split=validation_split, verbose=2)
     49 '''
     50 # Iterate over the input images and generate MISR data
     51 for i in range(len(lrTrainAug)):
   (...)
     80 )
     81 '''

File ~/anaconda3/envs/gpu/lib/python3.9/site-packages/keras/utils/traceback_utils.py:70, in filter_traceback.<locals>.error_handler(*args, **kwargs)
     67     filtered_tb = _process_traceback_frames(e.__traceback__)
     68     # To get the full stack trace, call:
     69     # `tf.debugging.disable_traceback_filtering()`
---> 70     raise e.with_traceback(filtered_tb) from None
     71 finally:
     72     del filtered_tb

File /tmp/__autograph_generated_filenq_pv254.py:15, in outer_factory.<locals>.inner_factory.<locals>.tf__train_function(iterator)
     13 try:
     14     do_return = True
---> 15     retval_ = ag__.converted_call(ag__.ld(step_function), (ag__.ld(self), ag__.ld(iterator)), None, fscope)
     16 except:
     17     do_return = False

ValueError: in user code:

    File "/usr/local/home/west42/anaconda3/envs/gpu/lib/python3.9/site-packages/keras/engine/training.py", line 1283, in train_function  *
        return step_function(self, iterator)
    File "/usr/local/home/west42/anaconda3/envs/gpu/lib/python3.9/site-packages/keras/engine/training.py", line 1267, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/usr/local/home/west42/anaconda3/envs/gpu/lib/python3.9/site-packages/keras/engine/training.py", line 1248, in run_step  **
        outputs = model.train_step(data)
    File "/usr/local/home/west42/anaconda3/envs/gpu/lib/python3.9/site-packages/keras/engine/training.py", line 1049, in train_step
        y_pred = self(x, training=True)
    File "/usr/local/home/west42/anaconda3/envs/gpu/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler
        raise e.with_traceback(filtered_tb) from None
    File "/usr/local/home/west42/anaconda3/envs/gpu/lib/python3.9/site-packages/keras/engine/input_spec.py", line 219, in assert_input_compatibility
        raise ValueError(

    ValueError: Layer "model" expects 4 input(s), but it received 1 input tensors. Inputs received: [<tf.Tensor 'IteratorGetNext:0' shape=(None, 50, 50, 3) dtype=float32>]

不知为何,模型似乎始终无法把输入图像数组识别为多个独立的输入。我已经盯着这段代码看了好几个星期了,非常感谢您能提供的任何帮助。谢谢!!

python conv-neural-network valueerror

评论


答: 暂无答案