[PyTorch] RuntimeError: Trying to backward through the graph a second time
Introduction
파이토치를 이용해 다양한 모델을 구현하다 보면, 아래와 같은 에러 메시지를 한번 쯤은 마주할 것이다.
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time
오늘은 이 에러 메시지가 어떤 상황에서, 어떤 이유로 출력되는 지 알아보고, 그 해결 방법에 대해 간략히 설명할 예정이다.
Autograd Mechanism
기본적으로 파이토치의 autograd 메커니즘은 requires_grad=True 로 설정된 텐서들에 대해, 텐서 데이터가 생성되기까지의 모든 연산 history를 담고 있는 directed acyclic graph를 생성하고 저장한다. 이렇게 생성해 놓은 그래프는 backward() 함수에서 호출 대상 텐서에 대한, 다른 텐서의 gradient 값을 구하는 데 사용된다.
Problem Cases
위 에러는 보통 .backward() 함수를 한번 만들어진 그래프(전체 혹은 일부) 에 대해 여러 번 호출할 때 발생한다.
backward()가 한 그래프에 대해 여러 번 호출되는 상황은 다양하게 있을 수 있는데, multi-task learning을 하기 위해 여러 loss function에 대해 각각 gradient를 구하고 있는 상황이거나, 멀티스레딩 등으로 인해 앞서 만들어 놓은 그래프의 일부가 공유되어야 하거나, 혹은 뭔가 코드가 잘못 짜여서 (.....) 한번 생성된 그래프에 의도치 않게 여러 번 backward가 호출되었거나 하는 경우가 있다.
마지막 케이스의 흔한 예로는, 이미 computational graph에 참여하고 있는 Tensor를 학습 이터레이션 내에서 재사용 할 때나 (recurrent layer), 의도치 않게 requires_grad=True 로 설정된 텐서를 학습 이터레이션 바깥에서 정의해 여러 이터레이션에 걸친 그래프를 생성할 때가 있다.
그렇다면, gradient를 구하기 위한 computational graph의 일부가 여러 번의 backward 호출에 참여하는 것은 왜 문제가 되는 것일까?
보통 autograd를 이용한 파이토치의 그래프는 모델의 forward 과정 중에서 생성되며, gradient를 구한 이후에는 일반적으로 다시 접근할 필요가 없기 때문에 메모리 최적화를 위해 backward() 함수가 끝나면 intermediate 텐서 버퍼 값들과 함께 버려지는 것이 보통의 학습 사이클에서의 그래프의 역할이다.
...
for index, x in enumerate(train_loader):
...
yhat = model.forward(x)
loss = loss_ftn(yhat, y) # 그래프 생성
loss.backward() # 그래프 제거
optimizer.step()
... # 반복
...
모델 학습 시에는 학습 이터레이션 마다 그래프를 생성하고, 그래프 버퍼를 free 해 주는 과정이 반복된다
만약에 한 학습 사이클에서 두 번째로 backward를 호출하게 되면, 이미 처음 backward 과정에서 gradient 계산을 위한 그래프 내 intermediate 텐서 버퍼가 버려진 상태이기 때문에 (다시 접근할 수 없음), 일부 공유하고 있는 그래프의 intermediate 값을 필요로 하는 이후의 backward 호출이 에러를 뱉어내는 것이다.
...
loss_1.backward() # 그래프 버퍼가 제거됨
loss_2.backward() # backward에 필요한 버퍼가 존재하지 않는다
optimizer.step()
...
두번째 backward 중 에러가 출력될 것이다
만약 잘못 그래프를 생성한 것이 아니라 정말로 multi-task learning이나 멀티스레드 학습처럼 같은 연산 그래프 상에서 backward가 여러 번 호출되어야 하는 상황이라면, 아래와 같이 backward 함수에 retain_graph=True 옵션을 줘서 해결할 수 있다.
...
loss_1.backward(retain_graph=True)
loss_2.backward()
optimizer.step()
...
마지막 backward 호출에서는 retain_graph 옵션을 지워 남은 그래프 버퍼를 제거하자
Possible Solutions
이런 경우가 아닌데 억울하게 동일한 에러가 발생하는 상황이라면, 어디선가 학습 이터레이션 바깥에 정의된 requires_grad=True 텐서나 이터레이션 과정에서 의도치 않게 재사용 되는 텐서가 존재하는 지 확인해 보자.
마지막으로, 특히 이터레이션 후반 부분에서 loss의 합을 계산할 때 total_loss = total_loss + loss 와 같은 식으로 그래프에서 detach를 하지 않은 채로 텐서 합산을 수행해 에러와 마주하는 경우도 흔한데, 만약 이렇게 짜여 있다면 loss.item() 이나 loss.data[0] 와 같은 식으로 수정해 본다면 문제 해결에 도움이 될 것이다.
아래는 참고할 만한 사이트이다.
References