Wie speichere ich ein trainiertes Modell in PyTorch?

Lesezeit: 11 Minuten

Benutzer-Avatar
Wasi Ahmad

Wie speichere ich ein trainiertes Modell in PyTorch? Das habe ich gelesen:

  1. torch.save()/torch.load() dient zum Speichern/Laden eines serialisierbaren Objekts.
  2. model.state_dict()/model.load_state_dict() dient zum Speichern/Laden des Modellzustands.

  • Ich denke, das liegt daran, dass torque.save() auch alle Zwischenvariablen speichert, wie z. B. Zwischenausgaben für die Verwendung der Rückwärtsausbreitung. Aber Sie müssen nur die Modellparameter wie Gewicht/Bias usw. speichern. Manchmal können erstere viel größer sein als letztere.

    – Dawei Yang

    18. März 2017 um 17:36 Uhr

  • Ich habe getestet torch.save(model, f) und torch.save(model.state_dict(), f). Die gespeicherten Dateien haben dieselbe Größe. Jetzt bin ich verwirrt. Außerdem fand ich die Verwendung von pickle zum Speichern von model.state_dict() extrem langsam. Ich denke, der beste Weg ist die Verwendung torch.save(model.state_dict(), f) da Sie die Erstellung des Modells übernehmen und Fackel das Laden der Modellgewichte übernimmt, wodurch mögliche Probleme beseitigt werden. Bezug: Discussion.pytorch.org/t/saving-torch-models/838/4

    – Dawei Yang

    29. März 2017 um 2:01 Uhr


  • Scheint, als hätte PyTorch dies in ihrem etwas expliziter angesprochen Abschnitt Tutorials– Es gibt dort viele gute Informationen, die in den Antworten hier nicht aufgeführt sind, einschließlich des Speicherns von mehr als einem Modell gleichzeitig und Warmstartmodellen.

    – whlteXbread

    24. März 2019 um 21:55 Uhr

  • was ist falsch an der Verwendung pickle?

    – Charly Parker

    13. Juli 2020 um 18:23 Uhr

  • @CharlieParker Torch.save basiert auf Pickle. Folgendes ist aus dem oben verlinkten Tutorial: “[torch.save] speichert das gesamte Modul mit dem pickle-Modul von Python. Der Nachteil dieses Ansatzes besteht darin, dass die serialisierten Daten an die spezifischen Klassen und die genaue Verzeichnisstruktur gebunden sind, die beim Speichern des Modells verwendet wird. Der Grund dafür ist, dass Pickle die Modellklasse selbst nicht speichert. Stattdessen speichert es einen Pfad zu der Datei, die die Klasse enthält, die während der Ladezeit verwendet wird. Aus diesem Grund kann Ihr Code auf verschiedene Weise beschädigt werden, wenn er in anderen Projekten oder nach Umgestaltungen verwendet wird.”

    – David Müller

    14. Juli 2020 um 9:56 Uhr


Benutzer-Avatar
dontloo

Gefunden diese Seite auf ihrem github repo:

Empfohlene Vorgehensweise zum Speichern eines Modells

Es gibt zwei Hauptansätze zum Serialisieren und Wiederherstellen eines Modells.

Die erste (empfohlene) speichert und lädt nur die Modellparameter:

torch.save(the_model.state_dict(), PATH)

Dann später:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

Die zweite speichert und lädt das gesamte Modell:

torch.save(the_model, PATH)

Dann später:

the_model = torch.load(PATH)

In diesem Fall sind die serialisierten Daten jedoch an die spezifischen Klassen und die genaue verwendete Verzeichnisstruktur gebunden, sodass sie bei der Verwendung in anderen Projekten oder nach einigen ernsthaften Umgestaltungen auf verschiedene Weise beschädigt werden können.


