「从零开始学大模型」PagedAttention
概述
Paged Attention 要解决的,是 KV Cache 显存浪费的问题。
一般来说,为了放下 KV Cache,会预先 alloc 出一大块显存,但是 KV Cache 大学又是随着推理进行不断增长,是一个动态的逐渐增大的存储需求。此外,在多个序列同时进行推理时,各个序列的长度不同,对应的 KV Cache 也不一样,此时也不方便为不同序列配置不一样大小的显存。
因为上述种种问题,传统的 KV Cache 的实现是非常浪费显存的。Paged Attention 引入了 OS 的内存分页的思想,将显存也进行分页管理,极大的提升了显存利用率。
核心思想
核心思想其实对于学过 OS 内存管理的同学来说并不复杂:
- 不再用连续的一整块的显存,而是将 KV Cache 分为 block(页)存储
- 每一页存固定大小的,少量 token (比如16个)
- 维护 block table(页表),处理逻辑 block 到 物理 block 的映射
- 物理 block 则是按需申请、分配 这样,显存浪费只会发生在最后一个 block 的内部,可以更高效的利用好显存空间,整体浪费率 < 4%。
额外红利
共享相同页
由于有了“逻辑地址”到“物理地址”这一层抽象,两个不同的逻辑地址可以指向同一个物理地址。这样就可以实现序列片段的“浅拷贝”。
如果两个不同的序列有共同的 token,就可以只在真实显存中只存一份,但是让两个序列共享。而这在并行执行很多序列推理时,是很常见的。比如说,同一个 chat bot,它的 system prompt 在每一个 session 中都是一样的,只有到了后面的 user prompt 才会有差别。这样如果一张显卡上有很多个 session,他们就全部能共享这一部分的 system prompt 对应的 KV Cache 显存占用,能省很多呢!
为了为了维持这个特性,Paged Attention 会对每个页维护一个 RefCount 引用计数,像 C++ shared_ptr 一样,以防还有人要用呢就 free 了。然后在 vLLM 的具体实现里,不对旧的页面进行清理。也就是说就算引用计数归零了,在这个页被别人申请之前,它上面的内容是不会变的,还可以通过页号来访问。不过这也导致了程序必须假设显存是没有初始化、是“脏”的。
Copy-on-Write
这是继续前面“共享相同页”的一个后续发展。假设有两个序列在前半部分都是一样的,他们 KV Cache 都是一样的,于是前半部分一直都是”共享相同页”,但从某一个词开始,它们的序列就不一样了,就开始分叉成两个序列,KV Cache 不一样,就需要不同的页来存,不再能共享了。
在这个点,会触发 Copy-on-Write 写时复制机制。之所以是“写时复制”,因为其实前面的序列只有第一次算的那一条是在写,其他相同的序列实际上在 vLLM 的实现里面,会计算 prompt 的 hash 值,如果 hash 值匹配上了,只需要读对应块就行了(甚至不需要真的去计算 KV,只要算一下 hash 值)。但在这一刻, hash 值不同的第二个序列开始“写”了,于是从这里开始,它就有自己的块来写了,这就是写时复制。
据论文所说:使用这两个额外的分页红利,这类 Chatbot 场景下,显存减少 55%,吞吐提升最多 2.2x。
参考资料
vLLM 解剖学 这个真是好文章,推荐所有人认真看一下