torch 模块中 sort 方法的基本使用

torch.sort()

方法原型

1torch.sort(input, dim=None, descending=False, out=None) -> (Tensor, LongTensor)

返回值

1A tuple of (sorted_tensor, sorted_indices) is returned, 
2where the sorted_indices are the indices of the elements in the original input tensor.

参数

  • input (Tensor) – the input tensor
    形式上与 numpy.narray 类似
  • dim (int, optional) – the dimension to sort along
    维度,对于二维数据:dim=0 按列排序,dim=1 按行排序,默认 dim=1
  • descending (bool, optional) – controls the sorting order (ascending or descending)
    降序,descending=True 从大到小排序,descending=False 从小到大排序,默认 descending=Flase

实例

 1import torch
2x = torch.randn(3,4)
3x  #初始值,始终不变
4tensor([[-0.9950-0.6175-0.1253,  1.3536],
5        [ 0.1208-0.4237-1.1313,  0.9022],
6        [-1.1995-0.0699-0.4396,  0.8043]])
7sorted, indices = torch.sort(x)  #按行从小到大排序
8sorted
9tensor([[-0.9950-0.6175-0.1253,  1.3536],
10        [-1.1313-0.4237,  0.1208,  0.9022],
11        [-1.1995-0.4396-0.0699,  0.8043]])
12indices
13tensor([[0123],
14        [2103],
15        [0213]])
16sorted, indices = torch.sort(x, descending=True)  #按行从大到小排序 (即反序)
17sorted
18tensor([[ 1.3536-0.1253-0.6175-0.9950],
19        [ 0.9022,  0.1208-0.4237-1.1313],
20        [ 0.8043-0.0699-0.4396-1.1995]])
21indices
22tensor([[3210],
23        [3012],
24        [3120]])
25sorted, indices = torch.sort(x, dim=0)  #按列从小到大排序
26sorted
27tensor([[-1.1995-0.6175-1.1313,  0.8043],
28        [-0.9950-0.4237-0.4396,  0.9022],
29        [ 0.1208-0.0699-0.1253,  1.3536]])
30indices
31tensor([[2012],
32        [0121],
33        [1200]])
34sorted, indices = torch.sort(x, dim=0, descending=True)  #按列从大到小排序
35sorted
36tensor([[ 0.1208-0.0699-0.1253,  1.3536],
37        [-0.9950-0.4237-0.4396,  0.9022],
38        [-1.1995-0.6175-1.1313,  0.8043]])
39indices
40tensor([[1200],
41        [0121],
42        [2012]])

官方文档:https://pytorch.org/docs/stable/torch.html