티스토리 뷰
반응형
pytorch에서 모델을 로딩할 때, state_dict_path로 .pth 파일로 불러오고 load_state_dict를 하면 에러가 뜨는 현상이 발생했다.
state_dict = torch.load(model_state_dict_path)
self.model.load_state_dict(new_state_dict)
다른 모델 구조라면 어쩔 수 없이 strict를 써야 되겠지만, 심지어는 같은 모델인데도 이런 현상이 발생한다.
대충 이럴 때는 state_dict에 있는 모델 레이어의 이름과 전이학습된 모델 레이어의 이름이 달라서 그렇다.
따라서 state_dict를 살펴 본 후, 해당 이름을 제거해주면 해결된다.
class LitCustomModel(L.LightningModule):
def __init__(self,model_state_dict_path=None):
super().__init__()
self.model = CustomModel()
self.validation_step_output = []
if model_state_dict_path is not None:
state_dict = torch.load(model_state_dict_path)
new_state_dict = {} # 새로운 상태 사전을 만듭니다.
for k, v in state_dict.items():
if k.startswith('module.'): # 'module.'로 시작하는 키에 대해서만 수정
k = k.replace('module.', '') # 'module.'을 제거하여 모델 클래스와 일치시킴
new_state_dict[k] = v
self.model.load_state_dict(new_state_dict)
필자의 경우에는 모든 레이어의 model 앞에 module이라는 키가 붙어 있어서(module.model..이런 식), 그걸 제거해주는 방향으로 해결했다.
안되는 경우, 다른 방법으로 시도해보자.
반응형