【动手学深度学习】pytorch-参数管理

news/2024/7/7 14:14:36

pytorch-参数管理

概述

 我们的目标是找到使损失函数最小化的模型参数值。 经过训练后,我们将需要使用这些参数来做出未来的预测。 此外,有时我们希望提取参数,以便在其他环境中复用它们, 将模型保存下来,以便它可以在其他软件中执行, 或者为了获得科学的理解而进行检查。

# 创建一个单隐藏层的MLP
import torch
from torch import nn

net = nn.Sequential(nn.Linear(4,8),nn.ReLU(),nn.Linear(8,1))
X = torch.rand(size = (2,4))
net(X)



参数访问

# 参数访问  全连接层包含两个参数  分别是该层的权重和偏置  两者都为存储单精度浮点数
print(net[2].state_dict())

在这里插入图片描述

print(type(net[2].bias))
print(net[2].bias)
print(net[2].bias.data)

在这里插入图片描述

# 一次性访问所有参数
print(*[(name,param.shape) for name,param in net[0].named_parameters()])
print(*[(name,param.shape) for name,param in net.named_parameters()])

在这里插入图片描述

嵌套块收集参数


def block1():
    return nn.Sequential(nn.Linear(4,8),nn.ReLU(),
                         nn.Linear(8,4),nn.ReLU())

def block2():
    net = nn.Sequential()
    for i in range(4):
        net.add_module(f'block{i}',block1())

    return net

#  块和层之间进行组合
rgnet = nn.Sequential(block2(),nn.Linear(4,1))
rgnet(X)

在这里插入图片描述

访问第一个主要的块中第二个子块的第一层的偏置
在这里插入图片描述

参数初始化

 pytorch根据一个范围均匀初始化权重和偏置矩阵 这个范围是根据输入和输出维度计算得到,Pytorch.init模块提供了多种预置初始化方法。

内置初始化

下面的代码将所有的权重参数初始化为标准差为0.01的高斯随机变量 并且将偏置参数设置为0

def init_normal(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight,mean = 0,std = 0.01)
        nn.init.zeros_(m.bias)

net.apply(init_normal)
net[0].weight.data[0],net[0].bias.data[0]

可以将所有的参数初始化为1


def init_constant(m):
    if type(m) == nn.Linear:
        nn.init.constant_(m.weight,1)
        nn.init.zeros_(m.bias)

net.apply(init_constant)
net[0].weight.data[0],net[0].bias.data[0]

针对不同的块进行初始化

def init_xavier(m):
    if type(m) == nn.Linear:
        nn.init.xavier_uniform_(m.weight)

def init_42(m):
    if type(m) == nn.Linear:
        nn.init.constant_(m.weight,42)

net[0].apply(init_xavier)
net[2].apply(init_42)
print(net[0].weight.data[0])
print(net[2].weight.data)

自定义初始化

def my_init(m):
    if type(m) == nn.Linear:
        print("Init", *[(name, param.shape)
                        for name, param in m.named_parameters()][0])
        nn.init.uniform_(m.weight, -10, 10)
        m.weight.data *= m.weight.data.abs() >= 5

net.apply(my_init)
net[0].weight[:2]

参数共享

第三层和第四层共享一个参数

shared = nn.Linear(8,8)
net = nn.Sequential(nn.Linear(4,8),nn.ReLU(),
                    
                    shared,nn.ReLU(),
                    shared,nn.ReLU(),
                    nn.Linear(8,1))


net(X)

print(net[2].weight.data[0] == net[4].weight.data[0])




http://www.niftyadmin.cn/n/4594191.html

相关文章

分享嵌入式入门学习指导

最近有好多同学在咨询嵌入式该怎么入门,应该怎么学习,有什么好的学习方法推荐,以及嵌入式入门的学习路线。今天我就带着大家的问题,一一为大家解决。首先嵌入式门槛虽然较高,但也跟其他事物一样,并不是牢不…

小心旺旺新骗局“这款还有货吗,我想拍哦”

这个死骗子,今天一大早给我发了一个如下所示的消息: 注意上面的链接,它实际上是一个转向,先转到淘宝的登陆页面,等你输入用户名与密码后,再转到了www.tcobco.com,以达到某些不可告人的目的&…

mysql B+tree

什么是索引?索引是为了加速对表中数据行的检索而创建的一种分散存储的数据结构。 id和磁盘地址的映射。 关系型数据库存在磁盘当中。 为什要用索引?索引能极大减少存储引擎需要扫描的数据量。 索引可以把随机IO变成顺序IO。 索引可以帮助我们在进行分组、…

写给电子工程师的,非常值得一看

今天带着大家了解下未来嵌入式大致发展方向,以及的对嵌入式入门学习的一个规划!!!! 嵌入式应用领域如下图所示:当我们在学习嵌入式时,我们首先需要了解嵌入式应用领域,且我们以后向往…

ansible基础-playbook剧本的使用

ansible基础-playbook剧本的使用 作者:尹正杰 版权声明:原创作品,谢绝转载!否则将追究法律责任。 一.YAML概述 1>.YAML的诞生 YAML是一个可读性高,用来表达数据序列的格式。  YAML参考了其他多种语言&#xff0c…

c语言的七大查找算法,非常值得学习

今天带着大家学习下,C语言七大查找算法!!!!!!!!!! 这里我们首先看下算法的概念:算法(Algorithm)是指解题方案的准确而完整的描述,是一系列解决问题的清晰指令,算法代表着用系统的方法描述解决问题的策略机制…

hdoj3038(带权并查集)

题目链接:http://acm.hdu.edu.cn/showproblem.php?pid3038 题意:对于给定的a1..an,通过询问下标x..y,给出a[x]...a[y],但给出的值可能是错的,需要判断,因为题目说的是整数,也可能是…

飞思卡尔MC9S12X Flash驱动

今天带着大家学习下飞思卡尔MC9S12 Flash驱动 在现今的经济社会,比拼的“快”不仅仅是速度快,更是效率高。身处社会分工细致的今天,让自己更快效率更高是有方法的。 每一家MCU产商都会提供他们生产的MCU型号的datasheet,Reference Manual等…