diff --git a/example_trainer/scripts/benchmark_lora_vs_shared.py b/example_trainer/scripts/benchmark_lora_vs_shared.py
index bedd6ed3..972b72dc 100755
--- a/example_trainer/scripts/benchmark_lora_vs_shared.py
+++ b/example_trainer/scripts/benchmark_lora_vs_shared.py
@@ -83,6 +83,7 @@ def start_vllm_server(
     gpu_id: int,
     mode: str = "base",  # "base", "lora_eager", "lora_no_eager"
     max_lora_rank: int = 32,
+    max_model_len: int = 8192,
    log_file: str = "vllm.log",
 ) -> subprocess.Popen:
     """
@@ -105,8 +106,8 @@ def start_vllm_server(
         sys.executable, str(vllm_server_path),
         "--model", model,
         "--port", str(port),
-        "--gpu-memory-utilization", "0.45",
-        "--max-model-len", "8192",
+        "--gpu-memory-utilization", "0.70",  # Higher to leave headroom for long (e.g., 32k) contexts
+        "--max-model-len", str(max_model_len),
         "--dtype", "bfloat16",
     ]
 
@@ -225,6 +226,8 @@ def main():
                         help="Port for vLLM server")
     parser.add_argument("--prompt", type=str, choices=["math", "long"], default="long",
                         help="Which prompt to use")
+    parser.add_argument("--max-model-len", type=int, default=8192,
+                        help="Maximum model context length (e.g., 8192, 32768)")
     parser.add_argument("--modes", type=str, default="all",
                         help="Comma-separated modes to test: base,lora_eager,lora_no_eager or 'all'")
     args = parser.parse_args()
@@ -262,6 +265,7 @@ def main():
     log(f"Model: {args.model}")
     log(f"LoRA adapter: {args.lora_path or 'None'}")
     log(f"Max tokens: {args.max_tokens}")
+    log(f"Max model len: {args.max_model_len}")
     log(f"Num runs: {args.num_runs}")
     log(f"Modes to test: {modes_to_test}")
     log("=" * 70)
@@ -277,7 +281,8 @@ def main():
         # Start server
         current_proc = start_vllm_server(
             args.model, args.port, args.gpu,
-            mode=mode, log_file=f"benchmark_{mode}.log"
+            mode=mode, max_model_len=args.max_model_len,
+            log_file=f"benchmark_{mode}.log"
         )
 
         # Wait for ready
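
Usage note: a minimal sketch of how one might exercise the new flag at a 32k context. Only --max-model-len, --modes, and --prompt are flags defined in this diff, and the script path comes from the diff header; the specific argument values below are illustrative assumptions.

    # Sketch: drive the benchmark script with a 32k context window.
    import subprocess
    import sys

    subprocess.run(
        [
            sys.executable, "example_trainer/scripts/benchmark_lora_vs_shared.py",
            "--max-model-len", "32768",    # new flag added in this diff
            "--modes", "base,lora_eager",  # subset of base,lora_eager,lora_no_eager
            "--prompt", "long",
        ],
        check=True,
    )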