@@ -2440,121 +2440,6 @@ def log_graph_warmup_summary(self, buckets, is_prompt, total_mem):
2440
2440
f'used_mem:{ format_bytes (total_mem )} ' )
2441
2441
logger .info (msg )
2442
2442
2443
- < << << << HEAD
2444
- == == == =
2445
- def warmup_scenario (self ,
2446
- batch_size ,
2447
- seq_or_block ,
2448
- num_blocks ,
2449
- is_prompt ,
2450
- kv_caches ,
2451
- num_iters = 3 ,
2452
- is_pt_profiler_run = True ,
2453
- align_worker = False ,
2454
- is_dummy_run = False ) -> None :
2455
- """Dummy warmup run for memory usage and graph compilation."""
2456
-
2457
- query_seq_len = seq_or_block if is_prompt else 1
2458
- input_ids = torch .zeros ((batch_size , query_seq_len ),
2459
- dtype = torch .int32 ,
2460
- device = 'cpu' )
2461
- position_ids = torch .zeros ((batch_size , query_seq_len ),
2462
- dtype = torch .int32 ,
2463
- device = 'cpu' )
2464
- slot_mapping = torch .zeros ((batch_size , query_seq_len ),
2465
- dtype = torch .int64 ,
2466
- device = 'cpu' )
2467
-
2468
- input_ids_device = _async_h2d_tensor_copy (input_ids , self .device )
2469
- position_ids_device = _async_h2d_tensor_copy (position_ids , self .device )
2470
- slot_mapping_device = _async_h2d_tensor_copy (slot_mapping , self .device )
2471
-
2472
- use_graphs = is_dummy_run or self ._use_graphs ()
2473
- phase = "prompt" if is_prompt else "decode"
2474
- scenario_name = ("warmup_"
2475
- f"{ phase } _"
2476
- f"bs{ batch_size } _"
2477
- f"seq{ query_seq_len } _"
2478
- f"ctx{ num_blocks } _"
2479
- f"graphs{ 'T' if use_graphs else 'F' } " )
2480
- input_ids = torch .zeros ((batch_size , query_seq_len ),
2481
- dtype = torch .int32 ,
2482
- device = 'cpu' )
2483
- position_ids = torch .zeros ((batch_size , query_seq_len ),
2484
- dtype = torch .int32 ,
2485
- device = 'cpu' )
2486
- slot_mapping = torch .zeros ((batch_size , query_seq_len ),
2487
- dtype = torch .int64 ,
2488
- device = 'cpu' )
2489
-
2490
- input_ids_device = _async_h2d_tensor_copy (input_ids , self .device )
2491
- position_ids_device = _async_h2d_tensor_copy (position_ids , self .device )
2492
- slot_mapping_device = _async_h2d_tensor_copy (slot_mapping , self .device )
2493
- self .profiler .start ('internal' , scenario_name )
2494
-
2495
- times = num_iters if use_graphs or is_pt_profiler_run else 1
2496
- for time_index in range (times ):
2497
- if is_prompt :
2498
- seq_lens = torch .zeros ((batch_size ),
2499
- dtype = torch .int32 ,
2500
- device = 'cpu' )
2501
- seq_lens .fill_ (seq_or_block )
2502
- seq_lens_device = _async_h2d_tensor_copy (seq_lens , self .device )
2503
- block_list_device = None
2504
- if num_blocks :
2505
- prefix_block_tables = torch .ones (
2506
- (batch_size , num_blocks ),
2507
- dtype = torch .int32 ,
2508
- device = 'cpu' ) * self ._PAD_BLOCK_ID
2509
- block_list_device = _async_h2d_tensor_copy (
2510
- prefix_block_tables .flatten (), self .device )
2511
- attn_metadata = \
2512
- HPUAttentionMetadataV1 .make_prefill_metadata (
2513
- attn_bias = None ,
2514
- seq_lens_tensor = seq_lens_device ,
2515
- context_lens_tensor = seq_lens_device ,
2516
- slot_mapping = slot_mapping_device ,
2517
- block_list = block_list_device ,
2518
- block_size = self .block_size )
2519
- else :
2520
- block_tables = [
2521
- x .tolist ()
2522
- for x in np .array_split (np .arange (num_blocks ), batch_size )
2523
- ]
2524
- block_list , block_groups , block_usage = \
2525
- self .get_habana_paged_attn_buffers (
2526
- slot_mapping = slot_mapping ,
2527
- block_tables = block_tables ,
2528
- batch_size = batch_size )
2529
- block_list_device = _async_h2d_tensor_copy (
2530
- block_list , self .device )
2531
- block_usage_device = _async_h2d_tensor_copy (
2532
- block_usage , self .device )
2533
- block_groups_device = _async_h2d_tensor_copy (
2534
- block_groups , self .device )
2535
- attn_metadata = HPUAttentionMetadataV1 .make_decode_metadata (
2536
- block_list = block_list_device ,
2537
- block_usage = block_usage_device ,
2538
- block_groups = block_groups_device ,
2539
- num_decode_tokens = batch_size ,
2540
- input_positions = None ,
2541
- slot_mapping = slot_mapping_device ,
2542
- block_size = self .block_size )
2543
-
2544
- logits_indices = torch .arange (0 , batch_size , device = 'cpu' )
2545
- logits_indices_device = _async_h2d_tensor_copy (logits_indices ,
2546
- self .device )
2547
- # Dummy run.
2548
- htorch .core .mark_step ()
2549
- _ = self ._execute_model_generic (input_ids_device , position_ids_device ,
2550
- attn_metadata , logits_indices_device ,
2551
- kv_caches , True )
2552
- # TODO: do sampling on logits, warmup sampler and prefill joiner
2553
- htorch .core .mark_step ()
2554
- self .profiler .end ()
2555
- return None
2556
-
2557
- >> >> >> > 68 ee934 (fix )
2558
2443
def log_warmup (self , phase , i , max_i , batch_size , seq_len , num_blocks ):
2559
2444
free_mem = format_bytes (
2560
2445
HabanaMemoryProfiler .current_free_device_memory ())
0 commit comments