【PyTorch】多对象分割项目
创始人
2024-11-15 19:35:14
0

 【PyTorch】单对象分割项目

对象分割任务的目标是找到图像中目标对象的边界,实际应用包括自动驾驶和医学影像分析。这里将使用PyTorch开发一个深度学习模型来完成多对象分割任务,其主要目标是自动勾勒出图像中多个目标对象的边界。

对象的边界通常由与图像大小相同的分割掩码定义:在分割掩码中,属于同一目标类别的所有像素都按照预定义的类别标签被赋予相同的值。例如本文使用的Pascal VOC数据集共有21个类别(含背景),标签取值为0~20。

目录

创建数据集

创建数据加载器

创建模型

部署模型

定义损失函数和优化器

训练和验证模型


创建数据集

import numpy as np
import torch
from PIL import Image
from torchvision.datasets import VOCSegmentation
from torchvision.transforms.functional import to_tensor, to_pil_image


class myVOCSegmentation(VOCSegmentation):
    """Pascal VOC segmentation dataset driven by albumentations-style transforms.

    ``self.transforms`` (when not None) is called as
    ``transforms(image=<HxWx3 ndarray>, mask=<HxW ndarray>)`` and must return a
    dict with ``"image"`` and ``"mask"`` entries, which is what an
    albumentations ``Compose`` returns.
    """

    def __getitem__(self, index):
        """Return ``(img, target)`` for sample ``index``.

        ``img`` is the RGB image converted with ``to_tensor`` (channel-first
        float tensor); ``target`` is a ``torch.long`` tensor of class ids in
        which every value above 20 has been replaced by 0.
        """
        # Convert to NumPy arrays up front so the no-transform path yields the
        # same types as the transform path; the boolean indexing and
        # torch.from_numpy below both require ndarrays, not PIL images.
        img = np.array(Image.open(self.images[index]).convert('RGB'))
        target = np.array(Image.open(self.masks[index]))

        if self.transforms is not None:
            augmented = self.transforms(image=img, mask=target)
            img = augmented['image']
            target = augmented['mask']

        # VOC has 21 classes (ids 0-20). Larger values (the VOC "void"/border
        # label, presumably 255) are folded into background so every id is a
        # valid index for the 21-class model output.
        target[target > 20] = 0

        img = to_tensor(img)
        target = torch.from_numpy(target).type(torch.long)
        return img, target


from albumentations import (
    HorizontalFlip,
    Compose,
    Resize,
    Normalize,
)

# Standard ImageNet per-channel mean/std (also the defaults of re_normalize).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
h, w = 520, 520

# Training: resize, random horizontal flip, normalize.
transform_train = Compose([
    Resize(h, w),
    HorizontalFlip(p=0.5),
    Normalize(mean=mean, std=std),
])

# Validation: deterministic resize + normalize only.
transform_val = Compose([
    Resize(h, w),
    Normalize(mean=mean, std=std),
])

path2data = "./data/"
train_ds = myVOCSegmentation(path2data,
                             year='2012',
                             image_set='train',
                             download=False,
                             transforms=transform_train)
print(len(train_ds))

val_ds = myVOCSegmentation(path2data,
                           year='2012',
                           image_set='val',
                           download=False,
                           transforms=transform_val)
print(len(val_ds))
import torch
import numpy as np
from skimage.segmentation import mark_boundaries
import matplotlib.pylab as plt
# %matplotlib inline
# ^ Jupyter/IPython magic: not valid Python syntax in a .py file, so it is
#   kept as a comment. Uncomment it when running the code inside a notebook.

np.random.seed(0)
num_classes = 21
# One RGB colour per class. randint(0, 2) gives each channel 0 or 1, so there
# are at most 8 distinct colours and several classes share one.
COLORS = np.random.randint(0, 2, size=(num_classes + 1, 3), dtype="uint8")


