
AI
CUDA OOM 해결 사례 공유 - PyTorch all_gather_object 의 비밀
두줄요약
데이터셋 로딩 중 발생한 CUDA OOM의 원인을 `all_gather_object` 내부 동작에서 찾았습니다. 데이터를 chunk로 나눠 gather하도록 바꿔 GPU 메모리 사용량을 줄였습니다.
문제 상황
- 학습이 아닌 데이터셋 로딩 단계에서 갑작스러운 CUDA OOM 발생
- CPU에서 처리할 것으로 예상한 데이터 합치기 과정에서 GPU 메모리 사용 확인
원인 분석
dist.all_gather_object내부에서 객체를 pickle 후 텐서로 변환해 NCCL 통신용 GPU 메모리 사용- 각 rank의 데이터를 한 번에 gather하면서 데이터 규모 증가 시 OOM 유발
해결 방법
- 데이터를 여러 chunk로 나눠
all_gather_object를 반복 호출 - chunk마다
torch.cuda.empty_cache()로 순간 메모리 사용량 완화
선택 이유
- 노드 수가 많고 GPU 간 통신이 빠른 환경에 적합
- sorted batch 같은 사전 로딩 기법을 유지하면서 분산 로딩 가능
