torch 模块中 sort 方法的基本使用

torch.sort()

方法原型

torch.sort(input, dim=None, descending=False, out=None) -> (Tensor, LongTensor)

返回值

A tuple of (sorted_tensor, sorted_indices) is returned, 
where the sorted_indices are the indices of the elements in the original input tensor.

参数

  • input (Tensor) – the input tensor
    形式上与 numpy.narray 类似
  • dim (int, optional) – the dimension to sort along
    维度,对于二维数据:dim=0 按列排序,dim=1 按行排序,默认 dim=1
  • descending (bool, optional) – controls the sorting order (ascending or descending)
    降序,descending=True 从大到小排序,descending=False 从小到大排序,默认 descending=Flase

实例

>>> import torch
>>> x = torch.randn(3,4)
>>> x  #初始值,始终不变
tensor([[-0.9950, -0.6175, -0.1253,  1.3536],
        [ 0.1208, -0.4237, -1.1313,  0.9022],
        [-1.1995, -0.0699, -0.4396,  0.8043]])
>>> sorted, indices = torch.sort(x)  #按行从小到大排序
>>> sorted
tensor([[-0.9950, -0.6175, -0.1253,  1.3536],
        [-1.1313, -0.4237,  0.1208,  0.9022],
        [-1.1995, -0.4396, -0.0699,  0.8043]])
>>> indices
tensor([[0, 1, 2, 3],
        [2, 1, 0, 3],
        [0, 2, 1, 3]])
>>> sorted, indices = torch.sort(x, descending=True)  #按行从大到小排序 (即反序)
>>> sorted
tensor([[ 1.3536, -0.1253, -0.6175, -0.9950],
        [ 0.9022,  0.1208, -0.4237, -1.1313],
        [ 0.8043, -0.0699, -0.4396, -1.1995]])
>>> indices
tensor([[3, 2, 1, 0],
        [3, 0, 1, 2],
        [3, 1, 2, 0]])
>>> sorted, indices = torch.sort(x, dim=0)  #按列从小到大排序
>>> sorted
tensor([[-1.1995, -0.6175, -1.1313,  0.8043],
        [-0.9950, -0.4237, -0.4396,  0.9022],
        [ 0.1208, -0.0699, -0.1253,  1.3536]])
>>> indices
tensor([[2, 0, 1, 2],
        [0, 1, 2, 1],
        [1, 2, 0, 0]])
>>> sorted, indices = torch.sort(x, dim=0, descending=True)  #按列从大到小排序
>>> sorted
tensor([[ 0.1208, -0.0699, -0.1253,  1.3536],
        [-0.9950, -0.4237, -0.4396,  0.9022],
        [-1.1995, -0.6175, -1.1313,  0.8043]])
>>> indices
tensor([[1, 2, 0, 0],
        [0, 1, 2, 1],
        [2, 0, 1, 2]])

官方文档:https://pytorch.org/docs/stable/torch.html

发表评论

电子邮件地址不会被公开。 必填项已用*标注