M4 Pro에서 이미지-3D 모델을 구동하다!
Microsoft의 TRELLIS.2 이미지-3D 모델을 Apple Silicon(M1 이상)에서 구동할 수 있도록 PyTorch MPS를 활용하여 포팅
CUDA 기반의 기존 코드를 순수 PyTorch 및 Python 코드로 대체하여 NVIDIA GPU 없이 실행 가능
M4 Pro (24GB)에서 단일 이미지로부터 400K+ 정점 메시(Mesh) 생성에 약 3.5분 소요
커뮤니티에서는 모델의 성능에 대한 다소 부정적인 평가와 함께, 오픈소스 모델에 대한 기대를 표명
CUDA 의존성 제거 및 PyTorch MPS 활용
본 포팅은 TRELLIS.2 모델의 CUDA 기반 연산을 PyTorch MPS(Metal Performance Shaders)를 활용하여 Apple Silicon 환경에서 실행 가능하도록 구현했다. 특히, gather-scatter sparse 3D convolution, SDPA(Scaled Dot-Product Attention) attention, Python 기반의 메시 추출을 통해 CUDA 관련 라이브러리 의존성을 제거했다. 이러한 변경 사항은 9개의 파일에서 수백 줄의 코드로 이루어졌다.
성능 벤치마크 및 병목 현상 분석
M4 Pro (24GB) 환경에서 파이프라인 유형 512 기준으로 모델 로딩에 약 45초, 메시 디코딩에 30초가 소요된다. 전체 생성 시간은 약 3.5분이며, 메모리 사용량은 최대 18GB로 나타났다. 순수 PyTorch 기반의 sparse convolution이 주요 병목 지점으로, CUDA flex_gemm 커널에 비해 약 10배 느리다는 점이 지적되었다.
메시(Mesh) 생성 과정 및 한계점
포팅된 모델은 단일 이미지로부터 400K+ 정점 메시를 생성하며, 텍스처가 포함된 OBJ 및 GLB 파일을 출력한다. 하지만, nvdiffrast의 부재로 인해 텍스처 추출이 불가능하며, 메시 홀(Hole) 채우기 기능도 비활성화되어 있다. 또한, 3D 모델의 품질에 대한 커뮤니티의 의견은 엇갈리며, 일부 사용자는 meshy.ai와 같은 다른 모델을 더 선호하는 경향을 보였다.
기술적 구현 상세
구현 세부 사항을 살펴보면, sparse 3D convolution은 활성 복셀(Voxel)의 공간 해시를 구축하고, 각 커널 위치에 대한 이웃 특징을 수집하여 가중치를 적용하는 방식으로 구현되었다. Mesh Extraction은 Python 딕셔너리를 사용하여 CUDA hashmap 연산을 대체했다. 또한, SDPA(Scaled Dot-Product Attention) 백엔드를 sparse attention 모듈에 추가하여 가변 길이 시퀀스를 배치로 패딩하고, scaled_dot_product_attention 함수를 실행한 후 결과를 언패딩하는 방식으로 구현되었다.