Context Encoders를 활용한 배경 채우기(2)

Context Encoder에 대한 자세한 내용은 이전 포스트를 참고해 주세요.
Context Encoders를 활용한 배경 채우기(1)

Dataset은 미리 준비 되어 있어야 하고 저장되어 있는 위치에 맞도록 경로를 바꿔 주어야 합니다.

Training 코드

환경

Python언어로 작성하였고 Pytorch와 기타 라이브러리를 이용했습니다. 학습은 Google colab GPU를 이용했습니다.

사전 준비

from google.colab import drive
drive.mount('/content/drive')

google drive를 mount합니다.

import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torch import optim
from torch.utils.tensorboard import SummaryWriter 
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os
import time


if torch.cuda.is_available():
  device = 'cuda:0'
else:
  device = 'cpu'

필요한 라이브러리를 import하고 GPU를 사용할 경우 device를 GPU로 바꿔 줍니다.

if not os.path.isfile("/content/voc_test_2007_tar"):
  !cp "/content/drive/MyDrive/Faster_RCNN/VOCtest_06-Nov-2007.tar" "/content/voc_test_2007_tar" 
  ! tar -xf "/content/voc_test_2007_tar"

마스크를 얻기 위해서 Pascal Voc 2007 dataset을 google drive에서 google colab런타임으로 복사해 줍니다.

if not os.path.isfile("/content/city.zip"):
  !cp "/content/drive/MyDrive/Inpainting/city.zip" "/content/city.zip" 
  ! unzip "/content/city.zip" -d "/content/city"

if not os.path.isfile("/content/nature.zip"):
  !cp "/content/drive/MyDrive/Inpainting/nature.zip" "/content/nature.zip" 
  ! unzip "/content/nature.zip" -d "/content/nature"

도시와 자연 이미지 데이터셋을 google colab런타임으로 복사해 줍니다.

Dataset 만들기

마스크로 가려진 이미지를 출력하기 위해 먼저 Pascal Voc 2007로부터 마스크를 만드는 코드를 작성합니다.

def make_mask(size):
  label = 0
  while True:
    seg_list = os.listdir("/content/VOCdevkit/VOC2007/SegmentationObject/")
    seg_image = Image.open("/content/VOCdevkit/VOC2007/SegmentationObject/"+seg_list[np.random.randint(0,len(seg_list))])
    seg_image = seg_image.resize((size,size))
    np_seg = np.array(seg_image,dtype=np.uint8)
    labels = np.unique(np_seg)

    for lb in labels[1:-1]:
      if len(np.where(np_seg == lb)[0]) < (size**2)/4:
        label = lb
        break

    if label != 0:
      break

  np_seg = np.where(np_seg == label,1.0,0)
  np_seg = np.stack((np_seg,np_seg,np_seg),axis = 2)

  return np_seg

Pascal Voc 2007의 Object Segmentation에 사용된 라벨을 이미지 파일로 엽니다. 서로 다른 object는 서로 다른 라벨로 픽셀에 표시되어 있습니다. 따라서 np.unique()로 어느 라벨이 있는지 알아냅니다.

이후 0번과 255번을 제외한 라벨 중 이미지 전체에 차지하는 비율이 1/4인 것 하나를 고릅니다. (if len(np.where(np_seg == lb)[0]) < (size**2)/4 부분)

마지막으로 정해진 라벨이 있는 곳을 1로 바꾸고 dataset이미지와 연산을 하기 위해 3체널로 만들어줍니다.

