@@ -64,23 +64,21 @@
 from datetime import datetime
 import gc
 import json
+import os
 import random
 import time
 from typing import Any, AsyncGenerator, Optional
-import os
-
 
+from benchmarks.eval_accuracy import eval_accuracy
+from benchmarks.metrics import CounterMetric, EventMetric
 import grpc
-from benchmarks.metrics import EventMetric, CounterMetric
 from jetstream.core.proto import jetstream_pb2
 from jetstream.core.proto import jetstream_pb2_grpc
 from jetstream.engine.token_utils import load_vocab
 from jetstream.external_tokenizers.llama3 import llama3_tokenizer
 import numpy as np
-from tqdm.asyncio import tqdm  # pytype: disable=pyi-error
 import pandas
-
-from eval_accuracy import eval_accuracy
+from tqdm.asyncio import tqdm  # pytype: disable=pyi-error
 from transformers import AutoTokenizer
 
 
@@ -706,136 +704,7 @@ def sample_warmup_requests(requests):
         break
 
 
-def main(args: argparse.Namespace):
-  print(args)
-  random.seed(args.seed)
-  np.random.seed(args.seed)
-
-  model_id = args.model
-  tokenizer_id = args.tokenizer
-  use_hf_tokenizer = args.use_hf_tokenizer
-
-  prefill_quota = AsyncCounter(init_value=3)
-  active_req_quota = AsyncCounter(init_value=450)
-
-  api_url = f"{args.server}:{args.port}"
-
-  tokenizer = get_tokenizer(model_id, tokenizer_id, use_hf_tokenizer)
-  if tokenizer == "test" or args.dataset == "test":
-    input_requests = mock_requests(
-        args.total_mock_requests
-    )  # e.g. [("AB", 2, "AB", 3)]
-  else:
-    dataset = []
-    if args.dataset == "openorca":
-      dataset = load_openorca_dataset_pkl(args.dataset_path)
-    elif args.dataset == "sharegpt":
-      dataset = load_sharegpt_dataset(
-          args.dataset_path,
-          args.conversation_starter,
-      )
-
-    # A given args.max_output_length value is the max generation step,
-    # when the args.max_output_length is default to None, the sample's golden
-    # output length will be used to decide the generation step.
-    input_requests = sample_requests(
-        dataset=dataset,
-        tokenizer=tokenizer,
-        num_requests=args.num_prompts,
-        max_output_length=args.max_output_length,
-    )
-
-  warmup_requests = None
-  if args.warmup_mode == "full":
-    warmup_requests = input_requests
-  elif args.warmup_mode == "sampled":
-    warmup_requests = list(sample_warmup_requests(input_requests)) * 2
-
-  if warmup_requests:
-    print(f"Warmup (mode: {args.warmup_mode}) is starting.")
-    _, _ = asyncio.run(
-        benchmark(
-            api_url=api_url,
-            tokenizer=tokenizer,
-            input_requests=warmup_requests,
-            request_rate=args.request_rate,
-            disable_tqdm=args.disable_tqdm,
-            prefill_quota=prefill_quota,
-            active_req_quota=active_req_quota,
-            is_warmup=True,
-        )
-    )
-    print(f"Warmup (mode: {args.warmup_mode}) has completed.")
-
-    # TODO: Replace this with warmup complete signal once supported.
-    # Wait for server completely warmup before running the benchmark.
-    time.sleep(5)
-
-  benchmark_result, request_outputs = asyncio.run(
-      benchmark(
-          api_url=api_url,
-          tokenizer=tokenizer,
-          input_requests=input_requests,
-          request_rate=args.request_rate,
-          disable_tqdm=args.disable_tqdm,
-          prefill_quota=prefill_quota,
-          active_req_quota=active_req_quota,
-      )
-  )
-
-  # Process output
-  output = [output.to_dict() for output in request_outputs]
-  if args.run_eval:
-    eval_json = eval_accuracy(output)
-
-  # Save config and results to json
-  if args.save_result:
-    # dimensions values are strings
-    dimensions_json = {}
-    # metrics values are numerical
-    metrics_json = {}
-
-    # Setup
-    current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
-    dimensions_json["date"] = current_dt
-    dimensions_json["model_id"] = model_id
-    dimensions_json["tokenizer_id"] = tokenizer_id
-    if args.additional_metadata_metrics_to_save is not None:
-      dimensions_json = {
-          **dimensions_json,
-          **json.loads(args.additional_metadata_metrics_to_save),
-      }
-    metrics_json["num_prompts"] = args.num_prompts
-
-    # Traffic
-    metrics_json["request_rate"] = args.request_rate
-    metrics_json = {**metrics_json, **benchmark_result}
-    if args.run_eval:
-      metrics_json = {**metrics_json, **eval_json}
-
-    final_json = {}
-    final_json["metrics"] = metrics_json
-    final_json["dimensions"] = dimensions_json
-
-    # Save to file
-    base_model_id = model_id.split("/")[-1]
-    file_name = (
-        f"JetStream-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
-    )
-    with open(file_name, "w", encoding="utf-8") as outfile:
-      json.dump(final_json, outfile)
-
-  if args.save_request_outputs:
-    file_path = args.request_outputs_file_path
-    with open(file_path, "w", encoding="utf-8") as output_file:
-      json.dump(
-          output,
-          output_file,
-          indent=4,
-      )
-
-
-if __name__ == "__main__":
+def parse_args() -> argparse.Namespace:
   parser = argparse.ArgumentParser(
       description="Benchmark the online serving throughput."
   )
@@ -909,7 +778,6 @@ def main(args: argparse.Namespace):
       default=150,
       help="The maximum number of mock requests to send for benchmark testing.",
   )
-
   parser.add_argument(
       "--max-output-length",
       type=int,
@@ -926,7 +794,6 @@ def main(args: argparse.Namespace):
           "the output length of the golden dataset would be passed."
       ),
   )
-
   parser.add_argument("--seed", type=int, default=0)
   parser.add_argument(
       "--disable-tqdm",
@@ -977,7 +844,138 @@ def main(args: argparse.Namespace):
       choices=["human", "gpt", "both"],
       help="What entity should be the one starting the conversations.",
   )
+  return parser.parse_args()
+
+
+def main(args: argparse.Namespace):
+  print(args)
+  random.seed(args.seed)
+  np.random.seed(args.seed)
+
+  model_id = args.model
+  tokenizer_id = args.tokenizer
+  use_hf_tokenizer = args.use_hf_tokenizer
+
+  prefill_quota = AsyncCounter(init_value=3)
+  active_req_quota = AsyncCounter(init_value=450)
+
+  api_url = f"{args.server}:{args.port}"
+
+  tokenizer = get_tokenizer(model_id, tokenizer_id, use_hf_tokenizer)
+  if tokenizer == "test" or args.dataset == "test":
+    input_requests = mock_requests(
+        args.total_mock_requests
+    )  # e.g. [("AB", 2, "AB", 3)]
+  else:
+    dataset = []
+    if args.dataset == "openorca":
+      dataset = load_openorca_dataset_pkl(args.dataset_path)
+    elif args.dataset == "sharegpt":
+      dataset = load_sharegpt_dataset(
+          args.dataset_path,
+          args.conversation_starter,
+      )
+
+    # A given args.max_output_length value is the max generation step,
+    # when the args.max_output_length is default to None, the sample's golden
+    # output length will be used to decide the generation step.
+    input_requests = sample_requests(
+        dataset=dataset,
+        tokenizer=tokenizer,
+        num_requests=args.num_prompts,
+        max_output_length=args.max_output_length,
+    )
+
+  warmup_requests = None
+  if args.warmup_mode == "full":
+    warmup_requests = input_requests
+  elif args.warmup_mode == "sampled":
+    warmup_requests = list(sample_warmup_requests(input_requests)) * 2
+
+  if warmup_requests:
+    print(f"Warmup (mode: {args.warmup_mode}) is starting.")
+    _, _ = asyncio.run(
+        benchmark(
+            api_url=api_url,
+            tokenizer=tokenizer,
+            input_requests=warmup_requests,
+            request_rate=args.request_rate,
+            disable_tqdm=args.disable_tqdm,
+            prefill_quota=prefill_quota,
+            active_req_quota=active_req_quota,
+            is_warmup=True,
+        )
+    )
+    print(f"Warmup (mode: {args.warmup_mode}) has completed.")
+
+    # TODO: Replace this with warmup complete signal once supported.
+    # Wait for server completely warmup before running the benchmark.
+    time.sleep(5)
+
+  benchmark_result, request_outputs = asyncio.run(
+      benchmark(
+          api_url=api_url,
+          tokenizer=tokenizer,
+          input_requests=input_requests,
+          request_rate=args.request_rate,
+          disable_tqdm=args.disable_tqdm,
+          prefill_quota=prefill_quota,
+          active_req_quota=active_req_quota,
+      )
+  )
+
+  # Process output
+  output = [output.to_dict() for output in request_outputs]
+  if args.run_eval:
+    eval_json = eval_accuracy(output)
+
+  # Save config and results to json
+  if args.save_result:
+    # dimensions values are strings
+    dimensions_json = {}
+    # metrics values are numerical
+    metrics_json = {}
 
-  parsed_args = parser.parse_args()
+    # Setup
+    current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
+    dimensions_json["date"] = current_dt
+    dimensions_json["model_id"] = model_id
+    dimensions_json["tokenizer_id"] = tokenizer_id
+    if args.additional_metadata_metrics_to_save is not None:
+      dimensions_json = {
+          **dimensions_json,
+          **json.loads(args.additional_metadata_metrics_to_save),
+      }
+    metrics_json["num_prompts"] = args.num_prompts
+
+    # Traffic
+    metrics_json["request_rate"] = args.request_rate
+    metrics_json = {**metrics_json, **benchmark_result}
+    if args.run_eval:
+      metrics_json = {**metrics_json, **eval_json}
+
+    final_json = {}
+    final_json["metrics"] = metrics_json
+    final_json["dimensions"] = dimensions_json
+
+    # Save to file
+    base_model_id = model_id.split("/")[-1]
+    file_name = (
+        f"JetStream-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
+    )
+    with open(file_name, "w", encoding="utf-8") as outfile:
+      json.dump(final_json, outfile)
+
+  if args.save_request_outputs:
+    file_path = args.request_outputs_file_path
+    with open(file_path, "w", encoding="utf-8") as output_file:
+      json.dump(
+          output,
+          output_file,
+          indent=4,
+      )
+
+
+if __name__ == "__main__":
   gc.disable()
-  main(parsed_args)
+  main(parse_args())
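
Note (not part of the commit): with argument parsing split out into parse_args(), the benchmark can also be driven programmatically by handing main() a pre-built argparse.Namespace. The sketch below is an illustration only; the module path benchmarks.benchmark_serving, the attribute values, and the assumption that a JetStream server is already listening on the chosen host and port are not taken from this diff. The attribute names simply mirror the args.* fields that main() reads above, and dataset="test" routes through mock_requests() so no dataset file is required.

import argparse

from benchmarks import benchmark_serving  # assumed module name for this file

# Every attribute read by main() must be present on the Namespace.
args = argparse.Namespace(
    model="test-model",  # placeholder model id
    tokenizer="test",  # "test" is assumed to skip loading a real tokenizer
    use_hf_tokenizer=False,
    server="localhost",  # assumed address of an already-running server
    port=9000,
    dataset="test",  # short-circuits to mock_requests()
    dataset_path="",
    conversation_starter="human",
    total_mock_requests=16,
    num_prompts=16,
    max_output_length=None,
    warmup_mode="sampled",
    request_rate=1.0,
    disable_tqdm=True,
    seed=0,
    run_eval=False,
    save_result=False,
    additional_metadata_metrics_to_save=None,
    save_request_outputs=False,
    request_outputs_file_path="request_outputs.json",
)

benchmark_serving.main(args)

Since parse_args() reads sys.argv, a test can alternatively patch sys.argv with the desired flags and call main(parse_args()) directly instead of building the Namespace by hand.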