类神经网络的训练技巧
类神经网络的训练技巧
一般指南
训练完的第一件事就是检测training data的loss。如果training data的loss都很大,显然就是在训练集上没有训练好,有两个可能:一是model bias,或者是optimization做得不好。
Model bias
model bias的意思是说,假设你的model太过简单。
大海裡面捞针,这个针指的是一个loss低的function,结果针根本就不在海裡。
重新设计一个model,给你的model更大的弹性
- 输入特征不够多,增加特征输入。如前一天的数据不够使,加上前两个月的数据。
- 特征够多了,但是模型太简单。换用deep learning 或者增加层数。
Optimization Issue
大海捞针,针确实在海裡,但是我们却没有办法把针捞起来。
training data的loss不够低的时候,到底是model bias,还是optimization的问题呢?
- 找不到一个loss低的function,到底是因為我们的model的弹性不够,我们的海裡面没有针
- 还是说,我们的model的弹性已经够了,只是optimization gradient descent不给力,它没办法把针捞出来
1.从比较中获得选择
透过比较不同的模型,来得知说,你的model现在到底够不够大。
假如测2个networks:
- 一个network有20层
- 一个network有56层
横轴指的是training的过程,就是你参数update的过程,随著参数的update,当然你的loss会越来越低,但是结果20层的loss比较低,56层的loss还比较高。
这个不是overfitting,这代表56层的network,它的optimization没有做好,它的optimization不给力。
因为一个56层的network要做到20层的network可以做到的事情,对它来说是轻而易举的。
2.先从一些简单的模型开始训练
看到一个你从来没有做过的问题,也许你可以先跑一些比较小的,比较浅的network,或甚至用一些,不是deep learning的方法
比如SVM之类的,这些模型会更容易做Optimize,不容易出现optimization失败的问题。
先train一些简单的model,对整体有个把握,简单的model可以获得怎样的loss。
3.如果更深层次的网络不能获得更小的训练数据损失,那么就存在优化问题。
如果你发现你深的model,跟浅的model比起来,深的model明明弹性比较大,但loss却没有办法比浅的model压得更低,那就代表说你的optimization有问题,你的gradient descent不给力
Overfitting
如果training data上面的loss小,testing data上的loss大,那你可能就是真的遇到overfitting的问题。
单是在training data上表现不好不应该怀疑是overfitting问题,一个反思模型问题。
解决办法:
增加你的训练集
这个方向往往是最有效的方向,用更多的数据来限制住模型,但是费时费力。
推荐使用data augmentation的方法。
将图片放大、截取、翻转、添加噪声等操作得到新的数据。
单是要根据你对资料的特性,对你现在要处理的问题的理解,来选择合适的,data augmentation的方式。
不要让你的模型,有那麼大的弹性,给它一些限制
- 给它比较少的参数,如果是deep learning的话,就给它比较少的神经元的数目,本来每层一千个神经元,改成一百个神经元之类的,或者是你可以让model共用参数。
- 用比较少的features,本来给三天的数据,改成用给两天的数据。
- Early stopping,提前结束。
- Dropout,随机丢弃一些权重参数。
但是不能给太多的限制,不然模型又回到了model bias的问题。
如何解决这个矛盾的问题呢?
所谓比较复杂的模型就是,它可以包含的function比较多,它的参数比较多,这个就是一个比较复杂的model,随著model越来越复杂,Training的loss可以越来越低,Test的loss会跟著下降,但是当复杂的程度,超过某一个程度以后,Testing的loss就会突然暴增了。
Cross Validation
把Training的资料分成两半,一部分叫作Training Set,一部分是Validation Set
在Training Set上做测试,在Validation Set上面,去衡量它们的分数。根据Validation Set上面的分数,去挑选结果。
N-fold Cross Validation
N-fold Cross Validation就是你先把你的训练集切成N等份,在这个例子裡面我们切成三等份,切完以后,你拿其中一份当作Validation Set,另外两份当Training Set,然后这件事情你要重复三次
在这三个setting下,在这三个Training跟Validation的,data set上面,通通跑过一次,然后把这三个模型,在这三种状况的结果都平均起来,把每一个模型在这三种状况的结果,都平均起来,再看看谁的结果最好
mismatch
mismatch它的原因跟overfitting,其实不一样,一般的overfitting,你可以用搜集更多的数据来解决。但是mismatch意思是说,你的训练集跟测试集,它们的分布本来就是不一样的
收工结束
假设你现在经过一番的努力,你已经可以让你的,training data的loss变小了,那接下来你就可以来看,testing data loss,如果testing data loss也小,有比这个strong baseline还要小就结束了,没什麼好做的就结束了。
参考: