
PyTorch:Encoder-RNN|LSTM|GRU

Views: 71   Published: 2024-02-21 14:47:56

-柚子皮-

#RNN
import torch
import torch.nn as nn

rnn = nn.RNN(10, 20, 2)          # (input_size, hidden_size, num_layers)
input = torch.randn(5, 3, 10)    # (seq_len, batch, input_size)
h0 = torch.randn(2, 3, 20)       # (num_layers * num_directions, batch, hidden_size)
output, hn = rnn(input, h0)
print(output.size(), hn.size())  # torch.Size([5, 3, 20]) torch.Size([2, 3, 20])
 
 
#LSTM
rnn = nn.LSTM(10, 20, 2)         # (input_size, hidden_size, num_layers)
input = torch.randn(5, 3, 10)    # (seq_len, batch, input_size)
h0 = torch.randn(2, 3, 20)       # (num_layers * num_directions, batch, hidden_size)
c0 = torch.randn(2, 3, 20)       # (num_layers * num_directions, batch, hidden_size)
output, (hn, cn) = rnn(input, (h0, c0))  # output: (seq_len, batch, num_directions * hidden_size)
print(output.size(), hn.size(), cn.size())
 
 
#GRU
rnn = nn.GRU(10, 20, 2)          # (input_size, hidden_size, num_layers)
input = torch.randn(5, 3, 10)    # (seq_len, batch, input_size)
h0 = torch.randn(2, 3, 20)       # (num_layers * num_directions, batch, hidden_size)
output, hn = rnn(input, h0)
print(output.size(), hn.size())
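Since the post's title concerns encoders, a minimal encoder built from these pieces might look like the sketch below. The `Encoder` class, vocabulary size, and all dimensions are illustrative assumptions, not part of the original post:

import torch
import torch.nn as nn

class Encoder(nn.Module):
    # Hypothetical minimal encoder: embedding lookup followed by a GRU.
    def __init__(self, vocab_size, hidden_size, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, num_layers)

    def forward(self, src):
        # src: (seq_len, batch) of token indices
        embedded = self.embedding(src)   # (seq_len, batch, hidden_size)
        output, hn = self.gru(embedded)  # h0 defaults to zeros when omitted
        return output, hn                # hn summarizes the whole sequence

enc = Encoder(vocab_size=100, hidden_size=20)
src = torch.randint(0, 100, (5, 3))      # (seq_len, batch)
output, hn = enc(src)
print(output.size(), hn.size())          # torch.Size([5, 3, 20]) torch.Size([1, 3, 20])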

 

gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=dropout, bidirectional=True)  # bidirectional=True makes num_directions = 2
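A runnable sketch of this bidirectional case, with `hidden_size`, `n_layers`, and `dropout` filled in with arbitrary illustrative values (they are undefined in the original line):

import torch
import torch.nn as nn

hidden_size, n_layers, dropout = 20, 2, 0.1  # illustrative values

gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=dropout, bidirectional=True)
x = torch.randn(5, 3, hidden_size)  # (seq_len, batch, input_size)
output, hn = gru(x)                 # h0 defaults to zeros when omitted
print(output.size())  # torch.Size([5, 3, 40]) -> (seq_len, batch, 2 * hidden_size)
print(hn.size())      # torch.Size([4, 3, 20]) -> (n_layers * 2, batch, hidden_size)

Note that the two directions are concatenated along the last dimension of `output`, which is why a downstream decoder must expect `2 * hidden_size` features per step.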


ref: [LSTM and GRU principles with PyTorch code, and input/output size notes]

 
