博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
PyTorch 魔改 Dataset,让 DataLoader 在 enumerate(test_loader) 时同时载入 image、target、name、oriimg
阅读量:2050 次
发布时间:2019-04-28

本文共 19456 字,大约阅读时间需要 64 分钟。

以项目pytorch-deeplab-xception为例:

测试代码:

def test(self):
    """Run the model over every batch of ``self.test_loader`` in eval mode.

    For each batch the network output and the label are copied to host
    memory as NumPy arrays (``pred`` / ``target``); the snippet as shown
    does not use them any further.
    """
    model = self.model
    model.eval()
    self.evaluator.reset()
    for batch in self.test_loader:
        image = batch['image']
        target = batch['label']
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            output = model(image)
        pred = output.data.cpu().numpy()
        target = target.cpu().numpy()

这里通过 enumerate 遍历 test_loader,每个 sample 中只有 image 和 target(label)。

但是 image 已经过数据增强/预处理,内容可能已被改变,而且还经过了 Normalize(减均值、除以标准差)和 ToTensor() 处理,

已经不能满足后续可视化任务的需要。

因此这里我们还需要 name(图像文件名)和 oriimg(resize 后的原始图像,用于可视化)。

def test(self):
    """Run inference over ``self.test_loader`` and print per-batch latency.

    CUDA kernels run asynchronously. The device is therefore synchronised
    both before starting the clock and *before* stopping it. Otherwise the
    measured time would cover only the kernel launches, not their execution.
    Synchronisation is skipped when CUDA is unavailable, because
    ``torch.cuda.synchronize()`` cannot be called there.

    For each batch, ``pred`` holds the arg-max over axis 1 of the network
    output and ``target`` holds the label as a NumPy array. The snippet as
    shown does not use them any further.
    """
    import time

    import numpy as np
    import torch

    def _sync():
        # Only meaningful (and only legal) when a CUDA device exists.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    self.model.eval()
    self.evaluator.reset()
    num = len(self.test_loader)
    for i, sample in enumerate(self.test_loader):
        image, target = sample['image'], sample['label']
        print(i, "/", num)
        _sync()  # drain earlier GPU work so it is not billed to this batch
        start = time.perf_counter()
        with torch.no_grad():
            output = self.model(image)
        _sync()  # wait for the forward pass to finish before reading the clock
        times = (time.perf_counter() - start) * 1000
        print(times, "ms")
        pred = output.data.cpu().numpy()
        pred = np.argmax(pred, axis=1)
        target = target.cpu().numpy()

原始的 dataset 代码:

# One dataset per split; `split` selects the annotation file and transforms.
train_set = coco.COCOSegmentation(args, split='train')

val_set = coco.COCOSegmentation(args, split='val')

# Shuffle only the training data; validation order stays fixed.
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)

val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)

from dataloaders.datasets import cityscapes, coco, combine_dbs, pascal, sbd
from torch.utils.data import DataLoader


def make_data_loader(args, **kwargs):
    """Create the data loaders for the dataset named by ``args.dataset``.

    Only the COCO branch is present in this version.

    Args:
        args: options object; ``args.dataset`` selects the branch and
            ``args.batch_size`` is passed to both loaders.
        **kwargs: forwarded unchanged to every ``DataLoader``.

    Returns:
        ``(train_loader, val_loader, test_loader, num_class)``. The COCO
        branch builds no test loader, so the third item is ``None``.

    Raises:
        NotImplementedError: for any dataset without a branch here.
    """
    if args.dataset == 'coco':
        train_set = coco.COCOSegmentation(args, split='train')
        val_set = coco.COCOSegmentation(args, split='val')
        num_class = train_set.NUM_CLASSES
        # Shuffle the training data only; keep validation order fixed.
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
        test_loader = None
        return train_loader, val_loader, test_loader, num_class
    else:
        raise NotImplementedError('Unsupported dataset: {}'.format(args.dataset))

通过 COCOSegmentation 类得到 COCO 的 dataset:

