目次前言torch.no_grad() 是 PyTorch 中的一个上下文管理器,用于在上下文中临时禁用主动梯度计算。它在模子评估或推理阶段非常有效,因为在这些阶段,我们通常不须要计算梯度。禁用梯度计算可以减少内存消耗,并加速计算速率。 基本概念在 PyTorch 中,每次对 requires_grad=True 的张量进行操作时,PyTorch 会构建一个计算图(computation graph),用于计算反向传播的梯度。这对训练模子是须要的,但在评估或推理时不须要。因此,我们可以利用 torch.no_grad() 来临时禁用这些计算图的构建和梯度计算。 用法torch.no_grad() 的利用非常简朴。只须要将不须要梯度计算的代码块放在 with torch.no_grad(): 下即可。 示例代码以下是一个利用 torch.no_grad() 的示例: [code]import torch # 创建一个张量,并设置 requires_grad=True 以便记录梯度 x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True) # 在 torch.no_grad() 上下文中禁用梯度计算 with torch.no_grad(): y = x + 2 print(y) # 此时,x 的 requires_grad 属性仍旧为 True,但 y 的 requires_grad 属性为 False print("x 的 requires_grad:", x.requires_grad) print("y 的 requires_grad:", y.requires_grad)[/code]具体解释创建张量并设置 requires_grad=True: [code]x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)[/code]创建一个包罗三个元素的张量 x。 设置 requires_grad=True,告诉 PyTorch 须要为该张量记录梯度。 禁用梯度计算: [code]with torch.no_grad(): y = x + 2 print(y)[/code]进入 torch.no_grad() 上下文,临时禁用梯度计算。 在上下文中,对 x 进行加法操作,得到新的张量 y。 打印 y,此时 y 的 requires_grad 属性为 False。 检察 requires_grad 属性: [code]print("x 的 requires_grad:", x.requires_grad) print("y 的 requires_grad:", y.requires_grad)[/code]打印 x 的 requires_grad 属性,仍旧为 True。 打印 y 的 requires_grad 属性,已被禁用为 False。 利用场景模子评估在评估模子性能时,不须要计算梯度。利用 torch.no_grad() 可以提高评估速率和减少内存消耗。 [code]model.eval() # 切换到评估模式 with torch.no_grad(): for data in validation_loader: outputs = model(data) # 计算评估指标[/code]模子推理在摆设和推理阶段,只须要前向传播,不须要反向传播,因此可以利用 torch.no_grad()。 [code]with torch.no_grad(): outputs = model(inputs) predicted = torch.argmax(outputs, dim=1)[/code]初始化权重或其他不须要梯度的操作 在某些初始化或操作中,不须要梯度计算。 [code]with torch.no_grad(): model.weight.fill_(1.0) # 直接修改权重[/code]小结torch.no_grad() 是一个用于禁用梯度计算的上下文管理器,适用于模子评估、推理等不须要梯度计算的场景。利用 torch.no_grad() 可以明显减少内存利用和加速计算。通过明确和合理利用 torch.no_grad(),可以使得模子评估和推理更加高效和稳固。 额外留意事项训练模式与评估模式: 在利用 torch.no_grad() 时,通常还会将模子设置为评估模式(model.eval()),以确保某些层(如 dropout 和 batch normalization)在推理时的行为与训练时不同。 嵌套利用: torch.no_grad() 可以嵌套利用,内层的 torch.no_grad() 仍旧会禁用梯度计算。 [code]with torch.no_grad(): with torch.no_grad(): y = x + 2 print(y)[/code]规复梯度计算: 在 torch.no_grad() 上下文管理器退出后,梯度计算会主动规复,不须要额外操作。 [code]with torch.no_grad(): y = x + 2 print(y) # 这里梯度计算规复 z = x * 2 print(z.requires_grad) # True[/code]通过合理利用 torch.no_grad(),可以在不须要梯度计算的场景中提升性能并节省资源。 总结到此这篇关于PyTorch中torch.no_grad()用法举例详解的文章就先容到这了,更多相关PyTorch torch.no_grad()详解内容请搜索脚本之家以前的文章或继续欣赏下面的相关文章渴望大家以后多多支持脚本之家! 来源:https://www.jb51.net/python/328240bn7.htm 免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作! |
|手机版|小黑屋|梦想之都-俊月星空
( 粤ICP备18056059号 )|网站地图
GMT+8, 2025-7-1 18:04 , Processed in 0.029850 second(s), 20 queries .
Powered by Mxzdjyxk! X3.5
© 2001-2025 Discuz! Team.