Siehe auch: Speichern und laden Sie das Modell Abschnitt aus den offiziellen PyTorch-Tutorials.

  • Laut @smth Discussion.pytorch.org/t/saving-and-loading-a-model-in-pytorch/… Das Modell wird standardmäßig neu geladen, um das Modell zu trainieren. Sie müssen also the_model.eval() nach dem Laden manuell aufrufen, wenn Sie es für die Inferenz laden und das Training nicht fortsetzen.

    – WillZ

    15. Juli 2018 um 22:30 Uhr


  • Die zweite Methode gibt unter Windows 10 den Fehler stackoverflow.com/questions/53798009/…. konnte ihn nicht lösen

    – Gülzar

    16. Dezember 2018 um 14:29 Uhr

  • Gibt es eine Möglichkeit zum Speichern, ohne dass ein Zugriff auf die Modellklasse erforderlich ist?

    – Michael D

    11. Dezember 2019 um 14:16 Uhr

  • Wie behalten Sie bei diesem Ansatz den Überblick über die *args und **kwargs, die Sie für den Lastfall übergeben müssen?

    – Mariano Kamp

    9. April 2020 um 14:40 Uhr

  • Hallo Leute, könnte mir jemand sagen, was die Erweiterung für die Modell-Diktatdatei (.pth?) und die Erweiterung für die gesamte Modelldatei (.pkl) ist?? Hab ich recht?

    – Franva

    9. August 2021 um 15:35 Uhr

Benutzer-Avatar
Jadiel de Armas

Es hängt davon ab, was Sie tun möchten.

Fall Nr. 1: Speichern Sie das Modell, um es selbst für die Inferenz zu verwenden: Sie speichern das Modell, stellen es wieder her und ändern das Modell dann in den Evaluierungsmodus. Dies geschieht, weil Sie normalerweise haben BatchNorm und Dropout Layer, die sich standardmäßig im Zugmodus bei der Konstruktion befinden:

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

Fall Nr. 2: Speichern Sie das Modell, um das Training später fortzusetzen: Wenn Sie das zu speichernde Modell weiter trainieren müssen, müssen Sie mehr als nur das Modell speichern. Sie müssen auch den Zustand des Optimierers, Epochen, Punktzahl usw. speichern. Sie würden es so machen:

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

Um das Training wieder aufzunehmen, würden Sie Folgendes tun: state = torch.load(filepath)und dann, um den Zustand jedes einzelnen Objekts wiederherzustellen, etwa so:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

Da Sie das Training wieder aufnehmen, UNTERLASSEN SIE Anruf model.eval() sobald Sie die Zustände beim Laden wiederherstellen.

Fall Nr. 3: Modell, das von jemand anderem ohne Zugriff auf Ihren Code verwendet werden soll: In Tensorflow können Sie eine erstellen .pb Datei, die sowohl die Architektur als auch die Gewichtungen des Modells definiert. Dies ist sehr praktisch, besonders bei der Verwendung Tensorflow serve. Der äquivalente Weg, dies in Pytorch zu tun, wäre:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

Dieser Weg ist immer noch nicht kugelsicher und da Pytorch immer noch viele Änderungen durchmacht, würde ich es nicht empfehlen.

  • Gibt es eine empfohlene Dateiendung für die 3 Fälle? Oder ist es immer .pth?

    – Verena Haunschmid

    12. Februar 2019 um 8:23 Uhr


  • Im Fall Nr. 3 torch.load gibt nur ein OrderedDict zurück. Wie erhalten Sie das Modell, um Vorhersagen zu treffen?

    – Alber8295

    12. Februar 2019 um 10:44 Uhr

  • Hallo, darf ich wissen, wie der erwähnte “Fall Nr. 2: Modell speichern, um das Training später fortzusetzen” durchgeführt wird? Ich habe es geschafft, den Checkpoint in das Modell zu laden, dann kann ich das Modell nicht ausführen oder fortsetzen, wie “model.to (device) model = train_model_epoch (model, category, optimizer, sched, epochs)”.

    – dnez

    8. März 2019 um 7:16 Uhr

  • Hallo, für Fall eins, der für die Inferenz bestimmt ist, sagen Sie im offiziellen Pytorch-Dokument, dass der Optimierer state_dict für die Inferenz oder den Abschluss des Trainings gespeichert werden muss. „Wenn Sie einen allgemeinen Prüfpunkt speichern, der entweder für die Inferenz oder die Wiederaufnahme des Trainings verwendet werden soll, müssen Sie mehr als nur das state_dict des Modells speichern. Es ist wichtig, auch das state_dict des Optimierers zu speichern, da dieses Puffer und Parameter enthält, die während der Modelleisenbahnen aktualisiert werden .”

    – Mohammed Awny

    21. September 2019 um 13:09 Uhr

  • Im Fall Nr. 3 sollte die Modellklasse irgendwo definiert werden.

    – Michael D

    11. Dezember 2019 um 13:41 Uhr