def show_img_target(img, target):
    """Draw the outline of every class region of ``target`` on ``img`` and plot it.

    Boundaries are drawn for every class id in ``range(num_classes)``,
    background (0) included.

    Args:
        img: channel-first image tensor (converted with ``to_pil_image``) or
            anything ``np.array`` turns into an HxWx3 image, e.g. a PIL image.
        target: class-id mask as a tensor or ndarray. It is assumed to have the
            same HxW as ``img``, which ``mark_boundaries`` requires.
    """
    # Convert each argument independently so mixed tensor/ndarray inputs work
    # and tensors that live on a GPU are moved to host memory first.
    if torch.is_tensor(img):
        img = to_pil_image(img)
    if torch.is_tensor(target):
        target = target.detach().cpu().numpy()

    for ll in range(num_classes):
        mask = (target == ll)
        img = mark_boundaries(np.array(img),
                              mask,
                              outline_color=COLORS[ll],
                              color=COLORS[ll])
    plt.imshow(img)


def re_normalize(x, mean=mean, std=std):
    """Undo per-channel normalization: return a copy of ``x`` with ``x * std + mean``.

    ``x`` is indexed by its first dimension as the channel axis.
    """
    x_r = x.clone()
    for c, (mean_c, std_c) in enumerate(zip(mean, std)):
        x_r[c] *= std_c
        x_r[c] += mean_c
    return x_r

 展示训练数据集示例图像

# Inspect training sample #10: print shape/type/max of image and mask, then
# show the de-normalized image, the raw mask and the overlaid class outlines.
img, mask = train_ds[10]
for sample in (img, mask):
    print(sample.shape, sample.type(), torch.max(sample))

img_r = re_normalize(img)
plt.figure(figsize=(20, 20))

plt.subplot(1, 3, 1)
plt.imshow(to_pil_image(img_r))   # de-normalized input image

plt.subplot(1, 3, 2)
plt.imshow(mask)                  # raw class-id mask

plt.subplot(1, 3, 3)
show_img_target(img_r, mask)      # class outlines drawn on the image

展示验证数据集示例图像

# Same inspection for validation sample #10.
img, mask = val_ds[10]
for sample in (img, mask):
    print(sample.shape, sample.type(), torch.max(sample))

img_r = re_normalize(img)
plt.figure(figsize=(20, 20))

# Panel 1: image, panel 2: mask, panel 3: outlines overlaid on the image.
plt.subplot(1, 3, 1)
plt.imshow(to_pil_image(img_r))
plt.subplot(1, 3, 2)
plt.imshow(mask)
plt.subplot(1, 3, 3)
show_img_target(img_r, mask)

创建数据加载器

使用torch.utils.data.DataLoader分别为训练集和验证集创建数据加载器,并打印第一个批次的形状和数据类型以检查效果。

from torch.utils.data import DataLoader

# Training batches are shuffled; validation batches keep a fixed order.
train_dl = DataLoader(train_ds, batch_size=4, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=8, shuffle=False)

# Peek at the first batch of each loader to check shapes and dtypes.
for loader in (train_dl, val_dl):
    for img_b, mask_b in loader:
        print(img_b.shape, img_b.dtype)
        print(mask_b.shape, mask_b.dtype)
        break

创建模型

创建使用预训练权重的DeepLabV3-ResNet101模型,并打印其网络结构

from torchvision.models.segmentation import deeplabv3_resnet101
import torch

# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# DeepLabV3 with a ResNet-101 backbone, pretrained weights, 21-class head.
# NOTE(review): `pretrained=` is torchvision's legacy flag (newer releases
# prefer `weights=`); kept unchanged to preserve behaviour.
model = deeplabv3_resnet101(pretrained=True, num_classes=21).to(device)
print(model)

部署模型

在验证集的一个批次上运行模型,观察预测效果

from torch import nn

# Run the model on the first validation batch without tracking gradients.
model.eval()
with torch.no_grad():
    for xb, yb in val_dl:
        yb_pred = model(xb.to(device))["out"].cpu()   # raw per-class scores
        print(yb_pred.shape)
        yb_pred = yb_pred.argmax(dim=1)               # per-pixel class id
        break
