前言
这两个函数优点是通过GPU 运算速度快
目录:
1 where
2 Gather
一 where
原理:
torch.where(condition,x,y)
输入参数:
condition: 判断条件
x,y: Tensor
返回值:
符合条件时: 取x, 不满足取y
优点: 可以使用GPU,加快运算速度
# -*- coding: utf-8 -*-
"""
Created on Thu Dec 22 21:48:02 2022
@author: cxf
"""
import torch
def statistics():
    ans = torch.rand(4,2)
    
    x = torch.tensor([[1,2],
               [1,2],
               [1,2],
               [1,2]])
    
    y = torch.tensor([[3,4],
               [3,4],
               [3,4],
               [3,4]])
    
    
    out =torch.where(ans>0.5,x,y)
    print("\n ans: \n",ans)
    
    print("\n out:  \n",out)
statistics()    
 
二 Gather
输入:
Input
函数说明:
data. gather(dim=d, index=idx)
输入参数:
index: 映射的索引值
data 的shape 和 index的shape 必须一致
但是各维度的size 可以不一致
dim:
映射的维度
输出参数
输出张量的shape 的大小和index 一样
例一 dim =0
# -*- coding: utf-8 -*-
"""
Created on Wed Dec 28 15:34:09 2022
@author: chengxf2
"""
import torch
def gather():
    data = torch.arange(1, 16, 1).view(3,5)
    
    
    print("\n\n",data.numpy())
    
    idx = torch.LongTensor([[0,0,1]])
    
    idx1 = torch.LongTensor([[0],
                             [0],
                             [2]])
    
    a = data.gather(dim=0, index= idx)
    b = data.gather(dim=0, index= idx1) 
    print("\n\n\n\n",a.numpy(),idx.shape)
    print("\n\n\n\n\n",b.numpy(),idx1.shape)
    
gather()data 的shape [3,5]

idx=[[0,0,2]] shape [1,3]

0,0,1 分别代表取data[0,:] data[0,:] .data[1,:],
对应列为索引所在的位置 [0,0,1] 所在位置分别为 【0,1,2】
输出为:
          
同理 idx1=[[0],[0],[2]],shape: torch.Size([3, 1])


例2 dim=1
def gather():
    data = torch.arange(1, 16, 1).view(3,5)
    
    
    print("\n\n",data.numpy())
    
    idx = torch.LongTensor([[0,1,2]])
    
    idx1 = torch.LongTensor([[0],
                             [1],
                             [2]])
    
    a = data.gather(dim=1, index= idx)
    b = data.gather(dim=1, index= idx1) 
    print("\n\n\n\n",a.numpy(),idx.shape)
   
    print("\n\n\n\n\n",b.numpy(),idx1.shape)index 内元素值指定所在列,
行是由index 元素所在行指定
输出的shape 保持一致




