Benutzer-Avatar
Prosti

Das Essiggurke Die Python-Bibliothek implementiert binäre Protokolle zum Serialisieren und Deserialisieren eines Python-Objekts.

Wenn du import torch (oder wenn Sie PyTorch verwenden) wird es import pickle für dich und du brauchst nicht anzurufen pickle.dump() und pickle.load() direkt, das sind die Methoden zum Speichern und Laden des Objekts.

In der Tat, torch.save() und torch.load() wird wickeln pickle.dump() und pickle.load() für dich.

EIN state_dict Die andere erwähnte Antwort verdient nur ein paar weitere Anmerkungen.

Was state_dict haben wir in PyTorch? Es sind eigentlich zwei state_dicts.

Das PyTorch-Modell ist torch.nn.Module was hat model.parameters() aufrufen, um lernbare Parameter (w und b) zu erhalten. Diese lernbaren Parameter, die einmal zufällig festgelegt wurden, werden im Laufe der Zeit aktualisiert, wenn wir lernen. Lernbare Parameter sind die ersten state_dict.

Der Zweite state_dict ist der Zustand des Optimierers dict. Sie erinnern sich, dass der Optimierer verwendet wird, um unsere lernbaren Parameter zu verbessern. Aber der Optimierer state_dict Ist repariert. Da gibt es nichts zu lernen.

Da state_dict Objekte sind Python-Wörterbücher, sie können einfach gespeichert, aktualisiert, geändert und wiederhergestellt werden, was den PyTorch-Modellen und -Optimierern eine große Modularität verleiht.

Lassen Sie uns ein supereinfaches Modell erstellen, um dies zu erklären:

import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Model weight:")    
print(model.weight)

print("Model bias:")    
print(model.bias)

print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

Dieser Code gibt Folgendes aus:

Model's state_dict:
weight      torch.Size([2, 5])
bias      torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328,  0.1360,  0.1553, -0.1838, -0.0316],
        [ 0.0479,  0.1760,  0.1712,  0.2244,  0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state      {}
param_groups      [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]

Beachten Sie, dass dies ein Minimalmodell ist. Sie können versuchen, einen Stapel von sequentiellen hinzuzufügen

model = torch.nn.Sequential(
          torch.nn.Linear(D_in, H),
          torch.nn.Conv2d(A, B, C)
          torch.nn.Linear(H, D_out),
        )

Beachten Sie, dass nur Schichten mit lernbaren Parametern (Faltungsschichten, lineare Schichten usw.) und registrierte Puffer (Batchnorm-Schichten) Einträge in den Modellen haben state_dict.

Nicht lernbare Dinge gehören zum Optimizer-Objekt state_dictdie Informationen über den Zustand des Optimierers sowie die verwendeten Hyperparameter enthält.

Der Rest der Geschichte ist derselbe; in der Inferenzphase (dies ist eine Phase, in der wir das Modell nach dem Training verwenden) zum Vorhersagen; wir prognostizieren basierend auf den Parametern, die wir gelernt haben. Für die Inferenz müssen wir also nur die Parameter speichern model.state_dict().

torch.save(model.state_dict(), filepath)

Und später zu verwenden model.load_state_dict(torch.load(filepath)) model.eval()

Hinweis: Vergessen Sie nicht die letzte Zeile model.eval() Dies ist nach dem Laden des Modells entscheidend.

Versuchen Sie auch nicht zu speichern torch.save(model.parameters(), filepath). Das model.parameters() ist nur das Generatorobjekt.

Auf der anderen Seite, torch.save(model, filepath) speichert das Modellobjekt selbst, aber denken Sie daran, dass das Modell nicht über den Optimierer verfügt state_dict. Überprüfen Sie die andere ausgezeichnete Antwort von @ Jadiel de Armas, um das Zustandsdikt des Optimierers zu speichern.

  • Obwohl es keine einfache Lösung ist, wird die Essenz des Problems gründlich analysiert! Stimme zu.

    – Jason Jung

    2. Juni 2020 um 14:58 Uhr


Benutzer-Avatar
harsch

Eine gängige PyTorch-Konvention besteht darin, Modelle entweder mit der Dateierweiterung .pt oder .pth zu speichern.

Gesamtes Modell speichern/laden

Speichern:

path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)

