自由Man

PyTorch 中的 RNN/LSTM/GRU模型参数解释

RNN、LSTM、GRU的结构与PyTorch调用方式基本一致,此处以RNN为例进行说明。

pytorch中的RNN,torch.nn.RNN():

参数:

    


参数含义注意
input_size输入 x 的特征维度数量是特征维度,不是序列长度
hidden_size隐状态 h 中的特征数量hidden层节点数目
num_layersRNN层数hidden层数目,决定网络深度,不是宽度
nonlinearity指定非线性函数使用 [‘tanh’|’relu’]. 默认: ‘tanh’
bias如果是 False , 那么 RNN 层就不会使用偏置权重 b_ih 和 b_hh, 默认: True
batch_first如果 True, 输入 Tensor 的 shape 为 (batch, seq, feature),输出一样默认是False, 输入Tensor的shape 为 (seq, batch, feature),输出一样
dropout如果值非零, 那么除了最后一层外, 其它层的输出都会套上一个 dropout 层
bidirectional如果 True , 将会变成一个双向 RNN, 默认为 False


易混淆点说明

  1. 序列seq的长度,决定了RNN网络的宽度;

   1.png

  1. 注意,此类图中h0~ht为输出,横向箭头表示状态量的传递。

  2. RNN层数,num_layers,决定了RNN网络的深度;示例num_layers=2:

    layers.png

  3. hidden_size, hidden层的节点数目。也就是说num_layers>1时,所有层的节点数目都是一样的,如果想要各个hidden层节点数不一样,可以使用nn.RNNCell进行拼装。

  4. 序列中每个顺序数据的维度,决定了输入数据的节点数input_size。

    注意,此图为其它图中,单个“竖条”结构的展开图。

  5. 输出h_n为最后序列的状态量,output为预测结果。

   20190321163246609.png

绘图来源于他人博客,存在错误,已修改。如还存在不理解的地方,欢迎留言。

参考:https://blog.csdn.net/lwgkzl/article/details/88717678

发表评论:

◎欢迎参与讨论,请在这里发表您的看法、交流您的观点。

Powered By Z-BlogPHP 1.5.2 Zero Theme By 爱墙纸

Copyright ZiYouMan.cn. All Rights Reserved. 蜀ICP备15004526号