JAX로 NanoChat을 재구현, XLA의 성능과 생태계의 딜레마

by DD
1개월 전
조회수 6

PyTorch 기반 NanoChat을 JAX 및 Flax NNX로 포팅하여 TPU 호환성(TPU Portability)을 확보하고, 모델 크기, 데이터 볼륨, 컴퓨팅 예산에 따른 확장성 연구(Scaling Law Research)를 수행

XLA 컴파일을 통해 Python 오버헤드(Python Overhead)를 제거하고, 단일 GPU에서 10분 이내에 Nano 모델(885K 파라미터) 훈련 및 스트리밍 채팅 UI 구현

JIT 컴파일된 함수 내 디버깅(Debugging)의 어려움과 vLLM, Flash Attention 3 부재 등 생태계 제약(Ecosystem Limitation)으로 인한 단점 존재

Muon 옵티마이저(Newton-Schulz)를 JIT 호환 방식으로 구현하고, JAX의 자동 미분(Automatic Differentiation) 기능을 활용하여 코드 간결성 확보

향후 TPU를 활용한 확장성 실험을 통해 Chinchilla 스타일의 파워 법칙(Power Law)을 분석하고, JAX의 장점을 극대화할 계획

XLA 컴파일(Compilation)의 장점과 단점

JAX의 XLA(XLA) 컴파일은 최적화된 커널(Optimized Kernel) 실행을 통해 훈련 단계의 Python 오버헤드를 제거한다. 특히, 첫 번째 단계에서 XLA가 전체 순방향, 역방향, Muon 옵티마이저 업데이트를 단일 HLO 프로그램으로 융합하여 약 35초의 컴파일 시간을 소요한다.

장점: 컴파일 이후 단계에서는 Python 레벨의 오버헤드 없이 290ms의 빠른 단계 시간(Step Time)을 유지

단점: JIT 컴파일된 함수 내에서 디버깅(Debugging)의 어려움과 XLA 컴파일로 인한 초기 컴파일 시간 소요

결과적으로, 모델 크기가 커지고 각 커널이 더 많은 연산을 수행할수록 XLA의 장점이 부각되지만, 초기 컴파일 시간은 고려해야 할 사항이다.

JAX와 PyTorch의 차이점: 자동 미분(Automatic Differentiation) 및 코드 간결성

JAX는 nnx.value_and_grad 패턴을 통해 (params, data) -> (loss, grads) 형태의 훈련 구조를 강제하여, PyTorch에서 발생할 수 있는 그래디언트 관련 버그(Gradient Bugs)를 줄인다. 특히, zero_grad() 누락, 그래디언트 누수, .detach() 부재 등의 문제를 방지한다.

JAX의 장점: 명시적인 PRNG 스레딩(Explicit PRNG Threading)을 통해 재현성(Reproducibility)을 기본적으로 보장

PyTorch의 단점: torch.use_deterministic_algorithms(True)를 사용해야 하며, 이는 최적화된 여러 커널을 비활성화하여 처리량(Throughput) 감소

결론적으로, JAX는 자동 미분 및 재현성 측면에서 PyTorch보다 유리하며, 코드의 간결성을 높여 개발 생산성(Development Productivity)을 향상시킨다.

TPU 호환성(Portability) 및 확장성 실험

JAX는 XLA 백엔드를 통해 CUDA와 TPU를 모두 지원하므로, 코드 변경 없이 GPU에서 개발하고 TPU에서 확장성 실험을 수행할 수 있다. 이는 AI GDE TPU Research Cloud 프로그램을 통해 TPU v4-8 pod를 사용하는 연구자들에게 큰 이점을 제공한다.

실험 설계: 모델 크기, 데이터 볼륨, 컴퓨팅 예산을 체계적으로 변경하여 Chinchilla 스타일의 파워 법칙(Power Law)을 분석

실험 결과: 600단계 훈련에서 alpha = 0.027로 측정되었으며, 이는 짧은 훈련 시간과 작은 데이터셋의 한계를 반영

향후 계획: TPU를 활용하여 scale_n 및 scale_c 실험을 수행하고, 신뢰할 수 있는 결과를 발표할 예정이다.

결과적으로, JAX는 TPU를 활용한 확장성 연구(Scaling Research)에 적합하며, 연구 개발 효율성을 높일 수 있다.

생태계(Ecosystem) 제약과 PyTorch의 강점

JAX는 vLLM, DeepSpeed, PEFT, bitsandbytes, HuggingFace Transformers 등 PyTorch 생태계의 주요 도구를 지원하지 않아, 해당 도구를 사용하는 경우 JAX로의 전환 비용이 높다. 특히, LoRA 파인 튜닝, vLLM 추론, HuggingFace 벤치마크 평가 등을 위해서는 PyTorch를 유지해야 한다.

PyTorch의 강점: Flash Attention 3, 분산 옵티마이저, PEFT, vLLM 등의 풍부한 생태계

JAX의 단점: 해당 도구들을 JAX 환경에서 처음부터 다시 구현해야 하는 어려움

결론적으로, GPU 환경에서 Flash Attention 3, 분산 옵티마이저 등을 사용해야 하는 경우 PyTorch가 더 적합하며, JAX는 TPU를 활용한 연구에 집중하는 것이 효율적이다.

NanoChat 아키텍처(Architecture)의 핵심 구성 요소

NanoChat은 표준 GPT 또는 LLaMA 아키텍처에 없는 5가지 구성 요소를 포함하며, JAX 포팅 과정에서 각 구성 요소의 구현 방식이 주목할 만하다.

Logit Softcap: 주의 집중(Attention) 점수를 제한하여 깊이에 따른 엔트로피 붕괴 방지

Value Embeddings: 각 토큰에 컨텍스트와 독립적인 잔여 벡터(Residual Vector)를 학습

Smear/Backout Token Mixing: 인접 토큰(Adjacent Token) 혼합을 통해 정보 손실 방지

Depth-Aware Weight Initialization: 레이어 깊이에 따라 가중치 초기화(Weight Initialization) 스케일 조정

Muon Optimizer: Newton-Schulz 반복을 통해 가중치 그래디언트(Weight Gradient) 직교화

이러한 구성 요소들은 JAX의 특징을 활용하여 구현되었으며, 모델 성능 향상에 기여한다.

I Rebuilt Karpathy's NanoChat in JAX. Here's What XLA Gets Right and What It Gets Dead Wrong.

댓글 0

첫 번째 댓글을 남겨보세요!