목록 보기
CUDA OOM 해결 사례 공유 - PyTorch all_gather_object 의 비밀
AI

CUDA OOM 해결 사례 공유 - PyTorch all_gather_object 의 비밀

데보션
데보션
2025년 4월 22일

두줄요약

데이터셋 로딩 중 발생한 CUDA OOM의 원인을 `all_gather_object` 내부 동작에서 찾았습니다. 데이터를 chunk로 나눠 gather하도록 바꿔 GPU 메모리 사용량을 줄였습니다.

문제 상황

  • 학습이 아닌 데이터셋 로딩 단계에서 갑작스러운 CUDA OOM 발생
  • CPU에서 처리할 것으로 예상한 데이터 합치기 과정에서 GPU 메모리 사용 확인

원인 분석

  • dist.all_gather_object 내부에서 객체를 pickle 후 텐서로 변환해 NCCL 통신용 GPU 메모리 사용
  • 각 rank의 데이터를 한 번에 gather하면서 데이터 규모 증가 시 OOM 유발

해결 방법

  • 데이터를 여러 chunk로 나눠 all_gather_object를 반복 호출
  • chunk마다 torch.cuda.empty_cache()로 순간 메모리 사용량 완화

선택 이유

  • 노드 수가 많고 GPU 간 통신이 빠른 환경에 적합
  • sorted batch 같은 사전 로딩 기법을 유지하면서 분산 로딩 가능

댓글 0

댓글을 작성하려면 로그인이 필요합니다.

댓글을 불러오는 중...