class Data(Dataset):
  def __init__(self, size= 128):
    self.city_list = os.listdir("/content/city/")
    self.nature_list = os.listdir("/content/nature")
    self.every_list = []
    self.to_tensor = transforms.ToTensor()
    self.size = size

    for x in self.city_list:
      self.every_list.append("/content/city/"+x)
    for x in self.nature_list:
      self.every_list.append("/content/nature/"+x)
    
    self.every_list.sort()

  def __len__(self):
    return len(self.every_list)
  
  def __getitem__(self,idx):
    mask = make_mask(self.size)
    image = Image.open(self.every_list[idx]).convert("RGB")

    image = image.crop((100,200,image.size[0],image.size[1]))
    image = image.resize((self.size,self.size))
    image = np.array(image,dtype=np.uint8)
    image = image/255

    #The missing region in the masked input image is filled with constant mean value.
    masked_image = (1-mask)*image + mask*(np.zeros_like(image)+np.mean(image))
    
    return self.to_tensor(masked_image).type(torch.float32), self.to_tensor(mask).type(torch.float32), self.to_tensor(image).type(torch.float32)

Dataset class를 만들어 줍니다. 큰 과정은 이미지와 마스크를 불러와 이미지에 마스크를 적용시키는 것입니다.
image = image.crop((100,200,image.size[0],image.size[1])) 으로 google streetview에 방향이 표시된 부분을 없앱니다. 그리고 Dataset instance를 생성할 때 사이즈를 지정하도록 했습니다.
마스크를 적용하는 코드는 아래 코드입니다.
masked_image = (1-mask)*image + mask*(np.zeros_like(image)+np.mean(image))

Model

Encoder, Decoder, Discriminator는 각각 아래와 같이 작성하였습니다.

class Encoder(nn.Module):
  def __init__(self):
    super(Encoder,self).__init__()
    self.encoder = nn.Sequential(nn.Conv2d(in_channels=3,out_channels=64,kernel_size=4,stride=2,padding=1),
                                 nn.BatchNorm2d(64),
                                 nn.LeakyReLU(0.2),
                                 nn.Conv2d(in_channels=64,out_channels=64,kernel_size=4,stride=2,padding=1),
                                 nn.BatchNorm2d(64),
                                 nn.LeakyReLU(0.2),
                                 nn.Conv2d(in_channels=64,out_channels=128,kernel_size=4,stride=2,padding=1),
                                 nn.BatchNorm2d(128),
                                 nn.LeakyReLU(0.2),
                                 nn.Conv2d(in_channels=128,out_channels=256,kernel_size=4,stride=2,padding=1),
                                 nn.BatchNorm2d(256),
                                 nn.LeakyReLU(0.2),
                                 nn.Conv2d(in_channels=256,out_channels=512,kernel_size=4,stride=2,padding=1),
                                 nn.BatchNorm2d(512),
                                 nn.LeakyReLU(0.2),
                                 nn.Conv2d(in_channels=512,out_channels=4000,kernel_size=4,stride=1,padding=0),
                                 nn.BatchNorm2d(4000),
                                 nn.LeakyReLU(0.2)
                                 )
  def forward(self, x):
    x = self.encoder(x)
    return x
class Decoder(nn.Module):
  def __init__(self):
    super(Decoder,self).__init__()
    self.decoder = nn.Sequential(nn.ConvTranspose2d(in_channels=4000,out_channels=512,kernel_size=4,stride=1,padding=0),
                                 nn.BatchNorm2d(512),
                                 nn.ReLU(),
                                 nn.ConvTranspose2d(in_channels=512,out_channels=256,kernel_size=4,stride=2,padding=1),
                                 nn.BatchNorm2d(256),
                                 nn.ReLU(),
                                 nn.ConvTranspose2d(in_channels=256,out_channels=128,kernel_size=4,stride=2,padding=1),
                                 nn.BatchNorm2d(128),
                                 nn.ReLU(),
                                 nn.ConvTranspose2d(in_channels=128,out_channels=64,kernel_size=4,stride=2,padding=1),
                                 nn.BatchNorm2d(64),
                                 nn.ReLU(),
                                 nn.ConvTranspose2d(in_channels=64,out_channels=64,kernel_size=4,stride=2,padding=1),
                                 nn.BatchNorm2d(64),
                                 nn.ReLU(),
                                 nn.ConvTranspose2d(in_channels=64,out_channels=3,kernel_size=4,stride=2,padding=1),
                                 nn.ReLU(),
                                 ) 
  def forward(self, x):
    x = self.decoder(x)
    return x
