![TensorFlow 2.0卷积神经网络实战](https://wfqqreader-1252317822.image.myqcloud.com/cover/133/29977133/b_29977133.jpg)
2.2 Hello TensorFlow & Keras
神经网络专家Rachel Thomas曾经说过,“接触了TensorFlow后,我感觉我还是不够聪明,但有了Keras之后,事情会变得简单一些。”
他所提到的Keras是一个高级别的Python神经网络框架,能在TensorFlow上运行的一种高级的API框架。Keras拥有丰富的对数据封装和一些先进模型的实现,避免了“重复造轮子”,如图2.4所示。换言之,Keras对于提升开发者的开发效率来讲意义重大。
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P32_27025.jpg?sign=1738821362-LnG7rDbIX2AXpGRNpk2E7Tk5TT6dFrN2-0-2d47f7d84cb8d39aec97fdb5b51d5e58)
图2.4 TensorFlow+Keras
“不要重复造轮子。”这是TensorFlow引入Keras API的最终目的,本书还是以TensorFlow代码编写为主,Keras作为辅助工具而使用的,目的是为了简化程序编写,这点请读者一定注意。
本章非常重要,强烈建议读者独立完成每个完整代码和代码段的编写。
2.2.1 MODEL!MODEL!MODEL!还是MODEL
神经网络的核心就是模型。
任何一个神经网络的主要设计思想和功能都集中在其模型中。
TensorFlow也是如此。
TensorFlow或者其使用的高级API-Keras核心数据结构是MODEL,一种组织网络层的方式。最简单的模型是Sequential顺序模型,它由多个网络层线性堆叠。对于更复杂的结构,应该使用Keras函数式API(本书的重点就是函数式API编写),其允许构建任意的神经网络图。
为了便于理解和易于上手,作者首先从顺序Sequential开始。一个标准的顺序Sequential模型如下:
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P32_27156.jpg?sign=1738821362-Jx47otG3jD8DIwe06iHLFUBK4rIhXJwR-0-0a7ab9d69fb822112f83386fac11b781)
可以看到,这里首先使用创建了一个Sequential模型,之后根据需要逐级向其中添加不同的全连接层,全连接层的作用是进行矩阵计算,而相互之间又通过不同的激活函数进行激活计算(这种没有输入输出值的编程方式对有经验的程序设计人员来说并不友好,仅供举例)。
对于损失函数的计算,根据不同拟合方式和数据集的特点,需要建立不同的损失函数去最大程度地反馈拟合曲线错误。这里的损失函数采用交叉熵函数(softmax_crossentroy),使得数据计算分布能够最大限度地拟合目标值。如果对此陌生的话,读者只需要记住这些名词和下面的代码编写即可继续往下学习。代码如下:
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P32_27157.jpg?sign=1738821362-mx5pbIGb2sU6uf0n58Xz1Q9eA8OCJQt6-0-0af48dee4afb18b0ae3a69e1c16d747a)
首先通过模型计算出对应的值。这里内部采用的前向调用函数,读者知道即可。之后tf.reduce_mean计算出损失函数。
模型建立完毕后,就是数据的准备。一份简单而标准的数据,一个简单而具有指导思想的例子往往事半功倍。深度学习中最常用的一个入门起手例子Iris分类,下面就从这个例子开始,最终使用TensorFlow 2.0的Keras模式实现一个Iris鸢尾花分类的例子。
2.2.2 使用Keras API实现鸢尾花分类的例子(顺序模式)
Iris数据集是常用的分类实验数据集,由Fisher于1936年收集整理。Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性。可通过花萼长度、花萼宽度、花瓣长度、花瓣宽度4个属性预测鸢尾花卉属于Setosa、Versicolour、Virginica这3个种类中的哪一类,如图2.5所示。
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P33_5041.jpg?sign=1738821362-rZv3yFLWNDcXqAAAeujyB1U96KIufsvC-0-67d4e3af5fe632d805f2bdfb109ee2f1)
图2.5 鸢尾花
1. 第一步:数据的准备
不需要读者下载这个数据集,一般常用的机器学习工具自带Iris数据集,引入数据集的代码如下:
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P33_27161.jpg?sign=1738821362-mLkyDJ8a24cwTGS4Etl1W0G5I2gtMmM3-0-1caf7e640b4c869c1ec0d18927d3135f)
这里调用的是sklearn的数据库中Iris数据集,直接载入即可。
而其中的数据又是以key-value值对应存放,key值如下:
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P33_5058.jpg?sign=1738821362-MdUUCm0pKdJR2SVpJgF3GU2v6tEblYEC-0-c3c95c989df515a47d4c26156025bc74)
由于本例中需要Iris的特征与分类目标,因此这里只需要获取data和target。代码如下:
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P34_27162.jpg?sign=1738821362-zNUNV2KUtUtV40IIn4BptUlTMeOhkeEy-0-c3d83f1b933869b210ac9d380aef8e97)
数据打印结果如图2.6所示。
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P34_5096.jpg?sign=1738821362-k56fE8CGp0Q5hYE1ACfGniSf7BIFy0zG-0-21c51c5010803f9a207fb17ac4f417a3)
图2.6 数据打印结果
这里是分别打印了前5条数据。可以看到Iris数据集中的特征,是分成了4个不同特征进行数据记录,而每条特征又对应于一个分类表示。
2. 第二步:数据的处理
下面就是数据处理部分,对特征的表示不需要变动。而对于分类表示的结果,全部打印结果如图2.7所示。
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P34_5100.jpg?sign=1738821362-V297XVB1wfA4ntusal5Se9skBt7WPEVx-0-744192dafc91eb87261f5bf627812e07)
图2.7 数据处理
这里按数字分成了3类,0、1和2分别代表3种类型。如果按直接计算的思路可以将数据结果向固定的数字进行拟合,这是一个回归问题。即通过回归曲线去拟合出最终结果。但是本例实际上是一个分类任务,因此需要对其进行分类处理。
分类处理的一个非常简单的方法就是进行one-hot处理,即将一个序列化数据分成到不同的数据领域空间进行表示,如图2.8所示。
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P34_5104.jpg?sign=1738821362-tEREHXeb6DbaivgKGHr2BB04IKoCRNih-0-745b4e3b54d98aeeb2c5c0ea443b6d7e)
图2.8 one-hot处理
具体在程序处理上,读者可以手动实现one-hot的代码表示,也可以使用Keras自带的分散工具对数据进行处理,代码如下:
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P35_27163.jpg?sign=1738821362-CEyWx7hKtD04NPcXGjIU8UuWnuCRCSZY-0-81a543932f8243da2cea424687ffff4e)
这里的num_classes是分成了3类,由一行三列对每个类别进行表示。
交叉熵函数与分散化表示的方法超出了本书的讲解范围,这里就不再做过多介绍,读者只需要知道交叉熵函数需要和softmax配合,从分布上向离散空间靠拢即可。
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P35_27164.jpg?sign=1738821362-e7WuHPAJpY0bs62Rxcad0gdtcGq5QyCg-0-17780f5b9a15cd52d5c9f6e9a60904c5)
当生成的数据读取到内存中并准备以批量的形式打印,使用的是tf.data.Dataset.from_tensor_slices函数,并且可以根据具体情况对batch进行设置。tf.data.Dataset函数更多的细节和用法在后面章节中会专门介绍。
3. 第三步:梯度更新函数的写法
梯度更新函数是根据误差的幅度对数据进行更新的方法,代码如下:
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P35_27165.jpg?sign=1738821362-XH0Wy8qYe2LOUIsgo5PVdnm5MDrct3Hg-0-8a885e57dd237cd2cdd241dd5a5c2f20)
与前面线性回归例子的差别是,使用的model直接获取参数的方式对数据进行不断更新而非人为指定,这点请读者注意。至于人为的指定和排除某些参数的方法属于高级程序设计,在后面的章节会介绍。
【程序2-7】
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P35_27168.jpg?sign=1738821362-WxmmJUviw44Et1ZAidDQVTJ2XFZW9FpJ-0-ae760d6d5b9a393f190cac9539ccc97d)
最终打印结果如图2.9所示。可以看到损失值在符合要求的条件下不停降低,达到预期目标。
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P36_5325.jpg?sign=1738821362-jfnQIvESl3NYymXlk2Qo91S3QiXIgrFl-0-7ec295ed46776843910ffc0c77a1ef64)
图2.9 打印结果
2.2.3 使用Keras函数式编程实现鸢尾花分类的例子(重点)
我们在前面也说了,对于有编程经验的程序设计人员来说,顺序编程过于抽象,同时缺乏过多的自由度,因此在较为高级的程序设计中达不到程序设计的目标。
Keras函数式编程是定义复杂模型(如多输出模型、有向无环图,或具有共享层的模型)的方法。
让我们从一个简单的例子开始,程序2-7建立模型的方法是使用顺序编程,即通过逐级添加的方式将数据“add”到模型中。这种方式在较低级水平的编程上可以较好地减轻编程的难度,但是在自由度方面会有非常大的影响,例如当需要对输入的数据进行重新计算时,顺序编程方法就不合适。
函数式编程方法类似于传统的编程,只需要建立模型导入输出和输出“形式参数”即可。有TensorFlow 1.X编程基础的读者可以将其看作是一种新格式的“占位符”。代码如下:
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P36_27170.jpg?sign=1738821362-fuFIdqJXHROaD1syQaZkSVpX7iXvpEOt-0-73c926672fa33df622b97cb16ae255cd)
下面开始逐对其进行分析。
1. 输入端
首先是input的形参:
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P37_5386.jpg?sign=1738821362-5q8r5A4abHDPu7uFE1mmz6W6fKNjGCJG-0-29b5b925a9d620c98e277c97fa8477be)
这一点需要从源码上来看,代码如下:
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P37_27173.jpg?sign=1738821362-xiyVZn7sy58eDima1spDgbw75GgYpTNt-0-52e7b71abd4c39420b750d17d976fa02)
input函数用于实例化Keras张量,Keras张量是来自底层后端输入的张量对象,其中增加了某些属性,使其能够通过了解模型的输入和输出来构建Keras模型。
input函数的参数:
●shape:形状元组(整数),不包括批量大小。例如shape=(32,)表示预期的输入将是32维向量的批次。
●batch_size:可选的静态批量大小(整数)。
●name:图层的可选名称字符串。在模型中应该是唯一的(不要重复使用相同的名称两次)。如果未提供,它将自动生成。
●dtype:数据类型由输入预期的,作为字符串(float32、float64、int32)。
●sparse:一个布尔值,指定是否创建占位符是稀疏的。
●tensor:可选的现有张量包裹到Input图层中。如果设置,图层将不会创建占位符张量。
●**kwargs:其他的一些参数。
上面是官方对其参数所做的解释,可以看到,这里的input函数就是根据设定的维度大小生成一个可供存放对象的张量空间,维度就是shape中设定的维度。需要注意的是,与传统的TensorFlow不同,这里的batch大小是不包含在创建的shape中。
举例来说,在一个后续的学习中会遇到MNIST数据集,即一个手写图片分类的数据集,每张图片的大小用4维来表示[1,28,28,1]。第1个数字是每个批次的大小,第2和3个数字是图片的尺寸大小,第4个1是图片通道的个数。因此输入到input中的数据为:
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P37_27178.jpg?sign=1738821362-EKnWyfw4em91Z6Hs90nCWSXPcW2O0tz4-0-db122b43bac2c3d8b8291bcb2615eec4)
2. 中间层
下面每个层的写法与使用顺序模式也是不同:
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P38_5473.jpg?sign=1738821362-0u8cPOqIummnuGwYeyZRjePzORjgGGeP-0-ee06cc834dc38fe48c5b68fd5f6b4b11)
在这里每个类被直接定义,之后将值作为类实例化以后的输入值进行输入计算。
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P38_27182.jpg?sign=1738821362-og1GSR6VzulebpKffjlO8wU5uloEfr0h-0-e5558dbc2135037706a1b484a7584c47)
因此可以看到这里与顺序最大的区别就在于实例化类以后有对应的输入端,这一点较为符合一般程序的编写习惯。
3. 输出端
对于输出端不需要额外的表示,直接将计算的最后一个层作为输出端即可:
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P38_5499.jpg?sign=1738821362-pn5WuvSXdCAVx4nXV32Hbm8X7lVPeZyO-0-5f6a4072339a46d4af8100e3b8808089)
4. 模型的组合方式
对于模型的组合方式也是很简单的,直接将输入端和输出端在模型类中显式的注明,Keras即可在后台将各个层级通过输入和输出对应的关系连接在一起。
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P38_5506.jpg?sign=1738821362-aejNXgGnuVTy2QNfB32yDaRAABJpkMz0-0-c75f0415804fd0c609bdcc1260e0514d)
完整的代码如下:
【程序2-8】
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P38_5515.jpg?sign=1738821362-X1Zkxbp52YNkqjkuBQUB0AJhSdM4z3Er-0-06e12989726add64c4ebb8dd35245f68)
程序2-8的基本架构对照前面的例子没有多少变化,损失函数和梯度更新方法是固定的写法,这里最大的不同点在于,代码使用了model自带的saver函数对数据进行保存。在TensorFlow 2.0中,数据的保存是由Keras完成,即将图和对应的参数完整地保存在h5格式中。
2.2.4 使用保存的Keras模式对模型进行复用
前面已经说过,对于保存的文件,Keras是将所有的信息都保存在h5文件中,这里包含的所有模型结构信息和训练过的参数信息。
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P39_5707.jpg?sign=1738821362-ZKha9YRIQrAcwxzmEMI4Udp1vOVflG3F-0-05b015cc33c25fe98e28cdca18e023bd)
tf.keras.models.load_model函数是从给定的地址中载入h5模型,载入完成后,会依据存档自动建立一个新的模型。
模型的复用可直接调用模型predict函数:
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P39_5714.jpg?sign=1738821362-aN3w9MNOMtQ8GkwVwXSHUJmXUGdN7OAj-0-218d87eb2d48d57be7c87684a4939871)
这里是直接将Iris数据作为预测数据进行输入。全部代码如下:
【程序2-9】
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P39_5722.jpg?sign=1738821362-GoyZEfzjw7AUUmGyQdi9IGehqkXBE3qB-0-cc11bcdbab70e59e30591bb163436b3a)
最终结果如图2.10所示,可以看到计算结果被完整打印出来。
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P40_5800.jpg?sign=1738821362-TNCTT2emOEYP6kErHirjXz49nf9aseCG-0-c3f49e04290605eb2e9e9155ebe61a9a)
图2.10 打印结果
2.2.5 使用TensorFlow 2.0标准化编译对Iris模型进行拟合
在2.1.3小节中,作者使用了符合传统TensorFlow习惯的梯度更新方式对参数进行更新。然实际这种看起来符合编程习惯的梯度计算和更新方法,可能并不符合大多数有机器学习使用经验的读者使用。本节就以修改后的Iris分类为例,讲解标准化TensorFlow 2.0的编译方法。
对于大多数机器学习的程序设计人员来说,往往习惯了使用fit函数和compile函数对数据进行数据载入和参数分析。代码如下(请读者先运行,后面会有详细的运行分析):
【程序2-10】
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P40_27192.jpg?sign=1738821362-kSdTy1b438dX1TpJ1WLEBplAR2poKjkw-0-abd40acd9bc86f1fbe3379219904aa11)
下面我们详细分析一下代码。
1. 数据的获取
本例还是使用了sklearn中的Iris数据集作为数据来源,之后将target转化成one-hot的形式进行存储。顺便提一句,TensorFlow本身也带有one-hot函数,即tf.one_hot,有兴趣的读者可以自行学习。
数据读取之后的处理在后文讲解,这个问题先放一下,请读者继续往下阅读。
2. 模型的建立和参数更新
这里不准备采用新模型的建立方法,对于读者来说,熟悉函数化编程已经能够应付绝对多数的深度学习模型的建立。在后面章节中,我们会教会读者自定义某些层的方法。
对于梯度的更新,到目前为止的程序设计中都是采用了类似回传调用等方式对参数进行更新,这是由程序设计者手动完成的。然而TensorFlow 2.0推荐使用自带的梯度更新方法,代码如下:
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P41_27194.jpg?sign=1738821362-QoFCFLocvmIxqltIngs1kGq81RAHzaEi-0-7bfbcfd1049dad6b9f8e828064980e27)
complie函数是模型适配损失函数和选择优化器的专用函数,而fit函数的作用是把训练参数加载进模型中。下面分别对其进行讲解。
(1)compile
compile函数的作用是TensorFlow 2.0中用于配置训练模型专用编译函数。源码如下:
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P41_27195.jpg?sign=1738821362-O02f2GFiKUQpxuL2mSYet7tarFq7HpFa-0-957be39c33a4d25ade5f5e0dd7eb7784)
这里我们主要介绍其中最重要的3个参数optimizer、loss和metrics。
●optimizer:字符串(优化器名)或者优化器实例。
●loss:字符串(目标函数名)或目标函数。如果模型具有多个输出,可以通过传递损失函数的字典或列表,在每个输出上使用不同的损失。模型最小化的损失值将是所有单个损失的总和。
●metrics:在训练和测试期间的模型评估标准。通常会使用metrics=['accuracy']。要为多输出模型的不同输出指定不同的评估标准,还可以传递一个字典,如metrics={'output_a':'accuracy'}。
●可以看到,优化器(optimizer)被传入了选定的优化器函数,loss是损失函数,这里也被传入选定的多分类crossentry函数。metrics用来评估模型的标准,一般用准确率表示。
实际上,compile编译函数是一个多重回调函数的集合,对于所有的参数来说,实际上就是根据对应函数的“地址”回调对应的函数,并将参数传入。
举个例子,在上面编译器中我们传递的是一个TensorFlow 2.0自带的损失函数,而实际上往往由于针对不同的计算和误差需要不同的损失函数,这里自定义一个均方差(MSE)损失函数,代码如下:
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P42_27201.jpg?sign=1738821362-oji1ume4SXhnR5Ug7HVFlBKLNRndR08I-0-7dd71a541f67811c598538cc507636fe)
这个损失函数接收2个参数,分别是y_true和y_pred,即预测值和真实值的形式参数。之后根据需要计算出真实值和预测值之间的误差。
损失函数名作为地址传递给compile后,即可作为自定义的损失函数在模型中进行编译。代码如下:
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P42_27202.jpg?sign=1738821362-IT5rvuhxXtfgF0h9JFNpyipd3x2mRQ2u-0-f6c47965359fa6ea862e66e8ce330ec0)
至于优化器的自定义实际上也是可以的。但是,一般情况下优化器的编写需要比较高的编程技巧以及对模型的理解,这里读者直接使用TensorFlow 2.0自带的优化器即可。
(2)fit
fit函数的作用是以给定数量的轮次(数据集上的迭代)训练模型。其主要参数有如下4个:
●x:训练数据的NumPy数组(如果模型只有一个输入),或者是NumPy数组的列表(如果模型有多个输入)。如果模型中的输入层被命名,你也可以传递一个字典,将输入层名称映射到NumPy数组。如果从本地框架张量馈送(例如TensorFlow数据张量)数据,x可以是None(默认)。
●y:目标(标签)数据的NumPy数组(如果模型只有一个输出),或者是NumPy数组的列表(如果模型有多个输出)。如果模型中的输出层被命名,你也可以传递一个字典,将输出层名称映射到NumPy数组。如果从本地框架张量馈送(例如TensorFlow数据张量)数据,y可以是None(默认)。
●batch_size:整数或None。每次梯度更新的样本数。如果未指定,默认为32。
●epochs:整数。训练模型迭代轮次。一个轮次是在整个x和y上的一轮迭代。请注意,与initial_epoch一起,epochs被理解为“最终轮次”。模型并不是训练了epochs轮,而是到第epochs轮停止训练。
fit函数的主要作用就是对输入的数据进行修改,如果读者已经成功运行了程序2-10,那么现在换一种略微修改后的代码,重新运行Iris数据集。代码如下:
【程序2-11】
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P42_6068.jpg?sign=1738821362-U7PzGtCmvyakIfnbATpvIademKBsvBGp-0-7e37738df096e62438c3894098b55293)
对比程序2-10和程序2-11可以看到,它们最大的不同在于数据读取方式的变化。更为细节的做出比较,在程序2-10中,数据的读取方式和fit函数的载入方式如下:
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P43_27210.jpg?sign=1738821362-xRJAxlMm2HIaWocqn7NlLXG5QPPViB2Q-0-107c270eced884522348f873f25ed37d)
Iris的数据读取被分成2个部分,分别是数据特征部分和label分布。而label部分使用Keras自带的工具进行离散化处理。
离散化后处理的部分又被tf.data.Dataset API整合成一个新的数据集,并且依batch被切分成多个部分。
此时fit的处理对象是一个被tf.data.Dataset API处理后的一个Tensor类型数据,并且在切分的时候依照整合的内容被依次读取。在读取的过程中,由于它是一个Tensor类型的数据,fit内部的batch_size划分不起作用,而使用生成数据的tf中数据生成器的batch_size划分。如果读者对其还是不能够理解的话,可以使用如下代码段打印重新整合后的train_data中的数据看看,代码如下:
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P44_6272.jpg?sign=1738821362-J8zm1KFoO0rDPWYvMGYB3BLe5igSjRLs-0-91e8fb46c010046366005c56ab437c03)
现在回到程序2-11中,作者取出对应于数据读取和载入的部分如下:
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P44_27213.jpg?sign=1738821362-FNXuiK0hYjNEZ5TIAvXMcjVAHEwxaBoM-0-99310377bb9e60f454f57fd75328c84f)
可以看到数据在读取和载入的过程中没有变化,将处理后的数据直接输入到fit函数中供模式使用。此时由于是直接对数据进行操作,因此对数据的划分由fit函数负责,此时fit函数中的batch_size被设定为128。
2.2.6 多输入单一输出TensorFlow 2.0编译方法(选学)
在前面内容的学习中,我们采用的是标准化的深度学习流程,即数据的准备与处理、数据的输入与计算,以及最后结果的打印。虽然在真实情况中可能会遇到各种各样的问题,但是基本步骤是不会变的。
这里存在一个非常重要的问题,在模型的计算过程中,如果遇到多个数据输入端应该怎么处理,如图2.11所示。
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P44_6327.jpg?sign=1738821362-Te4vjmOgB3cU3dsr5POImd3guT7U8hDG-0-ed9959951a0f0facb9c0f17ea7f8adcb)
图2.11 多个数据输入端
以Tensor格式的数据为例,在数据的转化部分就需要将数据进行“打包”处理,即将不同的数据按类型进行打包。如下所示:
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P44_6335.jpg?sign=1738821362-rPD9tXYzMQTPwcf0cU0EFdTaNhxVHE3u-0-5c5e4868efda030bfcd352029e91f491)
请注意小括号的位置,这里显示的将数据分成2个部分,输入与标签两类。而多输入的部分被使用小括号打包在一起形成一个整体。
下面还是以Iris数据集为例讲解多数据输入的问题。
1. 第一步:数据的获取与处理
从前面的介绍可以知道,Iris数据集每行是一个由4个特征组合在一起表示的特征集合,此时可以人为地将其切分,即将长度为4的特征转化成一个长度为3和一个长度为1的两个特征集合。代码如下:
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P45_27215.jpg?sign=1738821362-FAj8EUYSQc7C1OBfnepcvNx9xsc5Qu7b-0-4837882b3f3fb48c260743a3fe490eb7)
打印其中的一条数据,如下所示:
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P45_6416.jpg?sign=1738821362-wIxcbqJdFTum9ia8GQePPpPqVc6eksRW-0-2936044eebe33b8c684d1f438100b5a7)
可以看到,一行4列的数据被拆分成2组特征。
2. 第二步:模型的建立
接下来就是模型的建立,这里数据被人为地拆分成2个部分,因此在模型的输入端,也要能够对应处理2组数据的输入。
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P45_27216.jpg?sign=1738821362-mlo5NBkNkD7T4CJV3qlXo7EWjNo3Jzhp-0-d0e2e888a1fc0df7a493daccb2a40550)
可以看到代码中分别建立了input_xs_1和input_xs_2作为数据的接收端接受传递进来的数据,之后通过一个concat重新将数据组合起来,恢复成一条4特征的集合。
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P45_27218.jpg?sign=1738821362-GLTe8rfAUQiN39dikPErFojUxV2pqy5E-0-d2c22927ea775dbb63323de89cf6978c)
对剩余部分的数据处理没有变化,按前文程序处理即可。
3. 第三步:数据的组合
切分后的数据需要重新对其进行组合,生成能够符合模型需求的Tensor数据。这里最为关键的是在模型中对输入输出格式的定义,把模式的输入输出格式拆出如下:
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P46_6478.jpg?sign=1738821362-dIynTw7V0mRyd84f5lXqiTwjzTOFmJy5-0-2c34504bb37bf5d9726a1e6ecf716aca)
因此在Tensor建立的过程中,也要按模型输入的格式创建对应的数据集。格式如下:
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P46_6486.jpg?sign=1738821362-9BEqPXqK7Nuruak7f1LfCV7qWzTwebNX-0-cb6d0731a8ec1a53d73e7b29e8734d07)
请注意这里的括号有几重,这里我们采用了2层括号对数据进行包裹,即首先将输入1和输入2包裹成一个输入数据,之后重新打包输出,共同组成一个数据集。转化Tensor数据代码如下:
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P46_27220.jpg?sign=1738821362-xZAGbcmqudMn44nq522xVxeh2hK3JfMG-0-afab87f44a06f6875d573b78c26bc134)
注意
请读者一定要注意小括号的层数。
完整代码如下:
【程序2-12】
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P46_27223.jpg?sign=1738821362-YjmWcOIlbEKRKhvHK3KwzummWwgIReP6-0-8c674e4d9b4532664e4cb364a5d7db07)
最终结算结果如图2.12所示。
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P47_6740.jpg?sign=1738821362-zdn71uSwfiFeGI0MfKYVt4fS1NQNgZFm-0-aa833234d89524babf9db689be304b84)
图2.12 打印结果
对于认真阅读本书的读者来说,这个最终的打印结果应该见过很多次了,在这里TensorFlow 2.0默认输出了每个循环结束后的loss值,并且按compile函数中设定的内容输出准确率(accuarcy)值。最后的evaluate函数是通过对测试集中的数据进行重新计算,从而获取在测试集中的损失值和准确率。本例使用训练数据代替测试数据。
在程序2-12中数据的准备是使用tf.data API完成,即通过打包的方式将数据输出,也可以直接将输入的数据输入到模型中进行训练。代码如下:
【程序2-13】
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P47_27225.jpg?sign=1738821362-mdjs7kICwprTOteHJwvKKRJua3a1ZgNC-0-e4fa87d1d3d259f04c50d464bc7f067f)
最终打印结果请读者自行验证,需要注意的是其中数据的包裹情况。
2.2.7 多输入多输出TensorFlow 2.0编译方法(选学)
读者已经知道了多输入单一输出的TensorFlow 2.0的写法,而在实际编程中有没有可能遇到多输入多输出的情况。
事实上是有的。虽然读者可能遇到的情况会很少,但是在必要的时候还是需要设计多输出的神经网络模型去进行训练,例如“bert”模型。
对于多输出模型的写法,实际上也可以仿照单一输出模型改为多输入模型的写法,将output的数据使用中括号进行包裹。
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P48_27227.jpg?sign=1738821362-Ex4vRpnESbzBJWVBopE1cWfoppTD7TQe-0-e5178e99bed4dd7be502a0a86306e973)
首先是对数据的修正和设计,数据的输入被平均分成2组,每组有2个特征。这实际上没什么变化。而对于特征的分类,在引入one-hot处理的分类数据集外,还保留了数据分类本身的真实值作目标的辅助分类计算结果。而无论是多输入还是多输出,此时都使用打包的形式将数据重新打包成一个整体的数据集合。
在fit函数中,直接是调用了打包后的输入数据即可。
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P49_7003.jpg?sign=1738821362-VsDfuLfRVGeB5pnw1LfLfPXACEz9Zs68-0-3ac60dc91f7393dc37104e3921bcafad)
完整代码如下:
【程序2-14】
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P49_27230.jpg?sign=1738821362-BtHoVqf6P6GKrEp268oFyeaXioSnsqoX-0-bc1fadceecdd8ea76a9c0f86ffc2be28)
输出结果如图2.13所示。
![](https://epubservercos.yuewen.com/5F793F/16499866704817006/epubprivate/OEBPS/Images/Figure-P50_7236.jpg?sign=1738821362-kaCLKjADkzMbTOpUsN5B342l4XQhCtyH-0-859f64cfbf624587c2c013cd73f7d856)
图2.13 打印结果
限于篇幅关系,这里也只给出一部分结果,相信读者能够理解输出的数据内容。