После удаления нативной поддержки JAX и TensorFlow из библиотеки transformers разработчики столкнулись с проблемой интеграции моделей Hugging Face в JAX-окружение. Решение предложено в блоге Hugging Face: использование библиотеки torchax для преобразования PyTorch-моделей в JAX-совместимые функции.

Установка и настройка
Первоначальная настройка требует установки базовых пакетов:
# Создание venv/conda окружения pip install huggingface-cli huggingface-cli login pip install -U transformers datasets evaluate accelerate timm flax # Установка torchax и JAX pip install torchax pip install jax[tpu] # или jax[cuda12] для GPU
Преобразование модели
Загрузка модели с указанием возврата JAX-тензоров:
from transformers import AutoModelForCausalLM, AutoTokenizer import torchax model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype="bfloat16", device_map="cpu") tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") # Ключевой параметр для JAX-тензоров model_inputs = tokenizer(["Секрет вкусного торта: "], return_tensors="jax") # Преобразование модели weights, func = torchax.extract_jax(model) output = func(weights, (model_inputs.input_ids,))
Проблемы JIT-компиляции
При попытке JIT-компиляции возникают ошибки обработки кастомных типов Hugging Face. Решение — регистрация в системе Pytree:
from jax.tree_util import register_pytree_node from transformers import modeling_outputs, cache_utils # Для CausalLMOutputWithPast def output_flatten(v): return v.to_tuple(), None def output_unflatten(_, children): return modeling_outputs.CausalLMOutputWithPast(*children) register_pytree_node(modeling_outputs.CausalLMOutputWithPast, output_flatten, output_unflatten) # Для DynamicCache def _flatten_dynamic_cache(dynamic_cache): return (dynamic_cache.key_cache, dynamic_cache.value_cache), None def _unflatten_dynamic_cache(_, children): cache = cache_utils.DynamicCache() cache.key_cache, cache.value_cache = children return cache register_pytree_node(cache_utils.DynamicCache, _flatten_dynamic_cache, _unflatten_dynamic_cache)
Этот workaround — временное спасение для JAX-энтузиастов, но демонстрирует хрупкость экосистемы. Удаление нативной поддержки JAX из transformers создало ненужные барьеры: теперь разработчики тратят время на обходные решения вместо работы с моделями. Особенно чувствительно это для исследователей с ограниченным доступом к TPU-кластерам, где JAX даёт реальные преимущества. Хотя torchax — умный технический хак, стратегическая нестабильность API крупных проектов остаётся проблемой. В долгосрочной перспективе сообществу нужны стандарты совместимости фреймворков, а не костыли.
По материалам: Hugging Face Blog
Оставить комментарий