class ContextEncoder(nn.Module):
  def __init__(self):
    super(ContextEncoder,self).__init__()
    self.encoder = Encoder()
    self.decoder = Decoder()
  
  def forward(self, x):
    bottleneck = self.encoder(x)
    images = self.decoder(bottleneck)
    return images
class Discriminator(nn.Module):
    def __init__(self):
      super(Discriminator,self).__init__()
      self.encoder = nn.Sequential(nn.Conv2d(in_channels=3,out_channels=64,kernel_size=4,stride=2,padding=1),
                                  nn.BatchNorm2d(64),
                                  nn.LeakyReLU(0.2),
                                   nn.Conv2d(in_channels=64,out_channels=64,kernel_size=4,stride=2,padding=1),
                                  nn.BatchNorm2d(64),
                                  nn.LeakyReLU(0.2),
                                  nn.Conv2d(in_channels=64,out_channels=128,kernel_size=4,stride=2,padding=1),
                                  nn.BatchNorm2d(128),
                                  nn.LeakyReLU(0.2),
                                  nn.Conv2d(in_channels=128,out_channels=256,kernel_size=4,stride=2,padding=1),
                                  nn.BatchNorm2d(256),
                                  nn.LeakyReLU(0.2),
                                  nn.Conv2d(in_channels=256,out_channels=512,kernel_size=4,stride=2,padding=1),
                                  nn.BatchNorm2d(512),
                                  nn.LeakyReLU(0.2),
                                  nn.Conv2d(in_channels=512,out_channels=1,kernel_size=4,stride=1,padding=0),
                                  nn.BatchNorm2d(1),
                                  nn.LeakyReLU(0.2),
                                  nn.Flatten(),
                                  nn.Sigmoid()
                                  )
    def forward(self, x):
      x = self.encoder(x)
      return x

Training 과정 표시

Training이 되는 동안 중간 과정을 보기 위한 코드입니다. 우선 이미지를 저장하는 코드입니다.

def saveimg(masked,mask,result,location):
  #mask된 이미지, 복구된 이미지 4세트 저장
  buffer_masked = masked.detach().numpy()
  buffer_masked = np.concatenate(buffer_masked.transpose(0,2,3,1))

  buffer_mask = mask.detach().numpy()
  buffer_mask = np.concatenate(buffer_mask.transpose(0,2,3,1))

  buffer_result = result.detach().numpy()
  buffer_result = np.concatenate(buffer_result.transpose(0,2,3,1))

  buffer_result = buffer_mask*buffer_result + (1-buffer_mask)*buffer_masked

  size = buffer_masked.shape[1]
  sample = np.zeros([size*2,size*6,3])

  sample[:,:size,:] = buffer_masked[:size*2,:,:] #masked 1열
  sample[:,size:2*size,:] = buffer_mask[:size*2,:,:] #mask 1열
  sample[:,2*size:3*size] = buffer_result[:size*2,:,:] #result 1열

  sample[:,3*size:4*size,:] = buffer_masked[size*2:,:,:] #masked 2열
  sample[:,4*size:5*size,:] = buffer_mask[size*2:,:,:] #mask 2열
  sample[:,5*size:6*size,:] = buffer_result[size*2:,:,:] #mask 2열

  sample = (sample - sample.min()) / (sample.max() - sample.min()) #[0,1]

  plt.imsave('{}'.format(location),sample)

matplotlib의 라이브러리를 사용할 수도 있지만 저는 이미지를 크기별로 나눠 합쳤습니다.

writer = SummaryWriter("/content/drive/MyDrive/Inpainting/runs/ContextEncoder")
%load_ext tensorboard
%tensorboard --logdir="/content/drive/MyDrive/Inpainting/runs"

위의 코드는 tensorboard를 사용하기 위한 코드입니다. 셀 안에 출력하고자 magic command(%)를 사용했습니다.

