728x90
반응형
Hook
패키지 중간에 자기가 원하는 코드 끼워넣을 수 있는 부분 정도로 이해하면 될 듯하다! (register hook)
- hook: 일반적으로 hook은 프로그램, 혹은 특정 함수 실행 후에 걸어놓는 경우를 일컬음.
- pre-hook: 프로그램 실행 전에 걸어놓는 hook
- forward hook
- register_forward_hook: forward 호출 후에 forward output 계산 후 걸어두는 hook
input은 positional arguments만 담을 수 있으며 (index 같은?) keyword arguments 등은 담을 수 없고, forward 서만 적용이 된다.# register_forward_hook should have the following signature hook(module, input, output) -> None or modified output
hook은 forward output 수정 가능, input 또한 수정 가능하지만 forward에는 영향 없음. - register_forward_pre_hook: forward 호출 전에 걸어두는 hook
# The hook will be called every time before forward() is invoked. It should have the following signature: hook(module, input) -> None or modified input
input은 positional arguments만 담을 수 있으며 (index 같은?) keyword arguments 등은 담을 수 없고, forward 에서만 적용이 된다.
여기서 hook은 input을 수정할 수 있고, 출력은 튜플 혹은 single modified value를 리턴함. single value여도 tuple로 wrapping되어서 나간다는 점 참고.
- register_forward_hook: forward 호출 후에 forward output 계산 후 걸어두는 hook
- backward hook
- register_full_backward_hook (module에 적용)
module input에 대한 gradient가 계산될 때마다 hook이 호출됨.
grad_input과 grad_output은 각각 input과 output에 대한 gradient를 포함하고 있는 튜플. hook은 hook의 인자, 즉 grad_input과 grad_output을 수정할 수는 없지만 새로운 그래디언트를 리턴해서 grad_input 대신 사용할 수 있음 (이후 computation에서 사용 가능). 역시 positional arguments만 허용 가능하며 keyword arguments는 허용되지 않음.# The hook should have the following signature: hook(module, grad_input, grad_output) -> tuple(Tensor) or None
또한 여기서 input 또는 output을 직접적으로 수정하는 건 error 발생. - register_hook (in Tensor)
→ Tensor의 경우에는 only backward hook (Tensor의 gradient가 계산될 때마다 hook 호출됨. hook은 gradient를 바꿀 수는 없지만 새로운 gradient를 생성 가능하며, 기존 grad 대신 사용 가능함.)
- register_full_backward_hook (module에 적용)
hook이 어디에 사용이 될까?
- 디버깅 (레이어 shape, output 등을 출력하는 hook을 넣어주는 방식)
- feature extraction
""" author: Frank Odom https://medium.com/the-dl/how-to-use-pytorch-hooks-5041d777f904 """ from typing import Dict, Iterable, Callable class FeatureExtractor(nn.Module): def __init__(self, model: nn.Module, layers: Iterable[str]): super().__init__() self.model = model self.layers = layers self._features = {layer: torch.empty(0) for layer in layers} for layer_id in layers: layer = dict([*self.model.named_modules()])[layer_id] layer.register_forward_hook(self.save_outputs_hook(layer_id)) def save_outputs_hook(self, layer_id: str) -> Callable: def fn(_, __, output): self._features[layer_id] = output return fn def forward(self, x: Tensor) -> Dict[str, Tensor]: _ = self.model(x) return self._features
- gradient clipping (이 경우에는 torch.Tensor.register_hook)
- visualising activation (forward hook)
""" author: Ayoosh Kathuria https://blog.paperspace.com/pytorch-hooks-gradient-clipping-debugging/ """ import torch import torch.nn as nn class myNet(nn.Module): def __init__(self): super().__init__() self.conv = nn.Conv2d(3,10,2, stride = 2) self.relu = nn.ReLU() self.flatten = lambda x: x.view(-1) self.fc1 = nn.Linear(160,5) self.seq = nn.Sequential(nn.Linear(5,3), nn.Linear(3,2)) def forward(self, x): x = self.relu(self.conv(x)) x = self.fc1(self.flatten(x)) x = self.seq(x) net = myNet() visualisation = {} def hook_fn(m, i, o): visualisation[m] = o def get_all_layers(net): for name, layer in net._modules.items(): #If it is a sequential, don't register a hook on it # but recursively register hook on all it's module children if isinstance(layer, nn.Sequential): get_all_layers(layer) else: # it's a non sequential. Register a hook layer.register_forward_hook(hook_fn) get_all_layers(net) out = net(torch.randn(1,3,8,8)) # Just to check whether we got all layers visualisation.keys() #output includes sequential layers
728x90
반응형
'딥러닝&머신러닝 > 파이토치 기본 문법' 카테고리의 다른 글
[PYTORCH] POINTWISE OPS (0) | 2023.01.10 |
---|---|
torch.nn.functional.interpolate (0) | 2023.01.09 |
[PyTorch] Tensor 합치기: cat(), stack() (0) | 2023.01.09 |
[Pytorch] squeeze와 unsqueeze 함수 사용법 정리 (0) | 2023.01.09 |
파이토치 nn 모듈 (0) | 2023.01.09 |
댓글