import numpy as np
import torch
from torch.utils.data import Dataset
from mypath import Path
from tqdm import trange
import os
from pycocotools.coco import COCO
from pycocotools import mask
from torchvision import transforms
from dataloaders import custom_transforms as tr
from PIL import Image, ImageFile

# Let PIL load image files whose data is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True


class COCOSegmentation(Dataset):
    """COCO images paired with semantic masks built from instance annotations.

    Only the categories listed in ``CAT_LIST`` are drawn into the mask. A
    labelled pixel stores the index of its category in ``CAT_LIST``, and
    pixels with no kept annotation stay 0. Each item is the dict
    ``{'image': ..., 'label': ...}`` after the split-specific transforms.
    """
    NUM_CLASSES = 21
    # COCO category ids that are kept; list position == class index in the mask.
    CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4,
                1, 64, 20, 63, 7, 72]

    def __init__(self,
                 args,
                 base_dir=Path.db_root_dir('coco'),
                 split='train',
                 year='2017'):
        """
        Args:
            args: options object; ``args.base_size`` and ``args.crop_size``
                are read by the transform pipelines.
            base_dir: dataset root containing ``annotations/`` and ``images/``.
            split: 'train' or 'val'; selects the annotation file, the image
                directory and the transform pipeline.
            year: COCO release year used in the file names.
        """
        super().__init__()
        ann_file = os.path.join(base_dir, 'annotations/instances_{}{}.json'.format(split, year))
        ids_file = os.path.join(base_dir, 'annotations/{}_ids_{}.pth'.format(split, year))
        self.img_dir = os.path.join(base_dir, 'images/{}{}'.format(split, year))
        self.split = split
        self.coco = COCO(ann_file)
        self.coco_mask = mask
        # The filtered id list is cached on disk: building it renders the
        # mask of every image (see _preprocess).
        if os.path.exists(ids_file):
            self.ids = torch.load(ids_file)
        else:
            ids = list(self.coco.imgs.keys())
            self.ids = self._preprocess(ids, ids_file)
        self.args = args

    def __getitem__(self, index):
        """Return the transformed sample dict for position *index*.

        Raises:
            ValueError: if ``self.split`` has no transform pipeline here.
        """
        _img, _target = self._make_img_gt_point_pair(index)
        sample = {'image': _img, 'label': _target}
        if self.split == "train":
            return self.transform_tr(sample)
        elif self.split == 'val':
            return self.transform_val(sample)
        # Returning None here would only surface later as an obscure
        # collate error inside DataLoader.
        raise ValueError("Unsupported split: {!r}".format(self.split))

    def _make_img_gt_point_pair(self, index):
        """Load the RGB image and its rendered class-index mask for *index*."""
        coco = self.coco
        img_id = self.ids[index]
        img_metadata = coco.loadImgs(img_id)[0]
        path = img_metadata['file_name']
        _img = Image.open(os.path.join(self.img_dir, path)).convert('RGB')
        cocotarget = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
        _target = Image.fromarray(self._gen_seg_mask(
            cocotarget, img_metadata['height'], img_metadata['width']))
        return _img, _target

    def _preprocess(self, ids, ids_file):
        """Keep only images whose mask has more than 1000 labelled pixels.

        The surviving ids are saved to *ids_file* with ``torch.save`` and
        returned.
        """
        print("Preprocessing mask, this will take a while. " + \
              "But don't worry, it only run once for each split.")
        tbar = trange(len(ids))
        new_ids = []
        for i in tbar:
            img_id = ids[i]
            cocotarget = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
            img_metadata = self.coco.loadImgs(img_id)[0]
            mask = self._gen_seg_mask(cocotarget, img_metadata['height'],
                                      img_metadata['width'])
            # more than 1k pixels
            if (mask > 0).sum() > 1000:
                new_ids.append(img_id)
            tbar.set_description('Doing: {}/{}, got {} qualified images'. \
                                 format(i, len(ids), len(new_ids)))
        print('Found number of qualified images: ', len(new_ids))
        torch.save(new_ids, ids_file)
        return new_ids

    def _gen_seg_mask(self, target, h, w):
        """Render annotation list *target* into an (h, w) uint8 class mask.

        Annotations whose category is not in CAT_LIST are skipped. A pixel
        already labelled by an earlier annotation is not overwritten.
        """
        mask = np.zeros((h, w), dtype=np.uint8)
        coco_mask = self.coco_mask
        for instance in target:
            cat = instance['category_id']
            # Check the category first so discarded annotations are never
            # decoded (decoding is the expensive part).
            if cat not in self.CAT_LIST:
                continue
            c = self.CAT_LIST.index(cat)
            rle = coco_mask.frPyObjects(instance['segmentation'], h, w)
            m = coco_mask.decode(rle)
            # decode() returns a 2-D mask for one RLE and a stack along the
            # last axis for a list of RLEs; merge the stack in that case.
            if len(m.shape) < 3:
                mask[:, :] += (mask == 0) * (m * c)
            else:
                mask[:, :] += (mask == 0) * (((np.sum(m, axis=2)) > 0) * c).astype(np.uint8)
        return mask

    def transform_tr(self, sample):
        """Training pipeline: flip, random scale/crop, blur, normalise, to tensor."""
        composed_transforms = transforms.Compose([
            tr.RandomHorizontalFlip(),
            tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size),
            tr.RandomGaussianBlur(),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])
        return composed_transforms(sample)

    def transform_val(self, sample):
        """Validation pipeline: fixed resize + center crop, normalise, to tensor."""
        composed_transforms = transforms.Compose([
            tr.FixScaleCrop(crop_size=self.args.crop_size),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])
        return composed_transforms(sample)

    def __len__(self):
        return len(self.ids)

