Frameworks/PyTorch

[PyTorch] 파이토치 프로파일링 (PyTorch profiler API)

Jonghyuk Baek 2021. 5. 24. 00:47

Introduction

지지난 글에서는, 파이썬 코드를 실행할 때 코드의 시간 성능을 프로파일러 (profiler) 를 이용해 측정하는 방법을 알아봤었다. 여기

(지지난 글이었지만 글을 수정 하면서 발행일을 업데이트 했고... 목록의 한참 위로 올라가 버렸다)

 

함수의 호출 수와 개별 / 누적 소모 시간을 한 번에 정리해 주기 때문에 어느 구간에서 프로그램의 연산 병목이 일어나고 있는지 쉽게 찾을 수 있다.

 

그렇다면 비슷한 느낌으로 파이토치 연산의 전체적인 시간 / 메모리 성능을 한 눈에 보기 쉽게 정리해 주는 기능은 없을까? 일단 위의 프로파일러는 CPU 상에서 실행되는 함수만을 트래킹 하기 때문에, 대부분 GPU를 사용하는 파이토치 모델의 성능 측정에 사용하기에는 적합하지 않다. 

 

다행히도 파이토치 라이브러리에는 기본적으로 torch.autograd.profiler 모듈을 통해 간단한 프로파일러 API를 제공하며, 이를 이용해 파이토치 모델 내부 연산들의 시간 / 메모리 소모량을 한 번에 측정할 수 있다. 한번 간단한 예제들과 함께 알아보자.

 


1. Sample module

먼저 테스트를 위한 간단한 모듈을 하나 정의해 보자.

import torch
import numpy as np
import torch.nn as nn
import torch.autograd.profiler as profiler


class ProfileTargetModule(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True, bn: bool = True):
        super(ProfileTargetModule, self).__init__()
        self.conv = nn.Conv2d(in_features, out_features, kernel_size=3, padding=1, bias=bias)
        self.bn = nn.BatchNorm2d(out_features)
        
    def forward(self, input):
        with profiler.record_function("CONV FORWARD"):
            out = self.conv(input)
            out = self.bn(out)

        with profiler.record_function("SVD"):
            u, s, vh = np.linalg.svd(out.cpu().detach().numpy())
            s = torch.from_numpy(s).cuda()
            
        return out, s

이미지 데이터에 대해서 2차원 conv와 batchnorm을 수행한 후, spatial dimension에 대해서 SVD를 수행해 얻어낸 sigma 매트릭스를 내보내는 모델이다.

 

이 모듈에서는 SVD를 numpy 라이브러리를 사용해 수행하는데, 이를 위해 텐서를 CPU로 옮겨 numpy array로 바꿔주고, 다시 결과를 GPU tensor로 옮겨주는 작업을 하고 있다.

 

profiler.record_function < with profiler.record_function(...): > 부분은 코드를 작은 sub-task로 나누어 특정한 구역 별로 성능을 측정할 때 사용된다. 위 예시에서는 Conv+BN 부분과 SVD 부분에 설정해 두었다. 

 

프로파일링을 원하는 지점을 < with profiler.profile(...) as prof: > 블록으로 감싸고, 프로파일링 결과를 출력할 때는 < print(prof.key_averages().table(...)) > 함수를 이용해 표 형태로 결과를 출력할 수 있다. 이전의 파이썬 프로파일러와 같이, 여러 옵션을 통해 결과의 출력 방식을 마음대로 조절할 수 있다.

2. CPU profiling

한번 프로파일링 결과를 확인해 보자. 먼저 CPU만을 이용해 프로파일링을 수행해 보자. profiler.profile 호출 시 use_cuda 인자를 따로 넣어주지 않으면 자동적으로 CPU를 이용한 프로파일링이 실행된다.

x = torch.rand(1, 3, 128, 128).cuda()
model = ProfileTargetModule(3, 8, True, True).cuda()
out, s = model(x) # Warm-up

