加入收藏 | 设为首页 | 会员中心 | 我要投稿 衡阳站长网 (https://www.0734zz.cn/)- 数据集成、设备管理、备份、数据加密、智能搜索!
当前位置: 首页 > 站长资讯 > 动态 > 正文

深度学习Pytorch构造Tensor张量

发布时间:2021-11-07 13:47:52 所属栏目:动态 来源:互联网
导读:1 Tensor的裁剪运算 对Tensor中的元素进行范围过滤 常用于梯度裁剪(gradient clipping),即在发生梯度离散或者梯度爆炸时对梯度的处理 torch.clamp(input, min, max, out=None) Tensor:将输入input张量每个元素的夹紧到区间 [min,max],并返回结果到一个新
1 Tensor的裁剪运算
 对Tensor中的元素进行范围过滤  常用于梯度裁剪(gradient clipping),即在发生梯度离散或者梯度爆炸时对梯度的处理  torch.clamp(input, min, max, out=None) → Tensor:将输入input张量每个元素的夹紧到区间 [min,max],并返回结果到一个新张量。
 
 
2 Tensor的索引与数据筛选
 torch.where(codition,x,y):按照条件从x和y中选出满足条件的元素组成新的tensor,输入参数condition:条件限制,如果满足条件,则选择a,否则选择b作为输出。  torch.gather(input,dim,index,out=None):在指定维度上按照索引赋值输出tensor  torch.inex_select(input,dim,index,out=None):按照指定索引赋值输出tensor  torch.masked_select(input,mask,out=None):按照mask输出tensor,输出为向量  torch.take(input,indices):将输入看成1D-tensor,按照索引得到输出tensor  torch.nonzero(input,out=None):输出非0元素的坐标
import torch  #torch.where  a = torch.rand(4, 4)  b = torch.rand(4, 4)  print(a)  print(b)  out = torch.where(a > 0.5, a, b)  print(out)
 
 
print("torch.index_select")  a = torch.rand(4, 4)  print(a)  out = torch.index_select(a, dim=0,                     index=torch.tensor([0, 3, 2]))  #dim=0按列,index取的是行  print(out, out.shape)
 
 
print("torch.gather")  a = torch.linspace(1, 16, 16).view(4, 4)  print(a)  out = torch.gather(a, dim=0,               index=torch.tensor([[0, 1, 1, 1],                                   [0, 1, 2, 2],                                   [0, 1, 3, 3]]))  print(out)  print(out.shape)  #注:从0开始,第0列的第0个,第一列的第1个,第二列的第1个,第三列的第1个,,,以此类推  #dim=0, out[i, j, k] = input[index[i, j, k], j, k]  #dim=1, out[i, j, k] = input[i, index[i, j, k], k] #dim=2, out[i, j, k] = input[i, j, index[i, j, k]]
 
 
print("torch.masked_index")  a = torch.linspace(1, 16, 16).view(4, 4)  mask = torch.gt(a, 8)  print(a)  print(mask)  out = torch.masked_select(a, mask)  print(out)
 
 
print("torch.take")  a = torch.linspace(1, 16, 16).view(4, 4)  b = torch.take(a, index=torch.tensor([0, 15, 13, 10]))  print(b)
 
 
#torch.nonzero  print("torch.take")  a = torch.tensor([[0, 1, 2, 0], [2, 3, 0, 1]])  out = torch.nonzero(a)  print(out)  #稀疏表示
 
 
3 Tensor的组合/拼接
 torch.cat(seq,dim=0,out=None):按照已经存在的维度进行拼接  torch.stack(seq,dim=0,out=None):沿着一个新维度对输入张量序列进行连接。序列中所有的张量都应该为相同形状。
print("torch.stack")  a = torch.linspace(1, 6, 6).view(2, 3)  b = torch.linspace(7, 12, 6).view(2, 3)  print(a, b)  out = torch.stack((a, b), dim=2)  print(out)  print(out.shape)  print(out[:, :, 0])  print(out[:, :, 1])
 
 
4 Tensor的切片
 torch.chunk(tensor,chunks,dim=0):按照某个维度平均分块(最后一个可能小于平均值)  torch.split(tensor,split_size_or_sections,dim=0):按照某个维度依照第二个参数给出的list或者int进行分割tensor

(编辑:衡阳站长网)

【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!

    热点阅读