Today’s Paper: Jamba: A Hybrid Transformer-Mamba Language Model. Lieber et al. 28 Mar. 2024. https://arxiv.org/abs/2403.19887
With all the talk about ChatGPT, we sometimes forget about the importance of publicly available models. I am very encouraged by the uptake of the Llama models, for example, in research and applications alike. The introduction of Jamba in today’s paper marks the debut of a new publicly available large language model featuring a hybrid architecture. Architecturally, the model interleaves Transformer layers with Mamba layers (a type of state-space model) and includes a mixture-of-experts (MoE) component.
By integrating these two distinct architectural styles, the authors claim that Jamba achieves enhanced performance and greater throughput while maintaining a manageable memory footprint. The released version of Jamba, which fits on a single 80GB GPU, is a 7B-based model with 12 billion active parameters and 52 billion total available parameters. The architecture is meant to be versatile, supporting design options tailored to different hardware configurations. The model is truly open: its weights are released under the Apache 2.0 license.
The basic component of the model is the Jamba block, which may be repeated in sequence. Each Jamba block combines Mamba and attention layers. Each layer contains either an attention or a Mamba module, followed by a multi-layer perceptron (MLP). The different possible types of layers are shown in panel (b) of the figure above. A Jamba block contains l layers, mixed at a ratio of a : m, meaning a attention layers for every m Mamba layers. The degrees of freedom in the model are:
l: number of layers (per Jamba block).
a : m: ratio of attention-to-Mamba layers.
e: how often to use MoE instead of a single MLP (an MoE replaces the MLP every e layers).
n: total number of experts per MoE layer.
K: number of top experts used at each token.
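To make the design space concrete, here is a minimal Python sketch that expands one block's hyperparameters into a per-layer layout. The names `JambaBlockConfig` and `block_layout` are hypothetical, and two placement rules are assumptions rather than details taken from the paper: the a attention layers come first within each group of a + m layers, and every e-th layer swaps its MLP for an MoE. The example values are illustrative.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class JambaBlockConfig:
    """Hypothetical container for the per-block degrees of freedom."""
    l: int  # number of layers in the block
    a: int  # attention layers per group of a + m layers
    m: int  # Mamba layers per group of a + m layers
    e: int  # replace the MLP with an MoE every e layers
    n: int  # total experts per MoE layer
    k: int  # top-K experts routed per token


def block_layout(cfg: JambaBlockConfig) -> list[tuple[str, str]]:
    """Return (mixer, feed-forward) pairs, one per layer in the block.

    Assumed placement: attention layers lead each group of a + m layers,
    and every e-th layer (counting from 1) uses MoE instead of a plain MLP.
    """
    group = cfg.a + cfg.m
    if cfg.l % group != 0:
        raise ValueError("l must be a multiple of a + m")
    layout = []
    for i in range(cfg.l):
        mixer = "attention" if i % group < cfg.a else "mamba"
        ffn = f"moe(top-{cfg.k} of {cfg.n})" if (i + 1) % cfg.e == 0 else "mlp"
        layout.append((mixer, ffn))
    return layout


# Illustrative values: one attention layer per seven Mamba layers, MoE every other layer.
cfg = JambaBlockConfig(l=8, a=1, m=7, e=2, n=16, k=2)
for idx, (mixer, ffn) in enumerate(block_layout(cfg)):
    print(idx, mixer, ffn)
```

With these values, only one of the eight layers uses attention and four of the eight use an MoE, so stacking several such blocks keeps attention sparse throughout the network.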
Given this design space, Jamba provides flexibility in preferring certain properties over others. For example, increasing m and decreasing a, that is, increasing the ratio of Mamba layers at the expense of attention layers, reduces the required memory for storing the key-value cache. This reduces the overall memory footprint, which is especially important for processing long sequences. Increasing the ratio of Mamba layers also improves throughput, especially at long sequences. However, decreasing a might lower the model’s capabilities.
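The memory argument can be checked with back-of-the-envelope arithmetic: only attention layers store a key and a value for every past token, whereas a Mamba layer keeps a fixed-size state. The sketch below assumes a standard KV cache with grouped key-value heads; the dimensions (32 layers, 8 KV heads, head size 128, 16-bit values) are hypothetical and not taken from the paper.

```python
def kv_cache_bytes(attn_layers: int, seq_len: int, kv_heads: int = 8,
                   head_dim: int = 128, bytes_per_value: int = 2,
                   batch: int = 1) -> int:
    """Bytes needed to cache keys and values for all attention layers.

    The factor 2 accounts for storing both a key and a value vector per
    token, per KV head, per attention layer. Mamba layers are ignored
    because their recurrent state does not grow with seq_len.
    """
    return 2 * attn_layers * kv_heads * head_dim * seq_len * bytes_per_value * batch


GIB = 2 ** 30
seq_len = 256 * 1024  # a 256K-token context
total_layers = 32     # hypothetical depth

pure_attention = kv_cache_bytes(attn_layers=total_layers, seq_len=seq_len)
hybrid_1_to_7 = kv_cache_bytes(attn_layers=total_layers // 8, seq_len=seq_len)
print(f"all-attention: {pure_attention / GIB:.0f} GiB")  # 32 GiB
print(f"1:7 hybrid:    {hybrid_1_to_7 / GIB:.0f} GiB")   # 4 GiB
```

Because the cache grows linearly with the number of attention layers, moving from an all-attention stack to a 1:7 ratio cuts it by 8x under these assumptions; the small, constant-size Mamba state is omitted here.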
Additionally, balancing n, K, and e affects the relationship between active parameters and total available parameters. A larger n increases model capacity at the expense of memory footprint, while a larger K increases the number of active parameters and the compute requirement. In contrast, a larger e decreases model capacity but also decreases compute (when K > 1) and memory requirements, and it reduces communication dependencies (fewer memory transfers as well as less inter-GPU communication during expert-parallel training and inference).
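These trade-offs come down to counting parameters. The following sketch assumes that every e-th layer swaps its MLP for an MoE of n experts, each the size of one dense MLP, and that all other weights are shared; the function name `moe_param_counts` and the unit-sized counts are illustrative, not from the paper.

```python
def moe_param_counts(num_layers: int, e: int, n: int, k: int,
                     mlp_params: int, other_params_per_layer: int) -> tuple[int, int]:
    """Return (active, total) parameter counts for a stack with periodic MoE.

    Assumes floor(num_layers / e) layers use an MoE with n experts, each
    the size of one dense MLP, of which k are active per token.
    """
    num_moe = num_layers // e
    num_dense = num_layers - num_moe
    shared = num_layers * other_params_per_layer + num_dense * mlp_params
    total = shared + num_moe * n * mlp_params
    active = shared + num_moe * k * mlp_params
    return active, total


# Count in units of one MLP and ignore attention/Mamba weights to isolate the MoE effect.
for e in (2, 4):
    for k in (1, 2):
        active, total = moe_param_counts(32, e=e, n=16, k=k,
                                         mlp_params=1, other_params_per_layer=0)
        print(f"e={e} K={k}: active={active} total={total}")
```

Going from e = 2 to e = 4 shrinks total capacity from 272 to 152 MLP-equivalents. With K = 2 it also lowers the active count (48 to 40), but with K = 1 the active count stays at 32 either way, which illustrates why the compute saving from a larger e only applies when K > 1.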
The paper is quite technical overall, but the table below, which compares Jamba with recent publicly available models, makes one advantage easy to see: Jamba maintains a small KV cache even with 256K-token contexts.
The context length that Jamba can accommodate seems impressive. The figure below shows this nicely, comparing it with Llama and Mixtral:
For evaluating the model, the authors relied on the following benchmarks:
Common Sense Reasoning: Benchmarks like HellaSwag, WinoGrande, ARC-E (AI2 Reasoning Challenge Easy), ARC-Challenge, and PIQA (Physical Interaction QA) test the model's ability to reason about everyday and physical situations (HellaSwag, WinoGrande, PIQA) and to answer grade-school science questions (the two ARC sets).
Reading Comprehension: Datasets such as BoolQ and QuAC (Question Answering in Context) assess the model's ability to understand and process natural language queries within a given context.
Other Benchmarks: These include GSM8K (Grade School Math 8K), HumanEval (coding problems), and Natural Questions in a closed-book setting. They measure the model's math, coding, and factual-knowledge skills, respectively.
Aggregate Benchmarks: The MMLU (Massive Multitask Language Understanding) and BBH (Big Bench Hard) benchmarks aggregate scores across many tasks to evaluate broad language understanding capabilities.
For long-context evaluation, the authors use:
Needle-in-a-Haystack: This evaluation tests the model's ability to retrieve specific information from very long texts.
Naturalistic Long-Context Evaluation: This involves using question-answering benchmarks with extended narrative texts where the model must navigate and interpret lengthy inputs effectively.
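A needle-in-a-haystack prompt is simple to construct. This minimal sketch, with hypothetical helper names `build_haystack` and `needle_found`, places a "needle" sentence at a chosen relative depth inside filler text and scores a response by substring match. It is a generic illustration, not the authors' harness; real setups typically sweep both context length and needle depth and may use a stricter scoring rule.

```python
def build_haystack(filler: list[str], needle: str, depth: float) -> str:
    """Insert `needle` at a relative depth (0.0 = start, 1.0 = end) of `filler`."""
    if not 0.0 <= depth <= 1.0:
        raise ValueError("depth must be in [0, 1]")
    pos = round(depth * len(filler))
    return " ".join(filler[:pos] + [needle] + filler[pos:])


def needle_found(answer: str, expected: str) -> bool:
    """Case-insensitive check that the expected fact appears in the answer."""
    return expected.lower() in answer.lower()


filler = [f"Filler sentence number {i}." for i in range(1000)]
needle = "The secret code is 7319."
prompt = build_haystack(filler, needle, depth=0.35)
question = "What is the secret code?"
# prompt + question would be sent to the model; here we only score a mock answer.
print(needle_found("The code is 7319.", "7319"))  # True
```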
On the first set of benchmarks, the table below compares Jamba to several publicly available models. The authors compare with Llama-2 13B, which has about the same number of active parameters as Jamba; Llama-2 70B, which is larger than Jamba; Gemma, which has 7B parameters; and Mixtral, which has about the same number of active and total parameters as Jamba.
Notably, Jamba performs comparably to the leading publicly available models of similar or larger size, including Llama-2 70B and Mixtral. At the same time, it has fewer total available parameters than Llama-2 70B (52B compared to 70B). Moreover, as a sparse model, Jamba has only 12B active parameters, similar to Mixtral’s 12.9B active parameters.
The table below summarizes the F1 scores on the long-context question-answering benchmarks. Jamba outperforms Mixtral on most of the datasets as well as on average. Because these long-context tasks are computationally demanding, Jamba’s efficiency also shines here: its throughput at long context lengths is much higher.
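For reference, F1 on short-answer QA is commonly computed as token overlap between the prediction and the reference, in the style of SQuAD evaluation. The paper may apply its own answer normalization (for example, stripping punctuation and articles), so treat this simplified version, which only lowercases and splits on whitespace, as an assumption.

```python
from collections import Counter


def token_f1(prediction: str, reference: str) -> float:
    """Token-level F1: harmonic mean of precision and recall over shared tokens."""
    pred = prediction.lower().split()
    ref = reference.lower().split()
    # Multiset intersection counts each shared token at most min(count) times.
    common = Counter(pred) & Counter(ref)
    overlap = sum(common.values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred)
    recall = overlap / len(ref)
    return 2 * precision * recall / (precision + recall)


print(round(token_f1("the cat sat", "the cat"), 3))  # 0.8
```

The per-example scores would then be averaged over each dataset, which is how a single F1 number per benchmark is typically reported.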
The authors also investigate the benefits of combining attention and Mamba. First, they study the ratio of attention to Mamba layers (a : m), using 1.3B-parameter models trained on 250B tokens. As the table below shows, the hybrid Jamba models outperform both the pure attention and the pure Mamba models. The attention-to-Mamba ratio can be 1:3 or 1:7 with virtually no difference in performance. The figure below the table shows the training loss of these models, where the Jamba variants reach lower loss during training. Given that a 1:7 ratio is more compute-efficient and performs similarly, the authors opt for it in their larger-scale experiments, which I don’t cover here but which are in the main paper.
In conclusion, Jamba seems to be an open and novel architecture that effectively combines attention and Mamba layers with MoE modules. According to the authors, it reaches state-of-the-art performance and supports long contexts. It provides flexibility for balancing performance and memory requirements while maintaining high throughput. In future work, the authors plan to release model checkpoints from smaller-scale training runs. The model released alongside the paper has 12B active and 52B total available parameters, supports context lengths of up to 256K tokens, and fits on a single 80GB GPU even when processing 140K-token texts.