.. DO NOT EDIT. .. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. .. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: .. "beginner/basics/saveloadrun_tutorial.py" .. LINE NUMBERS ARE GIVEN BELOW. .. only:: html .. note:: :class: sphx-glr-download-link-note Click :ref:`here ` to download the full example code .. rst-class:: sphx-glr-example-title .. _sphx_glr_beginner_basics_saveloadrun_tutorial.py: `基础知识 `_ || `快速入门 `_ || `张量 `_ || `数据集与数据加载器 `_ || `Transforms `_ || `构建神经网络 `_ || `自动微分 `_ || `优化模型参数 `_ || **保存和加载模型** 保存和加载模型 ============================ 在本节中,我们将学习如何通过保存、加载以及运行模型预测,来持久化模型。 .. GENERATED FROM PYTHON SOURCE LINES 17-22 .. code-block:: default import torch import torchvision.models as models .. GENERATED FROM PYTHON SOURCE LINES 23-27 保存和加载模型权重 -------------------------------- PyTorch模型将学习到的参数存储在一个内部状态字典中,称为``state_dict``。这些参数可以通过``torch.save``进行持久化。 方法: .. GENERATED FROM PYTHON SOURCE LINES 27-31 .. code-block:: default model = models.vgg16(weights='IMAGENET1K_V1') torch.save(model.state_dict(), 'model_weights.pth') .. GENERATED FROM PYTHON SOURCE LINES 32-33 要加载模型权重,您需要先创建一个相同模型的实例,然后使用``load_state_dict()``方法加载参数。 .. GENERATED FROM PYTHON SOURCE LINES 33-38 .. code-block:: default model = models.vgg16() # we do not specify ``weights``, i.e. create untrained model model.load_state_dict(torch.load('model_weights.pth')) model.eval() .. GENERATED FROM PYTHON SOURCE LINES 39-40 。。 注意:: 在进行推理之前,请确保调用``model.eval()``方法以将 dropout 和 batch normalization layers设置为评估模式。如果不这样做,将导致不一致的推理结果。 .. GENERATED FROM PYTHON SOURCE LINES 42-46 保存和加载带有结构的模型 ------------------------------------- 在加载模型权重时,我们需要先实例化模型类,因为类定义了网络的结构。我们可能希望将这个类的结构与模型一起保存, 在这种情况下,我们可以将``model``(而不是``model.state_dict()``)传递给 save 函数: .. GENERATED FROM PYTHON SOURCE LINES 46-49 .. code-block:: default torch.save(model, 'model.pth') .. GENERATED FROM PYTHON SOURCE LINES 50-51 我们可以使用如下方式加载模型: .. GENERATED FROM PYTHON SOURCE LINES 51-54 .. code-block:: default model = torch.load('model.pth') .. GENERATED FROM PYTHON SOURCE LINES 55-56 .. 注意:: 这种方法在序列化模型时使用 Python 的 `pickle `_模块,因此在加载模型时需要依赖实际的类定义。 .. GENERATED FROM PYTHON SOURCE LINES 58-62 相关教程 ----------------- - `PyTorch 中保存和加载通用Checkpoint `_ - `从 checkpoint 加载 nn.Module 的实用技巧 `_ .. rst-class:: sphx-glr-timing **Total running time of the script:** ( 0 minutes 0.000 seconds) .. _sphx_glr_download_beginner_basics_saveloadrun_tutorial.py: .. only:: html .. container:: sphx-glr-footer sphx-glr-footer-example .. container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: saveloadrun_tutorial.py ` .. container:: sphx-glr-download sphx-glr-download-jupyter :download:`Download Jupyter notebook: saveloadrun_tutorial.ipynb ` .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery `_