Tensorflow中的tf.argmax()函数

官方API定义

tf.argmax(input, axis=None, name=None, dimension=None)

Returns the index with the largest value across axes of a tensor.
Args:

  • input: A Tensor. Must be one of the following types: float32, float64, int64, int32, uint8, uint16, int16, int8, complex64, complex128, qint8, quint8, qint32, half.
  • axis: A Tensor. Must be one of the following types: int32, int64. int32, 0 <= axis < rank(input). Describes which axis of the input Tensor to reduce across. For vectors, use axis = 0.
  • name: A name for the operation (optional).

Returns:

  • A Tensor of type int64.

定义中的axis与numpy中的axis是一致的,下面通过代码进行解释。

import numpy as np
import tensorflow as tf

sess = tf.session()
m = sess.run(tf.truncated_normal((5,10), stddev = 0.1) )
print type(m)
print m

-------------------------------------------------------------------------------
<type 'numpy.ndarray'>
[[ 0.09957541 -0.0965599   0.06064715 -0.03011306  0.05533558  0.17263047
  -0.02660419  0.08313394 -0.07225946  0.04916157]
 [ 0.11304571  0.02099175  0.03591062  0.01287777 -0.11302195  0.04822164
  -0.06853487  0.0800944  -0.1155676  -0.01168544]
 [ 0.15760773  0.05613248  0.04839646 -0.0218203   0.02233066  0.00929849
  -0.0942843  -0.05943     0.08726917 -0.059653  ]
 [ 0.02553608  0.07298559 -0.06958302  0.02948747  0.00232073  0.11875584
  -0.08325859 -0.06616175  0.15124641  0.09522969]
 [-0.04616683  0.01816062 -0.10866459 -0.12478453  0.01195056  0.0580056
  -0.08500613  0.00635608 -0.00108647  0.12054099]]

m是一个5行10列的矩阵,类型为numpy.ndarray

#使用tensorflow中的tf.argmax()
col_max = sess.run(tf.argmax(m, 0) )  #当axis=0时返回每一列的最大值的位置索引
print col_max

row_max = sess.run(tf.argmax(m, 1) )  #当axis=1时返回每一行中的最大值的位置索引
print row_max

array([2, 3, 0, 3, 0, 0, 0, 0, 3, 4])
array([5, 0, 0, 8, 9])

-------------------------------------------------------------------------------
#使用numpy中的numpy.argmax
row_max = m.argmax(0)
print row_max

col_max = m.argmax(1)
print col_max

array([2, 3, 0, 3, 0, 0, 0, 0, 3, 4])
array([5, 0, 0, 8, 9])

可以看到tf.argmax()与numpy.argmax()方法的用法是一致的

axis = 0的时候返回每一列最大值的位置索引
axis = 1的时候返回每一行最大值的位置索引
axis = 2、3、4…,即为多维张量时,同理推断

参考


  1. Tensorflow官方API tf.argmax说明
  2. Numpy官方AIP numpy.argmax说明

转自:http://www.jianshu.com/p/469789141af7

Related posts

Leave a Comment