博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
用SVM来识别魔方颜色(实践篇)
阅读量:2160 次
发布时间:2019-05-01

本文共 4653 字,大约阅读时间需要 15 分钟。

来源:极客Merry

为了能清楚表达接下来要干什么,先上个最终的效果图看看。没错只是让电脑能识别魔方的颜色,然后。。。。就没有然后了

在第一篇推文中,简单介绍了SVM的基本原理但是没有涉及核函数的内容。

SVM的基本原理

merryTong,公众号:极客Merry

核函数的出现是为了解决线性不可分的问题,核函数的思想就是通过一个非线性变换将输入空间映射到特征空间,这样在输入空间不能使用分隔超平面分类的问题,在特征空间就能使用分隔超平面完美解决了。

理论知识已经有了,所以要做的就是使用SVM算法来识别魔方颜色,而第一步需要训练SVM模型。在上一篇推文中,我们推导了SVM的最大间距公式和正确分类的约束条件,支持向量机算法求解的参数有w和b,只有这两个参数确定了,则建立的模型才随之确定。所以SVM的训练过程也是在寻找参数w和b的过程。

可能有人会觉得奇怪,支持向量机本身是一个二分类算法,魔方颜色识别是一个多分类问题,支持向量机能解决这个问题吗?虽然SVM是二分类算法,但是将SVM扩展后仍然可以解决多分类问题。SVM解决多分类问题的基本思路是:有几个类别就设置几个分类器,对于每一个类来说,都有一个当前类和其他类的二分类器。对于魔方颜色分类器来说,支持向量机并行地设置6个分类器,对于红色类来说,样本只有红色类和其他类;对于绿色类来说,样本只有绿色类和其他类,依此类推,多分类问题就转变成了多个二分类问题。

本文使用的训练数据集是个人采集制作完成,但是本文的重点在于SVM的实践,所以数据集的制作过程就此略过。自监督学习需要给训练数据集打上相应的标签,并按标签分类。所以本文用相应颜色的英文单词首字母来表示颜色图像属于什么类别。魔方的颜色有:红、绿、蓝、白、橙、黄,所以对应的类别标签为:R、G、B、W、O、Y。它大概长这个样子:

这是整个数据集中绿色图像的数据,每种颜色有100张左右的图像。数据标注的格式为"颜色_序号",每张图像大小为25*25像素。

训练数据集的文件结构如下图所示:

SVM模型的训练过程大致如下:第一步将图像像素值归一化,提高运算效率;第二步将彩色图像向量化,即将25*25*3的三维张量展平成向量,得到了1875维的向量,这个向量就是一个样本的数据特征;最后将向量数据送入模型进行训练。

在开始写代码前,还需要提前配置好环境,代码在Python3.x下实现,需要安装的库有:OpenCV,PIL,skit-learn,numpy。等一切准备就绪,就可以动手撸代码了!

首先导入必要的库。这里的支持向量机算法需要调用skit-learn机器学习库中封装好的API

import numpy as npimport osimport timefrom sklearn import svmfrom sklearn.externals import joblibfrom cv2 import cv2

第一步加载数据,按照上文所说的文件结构,编写代码将数据加载到内存。

def read_all_data(file_path):    cName = ['R','G', 'Y', 'W', 'O', 'B']        # 得到一个图像文件名列表flist       i = 0          for c in cName:                train_data_path = os.path.join(file_path, c)           # 获取文件夹下的所有图片路径列表            flist_ = get_file_list(train_data_path)         if i == 0:               dataMat, dataLabel = read_and_convert(flist_)        else:            dataMat_, dataLabel_ = read_and_convert(flist_)            # 按轴axis0连接array组成一个新的array            dataMat = np.concatenate((dataMat, dataMat_), axis=0)            dataLabel = np.concatenate((dataLabel, dataLabel_), axis=0)          #print(dataMat.shape)            #print(len(dataLabel))         i +=1    return dataMat, dataLabel
def read_and_convert(imgFileList):    dataLabel = [] # 存放类标签    #计算图像个数    dataNum = len(imgFileList)     # dataNum * 1875 的矩阵    dataMat = np.zeros((dataNum, 1875))     for i in range(dataNum):        imgNameStr = imgFileList[i]        # 得到 颜色_编号.jpg        imgName = get_img_name_str(imgNameStr)          # 得到 类标签(颜色) 如  B_5.jpg        classTag = imgName.split(".")[0].split("_")[0]         dataLabel.append(classTag)        dataMat[i,:] = img2vector(imgNameStr)    return dataMat, dataLabel
def get_file_list(path):    file_list=[]    # 获取path路径下的所有文件名    for file_name in os.listdir(path):         fin_path = os.path.join(path, file_name)        if (fin_path.endswith('.jpg')):            file_list.append(fin_path)    return file_list
# 按文件路径拆分并取最后一个元素,即图像文件名def get_img_name_str(imgPath):    return imgPath.split(os.path.sep)[-1]

