Easy FunctionGemma finetuning with Tunix on Google TPUs

The lightweight, JAX-based Tunix library makes fine-tuning FunctionGemma on Google TPUs fast and simple. The workflow shown here uses LoRA for supervised fine-tuning, improves accuracy while keeping TPU utilization high, and ends with a model ready for deployment.

Wei Wei

Overview

This tutorial demonstrates how to fine-tune FunctionGemma, a small language model for translating natural language into API calls, using Google's Tunix library on TPUs. It walks through a complete LoRA-based supervised fine-tuning workflow on a free-tier Google Colab TPU v5e-1: loading the model, applying LoRA adapters, creating a custom dataset with completion-only loss, running training, and exporting the merged model for on-device deployment.

What You'll Learn

1. How to fine-tune FunctionGemma using Tunix and LoRA on Google TPUs
2. How to set up JAX sharding and mesh configurations for TPU training
3. How to implement completion-only loss with a custom dataset class for function-calling models
4. How to merge LoRA adapters and export fine-tuned models to safetensors format
5. How to leverage free-tier Colab TPU v5e-1 for efficient LLM fine-tuning

Prerequisites & Requirements

  • Understanding of LLM fine-tuning concepts including LoRA and supervised fine-tuning
  • Familiarity with Python and deep learning frameworks
  • Google Colab account with TPU v5e-1 access (free tier)
  • Hugging Face Hub account for downloading model weights and datasets
  • Basic understanding of JAX and its sharding concepts (optional)

Key Questions Answered

What is FunctionGemma and what is it used for?
FunctionGemma is a small language model that enables developers to build fast and cost-effective agents that translate natural language into actionable API calls, particularly suited for edge devices. It bridges the gap between user intent expressed in natural language and structured API function calls.
How do you fine-tune FunctionGemma with Tunix on Google TPUs?
Download the FunctionGemma weights and the Mobile Action dataset from Hugging Face Hub, then set up a JAX mesh for TPU sharding. Load the model with Tunix's create_model_from_safe_tensors(), use Qwix to apply LoRA adapters to the attention and feed-forward projection layers, and create a custom dataset with completion-only loss. Finally, configure PeftTrainer with a cosine decay learning rate schedule and the AdamW optimizer, and train within the mesh context.
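To make the learning rate schedule in this workflow concrete, here is a minimal pure-Python sketch of a cosine decay. The function name and the hyperparameter values are hypothetical and are not taken from the tutorial; it only illustrates the shape of the curve, not the exact Optax configuration used there.

```python
import math


def cosine_decay_lr(step, peak_lr, decay_steps, end_lr=0.0):
    """Learning rate at `step` for a cosine decay from peak_lr to end_lr."""
    # Clamp so the rate stays at end_lr once the schedule has finished.
    step = min(max(step, 0), decay_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * step / decay_steps))
    return end_lr + (peak_lr - end_lr) * cosine


# Hypothetical hyperparameters, not the values used in the tutorial.
PEAK_LR = 1e-3
TOTAL_STEPS = 100

schedule = [cosine_decay_lr(s, PEAK_LR, TOTAL_STEPS) for s in range(TOTAL_STEPS + 1)]
print(schedule[0], schedule[50], schedule[100])
```

In an actual run this role would be filled by an Optax schedule handed to the optimizer; Optax offers `optax.cosine_decay_schedule` and `optax.adamw` for that purpose, though which variant and values the tutorial picks is not stated here.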
What is Tunix and what training techniques does it support?
Tunix is a lightweight JAX-based library designed to streamline LLM post-training. It is part of the extended JAX AI Stack and supports supervised fine-tuning, Parameter-Efficient Fine-Tuning (PEFT), preference tuning, reinforcement learning, and model distillation. It works with models like Gemma, Qwen, and LLaMA, and is designed for efficient use across large-scale hardware accelerators.
How do you apply LoRA adapters to FunctionGemma using Qwix?
Qwix's LoraProvider is configured with a regex module path targeting attention and projection layers (q_einsum, kv_einsum, gate_proj, down_proj, up_proj), along with LoRA rank and alpha parameters. The apply_lora_to_model function then wraps the base model with LoRA adapters. After application, the model state is sharded using JAX partition specs for efficient TPU utilization.
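As a rough illustration of how a regex module path selects layers, the sketch below uses only Python's `re` module. The pattern and the module paths are hypothetical stand-ins: the real path names depend on how the model is defined in Tunix, and Qwix's own matching semantics may differ from `re.fullmatch`.

```python
import re

# Hypothetical pattern naming the five layer types listed above.
LORA_MODULE_PATTERN = r".*(q_einsum|kv_einsum|gate_proj|down_proj|up_proj)"

# Hypothetical module paths, for illustration only.
module_paths = [
    "layers/0/attn/q_einsum",
    "layers/0/attn/kv_einsum",
    "layers/0/attn/attn_vec_einsum",
    "layers/0/mlp/gate_proj",
    "layers/0/mlp/up_proj",
    "layers/0/mlp/down_proj",
    "embedder/input_embedding",
]


def select_lora_targets(paths, pattern):
    """Return the paths that the pattern matches in full."""
    compiled = re.compile(pattern)
    return [p for p in paths if compiled.fullmatch(p)]


lora_targets = select_lora_targets(module_paths, LORA_MODULE_PATTERN)
for path in lora_targets:
    print(path)
```

Listing the matched paths like this before training is a cheap way to confirm that a pattern covers the intended layers and nothing else.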
How do you implement completion-only loss for function-calling fine-tuning?
A custom dataset class tokenizes both the full prompt-and-completion sequence and the prompt-only sequence. It then builds an input mask that is zero for the prompt tokens and one for the completion tokens. This ensures the loss is computed only on the target completion, not on the input prompt, so training focuses on the function call the model must produce.
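The masking step can be sketched in a few lines of plain Python. This is not the tutorial's dataset class: the function names, the token ids, and the padding id are all hypothetical, and the sketch assumes the prompt tokens form an exact prefix of the full sequence.

```python
def build_completion_mask(prompt_ids, full_ids, max_len, pad_id=0):
    """Pad or truncate full_ids to max_len and build a 0/1 loss mask.

    Assumes the tokenizer yields the same leading tokens for the prompt
    alone as for the prompt followed by the completion.
    """
    if full_ids[:len(prompt_ids)] != prompt_ids:
        raise ValueError("prompt tokens are not a prefix of the full sequence")
    tokens = full_ids[:max_len]
    n_prompt = min(len(prompt_ids), len(tokens))
    # 0 for prompt positions, 1 for completion positions.
    mask = [0] * n_prompt + [1] * (len(tokens) - n_prompt)
    pad = max_len - len(tokens)
    # Padding positions are also excluded from the loss.
    return tokens + [pad_id] * pad, mask + [0] * pad


def masked_mean_loss(per_token_loss, mask):
    """Average the per-token losses over the positions where mask is 1."""
    total = sum(loss * m for loss, m in zip(per_token_loss, mask))
    return total / max(sum(mask), 1)


# Hypothetical token ids: three prompt tokens followed by two completion tokens.
prompt_ids = [11, 12, 13]
full_ids = [11, 12, 13, 21, 22]
input_tokens, input_mask = build_completion_mask(prompt_ids, full_ids, max_len=8)
print(input_tokens)
print(input_mask)
```

Only the two completion positions contribute to `masked_mean_loss`, so large per-token losses on prompt or padding positions have no effect on the result.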
How do you export a fine-tuned LoRA model back to safetensors format?
After training, Tunix provides the save_lora_merged_model_as_safetensors() function from gemma_params, which takes the local model path, output directory, the LoRA model, rank, and alpha parameters. This merges the LoRA adapters back into the base model weights and saves the result as safetensors files, ready for downstream processing like on-device deployment with LiteRT.
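To show why the rank and alpha parameters are needed at export time, the sketch below applies the standard LoRA merge, W + (alpha / rank) * A @ B, to a toy weight matrix. It is a pure-Python illustration with made-up values, not Tunix's internal implementation, and the (input, rank) and (rank, output) factor layout is an assumption.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_lora(base_w, lora_a, lora_b, rank, alpha):
    """Fold a low-rank update into the base weight: W + (alpha / rank) * A @ B."""
    scale = alpha / rank
    delta = matmul(lora_a, lora_b)
    return [
        [w + scale * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(base_w, delta)
    ]


# Toy values for illustration only.
base_w = [[1.0, 0.0], [0.0, 1.0]]  # shape (in=2, out=2)
lora_a = [[1.0], [2.0]]            # shape (in=2, rank=1)
lora_b = [[0.5, -0.5]]             # shape (rank=1, out=2)

merged = merge_lora(base_w, lora_a, lora_b, rank=1, alpha=2.0)
print(merged)
```

After the merge, a single matrix multiplication with `merged` gives the same output as the base weight plus the scaled adapter path, which is why the exported model no longer needs separate adapter tensors.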
Can you fine-tune FunctionGemma on free Google Colab TPUs?
Yes, the entire fine-tuning workflow runs on free-tier Google Colab TPU v5e-1, which is a single-core TPU. The mesh setup code checks how many TPU devices are available, and when there is only one core it creates a simple mesh without sharding. Training takes only a few minutes and keeps TPU utilization high.

Technologies & Tools

AI Model
FunctionGemma
Small language model for translating natural language into API calls, the model being fine-tuned
ML Library
Tunix
Lightweight JAX-based library for LLM post-training, used to orchestrate the fine-tuning workflow
ML Framework
JAX
Underlying framework providing sharding, mesh configuration, and hardware acceleration
ML Library
Qwix
Library used to apply LoRA adapters to the model's attention and projection layers
Hardware Accelerator
Google TPU
TPU v5e-1 used as the training hardware, available on free-tier Google Colab
Fine-tuning Technique
LoRA
Parameter-efficient fine-tuning method applied to attention and projection layers
Model Repository
Hugging Face Hub
Used to download FunctionGemma model weights and the Mobile Action dataset
Development Environment
Google Colab
Free-tier notebook environment providing TPU v5e-1 access for training
ML Library
Optax
Used for AdamW optimizer and cosine decay learning rate schedule
Inference Runtime
LiteRT
Referenced as the target for on-device deployment after fine-tuning
Model Format
Safetensors
File format used for loading and exporting model weights

Key Actionable Insights

1
Use Tunix with LoRA for efficient fine-tuning of small language models on TPUs. This approach enables you to fine-tune FunctionGemma on free-tier Colab TPU v5e-1 hardware, achieving significant accuracy improvements with minimal training overhead and high TPU utilization.
This is particularly valuable for developers building edge-device agents who need cost-effective fine-tuning without expensive GPU infrastructure.
2
Implement completion-only loss by masking prompt tokens in your training data. Create an input mask that zeros out all prompt tokens and only marks completion tokens, ensuring the model learns to generate correct function calls without being penalized on the input context.
This technique is critical for function-calling models where the prompt structure is fixed and only the API call generation quality matters for training.
3
Apply LoRA adapters to both the attention layers (q_einsum, kv_einsum) and the feed-forward projection layers (gate_proj, down_proj, up_proj) for broad fine-tuning coverage. Qwix's regex-based module path targeting keeps this configuration short and easy to maintain.
Targeting these specific layers provides a good balance between fine-tuning effectiveness and parameter efficiency, keeping the trainable parameter count low while affecting the most impactful model components.
4
After fine-tuning, merge LoRA adapters back into the base model and export to safetensors format for production deployment. Tunix provides a dedicated function for this merge-and-export workflow, enabling seamless transition to on-device inference with tools like LiteRT.
This export step is essential for deploying fine-tuned models to edge devices where LoRA inference overhead is undesirable and a single merged model is required.
5
Use JAX mesh configuration to handle both single-TPU and multi-TPU setups dynamically. The mesh setup code checks the number of available TPU devices and creates an appropriate FSDP/TP sharding scheme, or falls back to a non-sharded mesh for single-core environments.
This pattern ensures your training code works across different hardware configurations without modification, from free-tier Colab to production TPU pods.
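The device-count check described in insight 5 can be sketched as a small pure function. The function name, the `tp_size` default, and the fallback for uneven device counts are assumptions made for illustration; the tutorial's own mesh shapes may differ.

```python
def choose_mesh(num_devices, tp_size=2):
    """Pick (axis_shapes, axis_names) for a 2D mesh with "fsdp" and "tp" axes.

    tp_size is a hypothetical default, not a value taken from the tutorial.
    """
    axis_names = ("fsdp", "tp")
    if num_devices < 1:
        raise ValueError("at least one device is required")
    if num_devices == 1:
        # Single core (e.g. free-tier Colab): a 1x1 mesh, so nothing is sharded.
        return (1, 1), axis_names
    if num_devices % tp_size == 0:
        # Split the devices between the fsdp axis and the tensor-parallel axis.
        return (num_devices // tp_size, tp_size), axis_names
    # Device count not divisible by tp_size: put every device on the fsdp axis.
    return (num_devices, 1), axis_names


for n in (1, 4, 8):
    print(n, choose_mesh(n))
```

With JAX installed, the device count would come from `len(jax.devices())`, and the returned shape and axis names could then be passed to `jax.make_mesh`.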

Common Pitfalls

1
Not configuring the JAX mesh correctly for the available hardware. When using free-tier Colab with a single TPU core, you need to detect the number of available devices and create an appropriate mesh without sharding, rather than assuming multi-device availability.
The article shows how to dynamically check jax.devices() and create either a sharded or non-sharded mesh configuration accordingly.
2
Computing loss on the entire sequence including the prompt tokens, rather than only on the completion tokens. This trains the model to reproduce the prompt rather than focusing on generating correct function calls, leading to suboptimal fine-tuning results.
The custom dataset class solves this by creating an input mask that zeros out prompt tokens and only marks completion tokens for loss computation.
3
Forgetting to merge LoRA adapters back into the base model before deployment. LoRA adds inference overhead when kept separate, and edge devices typically require a single merged model file for efficient inference with runtimes like LiteRT.
Tunix provides the save_lora_merged_model_as_safetensors() function specifically for this merge-and-export step.

Related Concepts

Parameter-Efficient Fine-Tuning (PEFT)
LoRA (Low-Rank Adaptation)
Supervised Fine-Tuning
Function Calling in LLMs
JAX Sharding and Parallelism
TPU Training Optimization
On-Device Model Deployment
Model Distillation
Preference Tuning
Reinforcement Learning for LLMs
Edge AI Inference
Chat Template Tokenization