LLM in a flash:
Efficient Large Language Model Inference with Limited Memory
Keivan Alizadeh, Iman Mirzadeh, Dmitry Belenko, Karen Khatamifard,
Minsik Cho, Carlo C Del Mundo, Mohammad Rastegari, Mehrdad Farajtabar§
Apple
Abstract
Large language models (LLMs) are central to modern natural language processing, delivering exceptional performance in various tasks. However, their intensive computational and memory requirements present challenges, especially for devices with limited DRAM capacity. This paper tackles the challenge of efficiently running LLMs that exceed the available DRAM capacity by storing the model parameters on flash memory but bringing them on demand to DRAM. Our method involves constructing an inference cost model that harmonizes with the flash memory behavior, guiding us to optimize in two critical areas: reducing the volume of data transferred from flash and reading data in larger, more contiguous chunks. Within this flash memory-informed framework, we introduce two principal techniques. First, “windowing” strategically reduces data transfer by reusing previously activated neurons, and second, “row-column bundling”, tailored to the sequential data access strengths of flash memory, increases the size of data chunks read from flash memory. These methods collectively enable running models up to twice the size of the available DRAM, with a 4-5x and 20-25x increase in inference speed compared to naive loading approaches in CPU and GPU, respectively. Our integration of sparsity awareness, context-adaptive loading, and a hardware-oriented design paves the way for effective inference of LLMs on devices with limited memory.
Primary Author: kalizadehvahid@apple.com
Major Contribution: imirzadeh@apple.com
Major Contribution: d_belenko@apple.com
§Senior Author: farajtabar@apple.com

1 Introduction

Figure 1: Inference latency of 1 token when half the memory of the model is available. [Bar chart comparing compute, load-from-flash, and memory-management time (ms) for naive loading versus ours on Falcon 7B (CPU), OPT 6.7B (CPU), and OPT 6.7B (GPU).]

In recent years, large language models (LLMs), such as GPT-3 (Brown et al., 2020), OPT (Zhang et al., 2022b), and PaLM (Chowdhery et al., 2022), have demonstrated strong performance across a
wide range of natural language tasks. However, the unprecedented capabilities of these models come with substantial computational and memory requirements for inference. LLMs can contain hundreds of billions or even trillions of parameters, making it challenging to load and run them efficiently, especially on resource-constrained devices. Currently, the standard approach is to load the entire model into DRAM for inference (Rajbhandari et al., 2021; Aminabadi et al., 2022). However, this severely limits the maximum model size that can be run. For example, a 7 billion parameter model requires over 14GB of memory just to load the parameters in half-precision floating point format, exceeding the capabilities of most edge devices.

Figure 2: (a) Flash memory offers significantly higher capacity but suffers from much lower bandwidth compared to DRAM and CPU/GPU caches and registers. (b) The throughput for random reads in flash memory increases with the size of sequential chunks and the number of threads. [Panel (a): unified memory architecture with flash memory (~100 GB, ~1 GB/s) and DRAM (~10 GB, ~10 GB/s) feeding the CPU/GPU. Panel (b): random read throughput (MB/s) versus chunk size (4-64 KB) for 2-32 threads, with a sequential-read upper bound.]

To address this limitation, we propose to store the model parameters on flash memory, which is at least an order of magnitude larger than DRAM. Then, during inference, we directly and cleverly load the required parameters from the flash memory, avoiding the need to fit the entire model in DRAM. Our methodology is built on top of recent works that have shown LLMs exhibit a high degree of sparsity in the FeedForward Network (FFN) layers, with models like OPT (Zhang et al.,
2022b) and Falcon (Almazrouei et al., 2023) exhibiting more than 90% sparsity (Mirzadeh et al., 2023; Liu et al., 2023b). We exploit this sparsity to selectively load from flash memory only the parameters that either have non-zero input or are predicted to have non-zero output. Specifically, we discuss a hardware-inspired cost model that includes flash memory, DRAM, and computing cores (CPU or GPU). Then, we introduce two complementary techniques to minimize data transfer and maximize flash memory throughput (a minimal sketch of the windowing idea is given at the end of this section):

Windowing: We load parameters for only the past few tokens, reusing activations from recently computed tokens. This sliding window approach reduces the number of IO requests needed to load weights.

Row-column bundling: We store a concatenated row and column of the up-projection and down-projection layers to read bigger contiguous chunks from flash memory. This increases throughput by reading larger chunks.

To further minimize the number of weights to be transferred from flash memory to DRAM, we also employ methods to predict FFN sparsity and avoid loading zeroed-out parameters, akin to approaches documented in DejaVu (Liu et al., 2023b). Together, windowing and sparsity prediction allow us to load only 2% of the FFN layer from flash for each inference query. We also propose static memory preallocation to minimize transfers within DRAM and reduce inference latency. Our load-from-flash cost model captures the tradeoff between loading less data and reading bigger chunks. Optimizing this cost model and selectively loading parameters on demand yields flash loading strategies that can run models 2x larger than the device’s DRAM capacity and speed up inference by 4-5x and 20-25x compared to a naive implementation on CPU and GPU, respectively.
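To make the windowing idea concrete, the following minimal Python sketch is our own illustration rather than the paper's released code; the load_rows and free_rows callbacks and the upstream sparsity predictor are hypothetical stand-ins. It keeps the union of neurons activated over the last few tokens resident in DRAM and, for each new token, loads from flash only the neuron weights that are newly required while freeing those that fell out of the window.

from collections import deque

class NeuronWindow:
    """Bookkeeping for the sliding-window neuron cache (illustrative sketch).

    load_rows / free_rows are hypothetical callbacks standing in for the
    actual flash reads and DRAM management.
    """

    def __init__(self, window_size, load_rows, free_rows):
        self.window_size = window_size
        self.history = deque()      # per-token sets of active neuron ids
        self.resident = set()       # neuron ids currently held in DRAM
        self.load_rows = load_rows
        self.free_rows = free_rows

    def step(self, active_neurons):
        """Advance the window by one token and reconcile DRAM contents."""
        self.history.append(set(active_neurons))
        if len(self.history) > self.window_size:
            self.history.popleft()

        needed = set().union(*self.history)   # neurons used in the window
        to_load = needed - self.resident      # incremental reads from flash
        to_free = self.resident - needed      # rows that left the window

        if to_load:
            self.load_rows(sorted(to_load))
        if to_free:
            self.free_rows(sorted(to_free))
        self.resident = needed
        return to_load, to_free

Because consecutive tokens tend to activate heavily overlapping neuron sets, to_load is typically a small fraction of needed, which is what reduces the number of I/O requests.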
2 Flash Memory & LLM Inference
In this section, we explore the characteristics of memory storage systems (e.g., flash, DRAM) and their implications for large language model (LLM) inference. Our aim is to elucidate the challenges and hardware-specific considerations essential for algorithm design, particularly in optimizing inference when working with flash memory.
2.1 Bandwidth and Energy Constraints
While modern NAND flash memory offers high bandwidth and low latency, it falls short of the performance levels of DRAM (Dynamic Random-Access Memory), especially in memory-constrained systems. Figure 2a illustrates these differences. A naive inference implementation that relies on NAND flash memory might necessitate reloading the entire model for each forward pass. This process is not only time-consuming, often taking seconds even for compressed models, but it also consumes more energy than transferring data from DRAM to the CPU or GPU's internal memory.

In scenarios where DRAM is abundant, the cost of loading data is somewhat mitigated, as the model can reside in DRAM. However, the initial loading of the model still incurs a penalty, particularly in situations requiring rapid response times for the first token. Our approach, leveraging activation sparsity in LLMs, addresses these challenges by enabling selective reading of model weights, thereby reducing both time and power costs.
2.2 Read Throughput
Flash memory systems perform optimally with large sequential reads. For instance, benchmarks on an Apple MacBook Pro M2 with 2TB flash demonstrate speeds exceeding 6 GiB/s for a 1 GiB linear read of an uncached file. However, this high bandwidth is not replicated for smaller, random reads due to the inherent multi-phase nature of these reads, encompassing the operating system, drivers, interrupt handling, and the flash controller, among others. Each phase introduces latency, disproportionately affecting smaller reads.

To circumvent these limitations, we advocate two primary strategies, which can be employed jointly. The first involves reading larger chunks of data. Although throughput growth is not linear (larger chunks take longer to transfer), the latency for the initial byte becomes a smaller fraction of the total request time, resulting in more efficient data reading. This principle is depicted in Figure 2b. A perhaps counterintuitive yet interesting observation is that, in some scenarios, it is faster to read more data than needed (but in larger chunks) and then discard it than to read only the necessary parts in smaller chunks. The second strategy leverages parallelized reads, utilizing the inherent parallelism within storage stacks and flash controllers. Our results indicate that throughputs appropriate for sparse LLM inference are achievable on standard hardware using 32 KiB or larger random reads across multiple threads.

Crucial to maximizing throughput is the way weights are stored: a layout that increases the average chunk length can significantly boost bandwidth. In some cases, it may be beneficial to read and subsequently discard excess data rather than splitting the data into smaller, less efficient chunks. Motivated by the challenges described in this section, in Section 3 we propose methods to optimize data transfer volume and enhance read throughput to significantly improve inference speeds.
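The chunk-size and parallelism effects described above can be reproduced with a short benchmark. The sketch below is a simplified illustration under our own assumptions (the file path is a placeholder, and it ignores page-cache effects, so the target file should be large and uncached): it issues random os.pread calls of a fixed chunk size from several threads and reports aggregate throughput, which on flash-backed storage should grow with both the chunk size and the thread count, as in Figure 2b.

import os, random, time
from concurrent.futures import ThreadPoolExecutor

def random_read_throughput(path, chunk_size, num_threads, reads_per_thread=256):
    """Aggregate random-read throughput (MB/s) for a given chunk size."""
    file_size = os.path.getsize(path)
    fd = os.open(path, os.O_RDONLY)

    def worker(_):
        nbytes = 0
        for _ in range(reads_per_thread):
            offset = random.randrange(0, max(1, file_size - chunk_size))
            nbytes += len(os.pread(fd, chunk_size, offset))
        return nbytes

    start = time.perf_counter()
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        total = sum(pool.map(worker, range(num_threads)))
    elapsed = time.perf_counter() - start
    os.close(fd)
    return total / elapsed / 1e6

# Example sweep mirroring Figure 2b (the file path is a placeholder):
# for threads in (2, 4, 8, 16, 32):
#     for kib in (4, 8, 16, 32, 64):
#         print(threads, kib, round(random_read_throughput("weights.bin", kib * 1024, threads)))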
3 Load From Flash
This section addresses the challenge of conducting inference on devices where the available computational memory is substantially smaller than the size of the model, which necessitates storing the full model weights in flash memory. Our primary metric for evaluating various flash loading strategies is latency, dissected into three distinct components: the I/O cost of loading from flash, the overhead of managing memory with newly loaded data, and the compute cost for inference operations (a simple sketch of this decomposition is given below).

Our proposed solutions for reducing latency under memory constraints are categorized into three strategic areas, each targeting a specific aspect of the latency:

Reducing Data Load: Aiming to decrease latency associated with flash I/O operations by loading less data¹.

Optimizing Data Chunk Size: Enhancing flash throughput by increasing the size of data chunks loaded, thereby mitigating latency.

Efficient Management of Loaded Data: Streamlining the management of data once it is loaded into memory to minimize overhead.

It is important to note that our focus is not on the compute aspect of the process, as it is orthogonal to the core concerns of our work. This delineation allows us to concentrate on optimizing flash memory interactions and memory management to achieve efficient inference on memory-constrained devices. Finally, we will elaborate on the implementation of these strategies in subsequent sections.
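As a rough illustration of this latency decomposition (the constants and the throughput curve below are illustrative placeholders, not measurements from the paper), per-token latency can be modeled as the sum of flash I/O time, memory-management overhead, and compute time, where the effective flash throughput itself depends on the chunk size being read:

def flash_throughput_bytes_per_s(chunk_bytes):
    """Toy model of Figure 2b: larger chunks amortize the fixed per-read cost.
    peak and per_read_latency are illustrative constants, not measurements."""
    peak = 3e9               # asymptotic random-read throughput (bytes/s)
    per_read_latency = 1e-4  # fixed cost of issuing one read (s)
    return chunk_bytes / (chunk_bytes / peak + per_read_latency)

def token_latency(bytes_from_flash, chunk_bytes, mem_mgmt_s, compute_s):
    """Per-token latency = flash I/O + memory management + compute."""
    io_s = bytes_from_flash / flash_throughput_bytes_per_s(chunk_bytes)
    return io_s + mem_mgmt_s + compute_s

# Reducing data load lowers bytes_from_flash, optimizing chunk size raises
# chunk_bytes, and efficient management of loaded data lowers mem_mgmt_s;
# compute_s is left untouched, matching the scope stated above.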
3.1 Reducing Data Transfer
Our methodology leverages the inherent sparsity found in Feed-Forward Network (FFN) layers, as documented in preceding research. The OPT 6.7B model, for instance, exhibits a notable 97% sparsity within its FFN layer. Similarly, the Falcon 7B model has been adapted through fine-tuning that swaps its activation function to ReLU, resulting in 95% sparsity while maintaining nearly the same accuracy (Mirzadeh et al., 2023). In light of this information, our approach involves the iterative transfer of only the essential, non-sparse data from flash memory to DRAM for processing during inference.

Note that we employ the 7B models as a practical example to elucidate our approach, but our findings are adaptable and can be extrapolated to both larger and smaller scale models with ease.
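To make the selective transfer concrete, the sketch below memory-maps a hypothetical bundled weight file in which, following the row-column bundling idea, the i-th row of the up projection and the i-th column of the down projection are stored contiguously for each FFN neuron i, and it copies into DRAM only the bundles of neurons a sparsity predictor marks as active. The file name, layout, and predictor are our own illustrative assumptions, not the paper's implementation.

import numpy as np

def load_active_bundles(bundle_path, active_ids, d_model, dtype=np.float16):
    """Copy into DRAM only the bundles of predicted-active FFN neurons.

    Assumed layout: bundle i holds the i-th row of the up projection
    concatenated with the i-th column of the down projection
    (2 * d_model values), so one contiguous read replaces two scattered ones.
    """
    bundle_len = 2 * d_model
    # Memory-map the bundled weights; pages are read from flash on access.
    bundles = np.memmap(bundle_path, dtype=dtype, mode="r").reshape(-1, bundle_len)

    active_ids = np.asarray(sorted(active_ids))
    selected = np.array(bundles[active_ids])  # materialize only these bundles

    up_rows = selected[:, :d_model]    # rows of the up projection
    down_cols = selected[:, d_model:]  # matching columns of the down projection
    return up_rows, down_cols

# Hypothetical usage with a sparsity predictor:
# active = predictor(hidden_state)  # ids of neurons expected to be non-zero
# up_rows, down_cols = load_active_bundles("ffn_layer.bundles", active, d_model=4096)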
Selective Persistence Strategy. We opt to retain the embeddings and matrices within the attention mechanism of the transformer constantly in DRAM.

¹Note that by data we mean the weights of the neural network. However, our developed techniques can be easily generalized to other data types transferred and used for LLM inference, such as activations or the KV cache, as suggested by Sheng et al. (2023).