自由Man

BottleneckLSTM原理及实现

BottleneckLSTM


BottleneckLSTM.jpg

BottleneckLSTM,源自CVPR2017的Google论文 Mobile Video Object Detection with Temporally-Aware Feature Maps

其相较于ConvLSTM,最大的改进是引入Bottleneck Gate,整合ht-1跟input,减少卷积次数及模型参数量,从而达到加速的目的。


Bottleneck

    根据上图,我们看着1×1 --> 3×3这段通路,来做个计算,假设输入feature map的维度为256维,要求输出维度也是256维。有以下两种操作:

1. 256维的输入直接经过一个3×3×256的卷积层,输出一个256维的feature map,那么参数量为:256×3×3×256 = 589,824

2. 256维的输入先经过一个1×1×64的卷积层,再经过一个3×3×64的卷积层,最后经过一个3×3×256的卷积层,输出256维,参数量为:256×1×1×64 + 64×3×3×63 + 64×1×1×256 = 69,632。足足把第一种操作的参数量降低到九分之一!


    1×1卷积核也被认为是影响深远的操作,往后大型的网络为了降低参数量都会应用上1×1卷积核。


实现源码:

class BottleneckLSTMCell(nn.Module):
	""" Creates a LSTM layer cell
	Arguments:
		input_channels : variable used to contain value of number of channels in input
		hidden_channels : variable used to contain value of number of channels in the hidden state of LSTM cell
	"""
	def __init__(self, input_channels, hidden_channels):
		super(BottleneckLSTMCell, self).__init__()
		assert hidden_channels % 2 == 0
		self.input_channels = int(input_channels)
		self.hidden_channels = int(hidden_channels)
		self.num_features = 4
		self.W = nn.Conv2d(in_channels=self.input_channels, out_channels=self.input_channels, kernel_size=3, groups=self.input_channels, stride=1, padding=1)
		self.Wy  = nn.Conv2d(int(self.input_channels+self.hidden_channels), self.hidden_channels, kernel_size=1)
		self.Wi  = nn.Conv2d(self.hidden_channels, self.hidden_channels, 3, 1, 1, groups=self.hidden_channels, bias=False)  
		self.Wbi = nn.Conv2d(self.hidden_channels, self.hidden_channels, 1, 1, 0, bias=False)
		self.Wbf = nn.Conv2d(self.hidden_channels, self.hidden_channels, 1, 1, 0, bias=False)
		self.Wbc = nn.Conv2d(self.hidden_channels, self.hidden_channels, 1, 1, 0, bias=False)
		self.Wbo = nn.Conv2d(self.hidden_channels, self.hidden_channels, 1, 1, 0, bias=False)
		self.Wci = None
		self.Wcf = None
		self.Wco = None
		logging.info("Initializing weights of lstm")
		self._initialize_weights()
	def _initialize_weights(self):
		"""
		Returns:
			initialized weights of the model
		"""
		for m in self.modules():
			if isinstance(m, nn.Conv2d):
				n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
				m.weight.data.normal_(0, math.sqrt(2. / n))
				if m.bias is not None:
					m.bias.data.zero_()
			elif isinstance(m, nn.BatchNorm2d):
				m.weight.data.fill_(1)
				m.bias.data.zero_()
			
	def forward(self, x, h, c): #implemented as mentioned in paper here the only difference is  Wbi, Wbf, Wbc & Wbo are commuted all together in paper
		"""
		Arguments:
			x : input tensor
			h : hidden state tensor
			c : cell state tensor
		Returns:
			output tensor after LSTM cell 
		"""
		x = self.W(x)
		y = torch.cat((x, h),1) #concatenate input and hidden layers 主要改善点
		i = self.Wy(y) #reduce to hidden layer size 
		b = self.Wi(i)	#depth wise 3*3
		ci = torch.sigmoid(self.Wbi(b) + c * self.Wci)
		cf = torch.sigmoid(self.Wbf(b) + c * self.Wcf)
		cc = cf * c + ci * torch.relu(self.Wbc(b))
		co = torch.sigmoid(self.Wbo(b) + cc * self.Wco)
		ch = co * torch.relu(cc)
		return ch, cc
	def init_hidden(self, batch_size, hidden, shape):
		"""
		Arguments:
			batch_size : an int variable having value of batch size while training
			hidden : an int variable having value of number of channels in hidden state
			shape : an array containing shape of the hidden and cell state 
		Returns:
			cell state and hidden state
		"""
		if self.Wci is None:
			self.Wci = Variable(torch.zeros(1, hidden, shape[0], shape[1])).cuda()
			self.Wcf = Variable(torch.zeros(1, hidden, shape[0], shape[1])).cuda()
			self.Wco = Variable(torch.zeros(1, hidden, shape[0], shape[1])).cuda()
		else:
			assert shape[0] == self.Wci.size()[2], 'Input Height Mismatched!'
			assert shape[1] == self.Wci.size()[3], 'Input Width Mismatched!'
		return (Variable(torch.zeros(batch_size, hidden, shape[0], shape[1])).cuda(),
				Variable(torch.zeros(batch_size, hidden, shape[0], shape[1])).cuda()
				)



发表评论:

◎欢迎参与讨论,请在这里发表您的看法、交流您的观点。

Powered By Z-BlogPHP 1.5.2 Zero Theme By 爱墙纸

Copyright ZiYouMan.cn. All Rights Reserved. 蜀ICP备15004526号