这里使用的数据增强是自己实现的类(dataloaders/custom_transforms.py),而不是 torchvision 中的 transforms,

只是借用了 torchvision.transforms 中的组合容器 transforms.Compose()。

数据增强类包括:

class Normalize(object):

class ToTensor(object):

class RandomHorizontalFlip(object):

class RandomRotate(object):

class RandomGaussianBlur(object):

class RandomScaleCrop(object):

class FixScaleCrop(object):

class FixedResize(object):

调用流程如下:

pytorch-deeplab-xception/train.py

# Inside the trainer class (train.py): build all loaders and the class count.
self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs) 

pytorch-deeplab-xception/dataloaders/__init__.py

# Example branch of make_data_loader (Cityscapes); the COCO branch builds
# coco.COCOSegmentation(args, split='train') the same way.
train_set = cityscapes.CityscapesSegmentation(args, split='train') 

pytorch-deeplab-xception/dataloaders/datasets/coco.py

def __getitem__(self, index):

def __getitem__(self, index):
    """Return the transformed sample dict for position *index*.

    The raw image/label pair is wrapped as ``{'image': ..., 'label': ...}``.
    It then goes through ``transform_tr`` for the 'train' split or
    ``transform_val`` for the 'val' split.

    Raises:
        ValueError: if ``self.split`` is neither 'train' nor 'val'.
    """
    _img, _target = self._make_img_gt_point_pair(index)
    sample = {'image': _img, 'label': _target}
    if self.split == "train":
        return self.transform_tr(sample)
    elif self.split == 'val':
        return self.transform_val(sample)
    # Returning None here would only surface later as an obscure collate
    # error inside DataLoader.
    raise ValueError("Unsupported split: {!r}".format(self.split))

 def _make_img_gt_point_pair(self, index):

def _make_img_gt_point_pair(self, index):
    """Load the RGB image and its class-index mask for sample *index*.

    Returns:
        ``(image, target)``: the image converted to RGB, and a PIL image
        built from the array that ``self._gen_seg_mask`` renders for the
        image's annotations at the image's height and width.
    """
    img_id = self.ids[index]
    meta = self.coco.loadImgs(img_id)[0]
    image_file = os.path.join(self.img_dir, meta['file_name'])
    _img = Image.open(image_file).convert('RGB')
    annotations = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
    seg = self._gen_seg_mask(annotations, meta['height'], meta['width'])
    _target = Image.fromarray(seg)
    return _img, _target

def transform_tr(self, sample): 

