Batch Normalization的细致理解
最近读论文遇见很多对BN的优化,例如MoCo当中的shuffling BN、Domain Generalization。连原来是什么东西都不知道,怎么看优化呢?
1.不就是归一化吗?其实并不是
可能大家觉得这个东西不就是一个归一化的过程吗?其实并不是这样的一个过程。
我们假定我们直接使用一个归一化,也就是我们本来天马行空的数据直接我们将其拉到均值为0方差为1,那么这样就出现一个问题:在BN层之后的模型必须是接收均值为0方差为1的模型,这可能不太符合普遍规律。
所以BN的作者在这上面加入了一个新的内容就是一个线性变换层。这样就能将本身均值在0方差在1的输出数据做一个适当地变换,让数据的限制不那么死板。所以取得了较好的效果。
大致的过程如下:
m = K.mean(X, axis=-1, keepdims=True)#计算均值
std = K.std(X, axis=-1, keepdims=True)#计算标准差
X_normed = (X - m) / (std + self.epsilon)#归一化
out = self.gamma * X_normed + self.beta#重构变换
说了这么多主要是让大家理解这个BN层最后的线性变换的作用。
2.测试的时候如何处理?
这时出现另外的一个问题,就是我们在训练的时候,数据是一个batch一个batch的通过网络,并且回传。这也就导致了BN(Batch Normalization)中的batch的来源,为什么是针对一个batch做标准化,其实是来自于这里。
这时候一个新的问题产生了,在我们训练的时候存在batch的概念,但是,当我们test(或者描述为evaluate)的时候并没有这个概念。只是一个数据一个数据的输入,并没有batch的概念。那么这个时候我们在向前传播的时候我们用什么做normalization呢?
所以作者就提出了一种解决方案,也就是使用所有batch的均值的平均值和所有batch方差的无偏估计。
分开理解一下:
这里的均值的平均值:其实就相当于全部数据的均值,也相当于每个batch均值的无偏估计。
这里的bacth的方差的无偏估计:其实只是全部batch的方差加和再除以(全部的batch数再减去1),这个是个概率问题。另外,需要注意的是这个batch的方差的无偏估计和全部数据的方差并不是一个东西。
3.还有什么影响?
3.1在BN层之前还有一个线性变换会怎样?
不难发现这个线性变化(x=wx+b)的+b被完全吞没了,因为你均值变回0,加不加b其实都完全一样。但是那个w还是有作用的。
那么这个b失效可怎么办?
其实BN层结束的线性变换,完全可以取代这里进行的变换。
Batch Normalization的细致理解
最近读论文遇见很多对BN的优化,例如MoCo当中的shuffling BN、Domain Generalization。连原来是什么东西都不知道,怎么看优化呢?
1.不就是归一化吗?其实并不是
可能大家觉得这个东西不就是一个归一化的过程吗?其实并不是这样的一个过程。
我们假定我们直接使用一个归一化,也就是我们本来天马行空的数据直接我们将其拉到均值为0方差为1,那么这样就出现一个问题:在BN层之后的模型必须是接收均值为0方差为1的模型,这可能不太符合普遍规律。
所以BN的作者在这上面加入了一个新的内容就是一个线性变换层。这样就能将本身均值在0方差在1的输出数据做一个适当地变换,让数据的限制不那么死板。所以取得了较好的效果。
大致的过程如下:
m = K.mean(X, axis=-1, keepdims=True)#计算均值
std = K.std(X, axis=-1, keepdims=True)#计算标准差
X_normed = (X - m) / (std + self.epsilon)#归一化
out = self.gamma * X_normed + self.beta#重构变换
说了这么多主要是让大家理解这个BN层最后的线性变换的作用。
2.测试的时候如何处理?
这时出现另外的一个问题,就是我们在训练的时候,数据是一个batch一个batch的通过网络,并且回传。这也就导致了BN(Batch Normalization)中的batch的来源,为什么是针对一个batch做标准化,其实是来自于这里。
这时候一个新的问题产生了,在我们训练的时候存在batch的概念,但是,当我们test(或者描述为evaluate)的时候并没有这个概念。只是一个数据一个数据的输入,并没有batch的概念。那么这个时候我们在向前传播的时候我们用什么做normalization呢?
所以作者就提出了一种解决方案,也就是使用所有batch的均值的平均值和所有batch方差的无偏估计。
分开理解一下:
这里的均值的平均值:其实就相当于全部数据的均值,也相当于每个batch均值的无偏估计。
这里的bacth的方差的无偏估计:其实只是全部batch的方差加和再除以(全部的batch数再减去1),这个是个概率问题。另外,需要注意的是这个batch的方差的无偏估计和全部数据的方差并不是一个东西。
3.还有什么影响?
3.1在BN层之前还有一个线性变换会怎样?
不难发现这个线性变化(x=wx+b)的+b被完全吞没了,因为你均值变回0,加不加b其实都完全一样。但是那个w还是有作用的。
那么这个b失效可怎么办?
其实BN层结束的线性变换,完全可以取代这里进行的变换。