forked from foundation-model-stack/foundation-model-stack
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbenchmark_inference.py
More file actions
373 lines (327 loc) · 11.5 KB
/
benchmark_inference.py
File metadata and controls
373 lines (327 loc) · 11.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
import argparse
import logging
import os
import random
import statistics
import timeit
import numpy as np
import torch
from torch import distributed as dist
from torch._dynamo import OptimizedModule
from fms import models
from fms.utils import fusion, generation, print0, tokenizers
# Example running llama 7B on one A100:
#
# (bare metal) $ CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 ./scripts/benchmark_inference.py --architecture=llama --variant=7b --tokenizer=~/models/tokenizer.model --batch_size=2 --seq_len=500
# (slurm) $ srun -N 1 --gres=gpu:1 torchrun --nproc_per_node=1 ./scripts/benchmark_inference.py --architecture=llama --variant=7b --tokenizer=~/models/tokenizer.model --batch_size=2 --seq_len=500
# loading model
# loading complete on rank 0
# Uncompiled results:
# - with use_cache=True
# 34.86 ms per token
# - with use_cache=False
# 86.39 ms per token
# End-to-end sequence generation
# - with use_cache=True
# 37.04 ms per token
# - with use_cache=False
# 90.68 ms per token
# Compiling model...
# Compiled results:
# - with use_cache=True
# 18.66 ms per token
# - with use_cache=False
# 67.66 ms per token
# (Compiled) End-to-end sequence generation
# - with use_cache=True
# 20.61 ms per token
# - with use_cache=False
# 71.45 ms per token
# Command-line interface. Options are registered in exactly the order they are
# listed by `--help`; repetitive option groups are driven from small tables.
parser = argparse.ArgumentParser(
    description="Script to benchmark inference time per token on a LLaMA model"
)
parser.add_argument("--device_type", type=str, default="cuda")

# Which model and tokenizer to benchmark.
parser.add_argument(
    "--architecture",
    type=str,
    default="llama",
    help="The model architecture to benchmark",
)
parser.add_argument(
    "--variant",
    type=str,
    default="7b",
    help="The model variant (configuration) to benchmark. E.g. 7b, 13b, 70b.",
)
parser.add_argument(
    "--tokenizer",
    type=str,
    required=True,
    help="Path to the tokenizer (e.g. ~/tokenizer.model)",
)

# Integer knobs describing the mock workload: (flag, default, help).
for _flag, _default, _help in (
    ("--seq_len", 512, "Sequence length of mock input"),
    ("--batch_size", 2, "Batch size of mock input"),
    ("--max_new_tokens", 256, "Max number of tokens to generate"),
):
    parser.add_argument(_flag, type=int, default=_default, help=_help)

parser.add_argument(
    "--compile_mode",
    type=str,
    help="Mode for compilation",
    default="default",
    choices=["default", "reduce-overhead"],
)

# Boolean switches (store_true): (flag, help).
for _flag, _help in (
    (
        "--deterministic",
        "Set seeds and torch.use_deterministic_algorithms? Requires env variable `CUBLAS_WORKSPACE_CONFIG=:4096:8`",
    ),
    (
        "--distributed",
        "This is a distributed job (multiple instances run with RANK+WORLD_SIZE)",
    ),
    (
        "--skip_correctness_check",
        "Do not test correctness of outputs vs just timing",
    ),
    ("--skip_eager_runs", "Do not run the eager benchmarks"),
    ("--skip_compile_runs", "Do not run the compiled benchmarks"),
    ("--skip_kvcache_runs", "Do not run the kv-cache benchmarks"),
    ("--skip_nokvcache_runs", "Do not run the no kv-cache benchmarks"),
    ("--skip_single_token_runs", "Do not run the single token benchmarks"),
    ("--skip_e2e_runs", "Do not run the e2e benchmarks"),
    (
        "--unfuse_weights",
        "If set to True, this will unfuse any fused weight modules that support the unfuse_weights method",
    ),
):
    parser.add_argument(_flag, action="store_true", help=_help)

# Drop the table-loop temporaries so they do not linger as module globals.
del _flag, _default, _help

