
Commit 8099334

Author: Mark-ZhouWX
Commit message: add cloud training config
1 parent 240b068 · commit 8099334

File tree: 2 files changed, +171 -0 lines

File 1 of 2: 85 additions & 0 deletions (new file)
@@ -0,0 +1,85 @@
# ---------------------------------------------
# Part 1: system basic config setting
distributed: True
device: Ascend
mode: 0  # 0: graph, 1: pynative
work_root: &work_root ./work_dir/
log_level: info
amp_level: O0

# ---------------------------------------------
# Part 2: module setting
loss_manager:
  # type: fixed  # dynamic or fixed
  # scale_sense: 1024
  loss_scaler:
    type: dynamic
  grad_clip: False

optimizer:
  type: segment_anything.optim.optimizer.AdamW
  weight_decay: 1e-4
  group_param:

lr_scheduler:
  type: segment_anything.optim.scheduler.SAMDynamicDecayLR
  learning_rate: 8e-6
  warmup_steps: 250
  decay_steps: [ 60000, 86666 ]
  decay_factor: 10


network:
  model:
    type: vit_h
    checkpoint: /home/ma-user/modelarts/user-job-dir/segment-anything/models/sam_vit_h-c72f8ba1.ckpt
    enable_text_encoder: True  # if set to False, the text encoder is not built, which improves performance
    text_encoder:
      type: blip2_stage1_classification
    freeze:
      image_encoder: True
      prompt_encoder:
        filter_prefix: text_  # NOTE: the text position embedding is not frozen
      text_encoder: True  # the text encoder is always frozen (no gradient updates)

  loss:
    type: segment_anything.modeling.loss.SAMLoss


train_loader:
  dataset:
    type: segment_anything.dataset.dataset.SA1BDataset
    data_dir: ./datasets/sa-1b/
    transform_pipeline:
      # extract image patches before image resize and pad
      - type: segment_anything.dataset.transform.ImagePatchFromBoxMask
      - type: segment_anything.dataset.transform.ImagePatchPreprocess  # norm, resize and to chw
        model: blip2_stage1_classification
      - type: segment_anything.dataset.transform.ImageResizeAndPad
        target_size: 1024
      - type: segment_anything.dataset.transform.ImageNorm  # norm and to chw
        hwc2chw: True
      - type: segment_anything.dataset.transform.LabelPad
        gt_size: 20
    output_column: ['image', 'masks', 'image_patches', 'valid_boxes']

  model_column: ['image', 'image_patches']  # columns fed to the model cell
  loss_column: ['masks', 'valid_boxes']  # columns fed to the loss function

  shuffle: True
  batch_size: 1
  epoch_size: 20
  drop_remainder: True
  num_workers: 2
  max_rowsize: 64  # 24M space for dataloader


callback:
  - type: segment_anything.utils.callbacks.TrainStatusLog
    loss_item: ['focal_loss', 'dice_loss', 'mse_loss']  # for log
    interval: 100
  - type: segment_anything.utils.callbacks.SaveCkpt
    work_root: *work_root
    interval: 1  # in epochs
  - type: segment_anything.utils.callbacks.FreezeChildCell  # freeze training-time behavior such as dropout and bn
    child_cells: ['text_encoder']
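
The lr_scheduler block above pairs a short linear warmup (250 steps) with step-wise decay at the milestones listed in decay_steps. Below is a minimal Python sketch of how such a schedule is commonly evaluated per global step; the function name and the exact decay semantics (dividing by decay_factor at each milestone already passed) are illustrative assumptions, not the repository's SAMDynamicDecayLR implementation.

    # Illustrative sketch of a warmup + step-decay schedule with the parameters above.
    # Assumption: decay_factor divides the learning rate at every passed milestone.
    def lr_at_step(step: int,
                   base_lr: float = 8e-6,
                   warmup_steps: int = 250,
                   decay_steps: tuple = (60000, 86666),
                   decay_factor: float = 10.0) -> float:
        """Return the learning rate for a given global training step."""
        if step < warmup_steps:
            # linear warmup from near zero up to base_lr
            return base_lr * (step + 1) / warmup_steps
        # divide the learning rate by decay_factor for every milestone already passed
        passed = sum(1 for milestone in decay_steps if step >= milestone)
        return base_lr / (decay_factor ** passed)

    # Example values: lr_at_step(0) = 3.2e-08, lr_at_step(1000) = 8e-06,
    # lr_at_step(70000) = 8e-07, lr_at_step(90000) = 8e-08.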
File 2 of 2: 86 additions & 0 deletions (new file)
@@ -0,0 +1,86 @@
# ---------------------------------------------
# Part 1: system basic config setting
distributed: True
device: Ascend
mode: 0  # 0: graph, 1: pynative
work_root: &work_root ./work_dir/
log_level: info
amp_level: O0

# ---------------------------------------------
# Part 2: module setting
loss_manager:
  # type: fixed  # dynamic or fixed
  # scale_sense: 1024
  loss_scaler:
    type: dynamic
  grad_clip: False

optimizer:
  type: segment_anything.optim.optimizer.AdamW
  weight_decay: 1e-4
  group_param:

lr_scheduler:
  type: segment_anything.optim.scheduler.SAMDynamicDecayLR
  learning_rate: 8e-6
  warmup_steps: 250
  decay_steps: [ 60000, 86666 ]
  decay_factor: 10


network:
  model:
    type: vit_h
    checkpoint: /home/ma-user/modelarts/user-job-dir/segment-anything/models/sam_vit_h-c72f8ba1.ckpt
    enable_text_encoder: True  # if set to False, the text encoder is not built, which improves performance
    text_encoder:
      type: clip_vit_l_14@336
      feature_dim: 768
    freeze:
      image_encoder: True
      prompt_encoder:
        filter_prefix: text_  # NOTE: the text position embedding is not frozen
      text_encoder: True  # the text encoder is always frozen

  loss:
    type: segment_anything.modeling.loss.SAMLoss


train_loader:
  dataset:
    type: segment_anything.dataset.dataset.SA1BDataset
    data_dir: ./datasets/sa-1b/
    transform_pipeline:
      # extract image patches before image resize and pad
      - type: segment_anything.dataset.transform.ImagePatchFromBoxMask
      - type: segment_anything.dataset.transform.ImagePatchPreprocess  # norm, resize and to chw
        model: clip_vit_l_14@336
      - type: segment_anything.dataset.transform.ImageResizeAndPad
        target_size: 1024
      - type: segment_anything.dataset.transform.ImageNorm  # norm and to chw
        hwc2chw: True
      - type: segment_anything.dataset.transform.LabelPad
        gt_size: 20
    output_column: ['image', 'masks', 'image_patches', 'valid_boxes']

  model_column: ['image', 'image_patches']  # columns fed to the model cell
  loss_column: ['masks', 'valid_boxes']  # columns fed to the loss function

  shuffle: True
  batch_size: 1
  epoch_size: 20
  drop_remainder: True
  num_workers: 2
  max_rowsize: 64  # 24M space for dataloader


callback:
  - type: segment_anything.utils.callbacks.TrainStatusLog
    loss_item: ['focal_loss', 'dice_loss', 'mse_loss']  # for log
    interval: 100
  - type: segment_anything.utils.callbacks.SaveCkpt
    work_root: *work_root
    interval: 1  # in epochs
  - type: segment_anything.utils.callbacks.FreezeChildCell
    child_cells: ['text_encoder']
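
Both config files identify most components (optimizer, scheduler, loss, dataset transforms, callbacks) by dotted Python class paths in their type fields, and reuse work_root through a YAML anchor/alias (&work_root / *work_root). The sketch below shows the generic "import the class and pass the remaining keys as kwargs" pattern such configs usually imply; build_from_config and the file name train_cloud.yaml are hypothetical names for illustration, not part of this commit.

    # Assumed build-from-config pattern; not the repository's actual builder code.
    import importlib

    import yaml


    def build_from_config(cfg: dict, **extra_kwargs):
        """Instantiate the class named by cfg['type'], passing the remaining keys as kwargs."""
        cfg = dict(cfg)  # copy so the caller's config dict is not mutated
        module_path, _, cls_name = cfg.pop("type").rpartition(".")
        cls = getattr(importlib.import_module(module_path), cls_name)
        return cls(**cfg, **extra_kwargs)


    with open("train_cloud.yaml", "r") as f:   # hypothetical file name for the config above
        config = yaml.safe_load(f)             # YAML aliases like *work_root are resolved at load time

    # e.g. loss_fn = build_from_config(config["network"]["loss"])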

0 commit comments