def transform_tr(self, sample):
    """Apply the training-time augmentation pipeline to *sample*.

    The steps run in this order:
    1. random horizontal flip;
    2. random rescale plus crop to ``self.args.crop_size``, with the scale
       driven by ``self.args.base_size``;
    3. random Gaussian blur;
    4. per-channel mean/std normalisation;
    5. conversion to tensors.
    """
    steps = [
        tr.RandomHorizontalFlip(),
        tr.RandomScaleCrop(base_size=self.args.base_size,
                           crop_size=self.args.crop_size),
        tr.RandomGaussianBlur(),
        tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        tr.ToTensor(),
    ]
    pipeline = transforms.Compose(steps)
    return pipeline(sample)

 

pytorch-deeplab-xception/dataloaders/custom_transforms.py

class Normalize(object):

class ToTensor(object):

class FixScaleCrop(object):

class FixScaleCrop(object):
    """Resize so the short side equals ``crop_size``, then center-crop a square.

    The image is resized bilinearly and the label with nearest-neighbour
    sampling, so class indices in the label are never interpolated.
    """

    def __init__(self, crop_size):
        self.crop_size = crop_size

    def __call__(self, sample):
        img, mask = sample['image'], sample['label']
        size = self.crop_size
        width, height = img.size
        # Scale the short side to `size`, keeping the aspect ratio.
        if width > height:
            new_w, new_h = int(1.0 * width * size / height), size
        else:
            new_w, new_h = size, int(1.0 * height * size / width)
        img = img.resize((new_w, new_h), Image.BILINEAR)
        mask = mask.resize((new_w, new_h), Image.NEAREST)
        # Center crop, using offsets measured on the resized image.
        width, height = img.size
        left = int(round((width - size) / 2.))
        top = int(round((height - size) / 2.))
        box = (left, top, left + size, top + size)
        return {'image': img.crop(box),
                'label': mask.crop(box)}
class Normalize(object):
    """Divide pixel values by 255, then standardise each channel.

    Args:
        mean (tuple): per-channel means, subtracted after the division.
        std (tuple): per-channel standard deviations, divided out last.

    Both image and label come back as float32 NumPy arrays; the label is
    only cast, not normalised.
    """

    def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
        self.mean = mean
        self.std = std

    def __call__(self, sample):
        pixels = np.array(sample['image']).astype(np.float32)
        label = np.array(sample['label']).astype(np.float32)
        # In-place arithmetic keeps float32; `pixels - self.mean` would
        # upcast to float64 because the tuple converts to float64.
        pixels /= 255.0
        pixels -= self.mean
        pixels /= self.std
        return {'image': pixels, 'label': label}
class ToTensor(object):
    """Turn the NumPy image/label pair of a sample into float32 tensors.

    The image is transposed from NumPy's H x W x C layout to PyTorch's
    C x H x W; the label keeps its shape.
    """

    def __call__(self, sample):
        hwc = np.array(sample['image']).astype(np.float32)
        chw = hwc.transpose((2, 0, 1))
        label = np.array(sample['label']).astype(np.float32)
        return {'image': torch.from_numpy(chw).float(),
                'label': torch.from_numpy(label).float()}

改进后,样本中还会额外返回原图(ori_image)和路径(path,即不含扩展名的文件名):

import torch
import random
import numpy as np

from PIL import Image, ImageOps, ImageFilter


def _update_sample(sample, **fields):
    """Return a shallow copy of *sample* with *fields* overwritten.

    Every transform builds its output through this helper. Keys it does
    not handle itself ('ori_image', 'path', ...) therefore travel through
    the whole pipeline instead of being silently dropped.
    """
    updated = dict(sample)
    updated.update(fields)
    return updated


class Normalize(object):
    """Normalize an image with mean and standard deviation.

    Pixel values are divided by 255, then the per-channel ``mean`` is
    subtracted and the result divided by ``std``. Image and label come back
    as float32 NumPy arrays; all other keys of the sample pass through.

    Args:
        mean (tuple): means for each channel.
        std (tuple): standard deviations for each channel.
    """

    def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
        self.mean = mean
        self.std = std

    def __call__(self, sample):
        img = np.array(sample['image']).astype(np.float32)
        mask = np.array(sample['label']).astype(np.float32)
        # In-place ops keep float32 (the tuples would otherwise upcast).
        img /= 255.0
        img -= self.mean
        img /= self.std
        return _update_sample(sample, image=img, label=mask)


