[pytorch] Pytorch에서 학습한 모델 저장 및 불러오기

Pytorch에서 학습한 모델 저장 및 불러오기

Pytorch 모델을 저장하고, 불러와 보기

Pytorch 에서 모델의 가중치를 저장하기 위해선 3가지 함수만 알면 충분 합니다.

  • torch.save객체를 디스크에 저장합니다. pickle 모듈을 이용하여 객체를 직렬화 하며, 이 함수를 사용하여 모든 종류의 모델, Tensor 등을 저장할 수 있습니다.
  • torch.loadpickle 모듈을 이용하여 객체를 역직렬화하여 메모리에 할당합니다.
  • torch.nn.Module.load_state_dict: 역직렬화된 state_dict를 사용, 모델의 매개변수들을 불러옵니다. state_dict는 간단히 말해 각 체층을 매개변수 Tensor로 매핑한 Python 사전(dict) 객체입니다.

간단한 DNN 모델을 통해 연습 해 보겠습니다.

Code

import torch
import torch.nn as nn

x_data = torch.Tensor([
    [0, 0],
    [1, 0],
    [1, 1],
    [0, 0],
    [0, 0],
    [0, 1]
])

y_data = torch.LongTensor([
    0,  # etc
    1,  # mammal
    2,  # birds
    0,
    0,
    2
])

class DNN(nn.Module):
    def __init__(self):
        super(DNN, self).__init__()
        self.w1 = nn.Linear(2, 10)
        self.bias1 = torch.zeros([10])

        self.w2 = nn.Linear(10, 3)
        self.bias2 = torch.zeros([3])
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=0)

    def forward(self, x):
        y = self.w1(x) + self.bias1
        y = self.relu(y)

        y = self.w2(y) + self.bias2
        return y

model = DNN()

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for epoch in range(1000):
    output = model(x_data)

    loss = criterion(output, y_data)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print("progress:", epoch, "loss=", loss.item())

torch.save(object, path)

전체 모델을 저장하거나, 모델의 state_dict를 저장 할 때 사용합니다.

  • object: 저장할 모델 객체
  • path: 저장할 위치 + 파일명
PATH = './weights/'

torch.save(model, PATH + 'model.pt')  # 전체 모델 저장
torch.save(model.state_dict(), PATH + 'model_state_dict.pt')  # 모델 객체의 state_dict 저장
torch.save({
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict()
}, PATH + 'all.tar')  # 여러 가지 값 저장, 학습 중 진행 상황 저장을 위해 epoch, loss 값 등 일반 scalar값 저장 가능

torch.load(path)

전체 모델을 불러오거나, 모델의 state_dict를 불러 올 때 사용합니다.

  • path: 불러올 위치 + 파일명

torch.nn.Module.load_state_dict(dict):

state_dict를 이용하여, 모델 객체 내의 매개 변수 값을 초기화 합니다.

  • dict: 불러올 매개 변수 값들이 담겨있는 state_dict 객체
model = torch.load(PATH + 'model.pt')  # 전체 모델을 통째로 불러옴, 클래스 선언 필수
model.load_state_dict(torch.load(PATH + 'model_state_dict.pt'))  # state_dict를 불러 온 후, 모델에 저장

checkpoint = torch.load(PATH + 'all.tar')   # dict 불러오기
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])

모델을 불러 온 이후에는 이 모델을 학습 할 껀지, 사용 할 껀지에 따라 각각 model.train()model.eval() 둘 중에 하나를 사용 하면 됩니다.

다른 모델의 매개변수 사용하기

모델의 매개변수의 일부만 불러 사용하는 것은 전이학습을 이용할 때 자주 사용합니다. state_dict의 일부만 불러오거나, 적재하려는 모델보다 더 많은 키를 갖고 있는 state_dict를 불러 올때는, load_state_dict() 함수의 파라미터에 strict=False를 입력 해 주면 됩니다.

경축! 아무것도 안하여 에스천사게임즈가 새로운 모습으로 재오픈 하였습니다.
어린이용이며, 설치가 필요없는 브라우저 게임입니다.
https://s1004games.com

주의: 위 코드에서는 구현 하지 않은 클래스 입니다.

torch.save(modelA.state_dict(), PATH)  # 저장하기

modelB = TheModelBClass(*args, **kwargs)  # 불러오기
modelB.load_state_dict(torch.load(PATH), strict=False)

GPU, CPU간 모델 불러오기

GPU 에서 학습 한 모델과, CPU 에서 학습 한 모델 간 저장하는 방법은 같지만, 케이스에 따라 불러오는 과정이 다릅니다.

저장하는 방법은 다음과 같습니다.

torch.save(model.state_dict(), PATH + 'model.pt')

GPU에서 저장, CPU에서 불러오기

