Official documentation:
torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)
Returns the k largest elements of the given input tensor along a given dimension.
If dim is not given, the last dimension of the input is chosen.
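A quick check of the default dimension, assuming a small hand-written 2-D float tensor: calling topk without dim should match passing dim=-1 explicitly, while dim=0 works down the columns instead.

```python
import torch

x = torch.tensor([[3.0, 1.0, 2.0],
                  [0.5, 4.0, 1.5]])

# No dim given: topk works along the last dimension (here dim=1, i.e. within each row)
default = torch.topk(x, k=2)
explicit = torch.topk(x, k=2, dim=-1)
print(default.values.tolist())   # [[3.0, 2.0], [4.0, 1.5]]
print(default.indices.tolist())  # [[0, 2], [1, 2]]

# dim=0 instead picks the top element of each column
by_col = torch.topk(x, k=1, dim=0)
print(by_col.indices.tolist())   # [[0, 1, 0]]
```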
If largest is False, then the k smallest elements are returned.
A namedtuple of (values, indices) is returned, where the indices are the indices of the elements in the original input tensor.
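A minimal sketch of largest=False together with the returned namedtuple, assuming a 1-D example tensor. The fields can be read by name (values, indices) or unpacked like an ordinary tuple, and the indices point back into the original tensor.

```python
import torch

x = torch.tensor([5.0, 1.0, 4.0, 2.0, 3.0])

# largest=False returns the k smallest elements (in ascending order, since sorted=True by default)
smallest = torch.topk(x, k=2, largest=False)
print(smallest.values.tolist())   # [1.0, 2.0]
print(smallest.indices.tolist())  # [1, 3]

# The namedtuple also unpacks like a plain (values, indices) tuple
values, indices = smallest

# The indices refer to positions in the original tensor
print(torch.equal(x[indices], values))  # True
```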
The boolean option sorted, if True, makes sure that the returned k elements are themselves sorted.
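To illustrate sorted, here is a sketch assuming random input with a fixed seed (so the values are distinct). With sorted=False the same k elements are selected but their order is not guaranteed; re-sorting them recovers the sorted=True result.

```python
import torch

torch.manual_seed(0)
x = torch.rand(10)

sorted_top = torch.topk(x, k=4)                  # sorted=True (default): descending order
unsorted_top = torch.topk(x, k=4, sorted=False)  # same 4 elements, order not guaranteed

# Re-sorting the unsorted values in descending order gives the sorted result back
resorted, _ = torch.sort(unsorted_top.values, descending=True)
print(torch.equal(resorted, sorted_top.values))  # True
```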
Parameters:
- input (Tensor) – the input tensor.
- k (int) – the k in “top-k”
- dim (int, optional) – the dimension to sort along
- largest (bool, optional) – controls whether to return largest or smallest elements
- sorted (bool, optional) – controls whether to return the elements in sorted order
- out (tuple, optional) – the output tuple of (Tensor, LongTensor) that can be optionally given to be used as output buffers
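The out parameter can be sketched as follows, assuming empty pre-allocated buffers (the names values_buf and indices_buf are just illustrative). The values buffer has to match the input's dtype and the indices buffer has to be int64 (LongTensor); topk resizes both and writes the result into them.

```python
import torch

x = torch.tensor([[2.0, 9.0, 4.0],
                  [7.0, 1.0, 8.0]])

# Pre-allocated output buffers: float32 for values (same as x), int64 for indices
values_buf = torch.empty(0)
indices_buf = torch.empty(0, dtype=torch.long)

# topk fills the buffers in place instead of allocating new tensors
torch.topk(x, k=2, dim=1, out=(values_buf, indices_buf))
print(values_buf.tolist())   # [[9.0, 4.0], [8.0, 7.0]]
print(indices_buf.tolist())  # [[1, 2], [2, 0]]
```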
For example, to get the column index of the largest element in each row:
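import torch

x = torch.rand(3, 5)          # random 3x5 tensor with values in [0, 1)
print(x)
result = x.topk(k=1, dim=1)   # top-1 element of each row
print(result[1])              # result[1] is result.indices: column index of each row's maximum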
Output (torch.rand is random, so the values differ on each run):
tensor([[0.9068, 0.6301, 0.3500, 0.3612, 0.8632],
        [0.6435, 0.7596, 0.5890, 0.4887, 0.3763],
        [0.7244, 0.7431, 0.2717, 0.7388, 0.7798]])
tensor([[0],
        [1],
        [4]])
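Because the returned indices refer to positions in the original tensor, they can be fed to torch.gather to read the values back. The sketch below assumes a fixed seed so the run is reproducible; with k=1 and no ties, the indices also coincide with argmax along the same dimension.

```python
import torch

torch.manual_seed(0)
x = torch.rand(3, 5)
result = x.topk(k=1, dim=1)

# Gathering with the indices along dim=1 gives back the top-k values
recovered = torch.gather(x, dim=1, index=result.indices)
print(torch.equal(recovered, result.values))  # True

# With k=1 (and no ties) the indices match argmax; keepdim=True keeps the (3, 1) shape
print(torch.equal(result.indices, x.argmax(dim=1, keepdim=True)))  # True
```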