|
| 1 | +#--------------------------------------------- |
| 2 | +# Part 1: system basic config setting |
| 3 | +distributed: True |
| 4 | +device: Ascend |
| 5 | +mode: 0 # 0: graph, 1: pynative |
| 6 | +work_root: &work_root ./work_dir/ |
| 7 | +log_level: info |
| 8 | +amp_level: O0 |
| 9 | + |
| 10 | +# --------------------------------------------- |
| 11 | +# Part 2: module setting
| 12 | +loss_manager: |
| 13 | +# type: fixed # fixed or dynamic
| 14 | +# scale_sense: 1024 |
| 15 | + loss_scaler: |
| 16 | + type: dynamic |
| 17 | + grad_clip: False |
| 18 | + |
| 19 | +optimizer: |
| 20 | + type: segment_anything.optim.optimizer.AdamW |
| 21 | + weight_decay: 1e-4 |
| 22 | + group_param: |
| 23 | + |
| 24 | + lr_scheduler: |
| 25 | + type: segment_anything.optim.scheduler.SAMDynamicDecayLR |
| 26 | + learning_rate: 8e-6 |
| 27 | + warmup_steps: 250 |
| 28 | + decay_steps: [ 60000, 86666 ] |
| 29 | + decay_factor: 10 |
| 30 | + |
| 31 | + |
| 32 | +network: |
| 33 | + model: |
| 34 | + type: vit_h |
| 35 | + checkpoint: /home/ma-user/modelarts/user-job-dir/segment-anything/models/sam_vit_h-c72f8ba1.ckpt |
| 36 | +  enable_text_encoder: True # if set to False, the text encoder is not built, improving performance
| 37 | + text_encoder: |
| 38 | + type: blip2_stage1_classification |
| 39 | + freeze: |
| 40 | + image_encoder: True |
| 41 | + prompt_encoder: |
| 42 | + filter_prefix: text_ # NOTE: text position embedding won't freeze |
| 43 | +    text_encoder: True # the text encoder is always frozen; this also stops its gradients.
| 44 | + |
| 45 | + loss: |
| 46 | + type: segment_anything.modeling.loss.SAMLoss |
| 47 | + |
| 48 | + |
| 49 | +train_loader: |
| 50 | + dataset: |
| 51 | + type: segment_anything.dataset.dataset.SA1BDataset |
| 52 | + data_dir: ./datasets/sa-1b/ |
| 53 | + transform_pipeline: |
| 54 | +      # extract image patches before image resize and pad
| 55 | + - type: segment_anything.dataset.transform.ImagePatchFromBoxMask |
| 56 | + - type: segment_anything.dataset.transform.ImagePatchPreprocess # norm, resize and to chw |
| 57 | + model: blip2_stage1_classification |
| 58 | + - type: segment_anything.dataset.transform.ImageResizeAndPad |
| 59 | + target_size: 1024 |
| 60 | + - type: segment_anything.dataset.transform.ImageNorm # norm and to chw |
| 61 | + hwc2chw: True |
| 62 | + - type: segment_anything.dataset.transform.LabelPad |
| 63 | + gt_size: 20 |
| 64 | + output_column: ['image', 'masks', 'image_patches', 'valid_boxes'] |
| 65 | + |
| 66 | + model_column: ['image', 'image_patches'] # columns for model cell input |
| 67 | + loss_column: ['masks', 'valid_boxes'] # columns for loss function input |
| 68 | + |
| 69 | + shuffle: True |
| 70 | + batch_size: 1 |
| 71 | + epoch_size: 20 |
| 72 | + drop_remainder: True |
| 73 | + num_workers: 2 |
| 74 | + max_rowsize: 64 # 24M space for dataloader |
| 75 | + |
| 76 | + |
| 77 | +callback: |
| 78 | + - type: segment_anything.utils.callbacks.TrainStatusLog |
| 79 | + loss_item: ['focal_loss', 'dice_loss', 'mse_loss'] # for log |
| 80 | + interval: 100 |
| 81 | + - type: segment_anything.utils.callbacks.SaveCkpt |
| 82 | + work_root: *work_root |
| 83 | + interval: 1 # in epoch |
| 84 | + - type: segment_anything.utils.callbacks.FreezeChildCell # freeze training behavior like dropout and bn |
| 85 | + child_cells: ['text_encoder'] |
0 commit comments