Оглавление

После удаления нативной поддержки JAX и TensorFlow из библиотеки transformers разработчики столкнулись с проблемой интеграции моделей Hugging Face в JAX-окружение. Решение предложено в блоге Hugging Face: использование библиотеки torchax для преобразования PyTorch-моделей в JAX-совместимые функции.

Иллюстрация совместимости JAX и PyTorch
Источник: huggingface.co

Установка и настройка

Первоначальная настройка требует установки базовых пакетов:

# Создание venv/conda окружения
pip install huggingface-cli
huggingface-cli login
pip install -U transformers datasets evaluate accelerate timm flax

# Установка torchax и JAX
pip install torchax
pip install jax[tpu] # или jax[cuda12] для GPU

Преобразование модели

Загрузка модели с указанием возврата JAX-тензоров:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torchax

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", 
                                             torch_dtype="bfloat16", 
                                             device_map="cpu")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# Ключевой параметр для JAX-тензоров
model_inputs = tokenizer(["Секрет вкусного торта: "], return_tensors="jax")

# Преобразование модели
weights, func = torchax.extract_jax(model)
output = func(weights, (model_inputs.input_ids,))

Проблемы JIT-компиляции

При попытке JIT-компиляции возникают ошибки обработки кастомных типов Hugging Face. Решение — регистрация в системе Pytree:

from jax.tree_util import register_pytree_node
from transformers import modeling_outputs, cache_utils

# Для CausalLMOutputWithPast
def output_flatten(v):
    return v.to_tuple(), None

def output_unflatten(_, children):
    return modeling_outputs.CausalLMOutputWithPast(*children)

register_pytree_node(modeling_outputs.CausalLMOutputWithPast, 
                     output_flatten, 
                     output_unflatten)

# Для DynamicCache
def _flatten_dynamic_cache(dynamic_cache):
    return (dynamic_cache.key_cache, dynamic_cache.value_cache), None

def _unflatten_dynamic_cache(_, children):
    cache = cache_utils.DynamicCache()
    cache.key_cache, cache.value_cache = children
    return cache

register_pytree_node(cache_utils.DynamicCache,
                     _flatten_dynamic_cache,
                     _unflatten_dynamic_cache)

Этот workaround — временное спасение для JAX-энтузиастов, но демонстрирует хрупкость экосистемы. Удаление нативной поддержки JAX из transformers создало ненужные барьеры: теперь разработчики тратят время на обходные решения вместо работы с моделями. Особенно чувствительно это для исследователей с ограниченным доступом к TPU-кластерам, где JAX даёт реальные преимущества. Хотя torchax — умный технический хак, стратегическая нестабильность API крупных проектов остаётся проблемой. В долгосрочной перспективе сообществу нужны стандарты совместимости фреймворков, а не костыли.

По материалам: Hugging Face Blog