- tensor.type([张量类型])
- torch.double()
代码
import torch
import numpy as np
# 使用 type() 函数进行转换
def test01():
data = torch.full([2,3], 10)
print(data.dtype)
# 注意:返回一个新的类型转换过的张量
data = data.type(torch.DoubleTensor)
#data = data.type(torch.IntTensor)
print(data.dtype)
# 使用具体类型函数进行转换
def test02():
data = torch.full([2,3], 10)
print(data.dtype)
# 转换程 float64 类型
data = data.double()
print(data.dtype)
"""
# 转换成其他类型
data.short() # 将张量元素转换成 int16 类型
data. int() # 将张量转换成 int32 类型
data.long() # 将张量转换成 int64 类型
data.float() # 将张量转换成 float32 类型
"""
if __name__ == '__main__':
test01()