본문 바로가기
딥러닝&머신러닝/파이토치 기본 문법

[PyTorch] `model.eval()` 의미

by David.Ho 2023. 1. 10.
728x90
반응형

딥러닝 모델의 코드를 살펴보다 보면 Evaluation 부분에서 꼭 이런 코드가 등장한다.

def evaluation(model, criterion, ...):
    model.eval()
    criterion.eval()
    ...

무슨 의미인지 궁금해서 찾아보니, nn.Module에서 train time과 eval time에서 수행하는 다른 작업을 수행할 수 있도록 switching 하는 함수라고 한다. stackoverflow

train time과 eval time에서 다르게 동작해야 하는 대표적인 예들은

  • Dropout Layer
  • BatchNorm Layer

등등이 있다고 한다.

.eval() 함수는 evaluation 과정에서 사용하지 않아야 하는 layer들을 알아서 off 시키도록 하는 함수인 셈이다.

evaluation/validation 과정에선 보통 model.eval()과 torch.no_grad()를 함께 사용한다고 한다.

# evaluate model:
model.eval()

with torch.no_grad():
    ...
    out_data = model(data)
    ...

eval/val 작업이 끝난 후에는 잊지말고 train mode로 모델을 변경해줘야 한다. 이것은 .train() 함수를 실행시키면 된다.

# after eval/val, and in training step
model.train()

 

PyTorch 공식 문서에서 .eval()에 대한 자세한 내용을 확인할 수 있다. nn.Module.eval()

728x90
반응형

댓글