print(yb_pred.shape)

# Visualise the prediction for image n of that batch.
n = 2
img, mask = xb[n], yb_pred[n]
img_r = re_normalize(img)

plt.figure(figsize=(20, 20))
plt.subplot(1, 3, 1)
plt.imshow(to_pil_image(img_r))
plt.subplot(1, 3, 2)
plt.imshow(mask)
plt.subplot(1, 3, 3)
show_img_target(img_r, mask)

可以看到,预训练模型已经能较好地勾勒出对象的边界。

定义损失函数和优化器

import torch.nn as nn

# Pixel-wise cross-entropy, summed (not averaged) over the whole batch.
criterion = nn.CrossEntropyLoss(reduction="sum")
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

opt = optim.Adam(model.parameters(), lr=1e-6)


def loss_batch(loss_func, output, target, opt=None):
    """Compute ``loss_func(output, target)`` and, if ``opt`` is given, take one step.

    Returns:
        ``(loss_value, None)``: the loss as a Python float plus a placeholder
        metric slot.
    """
    loss = loss_func(output, target)
    if opt is None:
        # Evaluation: no backward pass, no parameter update.
        return loss.item(), None

    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item(), None


# Halve the learning rate when the monitored loss stops decreasing.
lr_scheduler = ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=20, verbose=1)


def get_lr(opt):
    """Return the learning rate of the first parameter group, or None if there is none."""
    return next((group['lr'] for group in opt.param_groups), None)


current_lr = get_lr(opt)
print(f'current lr={current_lr}')

训练和验证模型

def loss_epoch(model, loss_func, dataset_dl, sanity_check=False, opt=None):
    """Run one pass over ``dataset_dl`` and return the loss per processed sample.

    Args:
        model: segmentation model whose forward output is a dict with an
            ``"out"`` entry.
        loss_func: loss called as ``loss_func(output, target)``.
        dataset_dl: iterable of ``(images, targets)`` batches, e.g. a DataLoader.
        sanity_check: if true, stop after the first batch.
        opt: optimizer; when given, the weights are updated on every batch.

    Returns:
        ``(loss, None)``: the summed batch losses divided by the number of
        samples actually processed, plus a placeholder metric slot.

    Raises:
        ValueError: if ``dataset_dl`` yields no samples.
    """
    running_loss = 0.0
    # Count the samples actually seen instead of using len(dataset). With
    # sanity_check only one batch runs, and dividing by the full dataset size
    # would under-report the loss by roughly len(dataset) / batch_size.
    num_samples = 0

    for xb, yb in dataset_dl:
        xb = xb.to(device)
        yb = yb.to(device)

        output = model(xb)["out"]
        loss_b, _ = loss_batch(loss_func, output, yb, opt)
        running_loss += loss_b
        num_samples += xb.size(0)

        if sanity_check:
            break

    if num_samples == 0:
        raise ValueError("dataset_dl yielded no samples")
    return running_loss / float(num_samples), None


import copy


