• sort/argsort
  • topk
  • top-5 Acc

1、sort,argsort

  sort:对序列进行一个完全的排序

  argsort:返回排序后的index

(1)tf.random.shuffle(),沿着张量的第一个维度进行打乱

 1 a = tf.range(5)
 2 b = tf.random.shuffle(a)
 3 print(b.numpy()) # [3 4 1 2 0]
 4
 5 a = tf.constant([[1,2],[2,3],[5,6]])
 6 b = tf.random.shuffle(a)
 7 print(b.numpy())
   #只打乱了第一个维度,里面的维度没有打乱
  """   [[5 6]   [1 2]    [2 3]]   """

(2)sort和argsort

a = tf.random.shuffle(tf.range(8))
print(a) #tf.Tensor([1 2 0 4 3], shape=(5,), dtype=int32)

b1 = tf.sort(a) #默认是升序
print(b1)  #tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)
b2 = tf.sort(a,direction='DESCENDING') # descending:下降的
print(b2)  #tf.Tensor([4 3 2 1 0], shape=(5,), dtype=int32)

indx = tf.argsort(a,direction='DESCENDING')
print(indx) #tf.Tensor([3 7 0 6 4 1 2 5], shape=(8,), dtype=int32),得到降序排列后的索引值
b3 = tf.gather(a,indx)
print(b3)  #tf.Tensor([7 6 5 4 3 2 1 0], shape=(8,), dtype=int32),通过索引index在a中收集对应的值,得到降序后的排列
01-04 21:34
查看更多