As the demand for long-context large language models (LLMs) increases, models with context windows of up to 128k or even 1M tokens are becoming increasingly prevalent. However, long-context LLM inference is challenging because inference speed decreases significantly as the sequence length grows. This slowdown is primarily caused by loading a large KV cache during self-attention. Previous works have shown that a small portion of critical tokens dominates the attention outcome. However, we observe that the criticality of a token depends strongly on the query. To this end, we propose Quest, a query-aware token criticality estimation algorithm. Quest keeps track of the minimal and maximal Key values in each KV cache page and estimates a page's criticality from the current Query vectors. By loading only the Top-K critical KV cache pages for attention, Quest significantly speeds up self-attention without sacrificing accuracy. We show that Quest achieves up to 7.03× self-attention speedup and reduces inference latency by 2.23×, while incurring negligible accuracy loss on tasks with long dependencies.
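To make the page-selection idea concrete, below is a minimal PyTorch sketch of query-aware page criticality estimation for a single attention head at decode time. Function names, the page size, and the Top-K budget are illustrative assumptions for exposition; they are not the repository's kernel interface.

```python
# Minimal sketch of Quest-style query-aware page selection (single head, decode step).
# Assumed names: select_critical_pages, sparse_attention, page_size, top_k_pages.
import torch

def select_critical_pages(q, keys, page_size=16, top_k_pages=8):
    """Return indices of the Top-K KV cache pages with the largest
    upper-bound attention score for query q.

    q:    [head_dim]          current query vector
    keys: [seq_len, head_dim] cached Key vectors for this head
    """
    seq_len, head_dim = keys.shape
    num_pages = (seq_len + page_size - 1) // page_size
    pad = num_pages * page_size - seq_len
    if pad:
        # Pad the last partial page so it does not distort the min/max metadata.
        keys = torch.cat([keys, keys[-1:].expand(pad, head_dim)], dim=0)
    pages = keys.view(num_pages, page_size, head_dim)

    # Per-page, per-channel metadata kept alongside the KV cache.
    k_min = pages.min(dim=1).values            # [num_pages, head_dim]
    k_max = pages.max(dim=1).values            # [num_pages, head_dim]

    # Upper bound of q . k for any key inside a page: per channel, take
    # whichever of q*k_min / q*k_max is larger, then sum over channels.
    ub = torch.maximum(q * k_min, q * k_max).sum(dim=-1)   # [num_pages]

    k = min(top_k_pages, num_pages)
    return ub.topk(k).indices                  # pages worth loading

def sparse_attention(q, keys, values, page_ids, page_size=16):
    """Attend over only the selected pages (illustrative, unbatched)."""
    token_ids = (page_ids[:, None] * page_size +
                 torch.arange(page_size)).flatten()
    token_ids = token_ids[token_ids < keys.shape[0]]
    k_sel, v_sel = keys[token_ids], values[token_ids]
    scores = (k_sel @ q) / (q.shape[-1] ** 0.5)
    return torch.softmax(scores, dim=-1) @ v_sel
```

Because the bound is computed from per-page min/max summaries rather than every key, the selection step touches only a small amount of metadata before the chosen pages are loaded for exact attention.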
Key Idea: preserve the entire KV cache, and significantly accelerate inference by reducing memory movement in self-attention from the full KV cache to a constant number of K selected pages.
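A back-of-the-envelope illustration of why a constant page budget helps, using assumed (not measured) numbers: a 128k-token context, 16-token pages, and a budget of 256 selected pages.

```python
# Illustrative arithmetic only; the context length, page size, and page budget
# below are assumptions, not figures reported by the paper.
context_len = 128 * 1024     # tokens currently in the KV cache
page_size   = 16             # tokens per KV cache page
top_k_pages = 256            # constant page budget per decoding step

full_tokens  = context_len               # tokens loaded by dense attention
quest_tokens = top_k_pages * page_size   # tokens loaded after page selection
print(f"KV tokens loaded per step: {full_tokens} -> {quest_tokens} "
      f"({full_tokens / quest_tokens:.0f}x less KV data movement)")
```

The key point is that the amount of KV data moved per step stays fixed as the context grows, which is where the self-attention speedup comes from.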
@inproceedings{tang2024quest,
  title={{QUEST: Query-Aware Sparsity for Efficient Long-Context LLM Inference}},
  author={Jiaming Tang and Yilong Zhao and Kan Zhu and Guangxuan Xiao and Baris Kasikci and Song Han},
  booktitle={Proceedings of the International Conference on Machine Learning (ICML)},
  year={2024},
}
We thank Zihao Ye for his insightful discussions, feedback, and valuable advice on algorithm design and FlashInfer integration. This work was supported by generous gifts from Intel, Google, and the PRISM Research Center, a JUMP Center cosponsored by SRC and DARPA.