
torch.topk()

Published: 2024-01-10 12:18:05

From the official documentation:

torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)

Returns the k largest elements of the given input tensor along a given dimension.

If dim is not given, the last dimension of the input is chosen.

If largest is False then the k smallest elements are returned.

A namedtuple of (values, indices) is returned, where the indices are the indices of the elements in the original input tensor.

If the boolean option sorted is True, the returned k elements are guaranteed to be sorted themselves.
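The following is a minimal sketch of these behaviours on a small fixed tensor. The tensor values are made up for illustration. It shows the default dimension, `largest=False`, and that the returned indices point into the original tensor.

```python
import torch

t = torch.tensor([[3.0, 1.0, 4.0, 1.5],
                  [9.0, 2.0, 6.0, 5.0]])

# dim not given: topk works along the last dimension, i.e. within each row here
top2 = torch.topk(t, k=2)
print(top2.values.tolist())   # [[4.0, 3.0], [9.0, 6.0]]
print(top2.indices.tolist())  # [[2, 0], [0, 2]]

# largest=False: the k smallest elements, returned in ascending order
low2 = torch.topk(t, k=2, largest=False)
print(low2.values.tolist())   # [[1.0, 1.5], [2.0, 5.0]]
print(low2.indices.tolist())  # [[1, 3], [1, 3]]

# The result is a namedtuple, so it also unpacks like a plain tuple
values, indices = torch.topk(t, k=2)

# indices refer to positions in the original tensor, so gathering them
# along the same dim recovers the returned values
print(torch.equal(t.gather(1, indices), values))  # True
```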

Parameters:

  • input (Tensor) – the input tensor.

  • k (int) – the k in “top-k”

  • dim (int, optional) – the dimension to sort along

  • largest (bool, optional) – controls whether to return largest or smallest elements

  • sorted (bool, optional) – controls whether to return the elements in sorted order

  • out (tuple, optional) – the output tuple of (Tensor, LongTensor) that can be optionally given to be used as output buffers
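The sketch below illustrates the dim and out parameters on a hand-built 3×4 tensor. The tensor and the buffer names are illustrative only. Along the chosen dimension, the result has size k. The out buffers are preallocated here with the expected shapes, and the index buffer uses int64 (LongTensor).

```python
import torch

m = torch.arange(12.0).reshape(3, 4)  # [[0,1,2,3],[4,5,6,7],[8,9,10,11]]

# dim=0: top-k down each column; the size of dim 0 becomes k
col = torch.topk(m, k=2, dim=0)
print(col.values.shape)      # torch.Size([2, 4])
print(col.values.tolist())   # [[8.0, 9.0, 10.0, 11.0], [4.0, 5.0, 6.0, 7.0]]

# dim=1: top-k along each row; the size of dim 1 becomes k
row = torch.topk(m, k=2, dim=1)
print(row.values.shape)      # torch.Size([3, 2])
print(row.indices.tolist())  # [[3, 2], [3, 2], [3, 2]]

# out: preallocated (values, indices) buffers; indices must be int64
vals = torch.empty(3, 2)
idx = torch.empty(3, 2, dtype=torch.long)
torch.topk(m, k=2, dim=1, out=(vals, idx))
print(torch.equal(vals, row.values), torch.equal(idx, row.indices))  # True True
```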

An example:

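```python
import torch

x = torch.rand(3, 5)  # random values in [0, 1), so the output differs from run to run
print(x)
result = x.topk(k=1, dim=1)  # largest element of each row
print(result[1])             # result[1] is result.indices: the column index of each row's maximum
```

Output from one run:

```
tensor([[0.9068, 0.6301, 0.3500, 0.3612, 0.8632],
        [0.6435, 0.7596, 0.5890, 0.4887, 0.3763],
        [0.7244, 0.7431, 0.2717, 0.7388, 0.7798]])
tensor([[0],
        [1],
        [4]])
```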
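Because torch.rand gives different numbers on every run, here is a reproducible variant. It types in the tensor printed in the example by hand. It also shows that with k=1 the indices agree with argmax(dim=1, keepdim=True) as long as no row contains ties.

```python
import torch

# The tensor printed in the example, entered by hand so the result is reproducible
x = torch.tensor([[0.9068, 0.6301, 0.3500, 0.3612, 0.8632],
                  [0.6435, 0.7596, 0.5890, 0.4887, 0.3763],
                  [0.7244, 0.7431, 0.2717, 0.7388, 0.7798]])

values, indices = x.topk(k=1, dim=1)
print(indices.tolist())  # [[0], [1], [4]]

# With k=1 (and no ties) the indices match argmax with keepdim=True
print(torch.equal(indices, x.argmax(dim=1, keepdim=True)))  # True
```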