parser.add_argument(
    "--quant_dtype",
    type=str,
    help="enables quantization to the specified dtype",
    default="",
    choices=["", "int8", "int4-fake"],
)
# Forwarded verbatim as `rotate=` to models.get_model below.
parser.add_argument("--rotate", action="store_true")

args = parser.parse_args()
# Process placement as exported by the launcher (torchrun sets these);
# defaults describe a single, non-distributed process.
local_rank = int(os.getenv("LOCAL_RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))

# On CUDA each process is bound to the GPU whose index equals its local rank;
# any other device type is used as given.
device = (
    torch.device(args.device_type, local_rank)
    if args.device_type == "cuda"
    else torch.device(args.device_type)
)
if args.device_type == "cuda":
    torch.cuda.set_device(device)

# All floating-point tensors created from here on default to fp16.
torch.set_default_dtype(torch.half)


def _seed_all_rngs(seed):
    """Seed the Python, PyTorch and NumPy random number generators with `seed`."""
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)


# requires setting environment variable: `CUBLAS_WORKSPACE_CONFIG=:4096:8`
if args.deterministic:
    SEED = 42
    _seed_all_rngs(SEED)
    torch.use_deterministic_algorithms(True)

# Process-group setup keys off WORLD_SIZE, not the --distributed flag
# (args.distributed is never read in this script).
if world_size > 1:
    dist.init_process_group()
    # Fix until PT 2.3: register the WORLD group under the name "default".
    torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)
# Build the model (with the requested quantization / rotation options),
# optionally unfuse its weights, and switch to inference mode.
print("loading model")
model = models.get_model(
    args.architecture,
    args.variant,
    device_type=args.device_type,
    quant_dtype=args.quant_dtype,
    rotate=args.rotate,
)
if args.unfuse_weights:
    print("unfusing weights")
    model = fusion.apply_unfuse_weights(model)

# Inference only: eval mode and autograd disabled globally for the rest of the run.
model.eval()
torch.set_grad_enabled(False)

# The tokenizer is only used below to size the random token ids (vocab_size()).
tokenizer = tokenizers.get_tokenizer(args.tokenizer)
print(f"loading complete on rank {local_rank}")
SEQ_LEN = args.seq_len
BATCH_SIZE = args.batch_size
MAX_NEW_TOKENS = args.max_new_tokens

# Random prompt token ids of shape (BATCH_SIZE, SEQ_LEN), values in [0, vocab_size).
ids = torch.randint(
    tokenizer.vocab_size(), (BATCH_SIZE, SEQ_LEN), device=device, dtype=torch.long
)

# This first forward call generates the cache for use in cases where
# `use_cache=True`.
#
# For performance purposes, this call can be considered equivalent to
# `use_cache=False`.
#
# The actual performance of generation with `use_cache=True` would be the cost
# of the first token without cache, plus the cost of all subsequent tokens with
# cache. I.e. the amortized per-token cost would depend on the number of tokens
# generated.
logits, cache = model.forward(ids, use_cache=True)
# Keep only the logits at the last sequence position.
logits = logits[:, -1, :]
# Argmax next token per batch row, reshaped to a (BATCH_SIZE, 1) column.
next_val = torch.argmax(logits, dim=-1).unsqueeze(0).t()
# The prompt extended by that token, used by the no-cache benchmark.
next_input = torch.cat((ids, next_val), dim=-1)

# The logits are no longer needed; drop the reference so the memory can be freed.
del logits

# Reference prediction for the following token via the kv-cache path...
expected, _ = model.forward(
    next_val, past_key_value_states=cache, use_cache=True, only_last_token=True
)
expected = torch.argmax(expected, dim=-1)
# ...and via a full no-cache pass over the extended sequence.
expected2 = model.forward(next_input, only_last_token=True)
expected2 = torch.argmax(expected2, dim=-1)

# Both paths should agree. Honor --skip_correctness_check here, as the timed
# benchmark functions do, instead of always asserting.
if not args.skip_correctness_check:
    torch.testing.assert_close(expected, expected2)
# Number of timing repetitions per benchmark; log_result reports their median.
repeat = 3


