GitHub地址链接:https://github.com/NathanUA/U-2-Net
这个显著性检测很好用,强烈推荐,建议二分类的任务都来试试,尤其对边缘细节要求比较高的任务。
下面的效果要不是第一张图预测有瑕疵,我都以为预测代码是把标签复制了一下(+_+)
这里的精度我就不评价了,肉眼看就已经能说明问题了
原图
标签
预测结果:
1.数据准备
数据集的目录结构和传统语义分割数据集一致：Image 和 Mask 两个文件夹中分别存放图像和标签，二者按文件名一一对应，组成图像—标签对。文件名相同即可，后缀可以不同（jpg、png 等都行，在训练代码里修改 image_ext / label_ext 即可识别）。
2.训练
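data_loader.py 第 222 行左右（ToTensorLab 的 return 语句处）需要给数组加上 .copy()：RandomCrop 里用 [::-1] 做随机翻转，得到的是负步长（negative stride）的数组，torch.from_numpy 不支持这种数组，不加 copy() 会直接报错。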
# data loader
from __future__ import print_function, division
import glob
import torch
from skimage import io, transform, color
import numpy as np
import random
import math
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from PIL import Image#==========================dataset load==========================
# ImageNet statistics used for the RGB (flag=0) normalisation.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)


def _random_flip(image, label):
    """With probability 0.5, reverse image and label along their first axis.

    ``x[::-1]`` reverses the row order (an upside-down flip) and returns a
    view with a negative stride, which is why the tensor conversion below
    copies arrays before calling ``torch.from_numpy``.
    """
    if random.random() >= 0.5:
        image = image[::-1]
        label = label[::-1]
    return image, label


def _normalize_label(label):
    """Divide label by its maximum; labels whose max is below 1e-6 are returned unchanged."""
    peak = np.max(label)
    if peak < 1e-6:
        return label
    return label / peak


def _imagenet_normalize(image):
    """Return an (H, W, 3) float64 array: image divided by its max, then ImageNet-normalised.

    A single-channel image is replicated into all three output channels
    using the first channel's mean/std. An all-zero image is not divided
    by its (zero) maximum, which previously produced NaNs.
    """
    peak = np.max(image)
    if peak != 0:
        image = image / peak
    out = np.zeros((image.shape[0], image.shape[1], 3))
    if image.shape[2] == 1:
        for c in range(3):
            out[:, :, c] = (image[:, :, 0] - _IMAGENET_MEAN[0]) / _IMAGENET_STD[0]
    else:
        for c in range(3):
            out[:, :, c] = (image[:, :, c] - _IMAGENET_MEAN[c]) / _IMAGENET_STD[c]
    return out


def _min_max(channel):
    """Scale a 2-D channel to [0, 1]; a constant channel becomes zeros instead of NaN."""
    lo = np.min(channel)
    span = np.max(channel) - lo
    if span == 0:
        return np.zeros(channel.shape)
    return (channel - lo) / span


def _standardize(channel):
    """Shift a 2-D channel to zero mean / unit std; a constant channel becomes zeros."""
    sd = np.std(channel)
    if sd == 0:
        return np.zeros(channel.shape)
    return (channel - np.mean(channel)) / sd


def _fill_normalized(dst, sources):
    """Write sources[c] into dst[:, :, c] min-max scaled, then standardise each channel of dst."""
    for c, src in enumerate(sources):
        dst[:, :, c] = _min_max(src)
    for c in range(len(sources)):
        dst[:, :, c] = _standardize(dst[:, :, c])
    return dst


def _as_rgb(image):
    """Return a 3-channel image; a single-channel image is copied into all three channels."""
    if image.shape[2] == 1:
        rgb = np.zeros((image.shape[0], image.shape[1], 3))
        for c in range(3):
            rgb[:, :, c] = image[:, :, 0]
        return rgb
    return image


def _to_tensor(array):
    """Convert an (H, W, C) array into a (C, H, W) tensor backed by a fresh contiguous copy."""
    # The copy removes negative strides left by _random_flip; torch.from_numpy rejects them.
    return torch.from_numpy(array.transpose((2, 0, 1)).copy())


class RescaleT(object):
    """Resize image and label to a fixed size without keeping the aspect ratio.

    Args:
        output_size: int ``s`` -> the output is ``s x s``;
            tuple ``(h, w)`` -> the output is exactly ``h x w``.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        imidx, image, label = sample['imidx'], sample['image'], sample['label']
        # An int means a square target; a tuple is used as-is (wrapping a
        # tuple as (size, size) would give resize a nested shape).
        if isinstance(self.output_size, int):
            size = (self.output_size, self.output_size)
        else:
            size = tuple(self.output_size)
        # resize turns an integer image from [0, 255] into floats in [0, 1];
        # the label uses nearest-neighbour (order=0) and keeps its value range.
        img = transform.resize(image, size, mode='constant')
        lbl = transform.resize(label, size, mode='constant', order=0, preserve_range=True)
        return {'imidx': imidx, 'image': img, 'label': lbl}


class Rescale(object):
    """Randomly flip, then resize image and label.

    Args:
        output_size: int -> the shorter side becomes ``output_size`` and the
            aspect ratio is kept; tuple ``(h, w)`` -> exact output size.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        imidx = sample['imidx']
        image, label = _random_flip(sample['image'], sample['label'])

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size
        new_h, new_w = int(new_h), int(new_w)

        # resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
        img = transform.resize(image, (new_h, new_w), mode='constant')
        lbl = transform.resize(label, (new_h, new_w), mode='constant', order=0, preserve_range=True)
        return {'imidx': imidx, 'image': img, 'label': lbl}


class RandomCrop(object):
    """Randomly flip, then crop a random ``output_size`` window from image and label.

    Args:
        output_size: int ``s`` -> ``s x s`` crop; tuple ``(h, w)`` -> ``h x w`` crop.

    Raises:
        ValueError: if the crop is larger than the input.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        imidx = sample['imidx']
        image, label = _random_flip(sample['image'], sample['label'])

        h, w = image.shape[:2]
        new_h, new_w = self.output_size
        if new_h > h or new_w > w:
            raise ValueError("crop size (%d, %d) is larger than the input (%d, %d)"
                             % (new_h, new_w, h, w))

        # randint's upper bound is exclusive: "+ 1" lets the crop reach the
        # bottom/right edge and allows a crop exactly as large as the input.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[top: top + new_h, left: left + new_w]
        label = label[top: top + new_h, left: left + new_w]
        return {'imidx': imidx, 'image': image, 'label': label}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors: ImageNet-normalised RGB image, max-scaled label, CHW layout."""

    def __call__(self, sample):
        imidx, image, label = sample['imidx'], sample['image'], sample['label']
        tmpImg = _imagenet_normalize(image)
        tmpLbl = _normalize_label(label)
        return {'imidx': torch.from_numpy(imidx.copy()),
                'image': _to_tensor(tmpImg),
                'label': _to_tensor(tmpLbl)}


class ToTensorLab(object):
    """Convert ndarrays in sample to Tensors, choosing the image colour space with ``flag``.

    flag == 2: RGB + Lab, each channel min-max scaled then standardised (6 channels).
    flag == 1: Lab, each channel min-max scaled then standardised (3 channels).
    otherwise: ImageNet-normalised RGB (3 channels).
    The label is divided by its maximum (unless that maximum is below 1e-6).
    """

    def __init__(self, flag=0):
        self.flag = flag

    def __call__(self, sample):
        imidx, image, label = sample['imidx'], sample['image'], sample['label']
        label = _normalize_label(label)

        if self.flag == 2:  # with rgb and Lab colors
            rgb = _as_rgb(image)
            lab = color.rgb2lab(rgb)
            sources = [rgb[:, :, c] for c in range(3)] + [lab[:, :, c] for c in range(3)]
            tmpImg = _fill_normalized(np.zeros((image.shape[0], image.shape[1], 6)), sources)
        elif self.flag == 1:  # with Lab color, normalised in place
            lab = color.rgb2lab(_as_rgb(image))
            tmpImg = _fill_normalized(lab, [lab[:, :, c] for c in range(3)])
        else:  # with rgb color
            tmpImg = _imagenet_normalize(image)

        return {'imidx': torch.from_numpy(imidx.copy()),
                'image': _to_tensor(tmpImg),
                'label': _to_tensor(label)}


class SalObjDataset(Dataset):
    """Dataset of (image, mask) pairs read from parallel lists of file paths.

    Args:
        img_name_list: image file paths.
        lbl_name_list: mask file paths in the same order as the images; pass
            ``[]`` (e.g. at inference) to get all-zero placeholder labels.
        transform: optional callable applied to each sample dict.
    """

    def __init__(self, img_name_list, lbl_name_list, transform=None):
        self.image_name_list = img_name_list
        self.label_name_list = lbl_name_list
        self.transform = transform

    def __len__(self):
        return len(self.image_name_list)

    def __getitem__(self, idx):
        image = io.imread(self.image_name_list[idx])
        imidx = np.array([idx])

        if len(self.label_name_list) == 0:
            label_3 = np.zeros(image.shape)
        else:
            label_3 = io.imread(self.label_name_list[idx])

        # Reduce the label to a single (H, W) mask: first channel of a 3-D label.
        label = np.zeros(label_3.shape[0:2])
        if len(label_3.shape) == 3:
            label = label_3[:, :, 0]
        elif len(label_3.shape) == 2:
            label = label_3

        # Make image and label 3-D: (H, W, C) image, (H, W, 1) label.
        if len(image.shape) == 3 and len(label.shape) == 2:
            label = label[:, :, np.newaxis]
        elif len(image.shape) == 2 and len(label.shape) == 2:
            image = image[:, :, np.newaxis]
            label = label[:, :, np.newaxis]

        sample = {'imidx': imidx, 'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        return sample
训练代码基本不用改，主要是修改路径拼接，保证能找到图像和标签的路径即可。
u2net_train.py
import glob
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as standard_transforms
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

from data_loader import Rescale
from data_loader import RescaleT
from data_loader import RandomCrop
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset

from model import U2NET
from model import U2NETP

# ------- 1. define loss function --------

# reduction='mean' is the non-deprecated spelling of size_average=True.
bce_loss = nn.BCELoss(reduction='mean')


def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
    """Sum the BCE losses of the seven network outputs against labels_v.

    Returns:
        (loss0, loss): BCE of d0 alone (logged as ``tar`` below) and the sum
        of all seven BCE terms (the value that is back-propagated).
    """
    losses = [bce_loss(d, labels_v) for d in (d0, d1, d2, d3, d4, d5, d6)]
    loss = losses[0]
    for term in losses[1:]:
        loss = loss + term
    # .item() converts each scalar loss tensor to a Python float for printing.
    print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n"
          % tuple(term.item() for term in losses))
    return losses[0], loss


# ------- 2. set the directory of training dataset --------
model_name = 'u2net'  # 'u2net' or 'u2netp'; selects the network built in main()

data_dir = os.path.join(os.getcwd(), 'data' + os.sep)
# tra_image_dir = os.path.join('DUTS', 'DUTS-TR', 'DUTS-TR', 'im_aug' + os.sep)
# tra_label_dir = os.path.join('DUTS', 'DUTS-TR', 'DUTS-TR', 'gt_aug' + os.sep)
tra_image_dir = os.path.join('RIVER', 'Train', 'Image' + os.sep)  # training images, relative to data_dir
tra_label_dir = os.path.join('RIVER', 'Train', 'Mask' + os.sep)  # training masks, relative to data_dir

image_ext = '.png'  # image file extension; change it to match your data
label_ext = '.png'  # mask file extension

model_dir = os.path.join(os.getcwd(), 'saved_models', model_name + os.sep)  # where checkpoints are saved

# training parameters
epoch_num = 2000
batch_size_train = 6
batch_size_val = 2
train_num = 0
val_num = 0
# Save a checkpoint every save_frq iterations (iterations, not epochs).
save_frq = 2000


def _label_path(img_path):
    """Return the mask path for an image: same file stem, label_ext, under tra_label_dir."""
    stem = os.path.splitext(os.path.basename(img_path))[0]
    return data_dir + tra_label_dir + stem + label_ext


def main():
    """Build the dataset and model, then train, saving a checkpoint every save_frq iterations."""
    tra_img_name_list = glob.glob(data_dir + tra_image_dir + '*' + image_ext)
    tra_lbl_name_list = [_label_path(p) for p in tra_img_name_list]

    print("---")
    print("train images: ", len(tra_img_name_list))
    print("train labels: ", len(tra_lbl_name_list))
    print("---")

    # Fail early instead of silently training on nothing or crashing mid-epoch.
    if not tra_img_name_list:
        raise FileNotFoundError("no '*%s' images found in %s" % (image_ext, data_dir + tra_image_dir))
    missing = [p for p in tra_lbl_name_list if not os.path.exists(p)]
    if missing:
        raise FileNotFoundError("%d label file(s) missing, e.g. %s" % (len(missing), missing[0]))

    train_num = len(tra_img_name_list)

    salobj_dataset = SalObjDataset(
        img_name_list=tra_img_name_list,
        lbl_name_list=tra_lbl_name_list,
        transform=transforms.Compose([
            RescaleT(320),
            RandomCrop(288),
            ToTensorLab(flag=0)]))
    salobj_dataloader = DataLoader(salobj_dataset, batch_size=batch_size_train, shuffle=True, num_workers=0)

    # ------- 3. define model --------
    if model_name == 'u2net':
        net = U2NET(3, 1)
    elif model_name == 'u2netp':
        net = U2NETP(3, 1)
    else:
        raise ValueError("model_name must be 'u2net' or 'u2netp', got %r" % model_name)

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net.cuda()

    # ------- 4. define optimizer --------
    print("---define optimizer...")
    optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

    # ------- 5. training process --------
    print("---start training...")
    # Make sure the checkpoint directory exists before the first torch.save.
    os.makedirs(model_dir, exist_ok=True)
    ite_num = 0
    running_loss = 0.0
    running_tar_loss = 0.0
    ite_num4val = 0

    for epoch in range(0, epoch_num):
        net.train()

        for i, data in enumerate(salobj_dataloader):
            ite_num = ite_num + 1
            ite_num4val = ite_num4val + 1

            inputs = data['image'].type(torch.FloatTensor)
            labels = data['label'].type(torch.FloatTensor)
            if use_cuda:
                inputs, labels = inputs.cuda(), labels.cuda()

            # zero the parameter gradients, then forward + backward + optimize
            optimizer.zero_grad()
            d0, d1, d2, d3, d4, d5, d6 = net(inputs)
            loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_tar_loss += loss2.item()

            # del temporary outputs and loss
            del d0, d1, d2, d3, d4, d5, d6, loss2, loss

            print("[epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f " % (
                epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num,
                running_loss / ite_num4val, running_tar_loss / ite_num4val))

            if ite_num % save_frq == 0:
                torch.save(net.state_dict(), model_dir + model_name + "_bce_itr_%d_train_%3f_tar_%3f.pth" % (
                    ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val))
                running_loss = 0.0
                running_tar_loss = 0.0
                net.train()  # resume train
                ite_num4val = 0


if __name__ == "__main__":
    main()
3.预测
u2net_test.py
import glob
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from PIL import Image
from skimage import io, transform
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms  # , utils
# import torch.optim as optim

from data_loader import RescaleT
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset

from model import U2NET  # full size version 173.6 MB
from model import U2NETP  # small version u2net 4.7 MB


def normPRED(d):
    """Min-max normalise the predicted SOD probability map to [0, 1].

    A constant map (max == min) is returned as all zeros instead of NaN.
    """
    ma = torch.max(d)
    mi = torch.min(d)
    span = ma - mi
    if span == 0:
        return torch.zeros_like(d)
    return (d - mi) / span


def save_output(image_name, pred, d_dir):
    """Save pred as '<image stem>.png' in d_dir, resized to the source image's size.

    Args:
        image_name: path of the input image; read only to get its height/width.
        pred: normalised prediction tensor in [0, 1].
        d_dir: output directory.
    """
    predict_np = pred.squeeze().cpu().data.numpy()
    im = Image.fromarray(predict_np * 255).convert('RGB')

    image = io.imread(image_name)
    imo = im.resize((image.shape[1], image.shape[0]), resample=Image.BILINEAR)

    stem = os.path.splitext(os.path.basename(image_name))[0]
    imo.save(os.path.join(d_dir, stem + '.png'))


def main():
    # --------- 1. get image path and name ---------
    model_name = 'u2net'  # 'u2net' or 'u2netp'; must match the trained model
    image_dir = 'D:/wcs/U-2-Net/data/RIVER/Test/Image/'  # test images
    prediction_dir = 'D:/wcs/U-2-Net/data/RIVER/Test/pre/'  # where predicted masks are written
    # Upstream default checkpoint location:
    # model_path = os.path.join(os.getcwd(), 'saved_models', model_name, model_name + '.pth')
    model_path = './saved_models/u2net/u2net_bce_itr_36000_train_0.091362_tar_0.003286.pth'  # your own checkpoint

    img_name_list = glob.glob(os.path.join(image_dir, '*'))
    if not img_name_list:
        raise FileNotFoundError("no images found in %s" % image_dir)

    # --------- 2. dataloader ---------
    test_salobj_dataset = SalObjDataset(img_name_list=img_name_list,
                                        lbl_name_list=[],
                                        transform=transforms.Compose([RescaleT(320),
                                                                      ToTensorLab(flag=0)]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    if model_name == 'u2net':
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif model_name == 'u2netp':
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    else:
        raise ValueError("model_name must be 'u2net' or 'u2netp', got %r" % model_name)

    # Load onto the CPU first so a checkpoint saved from a GPU also loads on a
    # CPU-only machine; the network is moved to the GPU below if available.
    net.load_state_dict(torch.load(model_path, map_location='cpu'))
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net.cuda()
    net.eval()

    os.makedirs(prediction_dir, exist_ok=True)

    # --------- 4. inference for each image ---------
    with torch.no_grad():  # inference only: no autograd graph needed
        for i_test, data_test in enumerate(test_salobj_dataloader):
            print("inferencing:", os.path.basename(img_name_list[i_test]))

            inputs_test = data_test['image'].type(torch.FloatTensor)
            if use_cuda:
                inputs_test = inputs_test.cuda()

            d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

            # d1 is the fused output (it combines d2..d7; see the model definition).
            pred = normPRED(d1[:, 0, :, :])

            # save results to the prediction folder
            save_output(img_name_list[i_test], pred, prediction_dir)

            del d1, d2, d3, d4, d5, d6, d7


if __name__ == "__main__":
    main()
另外,这个项目在GitHub上展示的人像轮廓提取效果非常好,说明模型关注细节的能力很强,建议需要提取线的相关任务也做做尝试