티스토리 뷰

반응형

pytorch에서 모델을 로딩할 때, state_dict_path로 .pth 파일로 불러오고 load_state_dict를 하면 에러가 뜨는 현상이 발생했다.

state_dict = torch.load(model_state_dict_path)
self.model.load_state_dict(new_state_dict)

 

다른 모델 구조라면 어쩔 수 없이 strict를 써야 되겠지만, 심지어는 같은 모델인데도 이런 현상이 발생한다.

 

대충 이럴 때는 state_dict에 있는 모델 레이어의 이름과 전이학습된 모델 레이어의 이름이 달라서 그렇다.

따라서 state_dict를 살펴 본 후, 해당 이름을 제거해주면 해결된다.

class LitCustomModel(L.LightningModule):
    def __init__(self,model_state_dict_path=None):
        super().__init__()
        self.model = CustomModel()
        self.validation_step_output = []
        
        if model_state_dict_path is not None:
            state_dict = torch.load(model_state_dict_path)
            new_state_dict = {}  # 새로운 상태 사전을 만듭니다.
            for k, v in state_dict.items():
                if k.startswith('module.'):  # 'module.'로 시작하는 키에 대해서만 수정
                    k = k.replace('module.', '')  # 'module.'을 제거하여 모델 클래스와 일치시킴
                new_state_dict[k] = v
            self.model.load_state_dict(new_state_dict)

 

필자의 경우에는 모든 레이어의 model 앞에 module이라는 키가 붙어 있어서(module.model..이런 식), 그걸 제거해주는 방향으로 해결했다.

 

안되는 경우, 다른 방법으로 시도해보자.

반응형