После удаления нативной поддержки JAX и TensorFlow из библиотеки transformers разработчики столкнулись с проблемой интеграции моделей Hugging Face в JAX-окружение. Решение предложено в блоге Hugging Face: использование библиотеки torchax для преобразования PyTorch-моделей в JAX-совместимые функции.

Установка и настройка
Первоначальная настройка требует установки базовых пакетов:
# Создание venv/conda окружения pip install huggingface-cli huggingface-cli login pip install -U transformers datasets evaluate accelerate timm flax # Установка torchax и JAX pip install torchax pip install jax[tpu] # или jax[cuda12] для GPU
Преобразование модели
Загрузка модели с указанием возврата JAX-тензоров:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torchax
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
torch_dtype="bfloat16",
device_map="cpu")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
# Ключевой параметр для JAX-тензоров
model_inputs = tokenizer(["Секрет вкусного торта: "], return_tensors="jax")
# Преобразование модели
weights, func = torchax.extract_jax(model)
output = func(weights, (model_inputs.input_ids,))
Проблемы JIT-компиляции
При попытке JIT-компиляции возникают ошибки обработки кастомных типов Hugging Face. Решение — регистрация в системе Pytree:
from jax.tree_util import register_pytree_node
from transformers import modeling_outputs, cache_utils
# Для CausalLMOutputWithPast
def output_flatten(v):
return v.to_tuple(), None
def output_unflatten(_, children):
return modeling_outputs.CausalLMOutputWithPast(*children)
register_pytree_node(modeling_outputs.CausalLMOutputWithPast,
output_flatten,
output_unflatten)
# Для DynamicCache
def _flatten_dynamic_cache(dynamic_cache):
return (dynamic_cache.key_cache, dynamic_cache.value_cache), None
def _unflatten_dynamic_cache(_, children):
cache = cache_utils.DynamicCache()
cache.key_cache, cache.value_cache = children
return cache
register_pytree_node(cache_utils.DynamicCache,
_flatten_dynamic_cache,
_unflatten_dynamic_cache)
Этот workaround — временное спасение для JAX-энтузиастов, но демонстрирует хрупкость экосистемы. Удаление нативной поддержки JAX из transformers создало ненужные барьеры: теперь разработчики тратят время на обходные решения вместо работы с моделями. Особенно чувствительно это для исследователей с ограниченным доступом к TPU-кластерам, где JAX даёт реальные преимущества. Хотя torchax — умный технический хак, стратегическая нестабильность API крупных проектов остаётся проблемой. В долгосрочной перспективе сообществу нужны стандарты совместимости фреймворков, а не костыли.
По материалам: Hugging Face Blog
Оставить комментарий