The tf.argmax() function in TensorFlow

Official API definition

tf.argmax(input, axis=None, name=None, dimension=None)

Returns the index with the largest value across axes of a tensor.
Args:

  • input: A Tensor. Must be one of the following types: float32, float64, int64, int32, uint8, uint16, int16, int8, complex64, complex128, qint8, quint8, qint32, half.
  • axis: A Tensor. Must be one of the following types: int32, int64. Must satisfy 0 <= axis < rank(input). Describes which axis of the input Tensor to reduce across. For vectors, use axis = 0.
  • name: A name for the operation (optional).

Returns:

  • A Tensor of type int64.

The axis argument in this definition has the same meaning as axis in NumPy. The code below demonstrates this.
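A tiny hand-checkable case makes the axis convention easy to verify before looking at a random matrix. This sketch uses NumPy only, so it runs without TensorFlow; it relies on the article's point that tf.argmax follows the same axis convention. The arrays v and a are made-up illustration data.

```python
import numpy as np

# Vector case ("For vectors, use axis = 0"): a single index comes back
v = np.array([3, 8, 1, 5])
vec_idx = np.argmax(v, axis=0)

# Small fixed 2x3 matrix so the results can be checked by hand
a = np.array([[1, 9, 3],
              [7, 2, 8]])
col_idx = np.argmax(a, axis=0)  # one index per column: which row holds the max
row_idx = np.argmax(a, axis=1)  # one index per row: which column holds the max

print(vec_idx)  # 1
print(col_idx)  # [1 0 1]
print(row_idx)  # [1 2]
```

Reading column 0 top to bottom (1, 7), the maximum is in row 1, which is why col_idx starts with 1.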


import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.truncated_normal)

sess = tf.Session()
# Sample a 5x10 matrix from a truncated normal distribution and evaluate it to a NumPy array
m = sess.run(tf.truncated_normal((5, 10), stddev=0.1))
print(type(m))
print(m)

-------------------------------------------------------------------------------
<class 'numpy.ndarray'>
[[ 0.09957541 -0.0965599   0.06064715 -0.03011306  0.05533558  0.17263047
  -0.02660419  0.08313394 -0.07225946  0.04916157]
 [ 0.11304571  0.02099175  0.03591062  0.01287777 -0.11302195  0.04822164
  -0.06853487  0.0800944  -0.1155676  -0.01168544]
 [ 0.15760773  0.05613248  0.04839646 -0.0218203   0.02233066  0.00929849
  -0.0942843  -0.05943     0.08726917 -0.059653  ]
 [ 0.02553608  0.07298559 -0.06958302  0.02948747  0.00232073  0.11875584
  -0.08325859 -0.06616175  0.15124641  0.09522969]
 [-0.04616683  0.01816062 -0.10866459 -0.12478453  0.01195056  0.0580056
  -0.08500613  0.00635608 -0.00108647  0.12054099]]

m is a matrix with 5 rows and 10 columns, of type numpy.ndarray.
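argmax removes the axis it reduces over, so for a 5 × 10 input the axis = 0 result has one entry per column (shape (10,)) and the axis = 1 result has one entry per row (shape (5,)). The sketch checks this with NumPy only; demo is a hypothetical stand-in for m (arbitrary seeded values, not the article's matrix), so only the shapes, not the indices, correspond to the article's output.

```python
import numpy as np

# Hypothetical stand-in for m: arbitrary values with the same 5x10 shape
rng = np.random.default_rng(0)
demo = rng.normal(scale=0.1, size=(5, 10))

col_max = np.argmax(demo, axis=0)  # reduces the 5 rows away: one row index per column
row_max = np.argmax(demo, axis=1)  # reduces the 10 columns away: one column index per row

print(col_max.shape)  # (10,)
print(row_max.shape)  # (5,)
```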


# Using tf.argmax() in TensorFlow
col_max = sess.run(tf.argmax(m, 0))  # axis=0: index of the maximum in each column
print(col_max)

row_max = sess.run(tf.argmax(m, 1))  # axis=1: index of the maximum in each row
print(row_max)

[2 3 0 3 0 0 0 0 3 4]
[5 0 0 8 9]

-------------------------------------------------------------------------------
# Using NumPy's argmax (ndarray.argmax, equivalent to numpy.argmax)
col_max = m.argmax(0)  # axis=0: index of the maximum in each column
print(col_max)

row_max = m.argmax(1)  # axis=1: index of the maximum in each row
print(row_max)

[2 3 0 3 0 0 0 0 3 4]
[5 0 0 8 9]

As the outputs show, tf.argmax() and numpy.argmax() are called the same way and return the same indices.
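One case where the two may not agree is a tie. NumPy documents that when the maximum occurs more than once, the index of the first occurrence is returned. Recent TensorFlow documentation states that in case of ties the identity of the returned index is not guaranteed (worth checking for the version you use). A minimal NumPy illustration, with a made-up array a:

```python
import numpy as np

# Row 0 has its maximum (5) at both index 1 and index 3;
# row 1 has its maximum (4) at both index 0 and index 1
a = np.array([[0, 5, 2, 5],
              [4, 4, 1, 0]])

# NumPy returns the first occurrence of the maximum in each row
print(np.argmax(a, axis=1))  # [1 0]
```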

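axis = 0 returns the index of the maximum value in each column
axis = 1 returns the index of the maximum value in each row
axis = 2, 3, 4, … works the same way for higher-rank tensors: the index is taken along that dimension, and that dimension is removed from the result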
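For a rank-3 array the same rule applies on each axis: the result keeps the other dimensions and drops the reduced one. This NumPy-only sketch uses a made-up array t of shape (2, 3, 4) with one planted maximum so the axis = 2 result can be checked by hand.

```python
import numpy as np

# Rank-3 array of shape (2, 3, 4) with increasing values 0..23
t = np.arange(24).reshape(2, 3, 4)
t[0, 1, 2] = 100  # plant a maximum inside one slice along axis 2

print(np.argmax(t, axis=0).shape)  # (3, 4)
print(np.argmax(t, axis=1).shape)  # (2, 4)

idx = np.argmax(t, axis=2)         # index along the last axis
print(idx.shape)                   # (2, 3)
print(idx.tolist())                # [[3, 2, 3], [3, 3, 3]]
```

Because the values increase along the last axis, every slice peaks at index 3 except t[0, 1], where the planted 100 sits at index 2.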

References


  1. TensorFlow official API: tf.argmax documentation
  2. NumPy official API: numpy.argmax documentation

Reposted from: http://www.jianshu.com/p/469789141af7