with profiler.profile(with_stack=True, profile_memory=True) as prof:
    out, s = model(x)
    
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

>>>

--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                            Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                             SVD        93.62%      29.385ms        95.06%      29.838ms      29.838ms          -4 b    -512.02 Kb       4.00 Kb           0 b             1  
                    CONV FORWARD         0.64%     201.873us         4.49%       1.409ms       1.409ms          -4 b         -20 b       1.00 Mb        -512 b             1  
                aten::batch_norm         0.08%      25.597us         1.96%     614.343us     614.343us           0 b           0 b     513.00 Kb           0 b             1  
    aten::_batch_norm_impl_index         0.08%      23.755us         1.88%     588.746us     588.746us           0 b           0 b     513.00 Kb           0 b             1  
          aten::cudnn_batch_norm         0.38%     119.148us         1.80%     564.991us     564.991us           0 b           0 b     513.00 Kb           0 b             1  
                    aten::conv2d         0.09%      28.544us         1.66%     521.786us     521.786us           0 b           0 b     512.00 Kb           0 b             1  
                     aten::empty         1.61%     504.614us         1.61%     504.614us      45.874us          40 b          40 b     530.00 Kb     530.00 Kb            11  
               aten::convolution         0.08%      25.357us         1.57%     493.242us     493.242us           0 b           0 b     512.00 Kb           0 b             1  
              aten::_convolution         0.14%      42.538us         1.49%     467.885us     467.885us           0 b           0 b     512.00 Kb           0 b             1  
                        aten::to         0.12%      36.240us         1.20%     376.719us     188.360us     512.00 Kb           0 b       4.00 Kb           0 b             2  
--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 31.388ms

결과는 함수의 누적 소요 시간 순으로 정렬했다 (CPU total). 반대로 표 상의 Self CPU는 함수 내에서 다른 함수를 호출해서 소요되는 시간을 포함하지 않는다는 것이 차이점이다. 정렬 기준은 sort_by argument를 통해 조절할 수 있다.

 

결과를 보면 SVD sub-task가 CONV FORWARD task에 비해서 현저히 많은 시간을 소요하고 있다는 것을 알 수 있다 (29.84ms  / 1.41ms). 또한 오퍼레이터나 블록 실행 중의 메모리 (CPU, GPU) 사용량도 모니터링 할 수 있어서, 어느 부분에서 특히 메모리를 많이 소요하는 지 확인하는 것이 가능하다. 

 

위의 < profiler.profile(...) > 부분에서 with_stack은 오퍼레이션이 실행된 파일 명과 라인 넘버를 함께 트래킹 할 지를 조절하는 파라미터이다.

 

아래와 같이 프로파일 결과 출력 시 group_by_stack_n 옵션을 추가로 넣어주면, 오퍼레이터의 마지막 n번의 호출 지점을 함께 출력해 준다.

group by stack option

print(prof.key_averages(group_by_stack_n=5).table(sort_by="cpu_time_total", row_limit=3))

>>>

--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ---------------------------------------------------------------------------  
                            Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  Source Location                                                              