Tensorboard에 나온 최종 Loss
CE_tensorboard

Training

try:
  check_point = torch.load("/content/drive/My Drive/Inpainting/Check_point.pth",map_location=device)
  start_epoch = check_point['epoch']
  start_iter = check_point['iter'] + 1
  load = 1
  print("Check point loaded")
  print("Epoch: {}, Iter: {}".format(start_epoch,start_iter))
except:
  load = 0
  start_epoch = 0
  start_iter = 0
  print("Check point load error")


context_encoder = ContextEncoder()
context_encoder.to(device)
discriminator = Discriminator()
discriminator.to(device)

if load:
  context_encoder.load_state_dict(check_point['CE_recent'])
  discriminator.load_state_dict(check_point['D_recent'])


CE_optimizer = optim.Adam(params= context_encoder.parameters(),lr=0.001) #discriminator learning rate 10배
D_optimizer = optim.Adam(params= discriminator.parameters(),lr= 0.0001)

if load:
  CE_optimizer.load_state_dict(check_point['CE_optim'])
  D_optimizer.load_state_dict(check_point['D_optim'])

bce_loss = nn.BCELoss()

dataset = Data()
dataloader = DataLoader(dataset,batch_size=128,shuffle=True)

Training중간 과정을 저장하는 check point를 불러오고 ContextEncoder와 Discriminator를 만들어줍니다. Optimizer는 Adam을 사용하였고 ContextEncoder의 learning rate는 Discriminator의 10배로 하였습니다. 학습 데이터는 128 batchsize로 입력됩니다.

Epoch = int(100000 / len(dataloader))
term = int(len(dataloader)/1)

D_loss_sum = 0
CE_loss_sum = 0

start = time.time()
for epoch in range(start_epoch,Epoch):

  for iter, (masked, mask, original) in enumerate(dataloader,start_iter):
    masked = masked.to(device)
    mask = mask.to(device)
    original = original.to(device)

    CE_result = context_encoder(masked)
    CE_result = CE_result.to(device)

    D_loss = bce_loss(discriminator(original),torch.ones((original.shape[0],1),device=device)) + bce_loss(discriminator(CE_result.detach()),torch.zeros((original.shape[0],1),device = device))
    D_loss_sum += D_loss.item()
    if (epoch*len(dataloader)+iter+1) % term == 0:
      print("Iter {} | D: {:0.10f}".format(epoch*len(dataloader)+iter+1,D_loss_sum/term),end=" | ")
      writer.add_scalar('D Loss',D_loss_sum / term, epoch * len(dataloader) + iter+1)
      D_loss_sum = 0
      CE_loss_sum = 0
    
    D_optimizer.zero_grad()
    D_loss.backward()
    D_optimizer.step()

    l_rec = rec_loss = torch.sum(torch.square(mask*(original - CE_result)))
    l_adv = bce_loss(discriminator(CE_result),torch.ones((original.shape[0],1),device=device))
    joint_loss = 0.999*l_rec + 0.001*l_adv
    CE_loss_sum += joint_loss.item()

    if (epoch*len(dataloader)+iter+1) % term == 0:
      end = time.time()
      print("CE: {:0.10f} | Duration: {}min".format(CE_loss_sum/term,int((end-start)/60)))
      writer.add_scalar('CE Loss',CE_loss_sum / term, epoch * len(dataloader) + iter+1)
      state = {
          'epoch' : epoch,
          'iter' : iter,
          'CE_recent' : context_encoder.state_dict(),
          'D_recent' : discriminator.state_dict(),
          'D_optim' : D_optimizer.state_dict(),
          'CE_optim' : CE_optimizer.state_dict()
      }
      torch.save(state,"/content/drive/My Drive/Inpainting/Check_point.pth")
      start = time.time()
      D_loss_sum = 0
      CE_loss_sum = 0
    
    CE_optimizer.zero_grad()
    joint_loss.backward()
    CE_optimizer.step()

    if (epoch * len(dataloader) + iter+1) % (5*term) == 0:
      
      saveimg(masked[:4].to('cpu'),mask[:4].to('cpu'),CE_result[:4].to('cpu'),"/content/drive/MyDrive/Inpainting/result/"+"Iter"+"%d.png"%(epoch * len(dataloader) + iter+1))
      torch.save(context_encoder.state_dict(),"/content/drive/My Drive/Inpainting/result/CE_recent.pth")
      torch.save(discriminator.state_dict(),"/content/drive/My Drive/Inpainting/result/D_recent.pth")