# The function we're measuring, with or without caching.
#
# In a realistic generate function, the sequence length would grow with each
# subsequent token, and so the average cost would be from a variety of sequence
# lengths.
# We capture the time to generate a single token from a given sequence length
# and batch size. This means we're measuring the cost of the forward pass
# in isolation in a way that's easier to compare, and avoids including the cost
# of the concatenation operation.
def one_token(model, use_cache):
    """Run one forward pass that predicts a single next token.

    With ``use_cache`` the pass consumes only ``next_val`` plus the prebuilt
    ``cache``; otherwise it re-reads the whole ``next_input`` sequence.
    On local rank 0 (unless --skip_correctness_check) the argmax prediction is
    compared with ``expected``. Otherwise, on CUDA, the device is synchronized
    so that the timing covers the queued kernels.
    """
    if use_cache:
        out, _ = model.forward(
            next_val, past_key_value_states=cache, use_cache=True, only_last_token=True
        )
    else:
        out = model.forward(next_input, only_last_token=True)
    predicted = torch.argmax(out, dim=-1)

    verify = local_rank == 0 and not args.skip_correctness_check
    if verify:
        torch.testing.assert_close(predicted, expected)
    elif args.device_type == "cuda":
        torch.cuda.synchronize()
def end_to_end(model, use_cache, expected=None):
    """Generate MAX_NEW_TOKENS tokens after the mock prompt ``ids``.

    Calls ``generation.generate`` with ``do_sample=False`` (presumably greedy
    decoding). On local rank 0 it asserts that the output length is
    SEQ_LEN + MAX_NEW_TOKENS.

    If ``expected`` is given and --skip_correctness_check is not set, the
    output is compared with it. Otherwise, on CUDA, the device is synchronized
    so that the caller's timer covers the queued kernels.

    Returns the generated tensor.
    """
    # Contiguous caches are only requested for a compiled (OptimizedModule)
    # model under mode="reduce-overhead"; per the original note, this is
    # needed for reduce-overhead to work correctly for now.
    wants_contiguous = args.compile_mode == "reduce-overhead" and isinstance(
        model, OptimizedModule
    )
    result = generation.generate(
        model,
        ids,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
        use_cache=use_cache,
        contiguous_cache=wants_contiguous,
    )

    if local_rank == 0:
        assert (
            result.size()[-1] == SEQ_LEN + MAX_NEW_TOKENS
        ), f"{result.size()}, {SEQ_LEN}, {MAX_NEW_TOKENS}"

    compare = expected is not None and not args.skip_correctness_check
    if compare:
        torch.testing.assert_close(result, expected)
    elif args.device_type == "cuda":
        torch.cuda.synchronize()
    return result
# Reference outputs for the end-to-end correctness checks, computed once on the
# eager model. They also act as an untimed first pass through generation.
# Bug fix: the no-cache reference must come from a use_cache=False run. It
# previously repeated the use_cache=True call, so the no-cache benchmarks were
# checked against a kv-cache reference and the same generation ran twice.
# Each reference is only computed when some end-to-end run will consume it;
# otherwise it is None, which end_to_end treats as "no comparison".
e2e_expected_cache = (
    end_to_end(model, True)
    if not (args.skip_e2e_runs or args.skip_kvcache_runs)
    else None
)
e2e_expected_nocache = (
    end_to_end(model, False)
    if not (args.skip_e2e_runs or args.skip_nokvcache_runs)
    else None
)
def log_result(result):
    """Print the median per-token latency in milliseconds (local rank 0 only).

    ``result`` is the list of durations in seconds returned by ``timeit.repeat``.
    Every entry covers MAX_NEW_TOKENS generated tokens: MAX_NEW_TOKENS single-token
    calls, or one end-to-end generation of MAX_NEW_TOKENS tokens.
    """
    if local_rank != 0:
        return
    ms = statistics.median(result) / MAX_NEW_TOKENS * 1000
    print(f"\t{ms:0.2f} ms per token")
def bench_one(use_cache):
    """Time MAX_NEW_TOKENS back-to-back single-token passes, ``repeat`` times."""
    print0(f"- with use_cache={use_cache}")
    # The lambda looks up the global `model` when called, so it always times
    # the current (eager or compiled) model.
    timings = timeit.repeat(
        lambda: one_token(model, use_cache), number=MAX_NEW_TOKENS, repeat=repeat
    )
    log_result(timings)