class ToTensor(object):
    """Convert ndarrays in a sample to float32 tensors.

    'image' and, when present, 'ori_image' are transposed from NumPy's
    H x W x C layout to PyTorch's C x H x W; 'label' keeps its shape.
    'ori_image' is not normalised, so it keeps the pixel range it was given
    (0-255 for an RGB PIL image) and can be used directly for visualisation.
    """

    @staticmethod
    def _to_chw_tensor(picture):
        # numpy image: H x W x C  ->  torch image: C x H x W
        chw = np.array(picture).astype(np.float32).transpose((2, 0, 1))
        return torch.from_numpy(chw).float()

    def __call__(self, sample):
        mask = np.array(sample['label']).astype(np.float32)
        fields = {'image': self._to_chw_tensor(sample['image']),
                  'label': torch.from_numpy(mask).float()}
        if 'ori_image' in sample:
            fields['ori_image'] = self._to_chw_tensor(sample['ori_image'])
        return _update_sample(sample, **fields)


class RandomHorizontalFlip(object):
    """Mirror image and label left-right with probability 0.5.

    'ori_image' is passed through unflipped. In the training pipeline a
    later RandomScaleCrop overwrites it with the augmented crop.
    """

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
        return _update_sample(sample, image=img, label=mask)


class RandomRotate(object):
    """Rotate image and label by the same random angle in [-degree, degree].

    The image is resampled bilinearly and the label with nearest-neighbour
    sampling, so class indices are never interpolated.
    """

    def __init__(self, degree):
        self.degree = degree

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        rotate_degree = random.uniform(-1 * self.degree, self.degree)
        img = img.rotate(rotate_degree, Image.BILINEAR)
        mask = mask.rotate(rotate_degree, Image.NEAREST)
        return _update_sample(sample, image=img, label=mask)


class RandomGaussianBlur(object):
    """With probability 0.5, blur the image (not the label).

    The blur radius is drawn uniformly from [0, 1).
    """

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        if random.random() < 0.5:
            img = img.filter(ImageFilter.GaussianBlur(
                radius=random.random()))
        return _update_sample(sample, image=img, label=mask)


class RandomScaleCrop(object):
    """Random rescale followed by a random ``crop_size`` x ``crop_size`` crop.

    The short side is resized to a random integer length between
    ``int(0.5 * base_size)`` and ``int(2.0 * base_size)`` inclusive. If that
    is smaller than ``crop_size``, the image is padded on the right/bottom
    with 0 and the label with ``fill``. 'ori_image' is set to the cropped,
    un-normalised image so it stays aligned with 'image' and 'label'.
    """

    def __init__(self, base_size, crop_size, fill=0):
        self.base_size = base_size
        self.crop_size = crop_size
        self.fill = fill

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        # random scale (short edge)
        short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
        w, h = img.size
        if h > w:
            ow = short_size
            oh = int(1.0 * h * ow / w)
        else:
            oh = short_size
            ow = int(1.0 * w * oh / h)
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # pad crop
        if short_size < self.crop_size:
            padh = self.crop_size - oh if oh < self.crop_size else 0
            padw = self.crop_size - ow if ow < self.crop_size else 0
            img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
            mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill)
        # random crop crop_size
        w, h = img.size
        x1 = random.randint(0, w - self.crop_size)
        y1 = random.randint(0, h - self.crop_size)
        img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
        mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
        return _update_sample(sample, image=img, label=mask, ori_image=img)


