Composable fine tuning and input masking with mlx-tuning-fork

Chimezie Ogbuji
6 min read · Feb 22, 2025
https://commons.wikimedia.org/wiki/File:B.B._King_3011710050.jpg

It’s been a long time coming (as BB King said), but mlx-tuning-fork has a new release, version 0.4.0. Most of the changes follow from changes in MLX, which has been modified significantly during this time. In particular, mlx_lm now supports input masking during instruction fine tuning.

Support for this capability in MLX was the main motivation behind creating mlx-tuning-fork. Now that it is a core part of mlx_lm, it is nice to retire this functionality in favor of their reference implementation and focus on other motivations, such as ease of use similar to that of axolotl.

To understand input masking during instruction fine tuning (or completion-only training), it is important to understand what logits are and what instruction fine tuning is.

Instruction fine tuning with labeled instruction data is a straightforward, supervised training method for adapting Large Language Models (LLMs) to follow instructions using annotated input-output text pairs. The instructions can stand on their own or have, as a component, input data that they are meant to act upon and are paired with the resulting output or a completion. Ideally, when an LLM is sampled with an instruction, it would provide output equivalent or similar to the label in the pair.

For example, consider training an LLM to follow instructions to provide a formal definition from a domain reference model of canonical human anatomy like the FMA. You can use the cogbuji/MrGrammaticalOntology_anatomist dataset for this instruction fine tuning. There are no inputs for this kind of instruction, just a request to define an anatomy term.

Below is an example of an instruction data record for the left lateral ventricle:

Example instructions for defining “left lateral ventricle”

In practice with transformer-based decoder models, this works by:

  • taking records such as these from the training set and encoding (or tokenizing) them into the internal representation the neural network uses during its autoregressive inference
  • calculating the loss, i.e., how accurately the probabilistic model infers (completes or produces) the completion tokens, which are the encodings of the corresponding completion text.
Tokenization or encoding of instruction task data (training data)
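
The two steps above start with turning a record into token ids while remembering where the prompt ends. Here is a minimal sketch, assuming a toy whitespace tokenizer and a made-up record (the text is a placeholder, not taken from the dataset). The helper names `build_vocab` and `encode` are hypothetical. Real tokenizers use subword vocabularies and chat templates.

```python
# Toy illustration only: a whitespace "tokenizer" whose vocabulary is built on the fly.

def build_vocab(texts):
    """Assign an integer id to every distinct whitespace-separated word."""
    vocab = {}
    for text in texts:
        for word in text.split():
            vocab.setdefault(word, len(vocab))
    return vocab

def encode(text, vocab):
    """Map a string to a list of token ids."""
    return [vocab[word] for word in text.split()]

# Hypothetical record: placeholder text, not an actual dataset entry.
record = {
    "input": "Define the anatomy term: left lateral ventricle",
    "output": "placeholder definition text",
}

vocab = build_vocab([record["input"], record["output"]])
prompt_ids = encode(record["input"], vocab)
completion_ids = encode(record["output"], vocab)

# The model sees one sequence; prompt_len marks where the completion begins.
tokens = prompt_ids + completion_ids
prompt_len = len(prompt_ids)
```

Keeping `prompt_len` next to the token sequence is what later lets the loss distinguish prompt tokens from completion tokens.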

The key component of the autoregressive inference process is a conditional probability distribution over tokens, which the model produces in response to the tokens it is given. These are called logits (strictly, raw scores that become probabilities once normalized). Conditional probability measures the probability of an event occurring, given that another event is already known to have happened. At each token position, the logits capture the conditional probability of every token in the model’s vocabulary appearing at that position, given all the tokens that precede it from left to right. In the example below, the logits at position C₂ are conditioned on being preceded by all the input tokens in black and the C₀ and C₁ completion tokens.

Logits, masking, and loss calculation
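
Logits are raw scores, and a softmax normalizes them into the conditional probabilities described above. Here is a minimal sketch, assuming a made-up four-token vocabulary and made-up logits for position C₂. None of these numbers come from a real model.

```python
import math

def softmax(logits):
    """Normalize raw scores into a probability distribution."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical vocabulary and logits at one token position (C2).
vocab = ["the", "left", "ventricle", "heart"]
logits_c2 = [0.5, 1.0, 3.0, -1.0]

# P(token at C2 | all preceding tokens), one entry per vocabulary item.
probs_c2 = softmax(logits_c2)
most_likely = vocab[probs_c2.index(max(probs_c2))]
```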

During instruction fine tuning, the loss is calculated from these logits, usually with a cross-entropy loss function. The loss can be taken over the entire sequence of tokens provided to the model or over just those that correspond to the completion (or target) tokens. Input masking, or prompt masking, is how the latter is done. It avoids penalizing the model for not reproducing the instruction or input exactly as provided in the training data record.
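
The difference between the two ways of computing the loss can be sketched in a few lines. This is a simplified illustration with made-up logits over a three-token vocabulary, not the mlx_lm implementation. The mask holds 0 for positions whose target is a prompt token and 1 for completion targets.

```python
import math

def log_softmax(logits):
    """Log of the softmax, computed stably."""
    m = max(logits)
    log_total = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_total for x in logits]

def cross_entropy(logits_per_position, targets, mask):
    """Mean negative log-probability of the target tokens where mask is 1."""
    total, count = 0.0, 0
    for logits, target, keep in zip(logits_per_position, targets, mask):
        if keep:
            total -= log_softmax(logits)[target]
            count += 1
    return total / count

# Made-up logits: the model predicts the two prompt targets poorly
# and the two completion targets well.
logits = [
    [2.0, 0.0, 0.0],  # target is a prompt token
    [0.0, 0.0, 0.0],  # target is a prompt token
    [0.0, 3.0, 0.0],  # target is a completion token
    [0.0, 0.0, 3.0],  # target is a completion token
]
targets = [1, 2, 1, 2]

full_loss = cross_entropy(logits, targets, mask=[1, 1, 1, 1])
masked_loss = cross_entropy(logits, targets, mask=[0, 0, 1, 1])
```

With the mask applied, the poorly predicted prompt positions no longer contribute, so `masked_loss` reflects only how well the completion is produced.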

This is a high-level abstraction of the initial process of LLM training, which shares much in common with benchmark evaluation. It is more formally defined in the appendix, shown below, of the paper associated with the popular and influential Language Model Evaluation Harness, which is used and referenced in hundreds of papers.

Appendix A2 of Stella Biderman, et al.’s “Lessons from the Trenches on Reproducible Evaluation of Language Models”

For what is being discussed here, loglikelihood can just be considered a normalized form of logits. However, it still captures the conditional probability that is the engine of the autoregressive modeling whose accuracy we measure in evaluation on the one hand and update during training on the other. It is also the same engine from which sampling is done to generate responses from the LLM during language tasks.
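
To make the link between logits and loglikelihood concrete, here is a small sketch that scores two candidate continuations of the same context. It assumes a hypothetical toy "model" (a lookup table in which the next-token logits depend only on the previous token) and is not how the evaluation harness is implemented.

```python
import math

# Hypothetical toy model: next-token logits keyed by the previous token id.
# Vocabulary ids: 0 = "define", 1 = "ventricle", 2 = "atrium"
BIGRAM_LOGITS = {
    0: [0.0, 2.0, 1.0],
    1: [1.0, 0.0, 0.0],
    2: [1.0, 0.0, 0.0],
}

def log_softmax(logits):
    """Normalize logits into log-probabilities."""
    m = max(logits)
    log_total = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_total for x in logits]

def loglikelihood(context, continuation):
    """Sum of log P(token | preceding tokens) over the continuation tokens."""
    prev = context[-1]
    total = 0.0
    for token in continuation:
        total += log_softmax(BIGRAM_LOGITS[prev])[token]
        prev = token
    return total

context = [0]
score_ventricle = loglikelihood(context, [1])
score_atrium = loglikelihood(context, [2])
```

The continuation with the higher loglikelihood is the one the toy model considers more probable given the context. The same quantity, negated and averaged, is the cross-entropy loss used in training.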

During training, the loss described earlier is minimized via backpropagation, or backward propagation, and this step is not related to benchmark evaluation. It occurs after producing the conditional probability from the models as logits, which is common to both. Backward propagation, however, is how the underlying neural network is trained by updating its parameters to minimize the loss or error.

From “How GPT3 Works — Visualizations and Animations”
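
As a rough illustration of how minimizing the loss works, the sketch below takes one gradient descent step directly on the logits at a single position. This is a deliberate simplification: real training backpropagates through the whole network and updates its parameters, not the logits themselves. The numbers and the learning rate are made up.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, target):
    """Negative log-probability assigned to the target token."""
    return -math.log(softmax(logits)[target])

def grad_wrt_logits(logits, target):
    """Gradient of cross-entropy with respect to the logits: softmax minus one-hot."""
    probs = softmax(logits)
    return [p - (1.0 if i == target else 0.0) for i, p in enumerate(probs)]

logits = [1.0, 0.5, -0.5]  # made-up scores over a three-token vocabulary
target = 1                 # index of the correct next token
learning_rate = 0.5

loss_before = cross_entropy(logits, target)
grads = grad_wrt_logits(logits, target)
updated = [x - learning_rate * g for x, g in zip(logits, grads)]
loss_after = cross_entropy(updated, target)
```

The step raises the score of the target token and lowers the others, so the loss at this position goes down.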

With MLX (mlx_lm), it is the ‘masking’ described earlier that distinguishes self-supervised pre-training from supervised instruction fine-tuning. These two training tasks are the main components of how foundational models are built. They can also be combined: first train a model on one task (continued pretraining on a large corpus of unlabeled text from a specific, new domain), then take advantage of the resulting model's capabilities when learning another task, i.e., transfer learning.

The major changes in this release are support for input (or prompt) masking and the ability to chain mlx-tuning-fork training tasks in a composable way using configuration files: performing one configuration-based training task after another and using the resulting adapters from one step as the basis for the next.

In this way, configurations can orchestrate the use of MLX for domain-adaptive pretraining followed by instruction tuning to produce domain-specific models from existing foundational models using domain-specific training data publicly available on Hugging Face or local data.
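
Conceptually, this kind of chaining threads the adapter produced by one training task into the next. The sketch below shows only that control flow. `run_chain`, `fake_train`, the `resume_from` parameter, and the `name` key are hypothetical stand-ins, not mlx-tuning-fork's or mlx_lm's API.

```python
def run_chain(configs, train_fn):
    """Run each training configuration in order, passing the adapter along."""
    adapter = None
    for config in configs:
        # Each step starts from the adapter the previous step produced.
        adapter = train_fn(config, resume_from=adapter)
    return adapter

def fake_train(config, resume_from=None):
    """Stand-in for a real training run: records which tasks shaped the adapter."""
    history = [] if resume_from is None else list(resume_from)
    return history + [config["name"]]

# Hypothetical task descriptions, named after the two kinds of task in the text.
configs = [
    {"name": "clinical-pretraining"},
    {"name": "anatomist-coder-instruction-fine-tuning"},
]

final_adapter = run_chain(configs, fake_train)
```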

So, using a configuration file with the following content (in a file called clinical-pretraining.yaml):

model: "mlx-community/Mistral-Small-24B-Instruct-2501-4bit"
train: true
hf_dataset:
  name: "cogbuji/medqa_corpus_en"
  train_split: "train[:-10%]"
  valid_split: "train[-10%:]"
  text_feature: "text"

num_layers: 20
batch_size: 2

epochs: 1
evals_per_epoch: 1

learning_rate: 5e-6

and another with the following (in a file called anatomist-coder-instruction-fine-tuning.yaml):

model: "mlx-community/Mistral-Small-24B-Instruct-2501-4bit"
train: true
mask_prompt: true
hf_dataset:
  - name: "cogbuji/MrGrammaticalOntology_anatomist"
    train_split: "train"
    valid_split: "validation"
    prompt_feature: "input"
    completion_feature: "output"
  - name: "cogbuji/MrGrammaticalOntology_clinical_coding"
    train_split: "train"
    valid_split: "validation"
    prompt_feature: "input"
    completion_feature: "output"

batch_size: 2

epochs: 2
evals_per_epoch: 1
learning_rate: 5e-6

Notice the new mask_prompt: true directive, which tells mlx_lm to apply prompt masking to the Hugging Face datasets. You can have mlx-tuning-fork perform both training tasks (pretraining on the clinical corpus, followed by instruction fine tuning on the two instruction datasets) via the following command line, which demonstrates using the training task configurations in a composable way:

$ mlx_tuning_fork_training --train-type lora \
/path/to/clinical-pretraining.yaml \
/path/to/anatomist-coder-instruction-fine-tuning.yaml

See the change log for a more detailed list of the changes in this release.

Written by Chimezie Ogbuji

An informatics engineer, data scientist, inventor, and business owner