def train_val(model, params):
    """Train ``model``, validating after every epoch, and keep the best weights.

    ``params`` must provide: ``num_epochs``, ``loss_func``, ``optimizer``,
    ``train_dl``, ``val_dl``, ``sanity_check``, ``lr_scheduler`` (stepped with
    the validation loss) and ``path2weights``.

    Each time the validation loss improves, the weights are copied in memory
    and saved to ``path2weights``. When the scheduler changes the learning
    rate, the best weights so far are reloaded.

    Returns:
        ``(model, loss_history, metric_history)`` with the best weights loaded.
        Each history is a dict with ``"train"``/``"val"`` lists, one entry per
        epoch.
    """
    num_epochs = params["num_epochs"]
    loss_func = params["loss_func"]
    opt = params["optimizer"]
    train_dl = params["train_dl"]
    val_dl = params["val_dl"]
    sanity_check = params["sanity_check"]
    lr_scheduler = params["lr_scheduler"]
    path2weights = params["path2weights"]

    loss_history = {
        "train": [],
        "val": []}

    metric_history = {
        "train": [],
        "val": []}

    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')

    for epoch in range(num_epochs):
        current_lr = get_lr(opt)
        print('Epoch {}/{}, current lr={}'.format(epoch, num_epochs - 1, current_lr))

        model.train()
        train_loss, train_metric = loss_epoch(model, loss_func, train_dl, sanity_check, opt)
        loss_history["train"].append(train_loss)
        metric_history["train"].append(train_metric)

        model.eval()
        with torch.no_grad():
            val_loss, val_metric = loss_epoch(model, loss_func, val_dl, sanity_check)
        loss_history["val"].append(val_loss)
        metric_history["val"].append(val_metric)

        if val_loss < best_loss:
            best_loss = val_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(model.state_dict(), path2weights)
            print("Copied best model weights!")

        lr_scheduler.step(val_loss)
        if current_lr != get_lr(opt):
            # The learning rate was reduced: restart from the best weights.
            print("Loading best model weights!")
            model.load_state_dict(best_model_wts)

        print("train loss: %.6f" % (train_loss))
        print("val loss: %.6f" % (val_loss))
        print("-" * 10)

    model.load_state_dict(best_model_wts)
    return model, loss_history, metric_history
import os

# Fresh optimizer and scheduler for this run (replacing the earlier ones).
opt = optim.Adam(model.parameters(), lr=1e-6)
lr_scheduler = ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=20, verbose=1)

path2models = "./models/"
# makedirs(exist_ok=True) avoids the check-then-create race of exists()+mkdir()
# and also creates any missing parent directories.
os.makedirs(path2models, exist_ok=True)

params_train = {
    "num_epochs": 10,
    "optimizer": opt,
    "loss_func": criterion,
    "train_dl": train_dl,
    "val_dl": val_dl,
    # With sanity_check=True, loss_epoch stops after the first batch of each
    # pass. Set it to False for a real training run.
    "sanity_check": True,
    "lr_scheduler": lr_scheduler,
    "path2weights": os.path.join(path2models, "sanity_weights.pt"),
}

model, loss_hist, _ = train_val(model, params_train)

绘制训练和验证损失曲线:

num_epochs = params_train["num_epochs"]
epochs = range(1, num_epochs + 1)

# One curve per phase, plotted against 1-based epoch numbers.
plt.title("Train-Val Loss")
for phase in ("train", "val"):
    plt.plot(epochs, loss_hist[phase], label=phase)
plt.ylabel("Loss")
plt.xlabel("Training Epochs")
plt.legend()
plt.show()

相关内容

热门资讯

类型推断技术及仓颉语言实践 史磊仓颉语言类型推断技术专家一、一种看待类型系统的方式一门编程语言一定得包含类型系统吗?...
Spring为啥用三级缓存而不... 1. 背景介绍Spring框架:一个广泛应用于Java应用开发的轻量级容器框架。三级缓...
LLAma-Factory框架... 一、安装 LLaMA Factorygit clone --depth 1 https://gith...
人工智能未来展望 深入剖析人工智能:现状、挑战与未来展望在当今这个日新月异的科技时代,人工...
Stable Diffusio... 今天分享一个用Stable Diffusion换背景的小教程。在以往为产品或照片更换背景时,...
动手学强化学习 第 17 章 ... 基于 https://github.com/boyu-ai/Hands-on-RL/blob/mai...
关于信创终端的10个基础问题与... 原文链接:关于信创终端的10个基础问题与答复1. 龙芯架构(LoongA...
项目实战——外挂开发(30小时... 项目实战——外挂开发(30小时精通C++和外挂实战)外挂...
从HTTP2的多路复用,到与T... 在讨论题目之前无疑涉及到两个知识点:什么是HTTP、什么是TCP。在了解以上两个知识点后,...
一文讲透如何使用CSS美化HT...         上一篇文章(初识HTML文件,创建自己的第一个网页!)介...