Ich denke, das liegt daran, dass torque.save() auch alle Zwischenvariablen speichert, wie z. B. Zwischenausgaben für die Verwendung der Rückwärtsausbreitung. Aber Sie müssen nur die Modellparameter wie Gewicht/Bias usw. speichern. Manchmal können erstere viel größer sein als letztere.
– Dawei Yang
18. März 2017 um 17:36 Uhr
Ich habe getestet torch.save(model, f) und torch.save(model.state_dict(), f). Die gespeicherten Dateien haben dieselbe Größe. Jetzt bin ich verwirrt. Außerdem fand ich die Verwendung von pickle zum Speichern von model.state_dict() extrem langsam. Ich denke, der beste Weg ist die Verwendung torch.save(model.state_dict(), f) da Sie die Erstellung des Modells übernehmen und Fackel das Laden der Modellgewichte übernimmt, wodurch mögliche Probleme beseitigt werden. Bezug: Discussion.pytorch.org/t/saving-torch-models/838/4
– Dawei Yang
29. März 2017 um 2:01 Uhr
Scheint, als hätte PyTorch dies in ihrem etwas expliziter angesprochen Abschnitt Tutorials– Es gibt dort viele gute Informationen, die in den Antworten hier nicht aufgeführt sind, einschließlich des Speicherns von mehr als einem Modell gleichzeitig und Warmstartmodellen.
– whlteXbread
24. März 2019 um 21:55 Uhr
was ist falsch an der Verwendung pickle?
– Charly Parker
13. Juli 2020 um 18:23 Uhr
@CharlieParker Torch.save basiert auf Pickle. Folgendes ist aus dem oben verlinkten Tutorial: “[torch.save] speichert das gesamte Modul mit dem pickle-Modul von Python. Der Nachteil dieses Ansatzes besteht darin, dass die serialisierten Daten an die spezifischen Klassen und die genaue Verzeichnisstruktur gebunden sind, die beim Speichern des Modells verwendet wird. Der Grund dafür ist, dass Pickle die Modellklasse selbst nicht speichert. Stattdessen speichert es einen Pfad zu der Datei, die die Klasse enthält, die während der Ladezeit verwendet wird. Aus diesem Grund kann Ihr Code auf verschiedene Weise beschädigt werden, wenn er in anderen Projekten oder nach Umgestaltungen verwendet wird.”
In diesem Fall sind die serialisierten Daten jedoch an die spezifischen Klassen und die genaue verwendete Verzeichnisstruktur gebunden, sodass sie bei der Verwendung in anderen Projekten oder nach einigen ernsthaften Umgestaltungen auf verschiedene Weise beschädigt werden können.
Laut @smth Discussion.pytorch.org/t/saving-and-loading-a-model-in-pytorch/… Das Modell wird standardmäßig neu geladen, um das Modell zu trainieren. Sie müssen also the_model.eval() nach dem Laden manuell aufrufen, wenn Sie es für die Inferenz laden und das Training nicht fortsetzen.
– WillZ
15. Juli 2018 um 22:30 Uhr
Die zweite Methode gibt unter Windows 10 den Fehler stackoverflow.com/questions/53798009/…. konnte ihn nicht lösen
– Gülzar
16. Dezember 2018 um 14:29 Uhr
Gibt es eine Möglichkeit zum Speichern, ohne dass ein Zugriff auf die Modellklasse erforderlich ist?
– Michael D
11. Dezember 2019 um 14:16 Uhr
Wie behalten Sie bei diesem Ansatz den Überblick über die *args und **kwargs, die Sie für den Lastfall übergeben müssen?
– Mariano Kamp
9. April 2020 um 14:40 Uhr
Hallo Leute, könnte mir jemand sagen, was die Erweiterung für die Modell-Diktatdatei (.pth?) und die Erweiterung für die gesamte Modelldatei (.pkl) ist?? Hab ich recht?
– Franva
9. August 2021 um 15:35 Uhr
Jadiel de Armas
Es hängt davon ab, was Sie tun möchten.
Fall Nr. 1: Speichern Sie das Modell, um es selbst für die Inferenz zu verwenden: Sie speichern das Modell, stellen es wieder her und ändern das Modell dann in den Evaluierungsmodus. Dies geschieht, weil Sie normalerweise haben BatchNorm und Dropout Layer, die sich standardmäßig im Zugmodus bei der Konstruktion befinden:
torch.save(model.state_dict(), filepath)
#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()
Fall Nr. 2: Speichern Sie das Modell, um das Training später fortzusetzen: Wenn Sie das zu speichernde Modell weiter trainieren müssen, müssen Sie mehr als nur das Modell speichern. Sie müssen auch den Zustand des Optimierers, Epochen, Punktzahl usw. speichern. Sie würden es so machen:
Um das Training wieder aufzunehmen, würden Sie Folgendes tun: state = torch.load(filepath)und dann, um den Zustand jedes einzelnen Objekts wiederherzustellen, etwa so:
Da Sie das Training wieder aufnehmen, UNTERLASSEN SIE Anruf model.eval() sobald Sie die Zustände beim Laden wiederherstellen.
Fall Nr. 3: Modell, das von jemand anderem ohne Zugriff auf Ihren Code verwendet werden soll: In Tensorflow können Sie eine erstellen .pb Datei, die sowohl die Architektur als auch die Gewichtungen des Modells definiert. Dies ist sehr praktisch, besonders bei der Verwendung Tensorflow serve. Der äquivalente Weg, dies in Pytorch zu tun, wäre:
torch.save(model, filepath)
# Then later:
model = torch.load(filepath)
Dieser Weg ist immer noch nicht kugelsicher und da Pytorch immer noch viele Änderungen durchmacht, würde ich es nicht empfehlen.
Gibt es eine empfohlene Dateiendung für die 3 Fälle? Oder ist es immer .pth?
– Verena Haunschmid
12. Februar 2019 um 8:23 Uhr
Im Fall Nr. 3 torch.load gibt nur ein OrderedDict zurück. Wie erhalten Sie das Modell, um Vorhersagen zu treffen?
– Alber8295
12. Februar 2019 um 10:44 Uhr
Hallo, darf ich wissen, wie der erwähnte “Fall Nr. 2: Modell speichern, um das Training später fortzusetzen” durchgeführt wird? Ich habe es geschafft, den Checkpoint in das Modell zu laden, dann kann ich das Modell nicht ausführen oder fortsetzen, wie “model.to (device) model = train_model_epoch (model, category, optimizer, sched, epochs)”.
– dnez
8. März 2019 um 7:16 Uhr
Hallo, für Fall eins, der für die Inferenz bestimmt ist, sagen Sie im offiziellen Pytorch-Dokument, dass der Optimierer state_dict für die Inferenz oder den Abschluss des Trainings gespeichert werden muss. „Wenn Sie einen allgemeinen Prüfpunkt speichern, der entweder für die Inferenz oder die Wiederaufnahme des Trainings verwendet werden soll, müssen Sie mehr als nur das state_dict des Modells speichern. Es ist wichtig, auch das state_dict des Optimierers zu speichern, da dieses Puffer und Parameter enthält, die während der Modelleisenbahnen aktualisiert werden .”
– Mohammed Awny
21. September 2019 um 13:09 Uhr
Im Fall Nr. 3 sollte die Modellklasse irgendwo definiert werden.
– Michael D
11. Dezember 2019 um 13:41 Uhr
Prosti
Das Essiggurke Die Python-Bibliothek implementiert binäre Protokolle zum Serialisieren und Deserialisieren eines Python-Objekts.
Wenn du import torch (oder wenn Sie PyTorch verwenden) wird es import pickle für dich und du brauchst nicht anzurufen pickle.dump() und pickle.load() direkt, das sind die Methoden zum Speichern und Laden des Objekts.
In der Tat, torch.save() und torch.load() wird wickeln pickle.dump() und pickle.load() für dich.
EIN state_dict Die andere erwähnte Antwort verdient nur ein paar weitere Anmerkungen.
Was state_dict haben wir in PyTorch? Es sind eigentlich zwei state_dicts.
Das PyTorch-Modell ist torch.nn.Module was hat model.parameters() aufrufen, um lernbare Parameter (w und b) zu erhalten. Diese lernbaren Parameter, die einmal zufällig festgelegt wurden, werden im Laufe der Zeit aktualisiert, wenn wir lernen. Lernbare Parameter sind die ersten state_dict.
Der Zweite state_dict ist der Zustand des Optimierers dict. Sie erinnern sich, dass der Optimierer verwendet wird, um unsere lernbaren Parameter zu verbessern. Aber der Optimierer state_dict Ist repariert. Da gibt es nichts zu lernen.
Da state_dict Objekte sind Python-Wörterbücher, sie können einfach gespeichert, aktualisiert, geändert und wiederhergestellt werden, was den PyTorch-Modellen und -Optimierern eine große Modularität verleiht.
Lassen Sie uns ein supereinfaches Modell erstellen, um dies zu erklären:
import torch
import torch.optim as optim
model = torch.nn.Linear(5, 2)
# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
print("Model weight:")
print(model.weight)
print("Model bias:")
print(model.bias)
print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])
Beachten Sie, dass dies ein Minimalmodell ist. Sie können versuchen, einen Stapel von sequentiellen hinzuzufügen
model = torch.nn.Sequential(
torch.nn.Linear(D_in, H),
torch.nn.Conv2d(A, B, C)
torch.nn.Linear(H, D_out),
)
Beachten Sie, dass nur Schichten mit lernbaren Parametern (Faltungsschichten, lineare Schichten usw.) und registrierte Puffer (Batchnorm-Schichten) Einträge in den Modellen haben state_dict.
Nicht lernbare Dinge gehören zum Optimizer-Objekt state_dictdie Informationen über den Zustand des Optimierers sowie die verwendeten Hyperparameter enthält.
Der Rest der Geschichte ist derselbe; in der Inferenzphase (dies ist eine Phase, in der wir das Modell nach dem Training verwenden) zum Vorhersagen; wir prognostizieren basierend auf den Parametern, die wir gelernt haben. Für die Inferenz müssen wir also nur die Parameter speichern model.state_dict().
torch.save(model.state_dict(), filepath)
Und später zu verwenden model.load_state_dict(torch.load(filepath)) model.eval()
Hinweis: Vergessen Sie nicht die letzte Zeile model.eval() Dies ist nach dem Laden des Modells entscheidend.
Versuchen Sie auch nicht zu speichern torch.save(model.parameters(), filepath). Das model.parameters() ist nur das Generatorobjekt.
Auf der anderen Seite, torch.save(model, filepath) speichert das Modellobjekt selbst, aber denken Sie daran, dass das Modell nicht über den Optimierer verfügt state_dict. Überprüfen Sie die andere ausgezeichnete Antwort von @ Jadiel de Armas, um das Zustandsdikt des Optimierers zu speichern.
Obwohl es keine einfache Lösung ist, wird die Essenz des Problems gründlich analysiert! Stimme zu.
– Jason Jung
2. Juni 2020 um 14:58 Uhr
harsch
Eine gängige PyTorch-Konvention besteht darin, Modelle entweder mit der Dateierweiterung .pt oder .pth zu speichern.
checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']
#Don't call DataParallel before loading the model otherwise you will get an error
model = nn.DataParallel(model) #ignore the line if you want to load on Single GPU
Wie Sie Ihr Modell speichern, hängt davon ab, wie Sie in Zukunft darauf zugreifen möchten. Wenn Sie eine neue Instanz von aufrufen können model Klasse, dann müssen Sie nur noch die Gewichte des Modells speichern/laden model.state_dict():
Wenn Sie dies aus irgendeinem Grund nicht können (oder die einfachere Syntax bevorzugen), können Sie das gesamte Modell (eigentlich ein Verweis auf die Datei(en), die das Modell definieren, zusammen mit seinem state_dict) mit speichern torch.save():
Da dies jedoch ein Verweis auf den Speicherort der Dateien ist, die die Modellklasse definieren, ist dieser Code nicht portierbar, es sei denn, diese Dateien werden ebenfalls in dieselbe Verzeichnisstruktur portiert.
Wenn Sie möchten, dass Ihr Modell portabel ist, können Sie es einfach mit importieren torch.hub. Wenn Sie eine entsprechend definierte hinzufügen hubconf.py Datei in ein Github-Repo, kann dies einfach aus PyTorch heraus aufgerufen werden, damit Benutzer Ihr Modell mit/ohne Gewichte laden können:
Ich denke, das liegt daran, dass torque.save() auch alle Zwischenvariablen speichert, wie z. B. Zwischenausgaben für die Verwendung der Rückwärtsausbreitung. Aber Sie müssen nur die Modellparameter wie Gewicht/Bias usw. speichern. Manchmal können erstere viel größer sein als letztere.
– Dawei Yang
18. März 2017 um 17:36 Uhr
Ich habe getestet
torch.save(model, f)
undtorch.save(model.state_dict(), f)
. Die gespeicherten Dateien haben dieselbe Größe. Jetzt bin ich verwirrt. Außerdem fand ich die Verwendung von pickle zum Speichern von model.state_dict() extrem langsam. Ich denke, der beste Weg ist die Verwendungtorch.save(model.state_dict(), f)
da Sie die Erstellung des Modells übernehmen und Fackel das Laden der Modellgewichte übernimmt, wodurch mögliche Probleme beseitigt werden. Bezug: Discussion.pytorch.org/t/saving-torch-models/838/4– Dawei Yang
29. März 2017 um 2:01 Uhr
Scheint, als hätte PyTorch dies in ihrem etwas expliziter angesprochen Abschnitt Tutorials– Es gibt dort viele gute Informationen, die in den Antworten hier nicht aufgeführt sind, einschließlich des Speicherns von mehr als einem Modell gleichzeitig und Warmstartmodellen.
– whlteXbread
24. März 2019 um 21:55 Uhr
was ist falsch an der Verwendung
pickle
?– Charly Parker
13. Juli 2020 um 18:23 Uhr
@CharlieParker Torch.save basiert auf Pickle. Folgendes ist aus dem oben verlinkten Tutorial: “[torch.save] speichert das gesamte Modul mit dem pickle-Modul von Python. Der Nachteil dieses Ansatzes besteht darin, dass die serialisierten Daten an die spezifischen Klassen und die genaue Verzeichnisstruktur gebunden sind, die beim Speichern des Modells verwendet wird. Der Grund dafür ist, dass Pickle die Modellklasse selbst nicht speichert. Stattdessen speichert es einen Pfad zu der Datei, die die Klasse enthält, die während der Ladezeit verwendet wird. Aus diesem Grund kann Ihr Code auf verschiedene Weise beschädigt werden, wenn er in anderen Projekten oder nach Umgestaltungen verwendet wird.”
– David Müller
14. Juli 2020 um 9:56 Uhr