--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ---------------------------------------------------------------------------  
                             SVD        94.45%      28.735ms        95.88%      29.171ms      29.171ms          -4 b    -512.02 Kb       4.00 Kb           0 b             1  ...H/lib/python3.9/site-packages/torch/autograd/profiler.py(616): __enter__  
                                                                                                                                                                              /HDD0/JHBaek/profiler.py(17): forward                                        
                                                                                                                                                                              .../lib/python3.9/site-packages/torch/nn/modules/module.py(889): _call_impl  
                                                                                                                                                                              /HDD0/JHBaek/profiler.py(28): <module>                                       
                                                                                                                                                                                                                                                           
                    CONV FORWARD         0.57%     172.193us         3.73%       1.135ms       1.135ms          -4 b         -20 b       1.00 Mb        -512 b             1  ...H/lib/python3.9/site-packages/torch/autograd/profiler.py(616): __enter__  
                                                                                                                                                                              /HDD0/JHBaek/profiler.py(13): forward                                        
                                                                                                                                                                              .../lib/python3.9/site-packages/torch/nn/modules/module.py(889): _call_impl  
                                                                                                                                                                              /HDD0/JHBaek/profiler.py(28): <module>                                       
                                                                                                                                                                                                                                                           
                aten::batch_norm         0.07%      20.741us         1.57%     477.751us     477.751us           0 b           0 b     513.00 Kb           0 b             1  ..._JH/lib/python3.9/site-packages/torch/nn/functional.py(2149): batch_norm  
                                                                                                                                                                              .../lib/python3.9/site-packages/torch/nn/modules/batchnorm.py(135): forward  
                                                                                                                                                                              .../lib/python3.9/site-packages/torch/nn/modules/module.py(889): _call_impl  
                                                                                                                                                                              /HDD0/JHBaek/profiler.py(15): forward                                        
                                                                                                                                                                              .../lib/python3.9/site-packages/torch/nn/modules/module.py(889): _call_impl  
                                                                                                                                                                                                                                                           
--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ---------------------------------------------------------------------------  
Self CPU time total: 30.424ms

위와 같은 방식으로는 출력이 많이 길어지지만, 정확히 어느 지점에서 오퍼레이션이 호출되었는 지 확인할 때 좋다.

3. GPU profiling

이번에는 use_cuda 인자를 profiler.profile에 전달해 GPU까지 이용한 프로파일링 결과를 확인해 보자. 단순히 프로파일러에 use_cuda 인자를 True로 전달해 주면 된다. GPU 연산 시간 위주로 확인하기 위해, 아래의 sort_by 인자를 "cuda_time_total"로 변경해 주자.

x = torch.rand(1, 3, 128, 128).cuda()
model = ProfileTargetModule(3, 8, True, True).cuda()
out, s = model(x)

with profiler.profile(with_stack=True, use_cuda=True, profile_memory=True) as prof:
    out, s = model(x)
    
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

>>>

--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                            Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                             SVD        93.40%      29.984ms        95.00%      30.495ms      30.495ms      30.025ms        93.54%      30.495ms      30.495ms          -4 b    -512.02 Kb       4.00 Kb           0 b             1  
                    CONV FORWARD         1.05%     336.994us         4.52%       1.452ms       1.452ms     322.752us         1.01%       1.452ms       1.452ms          -4 b         -20 b       1.00 Mb        -512 b             1  
                    aten::conv2d         0.07%      24.017us         1.78%     570.698us     570.698us      24.576us         0.08%     570.368us     570.368us           0 b           0 b     512.00 Kb           0 b             1  
               aten::convolution         0.06%      18.925us         1.70%     546.681us     546.681us      18.432us         0.06%     545.792us     545.792us           0 b           0 b     512.00 Kb           0 b             1  
              aten::_convolution         0.14%      44.645us         1.64%     527.756us     527.756us      45.056us         0.14%     527.360us     527.360us           0 b           0 b     512.00 Kb           0 b             1  
                aten::batch_norm         0.06%      20.842us         1.52%     489.137us     489.137us      19.264us         0.06%     513.856us     513.856us           0 b           0 b     513.00 Kb           0 b             1  
    aten::_batch_norm_impl_index         0.06%      19.772us         1.46%     468.295us     468.295us      17.408us         0.05%     494.592us     494.592us           0 b           0 b     513.00 Kb           0 b             1  
          aten::cudnn_batch_norm         0.32%     103.347us         1.40%     448.523us     448.523us     477.184us         1.49%     477.184us     477.184us           0 b           0 b     513.00 Kb           0 b             1  
                        aten::to         0.15%      47.316us         1.34%     430.551us     215.275us      87.103us         0.27%     430.848us     215.424us     512.00 Kb           0 b       4.00 Kb           0 b             2  
         aten::cudnn_convolution         1.04%     333.996us         1.16%     371.816us     371.816us     371.712us         1.16%     371.712us     371.712us           0 b           0 b     512.00 Kb     495.00 Kb             1  
