keras Lambda 层

时间:2018-10-01 16:09:38   收藏:0   阅读:12585

Lambda层

keras.layers.core.Lambda(function, output_shape=None, mask=None, arguments=None)

本函数用以对上一层的输出施以任何Theano/TensorFlow表达式

如果你只是想对流经该层的数据做个变换,而这个变换本身没有什么需要学习的参数,那么直接用Lambda Layer是最合适的了。

导入的方法是 

from keras.layers.core import Lambda

 Lambda函数接受两个参数,第一个是输入张量对输出张量的映射函数,第二个是输入的shape对输出的shape的映射函数。

参数

例子

# add a x -> x^2 layer
model.add(Lambda(lambda x: x ** 2))
# add a layer that returns the concatenation# of the positive part of the input and
# the opposite of the negative part

def antirectifier(x):
    x -= K.mean(x, axis=1, keepdims=True)
    x = K.l2_normalize(x, axis=1)
    pos = K.relu(x)
    neg = K.relu(-x)
    return K.concatenate([pos, neg], axis=1)

def antirectifier_output_shape(input_shape):
    shape = list(input_shape)
    assert len(shape) == 2  # only valid for 2D tensors
    shape[-1] *= 2
    return tuple(shape)

model.add(Lambda(antirectifier,
         output_shape=antirectifier_output_shape))

输入shape

任意,当使用该层作为第一层时,要指定input_shape

输出shape

output_shape参数指定的输出shape,当使用tensorflow时可自动推断

================================================

keras Lambda自定义层实现数据的切片,Lambda传参数

1、代码如下:

import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Activation,Reshape
from keras.layers import merge
from keras.utils.visualize_util import plot
from keras.layers import Input, Lambda
from keras.models import Model

def slice(x,index):
  return x[:,:,index]

a = Input(shape=(4,2))
x1 = Lambda(slice,output_shape=(4,1),arguments={index:0})(a)
x2 = Lambda(slice,output_shape=(4,1),arguments={index:1})(a)
x1 = Reshape((4,1,1))(x1)
x2 = Reshape((4,1,1))(x2)
output = merge([x1,x2],mode=concat)
model
= Model(a, output) x_test = np.array([[[1,2],[2,3],[3,4],[4,5]]]) print model.predict(x_test) plot(model, to_file=lambda.png,show_shapes=True)

2、注意Lambda 是可以进行参数传递的,传递的方式如下代码所述:

def slice(x,index):
    return x[:,:,index]

 如上,index是参数,通过字典将参数传递进去.

x1 = Lambda(slice,output_shape=(4,1),arguments={index:0})(a)
x2 = Lambda(slice,output_shape=(4,1),arguments={index:1})(a)

3、上述代码实现的是,将矩阵的每一列提取出来,然后单独进行操作,最后在拼在一起。可视化的图如下所示。

 

技术分享图片
 
 

 

参考:

https://blog.csdn.net/hewb14/article/details/53414068

https://blog.csdn.net/lujiandong1/article/details/54936185

https://keras-cn.readthedocs.io/en/latest/layers/core_layer/

来自为知笔记(Wiz)



评论(0
© 2014 mamicode.com 版权所有 京ICP备13008772号-2  联系我们:gaon5@hotmail.com
迷上了代码!