실제로 학습이 이뤄지는 부분입니다. D_loss는 discriminator가 원본 이미지를 원본이라고 예측하는 정도 + 복원 이미지를 가짜로 예측하는 정도로 정의합니다. l_adv는 discriminator가 복원 이미지를 원본이라고 예측하는 정도로 context encoder입장에서는 결과가 1에 가까울수록 좋은것 입니다. joint_loss는 픽셀 차이를 나타내는 l_rec와 l_adv를 더한 것입니다.

Discriminator와 context encoder를 번갈아가며 업데이트 합니다. 주의할 점은 D_loss에서 CE_result.detach()를 해주어야 한다는 점입니다. Context encoder에서 나온 복원된 이미지는 학습되어야하는 값이 아니라 Discriminator에 주어지는 입력이기 때문입니다.

Epoch는 최종적으로 100K iteration을 돌 수 있도록 코드를 작성 해 줍니다. term은 저장, 결과출력 등을 하는 최소 단위를 정하기 위한 변수입니다. 나누는 값을 변경하여 term을 조절할 수 있습니다.

Evaluation 코드

Pascal Voc 2012에서 특정 사물을 골라 마스크를 씌우고 그 자리를 뒷 배경으로 복원하도록 하였습니다. 이러한 작업을 위해 세가지 함수를 만들었습니다.

def make_sample(size,batch_size):
  original_list = os.listdir("/content/VOCdevkit/VOC2012/JPEGImages/")

  while True:
    try:
      item = np.random.randint(0,len(original_list))
      original_image = Image.open("/content/VOCdevkit/VOC2012/JPEGImages/"+original_list[item])
      seg_image = Image.open("/content/VOCdevkit/VOC2012/SegmentationObject/"+original_list[item][:-3]+"png")
    except:
      continue

    original_image = original_image.resize((size,size))
    np_original = np.array(original_image,dtype=np.uint8)/255

    seg_image = seg_image.resize((size,size))
    np_seg = np.array(seg_image,dtype=np.uint8)

    labels = np.unique(np_seg)
    label = 0
    for lb in labels[1:-1]:
      if len(np.where(np_seg == lb)[0]) < (size**2)/4:
        label = lb
        break

    if label != 0:
      break
  
  np_seg = np.where(np_seg == labels[1],1.0,0)
  np_seg = np.stack((np_seg,np_seg,np_seg),axis = 2)

  masked_image = (1-np_seg)*np_original + np_seg*(np.zeros_like(np_original)+np.mean(np_original))

  return np_original, np_seg, masked_image

첫 번째는 make_sample입니다. Training 코드의 방법과 같이 이미지에 마스크를 씌웁니다.

def testset(batch_size,size=128):

  l_original = []
  l_seg = []
  l_masked = []
  for i in range(batch_size):
    np_original, np_seg, masked_image = make_sample(size,batch_size)
    l_original.append(np_original.transpose(2,0,1))
    l_seg.append(np_seg.transpose(2,0,1))
    l_masked.append(masked_image.transpose(2,0,1))
  
  return torch.tensor(l_original,dtype=torch.float32), torch.tensor(l_seg,dtype=torch.float32), torch.tensor(l_masked,dtype=torch.float32)

두 번째는 testset입니다. 내부에 make_sample을 활용하여 출력으로 원본 이미지, 마스크, 마스크가 적용된 이미지를 출력합니다.