torch.load() 함수의 map_location 인자에 torch.device('cpu') 를 전달 함으로써, 모델을 동적으로 CPU 장치에 할당합니다.

device = torch.device('cpu')
model = DNN()
model.load_state_dict(torch.load(PATH, map_location=device))

GPU에서 저장, GPU에서 불러오기

torch.load() 로 초기화 한 모델의 model.to(torch.device('cuda')) 를 호출하여, CUDA Tensor로 내부 매개변수를 형변환 해 주어야 합니다.

device = torch.device('cuda')
model = DNN()
model.load_state_dict(torch.load(PATH))
model.to(device)

CPU에서 저장, GPU에서 불러오기

torch.load() 함수의 map_location 인자에 cuda:device_id 를 전달 함으로써, 모델을 동적으로 해당 GPU 장치에 할당합니다. 그 이후에 model.to(torch.device('cuda'))를 호출, 모델 내의 Tensor를 CUDA Tensor로 변환 합니다. 모든 모델 입력에도 .to(torch.device('cuda'))를 입력하여, CUDA Tensor로 변환하여야 합니다.

device = torch.device('cuda')
model = DNN()
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # 사용할 GPU 장치 번호 선택.
model.to(device)  # CUDA Tensor 형 변환
 

[출처] https://justkode.kr/deep-learning/pytorch-save/

 

본 웹사이트는 광고를 포함하고 있습니다.
광고 클릭에서 발생하는 수익금은 모두 웹사이트 서버의 유지 및 관리, 그리고 기술 콘텐츠 향상을 위해 쓰여집니다.
번호 제목 글쓴이 날짜 조회 수
공지 오라클 기본 샘플 데이터베이스 졸리운_곰 2014.01.02 25084
공지 [SQL컨셉] 서적 "SQL컨셉"의 샘플 데이타 베이스 SAMPLE DATABASE of ORACLE 가을의 곰을... 2013.02.10 24563
공지 [G_SQL] Sample Database 가을의 곰을... 2012.05.20 25942
1042 [AutoML][AutoKeras] [OSS] AutoKeras로 자동학습(AutoML) 하기 file 졸리운_곰 2023.07.02 12
1041 [NoSQL][MongoDB] Truncate a collection 졸리운_곰 2023.06.04 12
1040 [Tensorflow 2.0] 모델 저장하고 불러오기 졸리운_곰 2023.05.21 29
» [pytorch] Pytorch에서 학습한 모델 저장 및 불러오기 졸리운_곰 2023.05.21 11
1038 [MySQL] MySQL - 테이블 만들기 file 졸리운_곰 2023.05.13 29
1037 [R library] library(XML) # install.packages("XML") 인스톨 에러 졸리운_곰 2023.05.06 18
1036 [MySQL] MySQL Strict mode 끄기/켜기 졸리운_곰 2023.05.05 21
1035 [R 데이터 분석] Titanic: Machine Learning from Disaster (타이타닉 생존 예측) file 졸리운_곰 2023.04.29 55
1034 [R 데이터 분석] R 유명한 패키지 정리 졸리운_곰 2023.04.24 76
1033 [NoSQL] MongoDB 인증 모드 (password) 설정 졸리운_곰 2023.03.26 35
1032 [MySQL] [MySQL] 테이블 구조와 데이터 복사 (Table Structure and Data Copy) 졸리운_곰 2023.03.20 24
1031 [R 데이터 분석] Shiny : 대시보드 배포하기 file 졸리운_곰 2023.03.19 76
1030 [데이터 수집 및 전처리] (놀라운) 한글 데이터 짱! AwesomeKorean_Data file 졸리운_곰 2023.03.07 37
1029 [pytorch] Using BERT with Pytorch file 졸리운_곰 2023.03.06 28
1028 [pytorch] Full NMT model from pretrained BERT file 졸리운_곰 2023.03.06 16
1027 [기계학습][딥러닝] PyTorch Hello World 졸리운_곰 2023.02.12 23
1026 [PostgreSQL] 열을 행으로 전환 쿼리 졸리운_곰 2023.01.29 44
1025 [postgreSQL] PostgreSQL 계층형 쿼리 구현 방법 졸리운_곰 2023.01.29 35
1024 [postgreSQL] ORACLE쿼리에서 postgreSQL쿼리 변환 졸리운_곰 2023.01.29 26
1023 [postgreSQL] [PostgreSQL] stored function(stored procedures) 사용하기 졸리운_곰 2023.01.23 26
대표 김성준 주소 : 경기 용인 분당수지 U타워 등록번호 : 142-07-27414
통신판매업 신고 : 제2012-용인수지-0185호 출판업 신고 : 수지구청 제 123호 개인정보보호최고책임자 : 김성준 sjkim70@stechstar.com
대표전화 : 010-4589-2193 [fax] 02-6280-1294 COPYRIGHT(C) stechstar.com ALL RIGHTS RESERVED