--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 32.101ms
Self CUDA time total: 32.100ms

기존의 CPU 결과 이외에 GPU 상에서 측정된 시간 / 메모리 결과가 추가되었다.

 

SVDCONV FORWARD 부분이 각각 30ms, 1.5ms 정도를 소요하며, 2d convolutionbatchnorm이 각각 570us, 513us 정도를 소요하는 것을 볼 수 있다. 

 

그럼 한번 numpy 라이브러리를 사용해 SVD를 수행하는 부분을 torch.svd 함수로 바꾸면 시간 / 메모리 성능이 어떻게 차이날 지를 알아보자.

class ProfileTargetModule(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True, bn: bool = True):
        super(ProfileTargetModule, self).__init__()
        self.conv = nn.Conv2d(in_features, out_features, kernel_size=3, padding=1, bias=bias)
        self.bn = nn.BatchNorm2d(out_features)
        
    def forward(self, input):
        with profiler.record_function("CONV FORWARD"):
            out = self.conv(input)
            out = self.bn(out)

        with profiler.record_function("SVD"):
            u, s, vh = torch.svd(out)
            
        return out, s

위와 같이 모듈 코드를 수정했다. 새 프로파일링 결과를 기존 결과와 비교해 보자.

 

(기존)

--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                            Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                             SVD        93.40%      29.984ms        95.00%      30.495ms      30.495ms      30.025ms        93.54%      30.495ms      30.495ms          -4 b    -512.02 Kb       4.00 Kb           0 b             1  
                    CONV FORWARD         1.05%     336.994us         4.52%       1.452ms       1.452ms     322.752us         1.01%       1.452ms       1.452ms          -4 b         -20 b       1.00 Mb        -512 b             1  
                    aten::conv2d         0.07%      24.017us         1.78%     570.698us     570.698us      24.576us         0.08%     570.368us     570.368us           0 b           0 b     512.00 Kb           0 b             1  
               aten::convolution         0.06%      18.925us         1.70%     546.681us     546.681us      18.432us         0.06%     545.792us     545.792us           0 b           0 b     512.00 Kb           0 b             1  
              aten::_convolution         0.14%      44.645us         1.64%     527.756us     527.756us      45.056us         0.14%     527.360us     527.360us           0 b           0 b     512.00 Kb           0 b             1  
                aten::batch_norm         0.06%      20.842us         1.52%     489.137us     489.137us      19.264us         0.06%     513.856us     513.856us           0 b           0 b     513.00 Kb           0 b             1  
    aten::_batch_norm_impl_index         0.06%      19.772us         1.46%     468.295us     468.295us      17.408us         0.05%     494.592us     494.592us           0 b           0 b     513.00 Kb           0 b             1  
          aten::cudnn_batch_norm         0.32%     103.347us         1.40%     448.523us     448.523us     477.184us         1.49%     477.184us     477.184us           0 b           0 b     513.00 Kb           0 b             1  
                        aten::to         0.15%      47.316us         1.34%     430.551us     215.275us      87.103us         0.27%     430.848us     215.424us     512.00 Kb           0 b       4.00 Kb           0 b             2  
         aten::cudnn_convolution         1.04%     333.996us         1.16%     371.816us     371.816us     371.712us         1.16%     371.712us     371.712us           0 b           0 b     512.00 Kb     495.00 Kb             1  
--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 32.101ms
Self CUDA time total: 32.100ms

(변경 후)

