tensorflow-tf.shape(x)、x.shape和x.get_shape()的区别
tf.shape(x)、x.shape和x.get_shape()的区别
对于Tensor来说
1import tensorflow as tf
2
3input = tf.constant([[0,1,2],[3,4,5]])
4
5print(type(input.shape))
6print(type(input.get_shape()))
7print(type(tf.shape(input)))
8
9Out:
10<class 'tensorflow.python.framework.tensor_shape.TensorShape'>
11<class 'tensorflow.python.framework.tensor_shape.TensorShape'>
12<class 'tensorflow.python.framework.ops.Tensor'>
python
可以看到x.shape
和x.get_shape()
都是返回TensorShape类型对象,而tf.shape(x)
返回的是Tensor类型对象。
具体来说tf.shape()
返回的是tensor,想要获取tensor具体的shape结果需要sess.run
才行。而tf.get_shape
和x.shape
返回的是一个元组,因此要想操作维度信息,则需要调用TensorShape的tf.as_list()
方法,返回的是Python的list。
需要注意的是tf.get_shape()
返回的是元组,不能放到sess.run()
里面,这个里面只能放operation和tensor
对于placeholder来说
对tf.placeholder
占位符来说,如果shape设置的其中某一个是None,那么对于tf.shape,sess.run
会报错,而tf.get_shape
不会,它会在None位置显示“?”表示此位置的shape暂时未知。
1a = tf.Variable(tf.constant(1.5, dtype=tf.float32, shape=[1,2,3,4,5,6,7]), name='a')
2b = tf.placeholder(dtype=tf.int32, shape=[None, 3], name='b')
3s1 = tf.shape(a)
4s2 = a.get_shape()
5print (s1) # Tensor("Shape:0", shape=(7,), dtype=int32)
6print (s2) # 元组 (1, 2, 3, 4, 5, 6, 7)
7
8s11 = tf.shape(b)
9s21 = b.get_shape()
10print (s11) # Tensor("Shape_1:0", shape=(2,), dtype=int32)
11print (s21) # 因为第一位设置的是None,所以这里的第一位显示问号表示暂时不确认 (?, 3)
12with tf.Session() as sess:
13 sess.run(tf.global_variables_initializer())
14 print (sess.run(s1)) # [1 2 3 4 5 6 7]
15 print (sess.run(s11))
16 # InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'b' with dtype int32
17 # [[Node: b = Placeholder[dtype=DT_INT32, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]
python