Belastung:

(Modellklasse muss irgendwo definiert werden)

model.load_state_dict(torch.load(PATH))
model.eval()

Benutzer-Avatar
Joy Mazumder

Wenn Sie das Modell speichern und das Training später fortsetzen möchten:

Einzel-GPU:
Speichern:

state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath="checkpoint.t7"
torch.save(state,savepath)

Belastung:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

Mehrere GPUs:
Speichern

state = {
        'epoch': epoch,
        'state_dict': model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath="checkpoint.t7"
torch.save(state,savepath)

Belastung:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

#Don't call DataParallel before loading the model otherwise you will get an error

model = nn.DataParallel(model) #ignore the line if you want to load on Single GPU

Benutzer-Avatar
Jakob

Lokal speichern

Wie Sie Ihr Modell speichern, hängt davon ab, wie Sie in Zukunft darauf zugreifen möchten. Wenn Sie eine neue Instanz von aufrufen können model Klasse, dann müssen Sie nur noch die Gewichte des Modells speichern/laden model.state_dict():

# Save:
torch.save(old_model.state_dict(), PATH)

# Load:
new_model = TheModelClass(*args, **kwargs)
new_model.load_state_dict(torch.load(PATH))

Wenn Sie dies aus irgendeinem Grund nicht können (oder die einfachere Syntax bevorzugen), können Sie das gesamte Modell (eigentlich ein Verweis auf die Datei(en), die das Modell definieren, zusammen mit seinem state_dict) mit speichern torch.save():

# Save:
torch.save(old_model, PATH)

# Load:
new_model = torch.load(PATH)

Da dies jedoch ein Verweis auf den Speicherort der Dateien ist, die die Modellklasse definieren, ist dieser Code nicht portierbar, es sei denn, diese Dateien werden ebenfalls in dieselbe Verzeichnisstruktur portiert.

Speichern in der Cloud – TorchHub

Wenn Sie möchten, dass Ihr Modell portabel ist, können Sie es einfach mit importieren torch.hub. Wenn Sie eine entsprechend definierte hinzufügen hubconf.py Datei in ein Github-Repo, kann dies einfach aus PyTorch heraus aufgerufen werden, damit Benutzer Ihr Modell mit/ohne Gewichte laden können:

hubconf.py (github.com/repo_owner/repo_name)

dependencies = ['torch']
from my_module import mymodel as _mymodel

def mymodel(pretrained=False, **kwargs):
    return _mymodel(pretrained=pretrained, **kwargs)

Modell laden:

new_model = torch.hub.load('repo_owner/repo_name', 'mymodel')
new_model_pretrained = torch.hub.load('repo_owner/repo_name', 'mymodel', pretrained=True)

Benutzer-Avatar
Christian__

pip installiere pytorch-lightning

Stellen Sie sicher, dass Ihr übergeordnetes Modell pl.LightningModule anstelle von nn.Module verwendet

Speichern und Laden von Kontrollpunkten mit Pytorch-Blitz

import pytorch_lightning as pl

model = MyLightningModule(hparams)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
new_model = MyModel.load_from_checkpoint(checkpoint_path="example.ckpt")

1085240cookie-checkWie speichere ich ein trainiertes Modell in PyTorch?

This website is using cookies to improve the user-friendliness. You agree by using the website further.

Privacy policy