--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                            Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                             SVD         0.25%      76.107us        95.53%      28.572ms      28.572ms      85.153us         0.28%      28.572ms      28.572ms          -4 b         -20 b       1.00 Mb           0 b             1  
                       aten::svd         0.07%      20.633us        95.25%      28.487ms      28.487ms      20.702us         0.07%      28.487ms      28.487ms           0 b           0 b       1.00 Mb           0 b             1  
               aten::_svd_helper        92.60%      27.696ms        95.18%      28.466ms      28.466ms      27.745ms        92.77%      28.466ms      28.466ms           0 b         -32 b       1.00 Mb    -512.50 Kb             1  
                    CONV FORWARD         1.11%     330.670us         3.95%       1.183ms       1.183ms     311.456us         1.04%       1.183ms       1.183ms          -4 b         -20 b       1.00 Mb        -512 b             1  
                    aten::conv2d         0.08%      22.820us         1.98%     593.092us     593.092us      22.368us         0.07%     593.408us     593.408us           0 b           0 b     512.00 Kb           0 b             1  
               aten::convolution         0.06%      19.291us         1.91%     570.272us     570.272us      20.064us         0.07%     571.040us     571.040us           0 b           0 b     512.00 Kb           0 b             1  
              aten::_convolution         0.16%      48.661us         1.84%     550.981us     550.981us      48.576us         0.16%     550.976us     550.976us           0 b           0 b     512.00 Kb           0 b             1  
                     aten::zeros         0.42%     125.955us         1.80%     537.890us     134.473us     443.488us         1.48%     536.736us     134.184us           8 b           0 b     512.50 Kb           0 b             4  
         aten::cudnn_convolution         1.15%     342.481us         1.27%     381.053us     381.053us     381.088us         1.27%     381.088us     381.088us           0 b           0 b     512.00 Kb     495.00 Kb             1  
                aten::batch_norm         0.07%      20.896us         0.68%     202.769us     202.769us      17.664us         0.06%     230.400us     230.400us           0 b           0 b     513.00 Kb           0 b             1  
--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 29.908ms
Self CUDA time total: 29.907ms

메모리 측면에서는 큰 차이가 없지만, SVD sub-task의 시간 소요량이 CUDA total 기준으로 2ms 가량 빨라졌으며 전체적인 소요 시간 또한 줄어든 것을 확인할 수 있다.

 

다만 특이한 점으로는 코드를 수정하지 않은 CONV FORWARD 부분의 시간 소요량도 줄어들었는데, 후반부 그래프 구조가 바뀌면서 최적화가 된 건지 혹은 다른 이유로 바뀐 것인지는 확실하지 않다. 이 부분에 대해서는 추가적으로 공부가 필요할 듯 하다...

 


이번에는 파이토치에서 제공하고 있는 프로파일러를 사용해 모델의 시간 / 메모리 성능을 CPU, GPU 상에서 모니터링하는 방법에 대해 간략하게 알아봤다. 파이썬의 내장 프로파일러와 비교했을 때 전체적인 API 형태가 동일하지는 않지만, 개인적인 생각으로는 굉장히 유사하다고 느꼈다. 따라서 둘 중에 하나에 익숙해졌다면 나머지 하나도 적용해 보는데 큰 무리가 없을 것이라 생각된다.

 

기존의 프로파일러의 사용 방식과 똑같이, 파이토치 프로파일러는 모델에서 시간 지연이 일어나는 지점을 찾거나 메모리를 예상치 못하게 많이 소모하는 지점을 찾는 데 유용하게 쓰일 수 있다. 또한 내부의 코드 블록들을 record_function 함수를 이용해 묶어서 구간 별로 프로파일링 결과를 따로 볼 수 있다는 점이 특히 편리하다.

 

글에서는 최소한의 기능만을 소개해 두고 있기 때문에, API 도큐멘테이션 등을 한번 더 읽어보는 것을 추천한다.

 

아래는 참고할 만한 사이트들이다.

References

torch.profiler 도큐멘테이션

https://pytorch.org/docs/master/profiler.html

Profiling your PyTorch Module Tutorial

https://pytorch.org/tutorials/beginner/profiler.html#improve-time-performance

PyTorch Recipes: Pytorch Profiler

https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html

torch.svd

https://pytorch.org/docs/stable/generated/torch.svd.html

반응형