def show_result(original, mask, masked, result):

  original = original.detach().numpy().transpose(0,2,3,1)
  mask = mask.detach().numpy().transpose(0,2,3,1)
  masked = masked.detach().numpy().transpose(0,2,3,1)
  result = result.detach().numpy().transpose(0,2,3,1)

  result = mask*result + (1-mask)*masked

  fig = plt.figure(figsize=(12,10))
  
  batch = original.shape[0]
  for i in range(batch):
    fig.add_subplot(batch,4,4*i+1)
    plt.imshow(original[i],aspect='auto')
    plt.axis('off')
    fig.add_subplot(batch,4,4*i+2)
    plt.imshow(mask[i],aspect='auto')
    plt.axis('off')
    fig.add_subplot(batch,4,4*i+3)
    plt.imshow(masked[i],aspect='auto')
    plt.axis('off')
    fig.add_subplot(batch,4,4*i+4)
    plt.imshow(result[i],aspect='auto')
    plt.axis('off')
    

  plt.subplots_adjust(wspace=0.1,hspace=0)
  plt.show()

세 번째는 show_result입니다. 원본 이미지, 마스크, 마스크가 적용된 이미지, 복원 결과를 matplotlib를 활용해 한눈에 볼 수 있게 합니다. Training코드의 saveimg에서 한 방식과 matplotlib를 활용하는 방식 모두 가능합니다.

마지막으로 결과를 출력하는 코드를 작성합니다.

model = ContextEncoder()
model.load_state_dict(torch.load("/content/drive/My Drive/Inpainting/result/CE_recent.pth",map_location=device))

original, mask, masked = testset(4)
model.eval() #필수
result = model(masked)

show_result(original,mask,masked,result)

결과

논문에서는 100K iteration을 한다고 하였지만 앞서 나온 tensorboard이미지에서 알 수 있듯 50K iteration에서 loss가 줄지 않고 중간 결과로 나오는 복원 이미지도 성능이 충분하여 50K iteration에서 학습을 중단 하였습니다. 아래 사진은 그 결과입니다.

CE_result1

CE_result1

가장 왼쪽 첫 번째 열이 원본 이미지, 두 번째 열이 Pascal Voc 2012에서 제공하는 마스크, 세 번째 열이 마스크가 적용된 이미지, 네 번째 열이 최종 결과 입니다.

결론

Context Encoder를 사용하여 상당히 자연스러운 inpainting을 할 수 있습니다. 처음에 목표한 대로 지워지는 물체의 뒷 부분이 풍경이나 넓은 공간이면 우수한 결과가 나옵니다.

하지만 좁은 공간이나 풍경이 없는 이미지에서는 어색한 결과가 나왔습니다. 마스크가 지우고자 하는 물체를 완전히 덮지 못하면 물체의 윤곽이 남아있는 경우도 있었습니다. 또 결과 이미지의 공항사진처럼 주변에 많은 물체가 있는 이미지에서 하나의 대상을 지우면 지운 부분이 흐릿한 물체들로 채워져 어색함이 있습니다.

Context Encoder이후에 나온 U-Net, Pix2Pix등 더 좋은 성능의 모델을 사용하면 성능을 개선할 수 있을 것입니다.

후기

Image inpainting의 개념을 이해하고 실제로 실습 해 볼 수 있었습니다.
다음에는 Object detection모델로 직접 mask를 만들어 사물 삭제 기능의 전과정을 구현해 보고 싶습니다. 그 때는 최신의 inpainting 모델을 사용하여 성능을 향상 시킬 수 있을 것입니다.

또 Context Encoder 논문에서 inpainting에 사용된 네트워크를 backbone으로 활용하여 classification, detection, segmentation에 적용하는 방식으로 성능을 평가 하였는데 평가와 관련된 내용도 공부 해야겠습니다.

전체 코드: https://github.com/HyungjoByun/Projects/tree/main/Context%20Encoder

Leave a comment