class FixScaleCrop(object):
    """Resize the short side to ``crop_size``, then center-crop a square.

    'ori_image' is set to the resized crop (before normalisation), which is
    what the visualisation step needs.
    """

    def __init__(self, crop_size):
        self.crop_size = crop_size

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        w, h = img.size
        if w > h:
            oh = self.crop_size
            ow = int(1.0 * w * oh / h)
        else:
            ow = self.crop_size
            oh = int(1.0 * h * ow / w)
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # center crop
        w, h = img.size
        x1 = int(round((w - self.crop_size) / 2.))
        y1 = int(round((h - self.crop_size) / 2.))
        img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
        mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
        return _update_sample(sample, image=img, label=mask, ori_image=img)


class FixedResize(object):
    """Resize image (bilinear) and label (nearest) to a ``size`` x ``size`` square."""

    def __init__(self, size):
        self.size = (size, size)  # square, so (w, h) and (h, w) coincide

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        assert img.size == mask.size
        img = img.resize(self.size, Image.BILINEAR)
        mask = mask.resize(self.size, Image.NEAREST)
        return _update_sample(sample, image=img, label=mask)
def __getitem__(self, index):
    """Return the transformed sample dict for position *index*.

    The sample carries 'image', 'label', 'ori_image' (initially the raw
    image; the crop transforms replace it with their resized crop) and
    'path' (the file name without extension). 'train' uses
    ``transform_tr``; 'val' and 'test' share the deterministic
    ``transform_val``.

    Raises:
        ValueError: if ``self.split`` is not 'train', 'val' or 'test'.
    """
    _img, _target, _path = self._make_img_gt_point_pair(index)
    sample = {'image': _img, 'label': _target, 'ori_image': _img, 'path': _path}
    if self.split == "train":
        return self.transform_tr(sample)
    if self.split in ('val', 'test'):
        return self.transform_val(sample)
    # Returning None here would only surface later as an obscure collate
    # error inside DataLoader.
    raise ValueError("Unsupported split: {!r}".format(self.split))


def _make_img_gt_point_pair(self, index):
    """Load image, class-index mask and file stem for sample *index*.

    Returns:
        ``(image, target, name)``:
        - ``image``: the image converted to RGB;
        - ``target``: a PIL image of the array rendered by ``self._gen_seg_mask``;
        - ``name``: the COCO ``file_name`` with its extension removed.
    """
    coco = self.coco
    img_id = self.ids[index]
    img_metadata = coco.loadImgs(img_id)[0]
    path = img_metadata['file_name']
    # splitext strips any extension, not only '.jpg'.
    _path = os.path.splitext(path)[0]
    _img = Image.open(os.path.join(self.img_dir, path)).convert('RGB')
    cocotarget = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
    _target = Image.fromarray(self._gen_seg_mask(
        cocotarget, img_metadata['height'], img_metadata['width']))
    return _img, _target, _path

 

转载地址:http://tqgof.baihongyu.com/

你可能感兴趣的文章
图解HTTP(三)—— HTTP报文内的HTTP信息
查看>>
图解HTTP(四)—— 返回结果的HTTP状态码
查看>>
JavaWeb高级编程(五)—— 使用会话来维持HTTP状态
查看>>
Intellij IDEA使用(十五)—— 如何在IDEA中一个Tomcat启动多个项目和多个Tomcat启动多个项目
查看>>
图解HTTP(五)—— 与HTTP协作的Web服务器
查看>>
程序员的数学(五)—— 排列组合,解决计数问题的方法
查看>>
前后端分离实践(四)—— 使用vue-cli搭建前端展示层并用mock模拟测试数据
查看>>
前后端分离实践(六)—— 前端与后端在生产环境中的分离部署
查看>>
启航 —— 记 —— 第二次自考的反思:自考与自我改造的困境
查看>>
数据结构与算法(三)——线性表
查看>>
Java8学习笔记(一)—— 函数式编程的四个基本接口
查看>>
Java8学习笔记(二)—— Lambda表达式
查看>>
Java8学习笔记(三)—— Optional类的使用
查看>>
Java8学习笔记(四) —— Stream流式编程
查看>>
Java8学习笔记(五)—— 方法引用(::双冒号操作符)
查看>>
数据结构与算法(四)—— 栈与队列
查看>>
数据结构与算法(五)—— 广义表
查看>>
微服务简介
查看>>
CAP定理
查看>>
Docker初探
查看>>