上面已经说过,在拿到图像数据后,对图像归一化并且将图像展平成向量。

# 将图像转换为向量def img2vector(imgFile):    img = cv2.imread(imgFile,1)    img_arr = np.array(img)     # 对图像进行归一化    img_normlization = img_arr/255     # 1 * 1875 向量    img_arr2 = np.reshape(img_normlization, (1,-1))     return img_arr2

最后定义SVM分类器,skit-learn已经为我们提供了SVM算法的API,其中SVC类是用来做分类任务的。

首先需要选择SVM的核函数,核函数由参数kernel指定:'linear'表示线性核函数,它只能产生线性的分隔超平面;'poly'表示多项式核函数;'rbf'表示高斯核函数。多项式核函数和高斯核函数都可以产生复杂的分隔超平面。

除了指定kernel外,还需要指定gamma值,这是高斯核函数的参数,默认指定为1/features。kernel这个参数没有最优,只能是在实验过程中测试然后再调整。

def create_svm(dataMat, dataLabel, path):        clf = svm.SVC(C = 1.0, kernel='rbf')     # 开始训练模型        rf = clf.fit(dataMat, dataLabel)        #存储训练好的模型    joblib.dump(rf, path)       return clf

然后就可以把代码跑起来了。

if __name__ == '__main__':      st = time.clock()     path = '数据图片路径'       dataMat, dataLabel = read_all_data(path)           model_path = os.path.join(path,'svm_cube.model')        create_svm(dataMat, dataLabel, model_path)        et = time.clock()        print("Training spent {:.4f}s.".format((et - st)))

因为数据量不大,所以训练很快就能完成。接着使用已经训练好的模型,看看效果怎么样。

import numpy as npimport osfrom sklearn.externals import joblibimport matplotlib.pyplot as pltfrom cv2 import cv2model_path = '模型保存路径'img_path = "测试图片路径"clf = joblib.load(model_path) # 加载模型def img2vector(img):    img_arr = np.array(img)     img_normlization = img_arr/255     img_arr2 = np.reshape(img_normlization, (1,-1))     return img_arr2i = 1for file in os.listdir(img_path):    filepath = os.path.join(img_path, file)    img = cv2.imread(filepath)    img2arr = img2vector(img)    preResult = clf.predict(img2arr)    img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)    plt.subplot(2,3,i)    plt.imshow(img)    plt.title(preResult)    i += 1    print(preResult)plt.show()

正常来说这里要对模型进行评估,但是本文没有划分测试集,所以评估这一步就省略了。现在只是对一张单个颜色的图像进行识别,对一张魔方图像的识别还需要能够从整张图像中找到颜色块,这涉及到一些图像处理和边缘提取的内容。

转载地址:http://gczzb.baihongyu.com/

你可能感兴趣的文章
【LEETCODE】36-Valid Sudoku
查看>>
【LEETCODE】205-Isomorphic Strings
查看>>
【LEETCODE】204-Count Primes
查看>>
【LEETCODE】228-Summary Ranges
查看>>
【LEETCODE】27-Remove Element
查看>>
【LEETCODE】66-Plus One
查看>>
【LEETCODE】26-Remove Duplicates from Sorted Array
查看>>
【LEETCODE】118-Pascal's Triangle
查看>>
【LEETCODE】119-Pascal's Triangle II
查看>>
【LEETCODE】88-Merge Sorted Array
查看>>
【LEETCODE】19-Remove Nth Node From End of List
查看>>
【LEETCODE】125-Valid Palindrome
查看>>
【LEETCODE】28-Implement strStr()
查看>>
【LEETCODE】6-ZigZag Conversion
查看>>
【LEETCODE】8-String to Integer (atoi)
查看>>
【LEETCODE】14-Longest Common Prefix
查看>>
【LEETCODE】38-Count and Say
查看>>
【LEETCODE】278-First Bad Version
查看>>
【LEETCODE】303-Range Sum Query - Immutable
查看>>
【LEETCODE】21-Merge Two Sorted Lists
查看>>