def bench_end_to_end(use_cache, expected):
    """Time ``repeat`` full generations, each checked against ``expected``."""
    print0(f"- with use_cache={use_cache}")

    def run_once():
        # Reads the global `model` at call time (eager or compiled).
        return end_to_end(model, use_cache, expected)

    log_result(timeit.repeat(run_once, number=1, repeat=repeat))
print0(
    f"Results for batch size {BATCH_SIZE}, sequence length {SEQ_LEN}, new tokens generated {MAX_NEW_TOKENS}"
)


def _run_eager_benchmarks():
    """Run the enabled single-token and end-to-end benchmarks on the eager model."""
    print0("Uncompiled results:")
    print0("==========")

    # Enabled (use_cache, end-to-end reference) pairs, kv-cache variant first.
    modes = []
    if not args.skip_kvcache_runs:
        modes.append((True, e2e_expected_cache))
    if not args.skip_nokvcache_runs:
        modes.append((False, e2e_expected_nocache))

    if not args.skip_single_token_runs:
        print0("Single token generation")
        for use_cache, _ in modes:
            bench_one(use_cache)

    if not args.skip_e2e_runs:
        print0("End-to-end sequence generation")
        for use_cache, reference in modes:
            bench_end_to_end(use_cache, reference)


if not args.skip_eager_runs:
    _run_eager_benchmarks()
# Compiled benchmarks: compile the model, warm it up (untimed), then time the
# same single-token and end-to-end workloads as the eager section.
if not args.skip_compile_runs:
    # Import the inductor config explicitly rather than relying on some earlier
    # import having loaded the torch._inductor.config submodule as a side effect.
    import torch._inductor.config

    print0("Compiling model...")

    # This is to prevent a bug in PT 2.1 that has been fixed in PT 2.2 nightlies
    torch._inductor.config.joint_graph_constant_folding = False

    # with mode='reduce-overhead' we see better performance but on multi-GPU models
    # hit an error on the end-to-end test below when run after other tests (if it's
    # run first it works, confirming a memory leak):
    # `RuntimeError: Expected curr_block->ptr == block_state.ptr to be true, but got false.`
    # Rebinding the module-level `model` makes bench_one/bench_end_to_end (which
    # look `model` up at call time) use the compiled model from here on.
    model = torch.compile(model, dynamic=True, mode=args.compile_mode)

    print0()
    print0("Compiled results:")
    print0("==========")

    if not args.skip_single_token_runs:
        # Warmup. Especially with torch.compile, first inference pass can be slow.
        print(f"Warming up the compiled model for single token in rank {local_rank}")
        # Activate dynamo logs to ensure some output during compilation.
        # NOTE: nothing below resets this, so dynamo INFO logging stays enabled
        # during the timed runs as well.
        torch._logging.set_logs(dynamo=logging.INFO)
        if not args.skip_kvcache_runs:
            one_token(model, True)
        if not args.skip_nokvcache_runs:
            one_token(model, False)
        print(f"Model has warmed up in rank {local_rank}")

        # These get much better results with mode='reduce-overhead' but can lead to
        # some memory issues
        print0("(Compiled) Single token generation")
        if not args.skip_kvcache_runs:
            bench_one(True)
        if not args.skip_nokvcache_runs:
            bench_one(False)

    if not args.skip_e2e_runs:
        print0()
        # Untimed end-to-end passes so that compilation of the generation path
        # is not included in the measurements below.
        print(f"Warming up the compiled model e2e in rank {local_rank}")
        if not args.skip_kvcache_runs:
            end_to_end(model, True, e2e_expected_cache)
        if not args.skip_nokvcache_runs:
            end_to_end(model, False, e2e_expected_nocache)
        print(f"Model has warmed up e2e in rank {local_rank}")

        print0("(Compiled) End-to-end sequence generation")
        if not args.skip_kvcache_runs:
            bench_end_to_end(True, e2e_expected_cache)
        if not args.skip_nokvcache_runs:
            bench_end_to_end(False, e2e_expected_nocache)