diff --git a/configs/exps/alvaro/10k-training.yaml b/configs/exps/alvaro/10k-training.yaml new file mode 100644 index 0000000000..471ded838d --- /dev/null +++ b/configs/exps/alvaro/10k-training.yaml @@ -0,0 +1,47 @@ +# MODIFY THIS ONE FOR RUNS + +job: + mem: 32GB + cpus: 4 + gres: gpu:rtx8000:1 + partition: long + time: 15:00:00 + +default: + wandb_name: alvaro-carbonero-math + wandb_project: ocp-alvaro + test_ri: True + mode: train + graph_rewiring: remove-tag-0 + optim: + cp_data_to_tmpdir: true + wandb-tags: 'best-config-??' # Insert what model you're running if running one by one. + frame_averaging: 2D + fa_frames: se3-random + model: + mp_type: updownscale + phys_embeds: True + tag_hidden_channels: 32 + pg_hidden_channels: 64 + energy_head: weighted-av-final-embeds + complex_mp: False + graph_norm: True + hidden_channels: 352 + num_filters: 448 + num_gaussians: 99 + num_interactions: 6 + second_layer_MLP: True + skip_co: concat + edge_embed_type: all_rij + optim: + lr_initial: 0.001 + scheduler: LinearWarmupCosineAnnealingLR + max_epochs: 20 + eval_every: 0.4 + batch_size: 256 + eval_batch_size: 256 + +runs: + - config: schnet-is2re-10k + + - config: gemnet_oc-is2re-10k diff --git a/configs/exps/alvaro/all-training.yaml b/configs/exps/alvaro/all-training.yaml new file mode 100644 index 0000000000..f124a63c66 --- /dev/null +++ b/configs/exps/alvaro/all-training.yaml @@ -0,0 +1,47 @@ +# MODIFY THIS ONE FOR RUNS + +job: + mem: 32GB + cpus: 4 + gres: gpu:rtx8000:1 + partition: long + time: 15:00:00 + +default: + wandb_name: alvaro-carbonero-math + wandb_project: ocp-alvaro + test_ri: True + mode: train + graph_rewiring: remove-tag-0 + optim: + cp_data_to_tmpdir: true + wandb-tags: 'best-config-??' # Insert what model you're running if running one by one. + frame_averaging: 2D + fa_frames: se3-random + model: + mp_type: updownscale + phys_embeds: True + tag_hidden_channels: 32 + pg_hidden_channels: 64 + energy_head: weighted-av-final-embeds + complex_mp: False + graph_norm: True + hidden_channels: 352 + num_filters: 448 + num_gaussians: 99 + num_interactions: 6 + second_layer_MLP: True + skip_co: concat + edge_embed_type: all_rij + optim: + lr_initial: 0.0005 + scheduler: LinearWarmupCosineAnnealingLR + max_epochs: 20 + eval_every: 0.4 + batch_size: 256 + eval_batch_size: 256 + +runs: + - config: afaenet-is2re-all + model: + afaenet_gat_mode: v1 diff --git a/configs/exps/alvaro/best-configs-all.yaml b/configs/exps/alvaro/best-configs-all.yaml new file mode 100644 index 0000000000..106a6b263d --- /dev/null +++ b/configs/exps/alvaro/best-configs-all.yaml @@ -0,0 +1,45 @@ +# DON'T MODIFY THIS + +job: + mem: 32GB + cpus: 4 + gres: gpu:rtx8000:1 + partition: long + time: 15:00:00 + +default: + wandb_name: alvaro-carbonero-math + wandb_project: ocp-alvaro + test_ri: True + mode: train + graph_rewiring: remove-tag-0 + model: + edge_embed_type: all_rij + wandb_tags: 'best-config' + optim: + batch_size: 256 + eval_batch_size: 256 + cp_data_to_tmpdir: true + config: faenet-is2re-all + note: 'best-config-??' # Insert what model you're running if running one by one. 
+ frame_averaging: 2D + fa_frames: se3-random + model: + mp_type: updownscale + phys_embeds: False + tag_hidden_channels: 32 + pg_hidden_channels: 64 + energy_head: weighted-av-final-embeds + complex_mp: False + graph_norm: True + hidden_channels: 352 + num_filters: 448 + num_gaussians: 99 + num_interactions: 6 + second_layer_MLP: True + skip_co: concat + optim: + lr_initial: 0.0019 + scheduler: LinearWarmupCosineAnnealingLR + max_epochs: 20 + eval_every: 0.4 diff --git a/configs/exps/alvaro/dpp-config.yaml b/configs/exps/alvaro/dpp-config.yaml new file mode 100644 index 0000000000..1edbb37d80 --- /dev/null +++ b/configs/exps/alvaro/dpp-config.yaml @@ -0,0 +1,62 @@ +# MODIFY THIS ONE FOR RUNS + +job: + mem: 32GB + cpus: 4 + gres: gpu:rtx8000:1 + partition: long + time: 15:00:00 + +default: + wandb_name: alvaro-carbonero-math + wandb_project: ocp-alvaro + test_ri: True + mode: train + graph_rewiring: remove-tag-0 + optim: + batch_size: 16 + eval_batch_size: 16 + +runs: + # - config: dpp-is2re-10k + + # - config: dpp-is2re-10k + # is_disconnected: True + + # - config: depdpp-is2re-all + + - config: inddpp-is2re-all + note: so that cat get old dimensions + model: + hidden_channels: 256 + num_spherical: 7 + num_radial: 6 + out_emb_channels: 192 + + - config: inddpp-is2re-all + note: dimensions both smaller + model: + hidden_channels: 128 + num_spherical: 4 + num_radial: 3 + out_emb_channels: 96 + + - config: inddpp-is2re-all + note: so that ads get old dimensions + model: + hidden_channels: 512 + num_spherical: 14 + num_radial: 12 + out_emb_channels: 384 + + - config: inddpp-is2re-all + note: so that their average is old dimensions + model: + hidden_channels: 340 + num_spherical: 9 + num_radial: 8 + out_emb_channels: 256 + + # - config: adpp-is2re-10k + # model: + # gat_mode: v1 diff --git a/configs/exps/alvaro/faenet-top-config.yaml b/configs/exps/alvaro/faenet-top-config.yaml new file mode 100644 index 0000000000..2ac480c7fc --- /dev/null +++ b/configs/exps/alvaro/faenet-top-config.yaml @@ -0,0 +1,81 @@ +job: + mem: 32GB + cpus: 4 + gres: gpu:rtx8000:1 + partition: long + time: 15:00:00 + +default: + wandb_name: alvaro-carbonero-math + wandb_project: ocp-alvaro + wandb_tags: "best-config" + test_ri: True + mode: train + graph_rewiring: remove-tag-0 + note: "top-run" + frame_averaging: 2D + fa_method: se3-random + cp_data_to_tmpdir: True + model: + edge_embed_type: all_rij + mp_type: updownscale_base + phys_embeds: True + tag_hidden_channels: 32 + pg_hidden_channels: 96 + energy_head: weighted-av-final-embeds + complex_mp: True + graph_norm: True + hidden_channels: 352 + num_filters: 288 + num_gaussians: 68 + num_interactions: 5 + second_layer_MLP: False + skip_co: concat + cutoff: 4.0 + optim: + batch_size: 256 + eval_batch_size: 256 + lr_initial: 0.002 + scheduler: LinearWarmupCosineAnnealingLR + max_epochs: 9 + eval_every: 0.4 + +runs: + # - config: faenet-is2re-10k + + # - config: faenet-is2re-10k + # is_disconnected: True + + - config: indfaenet-is2re-all + note: so that cat get old dimensions + model: + hidden_channels: 352 + num_gaussians: 99 + num_filters: 448 + + - config: indfaenet-is2re-all + note: dimensions of both smaller + model: + hidden_channels: 176 + num_gaussians: 50 + num_filters: 224 + + - config: indfaenet-is2re-all + note: so that ads get old dimensions + model: + hidden_channels: 704 + num_gaussians: 200 + num_filters: 896 + + - config: indfaenet-is2re-all + note: so that their average is old dimension + model: + hidden_channels: 468 + num_gaussians: 132 + 
num_filters: 596 + + # - config: indfaenet-is2re-10k + + # - config: afaenet-is2re-all + # model: + # afaenet_gat_mode: v1 diff --git a/configs/exps/alvaro/faenet-training.yaml b/configs/exps/alvaro/faenet-training.yaml new file mode 100644 index 0000000000..62f56573f2 --- /dev/null +++ b/configs/exps/alvaro/faenet-training.yaml @@ -0,0 +1,46 @@ +# MODIFY THIS ONE FOR RUNS + +job: + mem: 32GB + cpus: 4 + gres: gpu:rtx8000:1 + partition: long + time: 15:00:00 + +default: + wandb_name: alvaro-carbonero-math + wandb_project: ocp-alvaro + test_ri: True + mode: train + graph_rewiring: remove-tag-0 + cp_data_to_tmpdir: true + wandb-tags: 'best-config-??' # Insert what model you're running if running one by one. + frame_averaging: 2D + fa_frames: se3-random + model: + mp_type: updownscale + phys_embeds: True + tag_hidden_channels: 32 + pg_hidden_channels: 64 + energy_head: weighted-av-final-embeds + complex_mp: False + graph_norm: True + hidden_channels: 352 + num_filters: 448 + num_gaussians: 99 + num_interactions: 6 + second_layer_MLP: True + skip_co: concat + edge_embed_type: all_rij + optim: + lr_initial: 0.0005 + scheduler: LinearWarmupCosineAnnealingLR + max_epochs: 20 + eval_every: 0.4 + batch_size: 256 + eval_batch_size: 256 + +runs: + - config: afaenet-is2re-all + model: + afaenet_gat_mode: v1 diff --git a/configs/exps/alvaro/gemnet-config.yaml b/configs/exps/alvaro/gemnet-config.yaml new file mode 100644 index 0000000000..782368994a --- /dev/null +++ b/configs/exps/alvaro/gemnet-config.yaml @@ -0,0 +1,35 @@ +job: + mem: 32GB + cpus: 4 + gres: gpu:rtx8000:1 + partition: long + time: 18:00:00 + +default: + wandb_name: alvaro-carbonero-math + wandb_project: ocp-alvaro + graph_rewiring: remove-tag-0 + model: + tag_hidden_channels: 32 + pg_hidden_channels: 32 + phys_embeds: True + otf_graph: False + max_num_neighbors: 40 + hidden_channels: 142 + graph_rewiring: remove-0-tag + optim: + batch_size: 32 + eval_batch_size: 32 + max_epochs: 30 + +runs: + - config: gemnet_t-is2re-all + + - config: gemnet_t-is2re-all + is_disconnected: True + + #- config: depgemnet_t-is2re-all + + #- config: indgemnet_t-is2re-all + + #- config: agemnet_t-is2re-all diff --git a/configs/exps/alvaro/gflownet.yaml b/configs/exps/alvaro/gflownet.yaml new file mode 100644 index 0000000000..2432f47339 --- /dev/null +++ b/configs/exps/alvaro/gflownet.yaml @@ -0,0 +1,143 @@ +job: + mem: 32GB + cpus: 4 + gres: gpu:rtx8000:1 + partition: long + time: 15:00:00 + +default: + # wandb_name: alvaro-carbonero-math + wandb_project: ocp-alvaro + wandb_tags: "gflownet-model" + test_ri: True + mode: train + # graph_rewiring: remove-tag-0 + graph_rewiring: "" + frame_averaging: 2D + fa_method: se3-random + cp_data_to_tmpdir: True + is_disconnected: true + model: + edge_embed_type: all_rij + mp_type: updownscale_base + phys_embeds: True + tag_hidden_channels: 0 + pg_hidden_channels: 96 + energy_head: weighted-av-final-embeds + complex_mp: True + graph_norm: True + hidden_channels: 352 + num_filters: 288 + num_gaussians: 68 + num_interactions: 5 + second_layer_MLP: False + skip_co: concat + cutoff: 4.0 + optim: + batch_size: 256 + eval_batch_size: 256 + lr_initial: 0.002 + scheduler: LinearWarmupCosineAnnealingLR + max_epochs: 9 + eval_every: 0.4 + +runs: + + # - config: faenet-is2re-all + # note: baseline faenet + + # - config: depfaenet-is2re-all + # note: depfaenet baseline + + # - config: depfaenet-is2re-all + # note: depfaenet per-adsorbate + # adsorbates: {'*O', '*OH', '*OH2', '*H'} + + # - config: depfaenet-is2re-all + # note: 
depfaenet per-adsorbate long string + # adsorbates: '*O, *OH, *OH2, *H' + + # - config: depfaenet-is2re-all + # note: depfaenet per-adsorbate string of a list + # adsorbates: "*O, *OH, *OH2, *H" + + # - config: depfaenet-is2re-all + # note: Trained on selected adsorbate more epochs + # adsorbates: "*O, *OH, *OH2, *H" + # optim: + # max_epochs: 10 + + # - config: depfaenet-is2re-all + # note: depfaenet full data + + # - config: depfaenet-is2re-all + # note: To be used for continue from dir + + # - config: depfaenet-is2re-all + # note: Fine-tune on per-ads-dataset 4 epoch + # continue_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4023244 + # adsorbates: "*O, *OH, *OH2, *H" + # optim: + # max_epochs: 4 + # lr_initial: 0.00015 + + # - config: depfaenet-is2re-all + # note: Fine-tune on per-ads-dataset 10 epoch + # continue_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4023244 + # adsorbates: "*O, *OH, *OH2, *H" + # optim: + # max_epochs: 10 + # lr_initial: 0.00015 + + - config: depfaenet-is2re-all + note: Fine-tune on per-ads-dataset 10 epoch + continue_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4023244 + adsorbates: "*O, *OH, *OH2, *H" + optim: + max_epochs: 20 + lr_initial: 0.0001 + + - config: depfaenet-is2re-all + note: Fine-tune on per-ads-dataset 20 epoch + continue_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4023244 + adsorbates: "*O, *OH, *OH2, *H" + optim: + max_epochs: 20 + lr_initial: 0.00015 + + - config: depfaenet-is2re-all + note: Fine-tune on per-ads-dataset 15 epoch + continue_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4023244 + adsorbates: "*O, *OH, *OH2, *H" + optim: + max_epochs: 15 + lr_initial: 0.0002 + + - config: depfaenet-is2re-all + note: Fine-tune on per-ads-dataset 10 epoch + continue_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4023244 + adsorbates: "*O, *OH, *OH2, *H" + optim: + max_epochs: 10 + lr_initial: 0.0001 + + - config: depfaenet-is2re-all + note: Fine-tune on per-ads-dataset starting from fine-tuned model + continue_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4071859 + adsorbates: "*O, *OH, *OH2, *H" + optim: + max_epochs: 10 + lr_initial: 0.0001 + + - config: depfaenet-is2re-all + note: Trained on selected adsorbate + adsorbates: "*O, *OH, *OH2, *H" + optim: + max_epochs: 25 + lr_initial: 0.0001 + + - config: depfaenet-is2re-all + note: Trained on selected adsorbate + adsorbates: "*O, *OH, *OH2, *H" + optim: + max_epochs: 25 diff --git a/configs/exps/alvaro/oldgemnet-config.yaml b/configs/exps/alvaro/oldgemnet-config.yaml new file mode 100644 index 0000000000..d9524c34bb --- /dev/null +++ b/configs/exps/alvaro/oldgemnet-config.yaml @@ -0,0 +1,35 @@ +job: + mem: 40GB + cpus: 4 + gres: gpu:rtx8000:1 + partition: long + time: 18:00:00 + +default: + wandb_name: alvaro-carbonero-math + wandb_project: ocp-alvaro + model: + tag_hidden_channels: 32 + pg_hidden_channels: 32 + phys_embeds: True + otf_graph: False + max_num_neighbors: 40 + hidden_channels: 142 + regress_forces: True + graph_rewiring: remove-0-tag + optim: + batch_size: 32 + eval_batch_size: 32 + max_epochs: 30 + +runs: + - config: gemnet_oc-is2re-all + + - config: depgemnet_oc-is2re-all + + #- config: indgemnet_oc-is2re-all + + - config: gemnet_oc-is2re-all + is_disconnected: True + + #- config: agemnet_oc-is2re-all diff --git a/configs/exps/alvaro/reproduce-configs.yaml b/configs/exps/alvaro/reproduce-configs.yaml new file mode 100644 index 0000000000..c4c834585c --- /dev/null +++ b/configs/exps/alvaro/reproduce-configs.yaml @@ -0,0 
+1,75 @@ +job: + mem: 32GB + cpus: 4 + gres: gpu:rtx8000:1 + partition: long + time: 15:00:00 + +default: + # wandb_name: alvaro-carbonero-math + wandb_project: ocp-alvaro + wandb_tags: "reproduce-best-config" + test_ri: True + mode: train + graph_rewiring: remove-tag-0 + note: "repoduce-top-run" + frame_averaging: 2D + fa_method: se3-random + cp_data_to_tmpdir: True + is_disconnected: true + model: + edge_embed_type: all_rij + mp_type: updownscale_base + phys_embeds: True + tag_hidden_channels: 32 + pg_hidden_channels: 96 + energy_head: weighted-av-final-embeds + complex_mp: True + graph_norm: True + hidden_channels: 352 + num_filters: 288 + num_gaussians: 68 + num_interactions: 5 + second_layer_MLP: False + skip_co: concat + cutoff: 4.0 + optim: + batch_size: 256 + eval_batch_size: 256 + lr_initial: 0.002 + scheduler: LinearWarmupCosineAnnealingLR + max_epochs: 9 + eval_every: 0.4 + +runs: + + - config: faenet-is2re-all + note: baseline faenet + + - config: indfaenet-is2re-all + note: baseline with top configs + + - config: indfaenet-is2re-all + note: baseline with runs' configs + model: + tag_hidden_channels: 32 + pg_hidden_channels: 96 + energy_head: weighted-av-final-embeds + complex_mp: True + graph_norm: True + hidden_channels: 528 + num_filters: 672 + num_gaussians: 148 + num_interactions: 5 + second_layer_MLP: False + skip_co: concat + + - config: depfaenet-is2re-all + note: baseline with top configs + + - config: indfaenet-is2re-all + note: so that ads get old dimensions + model: + hidden_channels: 704 + num_gaussians: 200 + num_filters: 896 \ No newline at end of file diff --git a/configs/exps/alvaro/schnet-config.yaml b/configs/exps/alvaro/schnet-config.yaml new file mode 100644 index 0000000000..c9ac15dbc6 --- /dev/null +++ b/configs/exps/alvaro/schnet-config.yaml @@ -0,0 +1,61 @@ +job: + mem: 32GB + cpus: 4 + gres: gpu:rtx8000:1 + partition: long + time: 15:00:00 + +default: + wandb_name: alvaro-carbonero-math + wandb_project: ocp-alvaro + energy_head: false + num_targets: 15 + graph_rewiring: remove-tag-0 + model: + otf_graph: false + max_num_neighbors: 40 + optim: + num_workers: 4 + max_epochs: 17 + warmup_factor: 0.2 + lr_initial: 0.0005 + +runs: + # - config: schnet-is2re-10k + + # - config: schnet-is2re-10k + # is_disconnected: True + + # - config: depschnet-is2re-all + + - config: indschnet-is2re-all + note: so that cat get old dimensions + model: + hidden_channels: 256 + num_filters: 128 + num_gaussians: 100 + + - config: indschnet-is2re-all + note: dimensions both smaller + model: + hidden_channels: 126 + num_filters: 64 + num_gaussians: 50 + + - config: indschnet-is2re-all + note: so that ads get old dimensions + model: + hidden_channels: 512 + num_filters: 256 + num_gaussians: 200 + + - config: indschnet-is2re-all + note: so that their average is old dimensions + model: + hidden_channels: 340 + num_filters: 170 + num_gaussians: 132 + + # - config: aschnet-is2re-10k + # model: + # gat_mode: v1 diff --git a/configs/exps/alvaro/standard-faenet.yaml b/configs/exps/alvaro/standard-faenet.yaml new file mode 100644 index 0000000000..256b3d001a --- /dev/null +++ b/configs/exps/alvaro/standard-faenet.yaml @@ -0,0 +1,37 @@ +# MODIFY THIS ONE FOR RUNS + +job: + mem: 32GB + cpus: 4 + gres: gpu:rtx8000:1 + partition: long + time: 15:00:00 + +default: + wandb_name: alvaro-carbonero-math + wandb_project: ocp-alvaro + test_ri: True + mode: train + graph_rewiring: remove-tag-0 + cp_data_to_tmpdir: true + frame_averaging: 2D + fa_frames: se3-random + model: + phys_embeds: True + 
tag_hidden_channels: 32 + pg_hidden_channels: 32 + energy_head: weighted-av-final-embeds + skip_co: concat + edge_embed_type: all_rij + optim: + lr_initial: 0.0005 + scheduler: LinearWarmupCosineAnnealingLR + max_epochs: 20 + eval_every: 0.4 + batch_size: 256 + eval_batch_size: 256 + +runs: + - config: faenet-is2re-all + model: + afaenet_gat_mode: v1 diff --git a/configs/models/adpp.yaml b/configs/models/adpp.yaml new file mode 100644 index 0000000000..045ffed1b6 --- /dev/null +++ b/configs/models/adpp.yaml @@ -0,0 +1,209 @@ +default: + model: + name: adpp + hidden_channels: 256 + out_emb_channels: 192 + num_blocks: 3 + cutoff: 6.0 + num_radial: 6 + num_spherical: 7 + num_before_skip: 1 + num_after_skip: 2 + num_output_layers: 3 + regress_forces: False + use_pbc: True + basis_emb_size: 8 + envelope_exponent: 5 + act: swish + int_emb_size: 64 + # drlab attributes: + tag_hidden_channels: 0 # 64 + pg_hidden_channels: 0 # 32 -> period & group embedding hidden channels + phys_embeds: False # True + phys_hidden_channels: 0 + energy_head: False # can be {False, weighted-av-initial-embeds, weighted-av-final-embeds, pooling, graclus, random} + optim: + batch_size: 4 + eval_batch_size: 4 + num_workers: 4 + lr_gamma: 0.1 + warmup_factor: 0.2 + +# ------------------- +# ----- IS2RE ----- +# ------------------- + +is2re: + # *** Important note *** + # The total number of gpus used for this run was 1. + # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. + 10k: + optim: + lr_initial: 0.0001 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 20000 + - 40000 + - 60000 + warmup_steps: 10000 + max_epochs: 20 + batch_size: 16 + eval_batch_size: 16 + + 100k: + optim: + lr_initial: 0.0001 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 200000 + - 400000 + - 600000 + warmup_steps: 100000 + max_epochs: 15 + batch_size: 16 + eval_batch_size: 16 + + all: + optim: + lr_initial: 0.0001 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 115082 + - 230164 + - 345246 + warmup_steps: 57541 + max_epochs: 8 + batch_size: 16 + eval_batch_size: 16 + +# ------------------ +# ----- S2EF ----- +# ------------------ + +s2ef: + default: + model: + regress_forces: "from_energy" + optim: + num_workers: 8 + eval_every: 10000 + + 200k: + optim: + # *** Important note *** + # The total number of gpus used for this run was 4. + # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. + batch_size: 48 + eval_batch_size: 48 + lr_initial: 0.00001 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 5208 + - 8333 + - 10416 + warmup_steps: 3125 + max_epochs: 10 + force_coefficient: 50 + + 2M: + optim: + batch_size: 96 + eval_batch_size: 96 + eval_every: 10000 + num_workers: 8 + lr_initial: 0.0001 + lr_gamma: 0.1 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 20833 + - 31250 + - 41666 + warmup_steps: 10416 + warmup_factor: 0.2 + max_epochs: 15 + force_coefficient: 50 + model: + hidden_channels: 192 + out_emb_channels: 192 + num_blocks: 3 + cutoff: 6.0 + num_radial: 6 + num_spherical: 7 + num_before_skip: 1 + num_after_skip: 2 + num_output_layers: 3 + regress_forces: True + use_pbc: True + + 20M: + optim: + # *** Important note *** + # The total number of gpus used for this run was 64. 
+ # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. + batch_size: 12 + eval_batch_size: 12 + lr_initial: 0.0001 + lr_gamma: 0.1 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 78125 + - 130208 + - 208333 + warmup_steps: 52083 + max_epochs: 15 + force_coefficient: 50 + + all: + optim: + # *** Important note *** + # The total number of gpus used for this run was 256. + # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. + batch_size: 8 + eval_batch_size: 8 + lr_initial: 0.0001 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 130794 + - 196192 + - 261589 + warmup_steps: 130794 + max_epochs: 7 + force_coefficient: 50 + +qm9: + default: + model: + num_blocks: 6 + hidden_channels: 128 + optim: + # *** Important note *** + # The total number of gpus used for this run was 4. + # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. + lr_initial: 0.001 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 2000000 + - 4000000 + - 6000000 + warmup_steps: 3000 + lr_gamma: 0.1 + batch_size: 128 + max_epochs: 600 + + 10k: {} + all: {} + +qm7x: + default: + optim: + # *** Important note *** + # The total number of gpus used for this run was 4. + # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. + lr_initial: 0.0001 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 115082 + - 230164 + - 345246 + warmup_steps: 57541 + max_epochs: 8 + + all: {} + 1k: {} diff --git a/configs/models/afaenet.yaml b/configs/models/afaenet.yaml new file mode 100644 index 0000000000..bb24eabf95 --- /dev/null +++ b/configs/models/afaenet.yaml @@ -0,0 +1,271 @@ +default: + model: + name: afaenet + act: swish + hidden_channels: 128 + num_filters: 100 + num_interactions: 3 + num_gaussians: 100 + cutoff: 6.0 + use_pbc: True + regress_forces: False + # drlab attributes: + tag_hidden_channels: 0 # 32 + pg_hidden_channels: 0 # 32 -> period & group embedding hidden channels + phys_embeds: False # True + phys_hidden_channels: 0 + energy_head: False # can be {False, weighted-av-initial-embeds, weighted-av-final-embeds, pooling, graclus, random} + # faenet new features + skip_co: False # output skip connections {False, "add", "concat"} + second_layer_MLP: False # in EmbeddingBlock + complex_mp: False + edge_embed_type: rij # {'rij','all_rij','sh', 'all'}) + mp_type: base # {'base', 'simple', 'updownscale', 'att', 'base_with_att', 'local_env'} + graph_norm: False # bool + att_heads: 1 # int + force_decoder_type: "mlp" # can be {"" or "simple"} | only used if regress_forces is True + force_decoder_model_config: + simple: + hidden_channels: 128 + norm: batch1d # batch1d, layer or null + mlp: + hidden_channels: 256 + norm: batch1d # batch1d, layer or null + res: + hidden_channels: 128 + norm: batch1d # batch1d, layer or null + res_updown: + hidden_channels: 128 + norm: batch1d # batch1d, layer or null + optim: + batch_size: 64 + eval_batch_size: 64 + num_workers: 4 + lr_gamma: 0.1 + lr_initial: 0.001 + warmup_factor: 0.2 + max_epochs: 20 + energy_grad_coefficient: 10 + force_coefficient: 30 + energy_coefficient: 1 + + frame_averaging: False # 2D, 3D, da, False + fa_frames: False # can be {None, full, random, det, e3, 
e3-random, e3-det} + +# ------------------- +# ----- IS2RE ----- +# ------------------- + +is2re: + # *** Important note *** + # The total number of gpus used for this run was 1. + # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. + 10k: + optim: + lr_initial: 0.005 + lr_milestones: # epochs at which lr_initial <- lr_initial * lr_gamma + - 1562 + - 2343 + - 3125 + warmup_steps: 468 + max_epochs: 20 + + 100k: + model: + hidden_channels: 256 + optim: + lr_initial: 0.005 + lr_milestones: # epochs at which lr_initial <- lr_initial * lr_gamma + - 1562 + - 2343 + - 3125 + warmup_steps: 468 + max_epochs: 20 + + all: + model: + hidden_channels: 384 + num_interactions: 4 + optim: + batch_size: 256 + eval_batch_size: 256 + lr_initial: 0.001 + lr_gamma: 0.1 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 18000 + - 27000 + - 37000 + warmup_steps: 6000 + max_epochs: 20 + +# ------------------ +# ----- S2EF ----- +# ------------------ + +# For 2 GPUs + +s2ef: + default: + model: + num_interactions: 4 + hidden_channels: 750 + num_gaussians: 200 + num_filters: 256 + regress_forces: "direct" + force_coefficient: 30 + energy_grad_coefficient: 10 + optim: + batch_size: 96 + eval_batch_size: 96 + warmup_factor: 0.2 + lr_gamma: 0.1 + lr_initial: 0.0001 + max_epochs: 15 + warmup_steps: 30000 + lr_milestones: + - 55000 + - 75000 + - 10000 + + 200k: {} + + # 1 gpus + 2M: + model: + num_interactions: 5 + hidden_channels: 1024 + num_gaussians: 200 + num_filters: 256 + optim: + batch_size: 192 + eval_batch_size: 192 + + 20M: {} + + all: {} + +qm9: + default: + model: + act: swish + att_heads: 1 + complex_mp: true + cutoff: 6.0 + edge_embed_type: all_rij + energy_head: '' + graph_norm: true + graph_rewiring: null + hidden_channels: 400 + max_num_neighbors: 30 + mp_type: updownscale_base + num_filters: 480 + num_gaussians: 100 + num_interactions: 5 + otf_graph: false + pg_hidden_channels: 32 + phys_embeds: false + phys_hidden_channels: 0 + regress_forces: '' + second_layer_MLP: true + skip_co: true + tag_hidden_channels: 0 + use_pbc: false + + optim: + batch_size: 64 + es_min_abs_change: 1.0e-06 + es_patience: 20 + es_warmup_epochs: 600 + eval_batch_size: 64 + factor: 0.9 + lr_initial: 0.0003 + loss_energy: mse + lr_gamma: 0.1 + lr_initial: 0.001 + max_epochs: 1500 + min_lr: 1.0e-06 + mode: min + optimizer: AdamW + patience: 15 + scheduler: ReduceLROnPlateau + threshold: 0.0001 + threshold_mode: abs + verbose: true + warmup_factor: 0.2 + warmup_steps: 3000 + + 10k: {} + all: {} + +qm7x: + default: + model: # SOTA settings + act: swish + att_heads: 1 + complex_mp: true + cutoff: 5.0 + edge_embed_type: all_rij + energy_head: false + force_decoder_model_config: + mlp: + hidden_channels: 256 + norm: batch1d + res: + hidden_channels: 128 + norm: batch1d + res_updown: + hidden_channels: 128 + norm: layer + simple: + hidden_channels: 128 + norm: batch1d + force_decoder_type: res_updown + graph_norm: false + hidden_channels: 500 + max_num_neighbors: 40 + mp_type: updownscale_base + num_filters: 400 + num_gaussians: 50 + num_interactions: 5 + otf_graph: false + pg_hidden_channels: 32 + phys_embeds: true + phys_hidden_channels: 0 + regress_forces: direct_with_gradient_target + second_layer_MLP: true + skip_co: false + tag_hidden_channels: 0 + use_pbc: false + + optim: + batch_size: 100 + energy_grad_coefficient: 5 + eval_batch_size: 100 + eval_every: 0.34 + factor: 0.75 + force_coefficient: 75 + loss_energy: mae + 
loss_force: mse + lr_gamma: 0.1 + lr_initial: 0.000193 + max_steps: 4000000 + min_lr: 1.0e-06 + mode: min + optimizer: AdamW + scheduler: ReduceLROnPlateau + threshold: 0.001 + threshold_mode: abs + verbose: true + warmup_factor: 0.2 + warmup_steps: 3000 + + all: {} + 1k: {} + +qm9: + default: + model: + use_pbc: False + all: {} + 10k: {} diff --git a/configs/models/agemnet_oc.yaml b/configs/models/agemnet_oc.yaml new file mode 100644 index 0000000000..5374c5da68 --- /dev/null +++ b/configs/models/agemnet_oc.yaml @@ -0,0 +1,102 @@ +default: + model: + name: agemnet_oc + num_spherical: 7 + num_radial: 128 + num_blocks: 4 + emb_size_atom: 256 + emb_size_edge: 512 + emb_size_trip_in: 64 + emb_size_trip_out: 64 + emb_size_quad_in: 32 + emb_size_quad_out: 32 + emb_size_aint_in: 64 + emb_size_aint_out: 64 + emb_size_rbf: 16 + emb_size_cbf: 16 + emb_size_sbf: 32 + num_before_skip: 2 + num_after_skip: 2 + num_concat: 1 + num_atom: 3 + num_output_afteratom: 3 + cutoff: 12.0 + cutoff_qint: 12.0 + cutoff_aeaint: 12.0 + cutoff_aint: 12.0 + max_neighbors: 30 + max_neighbors_qint: 8 + max_neighbors_aeaint: 20 + max_neighbors_aint: 1000 + rbf: + name: gaussian + envelope: + name: polynomial + exponent: 5 + cbf: + name: spherical_harmonics + sbf: + name: legendre_outer + extensive: True + output_init: HeOrthogonal + activation: silu + scale_file: configs/models/scaling_factors/gemnet-oc.pt + + regress_forces: True + direct_forces: True + forces_coupled: False + + quad_interaction: True + atom_edge_interaction: True + edge_atom_interaction: True + atom_interaction: True + + num_atom_emb_layers: 2 + num_global_out_layers: 2 + qint_tags: [1, 2] + + # PhAST + tag_hidden_channels: 0 # 64 + pg_hidden_channels: 0 # 32 -> period & group embedding hidden channels + phys_embeds: False # True + phys_hidden_channels: 0 + energy_head: False # can be {False, weighted-av-initial-embeds, weighted-av-final-embeds, pooling, graclus, random} + + optim: + batch_size: 16 + eval_batch_size: 16 + load_balancing: atoms + eval_every: 5000 + num_workers: 2 + lr_initial: 5.e-4 + optimizer: AdamW + optimizer_params: {"amsgrad": True} + scheduler: ReduceLROnPlateau + mode: min + factor: 0.8 + patience: 3 + max_epochs: 80 + force_coefficient: 100 + energy_coefficient: 1 + ema_decay: 0.999 + clip_grad_norm: 10 + loss_energy: mae + loss_force: l2mae + weight_decay: 0 + +is2re: + default: + model: + regress_forces: False + num_targets: 1 + 10k: {} + all: {} + +s2ef: + default: + model: + num_targets: 1 + 200k: {} + 2M: {} + 20M: {} + all: {} diff --git a/configs/models/aschnet.yaml b/configs/models/aschnet.yaml new file mode 100644 index 0000000000..23d8db1496 --- /dev/null +++ b/configs/models/aschnet.yaml @@ -0,0 +1,225 @@ +default: + model: + name: aschnet + num_filters: 128 + num_gaussians: 100 + hidden_channels: 256 + num_interactions: 3 + cutoff: 6.0 + use_pbc: True + regress_forces: False + readout: add + atomref: null + # drlab attributes: + tag_hidden_channels: 0 # 32 + pg_hidden_channels: 0 # 32 -> period & group embedding hidden channels + phys_embeds: False # True + phys_hidden_channels: 0 + energy_head: False # can be {False, weighted-av-initial-embeds, weighted-av-final-embeds, pooling, graclus, random} + optim: + batch_size: 64 + eval_batch_size: 64 + num_workers: 4 + lr_gamma: 0.1 + warmup_factor: 0.2 + +# ------------------- +# ----- IS2RE ----- +# ------------------- + +is2re: + # *** Important note *** + # The total number of gpus used for this run was 1. 
+ # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. + 10k: + model: + hidden_channels: 256 + num_interactions: 3 + optim: + lr_initial: 0.005 + max_epochs: 20 + lr_milestones: + - 1562 + - 2343 + - 3125 + warmup_steps: 468 + batch_size: 256 + eval_batch_size: 256 + + 100k: + model: + hidden_channels: 384 + num_interactions: 4 + optim: + lr_initial: 0.0005 + max_epochs: 25 + lr_milestones: + - 15625 + - 31250 + - 46875 + warmup_steps: 9375 + batch_size: 256 + eval_batch_size: 256 + + all: + model: + hidden_channels: 384 + num_interactions: 4 + optim: + lr_initial: 0.001 + max_epochs: 17 + lr_gamma: 0.1 + lr_milestones: + - 17981 + - 26972 + - 35963 + warmup_steps: 5394 + batch_size: 256 + eval_batch_size: 256 + +# ------------------ +# ----- S2EF ----- +# ------------------ + +s2ef: + default: + model: + regress_forces: "from_energy" + hidden_channels: 1024 + num_filters: 256 + num_interactions: 5 + num_gaussians: 200 + optim: + # *** Important note *** + # The total number of gpus used for this run was 1. + # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. + batch_size: 192 + eval_batch_size: 192 + num_workers: 16 + lr_initial: 0.0001 + lr_gamma: 0.1 + lr_milestones: + - 52083 + - 83333 + - 104166 + warmup_steps: 31250 + max_epochs: 15 + force_coefficient: 100 + + 200k: + model: + hidden_channels: 1024 + num_filters: 256 + num_interactions: 3 + num_gaussians: 200 + optim: + batch_size: 128 + eval_batch_size: 128 + num_workers: 16 + lr_initial: 0.0005 + lr_gamma: 0.1 + lr_milestones: + - 7812 + - 12500 + - 15625 + warmup_steps: 4687 + max_epochs: 30 + force_coefficient: 100 + + 2M: {} + + 20M: + model: + hidden_channels: 1024 + num_filters: 256 + num_interactions: 5 + num_gaussians: 200 + optim: + # *** Important note *** + # The total number of gpus used for this run was 48. + # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. + batch_size: 24 + eval_batch_size: 24 + num_workers: 16 + lr_initial: 0.0001 + lr_gamma: 0.1 + lr_milestones: + - 86805 + - 138888 + - 173611 + warmup_steps: 52083 + max_epochs: 30 + force_coefficient: 50 + + all: + model: + hidden_channels: 1024 + num_filters: 256 + num_interactions: 5 + num_gaussians: 200 + optim: + # *** Important note *** + # The total number of gpus used for this run was 64. + # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. 
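+ # For example, assuming the milestones and warmup_steps are counted in optimizer steps
+ # and the intent is to keep the schedule fixed in epochs (an assumption, not stated
+ # here): this block targets 64 GPUs x batch_size 20, i.e. a global batch of 1280.
+ # Training on 16 GPUs instead gives a global batch of 320, so each epoch takes 4x more
+ # steps, and the milestones 313907 / 523179 / 732451 and warmup_steps 209271 would
+ # scale by 4 to roughly 1255628 / 2092716 / 2929804 and 837084.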
+ batch_size: 20 + eval_batch_size: 20 + num_workers: 16 + lr_initial: 0.0001 + lr_gamma: 0.1 + lr_milestones: + - 313907 + - 523179 + - 732451 + warmup_steps: 209271 + max_epochs: 15 + force_coefficient: 30 + +qm9: + default: + model: + hidden_channels: 128 + num_gaussians: 100 + num_filters: 128 + num_interactions: 6 + cutoff: 5.0 + optim: + batch_size: 1024 + lr_initial: 0.001 + max_epochs: 1000 + decay_steps: 125000 + decay_rate: 0.01 + ema_decay: 0.999 + lr_gamma: 0.25 + lr_milestones: + - 17981 + - 26972 + - 35963 + - 52000 + - 100000 + warmup_steps: 1000 + + 10k: {} + all: {} + +qm7x: + default: + model: + hidden_channels: 384 + num_interactions: 4 + optim: + batch_size: 128 + lr_initial: 0.001 + max_epochs: 25 + lr_gamma: 0.1 + lr_milestones: + - 17981 + - 26972 + - 35963 + warmup_steps: 15000 + + all: {} + 1k: {} diff --git a/configs/models/depdpp.yaml b/configs/models/depdpp.yaml new file mode 100644 index 0000000000..3f04d06209 --- /dev/null +++ b/configs/models/depdpp.yaml @@ -0,0 +1,209 @@ +default: + model: + name: depdpp + hidden_channels: 256 + out_emb_channels: 192 + num_blocks: 3 + cutoff: 6.0 + num_radial: 6 + num_spherical: 7 + num_before_skip: 1 + num_after_skip: 2 + num_output_layers: 3 + regress_forces: False + use_pbc: True + basis_emb_size: 8 + envelope_exponent: 5 + act: swish + int_emb_size: 64 + # drlab attributes: + tag_hidden_channels: 0 # 64 + pg_hidden_channels: 0 # 32 -> period & group embedding hidden channels + phys_embeds: False # True + phys_hidden_channels: 0 + energy_head: False # can be {False, weighted-av-initial-embeds, weighted-av-final-embeds, pooling, graclus, random} + optim: + batch_size: 4 + eval_batch_size: 4 + num_workers: 4 + lr_gamma: 0.1 + warmup_factor: 0.2 + +# ------------------- +# ----- IS2RE ----- +# ------------------- + +is2re: + # *** Important note *** + # The total number of gpus used for this run was 1. + # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. + 10k: + optim: + lr_initial: 0.0001 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 20000 + - 40000 + - 60000 + warmup_steps: 10000 + max_epochs: 20 + batch_size: 16 + eval_batch_size: 16 + + 100k: + optim: + lr_initial: 0.0001 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 200000 + - 400000 + - 600000 + warmup_steps: 100000 + max_epochs: 15 + batch_size: 16 + eval_batch_size: 16 + + all: + optim: + lr_initial: 0.0001 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 115082 + - 230164 + - 345246 + warmup_steps: 57541 + max_epochs: 8 + batch_size: 16 + eval_batch_size: 16 + +# ------------------ +# ----- S2EF ----- +# ------------------ + +s2ef: + default: + model: + regress_forces: "from_energy" + optim: + num_workers: 8 + eval_every: 10000 + + 200k: + optim: + # *** Important note *** + # The total number of gpus used for this run was 4. + # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. 
+ batch_size: 48 + eval_batch_size: 48 + lr_initial: 0.00001 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 5208 + - 8333 + - 10416 + warmup_steps: 3125 + max_epochs: 10 + force_coefficient: 50 + + 2M: + optim: + batch_size: 96 + eval_batch_size: 96 + eval_every: 10000 + num_workers: 8 + lr_initial: 0.0001 + lr_gamma: 0.1 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 20833 + - 31250 + - 41666 + warmup_steps: 10416 + warmup_factor: 0.2 + max_epochs: 15 + force_coefficient: 50 + model: + hidden_channels: 192 + out_emb_channels: 192 + num_blocks: 3 + cutoff: 6.0 + num_radial: 6 + num_spherical: 7 + num_before_skip: 1 + num_after_skip: 2 + num_output_layers: 3 + regress_forces: True + use_pbc: True + + 20M: + optim: + # *** Important note *** + # The total number of gpus used for this run was 64. + # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. + batch_size: 12 + eval_batch_size: 12 + lr_initial: 0.0001 + lr_gamma: 0.1 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 78125 + - 130208 + - 208333 + warmup_steps: 52083 + max_epochs: 15 + force_coefficient: 50 + + all: + optim: + # *** Important note *** + # The total number of gpus used for this run was 256. + # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. + batch_size: 8 + eval_batch_size: 8 + lr_initial: 0.0001 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 130794 + - 196192 + - 261589 + warmup_steps: 130794 + max_epochs: 7 + force_coefficient: 50 + +qm9: + default: + model: + num_blocks: 6 + hidden_channels: 128 + optim: + # *** Important note *** + # The total number of gpus used for this run was 4. + # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. + lr_initial: 0.001 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 2000000 + - 4000000 + - 6000000 + warmup_steps: 3000 + lr_gamma: 0.1 + batch_size: 128 + max_epochs: 600 + + 10k: {} + all: {} + +qm7x: + default: + optim: + # *** Important note *** + # The total number of gpus used for this run was 4. + # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. 
+ lr_initial: 0.0001 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 115082 + - 230164 + - 345246 + warmup_steps: 57541 + max_epochs: 8 + + all: {} + 1k: {} diff --git a/configs/models/depfaenet.yaml b/configs/models/depfaenet.yaml new file mode 100644 index 0000000000..852ebc3bfd --- /dev/null +++ b/configs/models/depfaenet.yaml @@ -0,0 +1,271 @@ +default: + model: + name: depfaenet + act: swish + hidden_channels: 128 + num_filters: 100 + num_interactions: 3 + num_gaussians: 100 + cutoff: 6.0 + use_pbc: True + regress_forces: False + # drlab attributes: + tag_hidden_channels: 0 # 32 + pg_hidden_channels: 0 # 32 -> period & group embedding hidden channels + phys_embeds: False # True + phys_hidden_channels: 0 + energy_head: False # can be {False, weighted-av-initial-embeds, weighted-av-final-embeds, pooling, graclus, random} + # faenet new features + skip_co: False # output skip connections {False, "add", "concat"} + second_layer_MLP: False # in EmbeddingBlock + complex_mp: False + edge_embed_type: rij # {'rij','all_rij','sh', 'all'}) + mp_type: base # {'base', 'simple', 'updownscale', 'att', 'base_with_att', 'local_env'} + graph_norm: False # bool + att_heads: 1 # int + force_decoder_type: "mlp" # can be {"" or "simple"} | only used if regress_forces is True + force_decoder_model_config: + simple: + hidden_channels: 128 + norm: batch1d # batch1d, layer or null + mlp: + hidden_channels: 256 + norm: batch1d # batch1d, layer or null + res: + hidden_channels: 128 + norm: batch1d # batch1d, layer or null + res_updown: + hidden_channels: 128 + norm: batch1d # batch1d, layer or null + optim: + batch_size: 64 + eval_batch_size: 64 + num_workers: 4 + lr_gamma: 0.1 + lr_initial: 0.001 + warmup_factor: 0.2 + max_epochs: 20 + energy_grad_coefficient: 10 + force_coefficient: 30 + energy_coefficient: 1 + + frame_averaging: False # 2D, 3D, da, False + fa_frames: False # can be {None, full, random, det, e3, e3-random, e3-det} + +# ------------------- +# ----- IS2RE ----- +# ------------------- + +is2re: + # *** Important note *** + # The total number of gpus used for this run was 1. + # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. 
+ 10k: + optim: + lr_initial: 0.005 + lr_milestones: # epochs at which lr_initial <- lr_initial * lr_gamma + - 1562 + - 2343 + - 3125 + warmup_steps: 468 + max_epochs: 20 + + 100k: + model: + hidden_channels: 256 + optim: + lr_initial: 0.005 + lr_milestones: # epochs at which lr_initial <- lr_initial * lr_gamma + - 1562 + - 2343 + - 3125 + warmup_steps: 468 + max_epochs: 20 + + all: + model: + hidden_channels: 384 + num_interactions: 4 + optim: + batch_size: 256 + eval_batch_size: 256 + lr_initial: 0.001 + lr_gamma: 0.1 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 18000 + - 27000 + - 37000 + warmup_steps: 6000 + max_epochs: 20 + +# ------------------ +# ----- S2EF ----- +# ------------------ + +# For 2 GPUs + +s2ef: + default: + model: + num_interactions: 4 + hidden_channels: 750 + num_gaussians: 200 + num_filters: 256 + regress_forces: "direct" + force_coefficient: 30 + energy_grad_coefficient: 10 + optim: + batch_size: 96 + eval_batch_size: 96 + warmup_factor: 0.2 + lr_gamma: 0.1 + lr_initial: 0.0001 + max_epochs: 15 + warmup_steps: 30000 + lr_milestones: + - 55000 + - 75000 + - 10000 + + 200k: {} + + # 1 gpus + 2M: + model: + num_interactions: 5 + hidden_channels: 1024 + num_gaussians: 200 + num_filters: 256 + optim: + batch_size: 192 + eval_batch_size: 192 + + 20M: {} + + all: {} + +qm9: + default: + model: + act: swish + att_heads: 1 + complex_mp: true + cutoff: 6.0 + edge_embed_type: all_rij + energy_head: '' + graph_norm: true + graph_rewiring: null + hidden_channels: 400 + max_num_neighbors: 30 + mp_type: updownscale_base + num_filters: 480 + num_gaussians: 100 + num_interactions: 5 + otf_graph: false + pg_hidden_channels: 32 + phys_embeds: false + phys_hidden_channels: 0 + regress_forces: '' + second_layer_MLP: true + skip_co: true + tag_hidden_channels: 0 + use_pbc: false + + optim: + batch_size: 64 + es_min_abs_change: 1.0e-06 + es_patience: 20 + es_warmup_epochs: 600 + eval_batch_size: 64 + factor: 0.9 + lr_initial: 0.0003 + loss_energy: mse + lr_gamma: 0.1 + lr_initial: 0.001 + max_epochs: 1500 + min_lr: 1.0e-06 + mode: min + optimizer: AdamW + patience: 15 + scheduler: ReduceLROnPlateau + threshold: 0.0001 + threshold_mode: abs + verbose: true + warmup_factor: 0.2 + warmup_steps: 3000 + + 10k: {} + all: {} + +qm7x: + default: + model: # SOTA settings + act: swish + att_heads: 1 + complex_mp: true + cutoff: 5.0 + edge_embed_type: all_rij + energy_head: false + force_decoder_model_config: + mlp: + hidden_channels: 256 + norm: batch1d + res: + hidden_channels: 128 + norm: batch1d + res_updown: + hidden_channels: 128 + norm: layer + simple: + hidden_channels: 128 + norm: batch1d + force_decoder_type: res_updown + graph_norm: false + hidden_channels: 500 + max_num_neighbors: 40 + mp_type: updownscale_base + num_filters: 400 + num_gaussians: 50 + num_interactions: 5 + otf_graph: false + pg_hidden_channels: 32 + phys_embeds: true + phys_hidden_channels: 0 + regress_forces: direct_with_gradient_target + second_layer_MLP: true + skip_co: false + tag_hidden_channels: 0 + use_pbc: false + + optim: + batch_size: 100 + energy_grad_coefficient: 5 + eval_batch_size: 100 + eval_every: 0.34 + factor: 0.75 + force_coefficient: 75 + loss_energy: mae + loss_force: mse + lr_gamma: 0.1 + lr_initial: 0.000193 + max_steps: 4000000 + min_lr: 1.0e-06 + mode: min + optimizer: AdamW + scheduler: ReduceLROnPlateau + threshold: 0.001 + threshold_mode: abs + verbose: true + warmup_factor: 0.2 + warmup_steps: 3000 + + all: {} + 1k: {} + +qm9: + default: + model: + use_pbc: False 
+ all: {} + 10k: {} diff --git a/configs/models/depgemnet_oc.yaml b/configs/models/depgemnet_oc.yaml new file mode 100644 index 0000000000..d3fb79cabc --- /dev/null +++ b/configs/models/depgemnet_oc.yaml @@ -0,0 +1,102 @@ +default: + model: + name: depgemnet_oc + num_spherical: 7 + num_radial: 128 + num_blocks: 4 + emb_size_atom: 256 + emb_size_edge: 512 + emb_size_trip_in: 64 + emb_size_trip_out: 64 + emb_size_quad_in: 32 + emb_size_quad_out: 32 + emb_size_aint_in: 64 + emb_size_aint_out: 64 + emb_size_rbf: 16 + emb_size_cbf: 16 + emb_size_sbf: 32 + num_before_skip: 2 + num_after_skip: 2 + num_concat: 1 + num_atom: 3 + num_output_afteratom: 3 + cutoff: 12.0 + cutoff_qint: 12.0 + cutoff_aeaint: 12.0 + cutoff_aint: 12.0 + max_neighbors: 30 + max_neighbors_qint: 8 + max_neighbors_aeaint: 20 + max_neighbors_aint: 1000 + rbf: + name: gaussian + envelope: + name: polynomial + exponent: 5 + cbf: + name: spherical_harmonics + sbf: + name: legendre_outer + extensive: True + output_init: HeOrthogonal + activation: silu + scale_file: configs/models/scaling_factors/gemnet-oc.pt + + regress_forces: True + direct_forces: True + forces_coupled: False + + quad_interaction: True + atom_edge_interaction: True + edge_atom_interaction: True + atom_interaction: True + + num_atom_emb_layers: 2 + num_global_out_layers: 2 + qint_tags: [1, 2] + + # PhAST + tag_hidden_channels: 0 # 64 + pg_hidden_channels: 0 # 32 -> period & group embedding hidden channels + phys_embeds: False # True + phys_hidden_channels: 0 + energy_head: False # can be {False, weighted-av-initial-embeds, weighted-av-final-embeds, pooling, graclus, random} + + optim: + batch_size: 16 + eval_batch_size: 16 + load_balancing: atoms + eval_every: 5000 + num_workers: 2 + lr_initial: 5.e-4 + optimizer: AdamW + optimizer_params: {"amsgrad": True} + scheduler: ReduceLROnPlateau + mode: min + factor: 0.8 + patience: 3 + max_epochs: 80 + force_coefficient: 100 + energy_coefficient: 1 + ema_decay: 0.999 + clip_grad_norm: 10 + loss_energy: mae + loss_force: l2mae + weight_decay: 0 + +is2re: + default: + model: + regress_forces: False + num_targets: 1 + 10k: {} + all: {} + +s2ef: + default: + model: + num_targets: 1 + 200k: {} + 2M: {} + 20M: {} + all: {} diff --git a/configs/models/depgemnet_t.yaml b/configs/models/depgemnet_t.yaml new file mode 100644 index 0000000000..2523da00e8 --- /dev/null +++ b/configs/models/depgemnet_t.yaml @@ -0,0 +1,113 @@ +# From OCP original repo -> https://github.com/Open-Catalyst-Project/ocp/blob/d16de9ee6f26d8661be5b9171e8c73c80237a82f/configs/oc22/is2re/gemnet-dT/gemnet-dT.yml +# Run this on 1 GPU -- so with an effective batch size of 8. 
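+# (Here "effective batch size" means num_gpus * per-GPU optim.batch_size, following the
+# convention used in the other model configs: 1 GPU * batch_size 8 = 8, matching optim below.)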
+ +default: + model: + name: depgemnet_t + use_pbc: true + num_spherical: 7 + num_radial: 64 + num_blocks: 5 + emb_size_atom: 256 + emb_size_edge: 512 + emb_size_trip: 64 + emb_size_rbf: 64 + emb_size_cbf: 16 + emb_size_bil_trip: 64 + num_before_skip: 1 + num_after_skip: 2 + num_concat: 1 + num_atom: 3 + cutoff: 6.0 + max_neighbors: 50 + rbf: + name: gaussian + envelope: + name: polynomial + exponent: 5 + cbf: + name: spherical_harmonics + extensive: True + otf_graph: False + output_init: HeOrthogonal + activation: silu + scale_file: configs/models/scaling_factors/gemnet-dT_c12.json + regress_forces: False + # PhAST + tag_hidden_channels: 0 # 64 + pg_hidden_channels: 0 # 32 -> period & group embedding hidden channels + phys_hidden_channels: 0 # 32 -> physical properties embedding hidden channels + phys_embeds: False # True + optim: + batch_size: 8 + eval_batch_size: 8 + num_workers: 2 + lr_initial: 1.e-4 + optimizer: AdamW + optimizer_params: { "amsgrad": True } + scheduler: ReduceLROnPlateau + mode: min + factor: 0.8 + patience: 3 + max_epochs: 100 + energy_coefficient: 1 + ema_decay: 0.999 + clip_grad_norm: 10 + loss_energy: mae + +# ------------------- +# ----- IS2RE ----- +# ------------------- + +is2re: + 10k: {} + + 100k: {} + + all: {} +# ------------------ +# ----- S2EF ----- +# ------------------ + +s2ef: + default: + model: + cutoff: 6.0 + scale_file: configs/models/scaling_factors/gemnet-dT.json + regress_forces: "direct" + otf_graph: False + max_neighbors: 50 + num_radial: 128 + num_blocks: 3 + emb_size_atom: 512 + emb_size_trip: 64 + emb_size_rbf: 16 + optim: + clip_grad_norm: 10 + loss_force: l2mae + batch_size: 32 + eval_batch_size: 32 + lr_initial: 5.e-4 + max_epochs: 80 + force_coefficient: 100 + energy_coefficient: 1 + + 200k: {} + + 2M: {} + + 20M: {} + + all: {} + +qm9: + default: {} + + 10k: {} + all: {} + +qm7x: + default: {} + + all: {} + 1k: {} diff --git a/configs/models/depschnet.yaml b/configs/models/depschnet.yaml new file mode 100644 index 0000000000..65fcb15037 --- /dev/null +++ b/configs/models/depschnet.yaml @@ -0,0 +1,225 @@ +default: + model: + name: depschnet + num_filters: 128 + num_gaussians: 100 + hidden_channels: 256 + num_interactions: 3 + cutoff: 6.0 + use_pbc: True + regress_forces: False + readout: add + atomref: null + # drlab attributes: + tag_hidden_channels: 0 # 32 + pg_hidden_channels: 0 # 32 -> period & group embedding hidden channels + phys_embeds: False # True + phys_hidden_channels: 0 + energy_head: False # can be {False, weighted-av-initial-embeds, weighted-av-final-embeds, pooling, graclus, random} + optim: + batch_size: 64 + eval_batch_size: 64 + num_workers: 4 + lr_gamma: 0.1 + warmup_factor: 0.2 + +# ------------------- +# ----- IS2RE ----- +# ------------------- + +is2re: + # *** Important note *** + # The total number of gpus used for this run was 1. + # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. 
+ 10k: + model: + hidden_channels: 256 + num_interactions: 3 + optim: + lr_initial: 0.005 + max_epochs: 20 + lr_milestones: + - 1562 + - 2343 + - 3125 + warmup_steps: 468 + batch_size: 256 + eval_batch_size: 256 + + 100k: + model: + hidden_channels: 384 + num_interactions: 4 + optim: + lr_initial: 0.0005 + max_epochs: 25 + lr_milestones: + - 15625 + - 31250 + - 46875 + warmup_steps: 9375 + batch_size: 256 + eval_batch_size: 256 + + all: + model: + hidden_channels: 384 + num_interactions: 4 + optim: + lr_initial: 0.001 + max_epochs: 17 + lr_gamma: 0.1 + lr_milestones: + - 17981 + - 26972 + - 35963 + warmup_steps: 5394 + batch_size: 256 + eval_batch_size: 256 + +# ------------------ +# ----- S2EF ----- +# ------------------ + +s2ef: + default: + model: + regress_forces: "from_energy" + hidden_channels: 1024 + num_filters: 256 + num_interactions: 5 + num_gaussians: 200 + optim: + # *** Important note *** + # The total number of gpus used for this run was 1. + # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. + batch_size: 192 + eval_batch_size: 192 + num_workers: 16 + lr_initial: 0.0001 + lr_gamma: 0.1 + lr_milestones: + - 52083 + - 83333 + - 104166 + warmup_steps: 31250 + max_epochs: 15 + force_coefficient: 100 + + 200k: + model: + hidden_channels: 1024 + num_filters: 256 + num_interactions: 3 + num_gaussians: 200 + optim: + batch_size: 128 + eval_batch_size: 128 + num_workers: 16 + lr_initial: 0.0005 + lr_gamma: 0.1 + lr_milestones: + - 7812 + - 12500 + - 15625 + warmup_steps: 4687 + max_epochs: 30 + force_coefficient: 100 + + 2M: {} + + 20M: + model: + hidden_channels: 1024 + num_filters: 256 + num_interactions: 5 + num_gaussians: 200 + optim: + # *** Important note *** + # The total number of gpus used for this run was 48. + # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. + batch_size: 24 + eval_batch_size: 24 + num_workers: 16 + lr_initial: 0.0001 + lr_gamma: 0.1 + lr_milestones: + - 86805 + - 138888 + - 173611 + warmup_steps: 52083 + max_epochs: 30 + force_coefficient: 50 + + all: + model: + hidden_channels: 1024 + num_filters: 256 + num_interactions: 5 + num_gaussians: 200 + optim: + # *** Important note *** + # The total number of gpus used for this run was 64. + # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. 
+ batch_size: 20 + eval_batch_size: 20 + num_workers: 16 + lr_initial: 0.0001 + lr_gamma: 0.1 + lr_milestones: + - 313907 + - 523179 + - 732451 + warmup_steps: 209271 + max_epochs: 15 + force_coefficient: 30 + +qm9: + default: + model: + hidden_channels: 128 + num_gaussians: 100 + num_filters: 128 + num_interactions: 6 + cutoff: 5.0 + optim: + batch_size: 1024 + lr_initial: 0.001 + max_epochs: 1000 + decay_steps: 125000 + decay_rate: 0.01 + ema_decay: 0.999 + lr_gamma: 0.25 + lr_milestones: + - 17981 + - 26972 + - 35963 + - 52000 + - 100000 + warmup_steps: 1000 + + 10k: {} + all: {} + +qm7x: + default: + model: + hidden_channels: 384 + num_interactions: 4 + optim: + batch_size: 128 + lr_initial: 0.001 + max_epochs: 25 + lr_gamma: 0.1 + lr_milestones: + - 17981 + - 26972 + - 35963 + warmup_steps: 15000 + + all: {} + 1k: {} diff --git a/configs/models/inddpp.yaml b/configs/models/inddpp.yaml new file mode 100644 index 0000000000..aae9ffb0f3 --- /dev/null +++ b/configs/models/inddpp.yaml @@ -0,0 +1,209 @@ +default: + model: + name: inddpp + hidden_channels: 256 + out_emb_channels: 192 + num_blocks: 3 + cutoff: 6.0 + num_radial: 6 + num_spherical: 7 + num_before_skip: 1 + num_after_skip: 2 + num_output_layers: 3 + regress_forces: False + use_pbc: True + basis_emb_size: 8 + envelope_exponent: 5 + act: swish + int_emb_size: 64 + # drlab attributes: + tag_hidden_channels: 0 # 64 + pg_hidden_channels: 0 # 32 -> period & group embedding hidden channels + phys_embeds: False # True + phys_hidden_channels: 0 + energy_head: False # can be {False, weighted-av-initial-embeds, weighted-av-final-embeds, pooling, graclus, random} + optim: + batch_size: 4 + eval_batch_size: 4 + num_workers: 4 + lr_gamma: 0.1 + warmup_factor: 0.2 + +# ------------------- +# ----- IS2RE ----- +# ------------------- + +is2re: + # *** Important note *** + # The total number of gpus used for this run was 1. + # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. + 10k: + optim: + lr_initial: 0.0001 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 20000 + - 40000 + - 60000 + warmup_steps: 10000 + max_epochs: 20 + batch_size: 16 + eval_batch_size: 16 + + 100k: + optim: + lr_initial: 0.0001 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 200000 + - 400000 + - 600000 + warmup_steps: 100000 + max_epochs: 15 + batch_size: 16 + eval_batch_size: 16 + + all: + optim: + lr_initial: 0.0001 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 115082 + - 230164 + - 345246 + warmup_steps: 57541 + max_epochs: 8 + batch_size: 16 + eval_batch_size: 16 + +# ------------------ +# ----- S2EF ----- +# ------------------ + +s2ef: + default: + model: + regress_forces: "from_energy" + optim: + num_workers: 8 + eval_every: 10000 + + 200k: + optim: + # *** Important note *** + # The total number of gpus used for this run was 4. + # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. 
+ batch_size: 48 + eval_batch_size: 48 + lr_initial: 0.00001 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 5208 + - 8333 + - 10416 + warmup_steps: 3125 + max_epochs: 10 + force_coefficient: 50 + + 2M: + optim: + batch_size: 96 + eval_batch_size: 96 + eval_every: 10000 + num_workers: 8 + lr_initial: 0.0001 + lr_gamma: 0.1 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 20833 + - 31250 + - 41666 + warmup_steps: 10416 + warmup_factor: 0.2 + max_epochs: 15 + force_coefficient: 50 + model: + hidden_channels: 192 + out_emb_channels: 192 + num_blocks: 3 + cutoff: 6.0 + num_radial: 6 + num_spherical: 7 + num_before_skip: 1 + num_after_skip: 2 + num_output_layers: 3 + regress_forces: True + use_pbc: True + + 20M: + optim: + # *** Important note *** + # The total number of gpus used for this run was 64. + # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. + batch_size: 12 + eval_batch_size: 12 + lr_initial: 0.0001 + lr_gamma: 0.1 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 78125 + - 130208 + - 208333 + warmup_steps: 52083 + max_epochs: 15 + force_coefficient: 50 + + all: + optim: + # *** Important note *** + # The total number of gpus used for this run was 256. + # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. + batch_size: 8 + eval_batch_size: 8 + lr_initial: 0.0001 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 130794 + - 196192 + - 261589 + warmup_steps: 130794 + max_epochs: 7 + force_coefficient: 50 + +qm9: + default: + model: + num_blocks: 6 + hidden_channels: 128 + optim: + # *** Important note *** + # The total number of gpus used for this run was 4. + # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. + lr_initial: 0.001 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 2000000 + - 4000000 + - 6000000 + warmup_steps: 3000 + lr_gamma: 0.1 + batch_size: 128 + max_epochs: 600 + + 10k: {} + all: {} + +qm7x: + default: + optim: + # *** Important note *** + # The total number of gpus used for this run was 4. + # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. 
+ lr_initial: 0.0001 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 115082 + - 230164 + - 345246 + warmup_steps: 57541 + max_epochs: 8 + + all: {} + 1k: {} diff --git a/configs/models/indfaenet.yaml b/configs/models/indfaenet.yaml new file mode 100644 index 0000000000..acfb22166f --- /dev/null +++ b/configs/models/indfaenet.yaml @@ -0,0 +1,271 @@ +default: + model: + name: indfaenet + act: swish + hidden_channels: 128 + num_filters: 100 + num_interactions: 3 + num_gaussians: 100 + cutoff: 6.0 + use_pbc: True + regress_forces: False + # drlab attributes: + tag_hidden_channels: 0 # 32 + pg_hidden_channels: 0 # 32 -> period & group embedding hidden channels + phys_embeds: False # True + phys_hidden_channels: 0 + energy_head: False # can be {False, weighted-av-initial-embeds, weighted-av-final-embeds, pooling, graclus, random} + # faenet new features + skip_co: False # output skip connections {False, "add", "concat"} + second_layer_MLP: False # in EmbeddingBlock + complex_mp: False + edge_embed_type: rij # {'rij','all_rij','sh', 'all'}) + mp_type: base # {'base', 'simple', 'updownscale', 'att', 'base_with_att', 'local_env'} + graph_norm: False # bool + att_heads: 1 # int + force_decoder_type: "mlp" # can be {"" or "simple"} | only used if regress_forces is True + force_decoder_model_config: + simple: + hidden_channels: 128 + norm: batch1d # batch1d, layer or null + mlp: + hidden_channels: 256 + norm: batch1d # batch1d, layer or null + res: + hidden_channels: 128 + norm: batch1d # batch1d, layer or null + res_updown: + hidden_channels: 128 + norm: batch1d # batch1d, layer or null + optim: + batch_size: 64 + eval_batch_size: 64 + num_workers: 4 + lr_gamma: 0.1 + lr_initial: 0.001 + warmup_factor: 0.2 + max_epochs: 20 + energy_grad_coefficient: 10 + force_coefficient: 30 + energy_coefficient: 1 + + frame_averaging: False # 2D, 3D, da, False + fa_frames: False # can be {None, full, random, det, e3, e3-random, e3-det} + +# ------------------- +# ----- IS2RE ----- +# ------------------- + +is2re: + # *** Important note *** + # The total number of gpus used for this run was 1. + # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. 
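+  # Note for the 10k split below: at the default batch_size of 64 this is ~156 steps per epoch, so the milestone values 1562 / 2343 / 3125 fall at roughly epochs 10, 15 and 20, and warmup_steps 468 at ~3 epochs. Despite the inline "epochs" label, these values are step counts.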
+ 10k: + optim: + lr_initial: 0.005 + lr_milestones: # epochs at which lr_initial <- lr_initial * lr_gamma + - 1562 + - 2343 + - 3125 + warmup_steps: 468 + max_epochs: 20 + + 100k: + model: + hidden_channels: 256 + optim: + lr_initial: 0.005 + lr_milestones: # epochs at which lr_initial <- lr_initial * lr_gamma + - 1562 + - 2343 + - 3125 + warmup_steps: 468 + max_epochs: 20 + + all: + model: + hidden_channels: 384 + num_interactions: 4 + optim: + batch_size: 256 + eval_batch_size: 256 + lr_initial: 0.001 + lr_gamma: 0.1 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 18000 + - 27000 + - 37000 + warmup_steps: 6000 + max_epochs: 20 + +# ------------------ +# ----- S2EF ----- +# ------------------ + +# For 2 GPUs + +s2ef: + default: + model: + num_interactions: 4 + hidden_channels: 750 + num_gaussians: 200 + num_filters: 256 + regress_forces: "direct" + force_coefficient: 30 + energy_grad_coefficient: 10 + optim: + batch_size: 96 + eval_batch_size: 96 + warmup_factor: 0.2 + lr_gamma: 0.1 + lr_initial: 0.0001 + max_epochs: 15 + warmup_steps: 30000 + lr_milestones: + - 55000 + - 75000 + - 10000 + + 200k: {} + + # 1 gpus + 2M: + model: + num_interactions: 5 + hidden_channels: 1024 + num_gaussians: 200 + num_filters: 256 + optim: + batch_size: 192 + eval_batch_size: 192 + + 20M: {} + + all: {} + +qm9: + default: + model: + act: swish + att_heads: 1 + complex_mp: true + cutoff: 6.0 + edge_embed_type: all_rij + energy_head: '' + graph_norm: true + graph_rewiring: null + hidden_channels: 400 + max_num_neighbors: 30 + mp_type: updownscale_base + num_filters: 480 + num_gaussians: 100 + num_interactions: 5 + otf_graph: false + pg_hidden_channels: 32 + phys_embeds: false + phys_hidden_channels: 0 + regress_forces: '' + second_layer_MLP: true + skip_co: true + tag_hidden_channels: 0 + use_pbc: false + + optim: + batch_size: 64 + es_min_abs_change: 1.0e-06 + es_patience: 20 + es_warmup_epochs: 600 + eval_batch_size: 64 + factor: 0.9 + lr_initial: 0.0003 + loss_energy: mse + lr_gamma: 0.1 + lr_initial: 0.001 + max_epochs: 1500 + min_lr: 1.0e-06 + mode: min + optimizer: AdamW + patience: 15 + scheduler: ReduceLROnPlateau + threshold: 0.0001 + threshold_mode: abs + verbose: true + warmup_factor: 0.2 + warmup_steps: 3000 + + 10k: {} + all: {} + +qm7x: + default: + model: # SOTA settings + act: swish + att_heads: 1 + complex_mp: true + cutoff: 5.0 + edge_embed_type: all_rij + energy_head: false + force_decoder_model_config: + mlp: + hidden_channels: 256 + norm: batch1d + res: + hidden_channels: 128 + norm: batch1d + res_updown: + hidden_channels: 128 + norm: layer + simple: + hidden_channels: 128 + norm: batch1d + force_decoder_type: res_updown + graph_norm: false + hidden_channels: 500 + max_num_neighbors: 40 + mp_type: updownscale_base + num_filters: 400 + num_gaussians: 50 + num_interactions: 5 + otf_graph: false + pg_hidden_channels: 32 + phys_embeds: true + phys_hidden_channels: 0 + regress_forces: direct_with_gradient_target + second_layer_MLP: true + skip_co: false + tag_hidden_channels: 0 + use_pbc: false + + optim: + batch_size: 100 + energy_grad_coefficient: 5 + eval_batch_size: 100 + eval_every: 0.34 + factor: 0.75 + force_coefficient: 75 + loss_energy: mae + loss_force: mse + lr_gamma: 0.1 + lr_initial: 0.000193 + max_steps: 4000000 + min_lr: 1.0e-06 + mode: min + optimizer: AdamW + scheduler: ReduceLROnPlateau + threshold: 0.001 + threshold_mode: abs + verbose: true + warmup_factor: 0.2 + warmup_steps: 3000 + + all: {} + 1k: {} + +qm9: + default: + model: + use_pbc: False 
+ all: {} + 10k: {} diff --git a/configs/models/indgemnet_oc.yaml b/configs/models/indgemnet_oc.yaml new file mode 100644 index 0000000000..8cf5c66d17 --- /dev/null +++ b/configs/models/indgemnet_oc.yaml @@ -0,0 +1,102 @@ +default: + model: + name: indgemnet_oc + num_spherical: 7 + num_radial: 128 + num_blocks: 4 + emb_size_atom: 256 + emb_size_edge: 512 + emb_size_trip_in: 64 + emb_size_trip_out: 64 + emb_size_quad_in: 32 + emb_size_quad_out: 32 + emb_size_aint_in: 64 + emb_size_aint_out: 64 + emb_size_rbf: 16 + emb_size_cbf: 16 + emb_size_sbf: 32 + num_before_skip: 2 + num_after_skip: 2 + num_concat: 1 + num_atom: 3 + num_output_afteratom: 3 + cutoff: 12.0 + cutoff_qint: 12.0 + cutoff_aeaint: 12.0 + cutoff_aint: 12.0 + max_neighbors: 30 + max_neighbors_qint: 8 + max_neighbors_aeaint: 20 + max_neighbors_aint: 1000 + rbf: + name: gaussian + envelope: + name: polynomial + exponent: 5 + cbf: + name: spherical_harmonics + sbf: + name: legendre_outer + extensive: True + output_init: HeOrthogonal + activation: silu + scale_file: configs/models/scaling_factors/gemnet-oc.pt + + regress_forces: True + direct_forces: True + forces_coupled: False + + quad_interaction: True + atom_edge_interaction: True + edge_atom_interaction: True + atom_interaction: True + + num_atom_emb_layers: 2 + num_global_out_layers: 2 + qint_tags: [1, 2] + + # PhAST + tag_hidden_channels: 0 # 64 + pg_hidden_channels: 0 # 32 -> period & group embedding hidden channels + phys_embeds: False # True + phys_hidden_channels: 0 + energy_head: False # can be {False, weighted-av-initial-embeds, weighted-av-final-embeds, pooling, graclus, random} + + optim: + batch_size: 16 + eval_batch_size: 16 + load_balancing: atoms + eval_every: 5000 + num_workers: 2 + lr_initial: 5.e-4 + optimizer: AdamW + optimizer_params: {"amsgrad": True} + scheduler: ReduceLROnPlateau + mode: min + factor: 0.8 + patience: 3 + max_epochs: 80 + force_coefficient: 100 + energy_coefficient: 1 + ema_decay: 0.999 + clip_grad_norm: 10 + loss_energy: mae + loss_force: l2mae + weight_decay: 0 + +is2re: + default: + model: + regress_forces: False + num_targets: 1 + 10k: {} + all: {} + +s2ef: + default: + model: + num_targets: 1 + 200k: {} + 2M: {} + 20M: {} + all: {} diff --git a/configs/models/indgemnet_t.yaml b/configs/models/indgemnet_t.yaml new file mode 100644 index 0000000000..a68b95f7b4 --- /dev/null +++ b/configs/models/indgemnet_t.yaml @@ -0,0 +1,113 @@ +# From OCP original repo -> https://github.com/Open-Catalyst-Project/ocp/blob/d16de9ee6f26d8661be5b9171e8c73c80237a82f/configs/oc22/is2re/gemnet-dT/gemnet-dT.yml +# Run this on 1 GPU -- so with an effective batch size of 8. 
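+# "Effective batch size" above means num_gpus * optim.batch_size; the optim block below sets batch_size: 8, hence an effective batch size of 8 on a single GPU.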
+ +default: + model: + name: indgemnet_t + use_pbc: true + num_spherical: 7 + num_radial: 64 + num_blocks: 5 + emb_size_atom: 256 + emb_size_edge: 512 + emb_size_trip: 64 + emb_size_rbf: 64 + emb_size_cbf: 16 + emb_size_bil_trip: 64 + num_before_skip: 1 + num_after_skip: 2 + num_concat: 1 + num_atom: 3 + cutoff: 6.0 + max_neighbors: 50 + rbf: + name: gaussian + envelope: + name: polynomial + exponent: 5 + cbf: + name: spherical_harmonics + extensive: True + otf_graph: False + output_init: HeOrthogonal + activation: silu + scale_file: configs/models/scaling_factors/gemnet-dT_c12.json + regress_forces: False + # PhAST + tag_hidden_channels: 0 # 64 + pg_hidden_channels: 0 # 32 -> period & group embedding hidden channels + phys_hidden_channels: 0 # 32 -> physical properties embedding hidden channels + phys_embeds: False # True + optim: + batch_size: 8 + eval_batch_size: 8 + num_workers: 2 + lr_initial: 1.e-4 + optimizer: AdamW + optimizer_params: { "amsgrad": True } + scheduler: ReduceLROnPlateau + mode: min + factor: 0.8 + patience: 3 + max_epochs: 100 + energy_coefficient: 1 + ema_decay: 0.999 + clip_grad_norm: 10 + loss_energy: mae + +# ------------------- +# ----- IS2RE ----- +# ------------------- + +is2re: + 10k: {} + + 100k: {} + + all: {} +# ------------------ +# ----- S2EF ----- +# ------------------ + +s2ef: + default: + model: + cutoff: 6.0 + scale_file: configs/models/scaling_factors/gemnet-dT.json + regress_forces: "direct" + otf_graph: False + max_neighbors: 50 + num_radial: 128 + num_blocks: 3 + emb_size_atom: 512 + emb_size_trip: 64 + emb_size_rbf: 16 + optim: + clip_grad_norm: 10 + loss_force: l2mae + batch_size: 32 + eval_batch_size: 32 + lr_initial: 5.e-4 + max_epochs: 80 + force_coefficient: 100 + energy_coefficient: 1 + + 200k: {} + + 2M: {} + + 20M: {} + + all: {} + +qm9: + default: {} + + 10k: {} + all: {} + +qm7x: + default: {} + + all: {} + 1k: {} diff --git a/configs/models/indschnet.yaml b/configs/models/indschnet.yaml new file mode 100644 index 0000000000..d8acf62ba2 --- /dev/null +++ b/configs/models/indschnet.yaml @@ -0,0 +1,225 @@ +default: + model: + name: indschnet + num_filters: 128 + num_gaussians: 100 + hidden_channels: 256 + num_interactions: 3 + cutoff: 6.0 + use_pbc: True + regress_forces: False + readout: add + atomref: null + # drlab attributes: + tag_hidden_channels: 0 # 32 + pg_hidden_channels: 0 # 32 -> period & group embedding hidden channels + phys_embeds: False # True + phys_hidden_channels: 0 + energy_head: False # can be {False, weighted-av-initial-embeds, weighted-av-final-embeds, pooling, graclus, random} + optim: + batch_size: 64 + eval_batch_size: 64 + num_workers: 4 + lr_gamma: 0.1 + warmup_factor: 0.2 + +# ------------------- +# ----- IS2RE ----- +# ------------------- + +is2re: + # *** Important note *** + # The total number of gpus used for this run was 1. + # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. 
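+  # Note (illustrative arithmetic): with batch_size 256 the 10k split gives ~39 optimizer steps per epoch, i.e. ~780 steps over the 20 epochs below, so the listed milestones (1562 and above) are only reached on longer runs.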
+ 10k: + model: + hidden_channels: 256 + num_interactions: 3 + optim: + lr_initial: 0.005 + max_epochs: 20 + lr_milestones: + - 1562 + - 2343 + - 3125 + warmup_steps: 468 + batch_size: 256 + eval_batch_size: 256 + + 100k: + model: + hidden_channels: 384 + num_interactions: 4 + optim: + lr_initial: 0.0005 + max_epochs: 25 + lr_milestones: + - 15625 + - 31250 + - 46875 + warmup_steps: 9375 + batch_size: 256 + eval_batch_size: 256 + + all: + model: + hidden_channels: 384 + num_interactions: 4 + optim: + lr_initial: 0.001 + max_epochs: 17 + lr_gamma: 0.1 + lr_milestones: + - 17981 + - 26972 + - 35963 + warmup_steps: 5394 + batch_size: 256 + eval_batch_size: 256 + +# ------------------ +# ----- S2EF ----- +# ------------------ + +s2ef: + default: + model: + regress_forces: "from_energy" + hidden_channels: 1024 + num_filters: 256 + num_interactions: 5 + num_gaussians: 200 + optim: + # *** Important note *** + # The total number of gpus used for this run was 1. + # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. + batch_size: 192 + eval_batch_size: 192 + num_workers: 16 + lr_initial: 0.0001 + lr_gamma: 0.1 + lr_milestones: + - 52083 + - 83333 + - 104166 + warmup_steps: 31250 + max_epochs: 15 + force_coefficient: 100 + + 200k: + model: + hidden_channels: 1024 + num_filters: 256 + num_interactions: 3 + num_gaussians: 200 + optim: + batch_size: 128 + eval_batch_size: 128 + num_workers: 16 + lr_initial: 0.0005 + lr_gamma: 0.1 + lr_milestones: + - 7812 + - 12500 + - 15625 + warmup_steps: 4687 + max_epochs: 30 + force_coefficient: 100 + + 2M: {} + + 20M: + model: + hidden_channels: 1024 + num_filters: 256 + num_interactions: 5 + num_gaussians: 200 + optim: + # *** Important note *** + # The total number of gpus used for this run was 48. + # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. + batch_size: 24 + eval_batch_size: 24 + num_workers: 16 + lr_initial: 0.0001 + lr_gamma: 0.1 + lr_milestones: + - 86805 + - 138888 + - 173611 + warmup_steps: 52083 + max_epochs: 30 + force_coefficient: 50 + + all: + model: + hidden_channels: 1024 + num_filters: 256 + num_interactions: 5 + num_gaussians: 200 + optim: + # *** Important note *** + # The total number of gpus used for this run was 64. + # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. 
+ batch_size: 20 + eval_batch_size: 20 + num_workers: 16 + lr_initial: 0.0001 + lr_gamma: 0.1 + lr_milestones: + - 313907 + - 523179 + - 732451 + warmup_steps: 209271 + max_epochs: 15 + force_coefficient: 30 + +qm9: + default: + model: + hidden_channels: 128 + num_gaussians: 100 + num_filters: 128 + num_interactions: 6 + cutoff: 5.0 + optim: + batch_size: 1024 + lr_initial: 0.001 + max_epochs: 1000 + decay_steps: 125000 + decay_rate: 0.01 + ema_decay: 0.999 + lr_gamma: 0.25 + lr_milestones: + - 17981 + - 26972 + - 35963 + - 52000 + - 100000 + warmup_steps: 1000 + + 10k: {} + all: {} + +qm7x: + default: + model: + hidden_channels: 384 + num_interactions: 4 + optim: + batch_size: 128 + lr_initial: 0.001 + max_epochs: 25 + lr_gamma: 0.1 + lr_milestones: + - 17981 + - 26972 + - 35963 + warmup_steps: 15000 + + all: {} + 1k: {} diff --git a/configs/models/painn.yaml b/configs/models/painn.yaml index 2c0abac112..c138652a81 100644 --- a/configs/models/painn.yaml +++ b/configs/models/painn.yaml @@ -2,6 +2,9 @@ default: model: name: painn use_pbc: True + optim: + num_workers: 4 + eval_batch_size: 64 # ------------------- # ----- IS2RE ----- diff --git a/configs/models/tasks/is2re.yaml b/configs/models/tasks/is2re.yaml index cf47f159de..787e20295f 100644 --- a/configs/models/tasks/is2re.yaml +++ b/configs/models/tasks/is2re.yaml @@ -16,6 +16,8 @@ default: otf_graph: False max_num_neighbors: 40 mode: train + adsorbates: all # {"*O", "*OH", "*OH2", "*H"} + adsorbates_ref_dir: /network/scratch/s/schmidtv/ocp/datasets/ocp/per_ads dataset: default_val: val_id train: diff --git a/debug.py b/debug.py new file mode 100644 index 0000000000..0787829521 --- /dev/null +++ b/debug.py @@ -0,0 +1,200 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import logging +import os +import time +import traceback +import sys +import torch +from yaml import dump + +from ocpmodels.common import dist_utils +from ocpmodels.common.flags import flags +from ocpmodels.common.registry import registry +from ocpmodels.common.utils import ( + JOB_ID, + auto_note, + build_config, + merge_dicts, + move_lmdb_data_to_slurm_tmpdir, + resolve, + setup_imports, + setup_logging, + update_from_sbatch_py_vars, + set_min_hidden_channels, +) +from ocpmodels.common.orion_utils import ( + continue_orion_exp, + load_orion_exp, + sample_orion_hparams, +) +from ocpmodels.trainers import BaseTrainer + +# os.environ["CUDA_LAUNCH_BLOCKING"] = "1" +torch.multiprocessing.set_sharing_strategy("file_system") + + +def print_warnings(): + warnings = [ + "`max_num_neighbors` is set to 40. 
This should be tuned per model.", + "`tag_specific_weights` is not handled for " + + "`regress_forces: direct_with_gradient_target` in compute_loss()", + ] + print("\n" + "-" * 80 + "\n") + print("🛑 OCP-DR-Lab Warnings (nota benes):") + for warning in warnings: + print(f" • {warning}") + print("Remove warnings when they are fixed in the code/configs.") + print("\n" + "-" * 80 + "\n") + + +def wrap_up(args, start_time, error=None, signal=None, trainer=None): + total_time = time.time() - start_time + logging.info(f"Total time taken: {total_time}") + if trainer and trainer.logger is not None: + trainer.logger.log({"Total time": total_time}) + + if args.distributed: + print( + "\nWaiting for all processes to finish with dist_utils.cleanup()...", + end="", + ) + dist_utils.cleanup() + print("Done!") + + if "interactive" not in os.popen(f"squeue -hj {JOB_ID}").read(): + print("\nSelf-canceling SLURM job in 32s", JOB_ID) + os.popen(f"sleep 32 && scancel {JOB_ID}") + + if trainer and trainer.logger: + trainer.logger.finish(error or signal) + + +if __name__ == "__main__": + error = signal = orion_exp = orion_trial = trainer = None + orion_race_condition = False + hparams = {} + + setup_logging() + + parser = flags.get_parser() + args, override_args = parser.parse_known_args() + args = update_from_sbatch_py_vars(args) + if args.logdir: + args.logdir = resolve(args.logdir) + + # -- Build config + + args.wandb_name = "alvaro-carbonero-math" + args.wandb_project = "ocp-alvaro" + args.config = "inddpp-is2re-all" + + args.graph_rewiring = "remove-tag-0" + + trainer_config = build_config(args, override_args) + + if dist_utils.is_master(): + trainer_config = move_lmdb_data_to_slurm_tmpdir(trainer_config) + dist_utils.synchronize() + + trainer_config["dataset"] = dist_utils.broadcast_from_master( + trainer_config["dataset"] + ) + + # trainer_config["optim"]["batch_size"] = 32 + # trainer_config["optim"]["eval_batch_size"] = 32 + # trainer_config["optim"]["max_epochs"] = 30 + # trainer_config["optim"]["es_patience"] = 5 + trainer_config["optim"]["num_workers"] = 0 + # trainer_config["model"]["regress_forces"] = False + + # -- Initial setup + + setup_imports() + print("\n🚩 All things imported.\n") + start_time = time.time() + + try: + # -- Orion + + if args.orion_exp_config_path and dist_utils.is_master(): + orion_exp = load_orion_exp(args) + hparams, orion_trial = sample_orion_hparams(orion_exp, trainer_config) + + if hparams.get("orion_race_condition"): + logging.warning("\n\n ⛔️ Orion race condition. 
Stopping here.\n\n") + wrap_up(args, start_time, error, signal) + sys.exit() + + hparams = dist_utils.broadcast_from_master(hparams) + if hparams: + print("\n💎 Received hyper-parameters from Orion:") + print(dump(hparams), end="\n") + trainer_config = merge_dicts(trainer_config, hparams) + + # -- Setup trainer + trainer_config = continue_orion_exp(trainer_config) + trainer_config = auto_note(trainer_config) + trainer_config = set_min_hidden_channels(trainer_config) + + try: + cls = registry.get_trainer_class(trainer_config["trainer"]) + trainer: BaseTrainer = cls(**trainer_config) + except Exception as e: + traceback.print_exc() + logging.warning(f"\n💀 Error in trainer initialization: {e}\n") + signal = "trainer_init_error" + + if signal is None: + task = registry.get_task_class(trainer_config["mode"])(trainer_config) + task.setup(trainer) + print_warnings() + + # -- Start Training + + signal = task.run() + + # -- End of training + + # handle job preemption / time limit + if signal == "SIGTERM": + print("\nJob was preempted. Wrapping up...\n") + if trainer: + trainer.close_datasets() + + dist_utils.synchronize() + + objective = dist_utils.broadcast_from_master( + trainer.objective if trainer else None + ) + + if orion_exp is not None: + if objective is None: + if signal == "loss_is_nan": + objective = 1e12 + print("Received NaN objective from worker. Setting to 1e12.") + if signal == "trainer_init_error": + objective = 1e12 + print( + "Received trainer_init_error from worker.", + "Setting objective to 1e12.", + ) + if objective is not None: + orion_exp.observe( + orion_trial, + [{"type": "objective", "name": "energy_mae", "value": objective}], + ) + else: + print("Received None objective from worker. Skipping observation.") + + except Exception: + error = True + print(traceback.format_exc()) + + finally: + wrap_up(args, start_time, error, signal, trainer=trainer) diff --git a/debug_faenet.py b/debug_faenet.py new file mode 100644 index 0000000000..56d79c3d68 --- /dev/null +++ b/debug_faenet.py @@ -0,0 +1,222 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import logging +import os +import time +import traceback +import sys +import torch +from yaml import dump + +from ocpmodels.common import dist_utils +from ocpmodels.common.flags import flags +from ocpmodels.common.registry import registry +from ocpmodels.common.utils import ( + JOB_ID, + auto_note, + build_config, + merge_dicts, + move_lmdb_data_to_slurm_tmpdir, + resolve, + setup_imports, + setup_logging, + update_from_sbatch_py_vars, + set_min_hidden_channels, +) +from ocpmodels.common.orion_utils import ( + continue_orion_exp, + load_orion_exp, + sample_orion_hparams, +) +from ocpmodels.trainers import BaseTrainer + +# os.environ["CUDA_LAUNCH_BLOCKING"] = "1" +torch.multiprocessing.set_sharing_strategy("file_system") + + +def print_warnings(): + warnings = [ + "`max_num_neighbors` is set to 40. 
This should be tuned per model.", + "`tag_specific_weights` is not handled for " + + "`regress_forces: direct_with_gradient_target` in compute_loss()", + ] + print("\n" + "-" * 80 + "\n") + print("🛑 OCP-DR-Lab Warnings (nota benes):") + for warning in warnings: + print(f" • {warning}") + print("Remove warnings when they are fixed in the code/configs.") + print("\n" + "-" * 80 + "\n") + + +def wrap_up(args, start_time, error=None, signal=None, trainer=None): + total_time = time.time() - start_time + logging.info(f"Total time taken: {total_time}") + if trainer and trainer.logger is not None: + trainer.logger.log({"Total time": total_time}) + + if args.distributed: + print( + "\nWaiting for all processes to finish with dist_utils.cleanup()...", + end="", + ) + dist_utils.cleanup() + print("Done!") + + if "interactive" not in os.popen(f"squeue -hj {JOB_ID}").read(): + print("\nSelf-canceling SLURM job in 32s", JOB_ID) + os.popen(f"sleep 32 && scancel {JOB_ID}") + + if trainer and trainer.logger: + trainer.logger.finish(error or signal) + + +if __name__ == "__main__": + error = signal = orion_exp = orion_trial = trainer = None + orion_race_condition = False + hparams = {} + + setup_logging() + + parser = flags.get_parser() + args, override_args = parser.parse_known_args() + args = update_from_sbatch_py_vars(args) + if args.logdir: + args.logdir = resolve(args.logdir) + + # -- Build config + + args.wandb_name = "alvaro-carbonero-math" + args.wandb_project = "ocp-alvaro" + args.test_ri = True + args.mode = "train" + args.graph_rewiring = "remove-tag-0" + args.cp_data_to_tmpdir = True + args.config = "indfaenet-is2re-10k" + args.frame_averaging = "2D" + args.fa_frames = "se3-random" + + trainer_config = build_config(args, override_args) + + if dist_utils.is_master(): + trainer_config = move_lmdb_data_to_slurm_tmpdir(trainer_config) + dist_utils.synchronize() + + trainer_config["dataset"] = dist_utils.broadcast_from_master( + trainer_config["dataset"] + ) + + trainer_config["model"]["edge_embed_type"] = "all_rij" + trainer_config["model"]["mp_type"] = "updownscale" + trainer_config["model"]["phys_embeds"] = True + trainer_config["model"]["tag_hidden_channels"] = 32 + trainer_config["model"]["pg_hidden_channels"] = 64 + trainer_config["model"]["energy_head"] = "weighted-av-final-embeds" + trainer_config["model"]["complex_mp"] = False + trainer_config["model"]["graph_norm"] = True + trainer_config["model"]["hidden_channels"] = 352 + trainer_config["model"]["num_filters"] = 448 + trainer_config["model"]["num_gaussians"] = 99 + trainer_config["model"]["num_interactions"] = 6 + trainer_config["model"]["second_layer_MLP"] = True + trainer_config["model"]["skip_co"] = "concat" + # trainer_config["model"]["transformer_out"] = False + trainer_config["model"]["afaenet_gat_mode"] = "v1" + # trainer_config["model"]["disconnected_mlp"] = True + + # trainer_config["optim"]["batch_sizes"] = 256 + # trainer_config["optim"]["eval_batch_sizes"] = 256 + trainer_config["optim"]["lr_initial"] = 0.0019 + trainer_config["optim"]["scheduler"] = "LinearWarmupCosineAnnealingLR" + trainer_config["optim"]["max_epochs"] = 20 + trainer_config["optim"]["eval_every"] = 0.4 + + # -- Initial setup + + setup_imports() + print("\n🚩 All things imported.\n") + start_time = time.time() + + try: + # -- Orion + + if args.orion_exp_config_path and dist_utils.is_master(): + orion_exp = load_orion_exp(args) + hparams, orion_trial = sample_orion_hparams(orion_exp, trainer_config) + + if hparams.get("orion_race_condition"): + 
logging.warning("\n\n ⛔️ Orion race condition. Stopping here.\n\n") + wrap_up(args, start_time, error, signal) + sys.exit() + + hparams = dist_utils.broadcast_from_master(hparams) + if hparams: + print("\n💎 Received hyper-parameters from Orion:") + print(dump(hparams), end="\n") + trainer_config = merge_dicts(trainer_config, hparams) + + # -- Setup trainer + trainer_config = continue_orion_exp(trainer_config) + trainer_config = auto_note(trainer_config) + trainer_config = set_min_hidden_channels(trainer_config) + + try: + cls = registry.get_trainer_class(trainer_config["trainer"]) + trainer: BaseTrainer = cls(**trainer_config) + except Exception as e: + traceback.print_exc() + logging.warning(f"\n💀 Error in trainer initialization: {e}\n") + signal = "trainer_init_error" + + if signal is None: + task = registry.get_task_class(trainer_config["mode"])(trainer_config) + task.setup(trainer) + print_warnings() + + # -- Start Training + + signal = task.run() + + # -- End of training + + # handle job preemption / time limit + if signal == "SIGTERM": + print("\nJob was preempted. Wrapping up...\n") + if trainer: + trainer.close_datasets() + + dist_utils.synchronize() + + objective = dist_utils.broadcast_from_master( + trainer.objective if trainer else None + ) + + if orion_exp is not None: + if objective is None: + if signal == "loss_is_nan": + objective = 1e12 + print("Received NaN objective from worker. Setting to 1e12.") + if signal == "trainer_init_error": + objective = 1e12 + print( + "Received trainer_init_error from worker.", + "Setting objective to 1e12.", + ) + if objective is not None: + orion_exp.observe( + orion_trial, + [{"type": "objective", "name": "energy_mae", "value": objective}], + ) + else: + print("Received None objective from worker. Skipping observation.") + + except Exception: + error = True + print(traceback.format_exc()) + + finally: + wrap_up(args, start_time, error, signal, trainer=trainer) diff --git a/debug_schnet.py b/debug_schnet.py new file mode 100644 index 0000000000..b1fec82ddb --- /dev/null +++ b/debug_schnet.py @@ -0,0 +1,209 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import logging +import os +import time +import traceback +import sys +import torch +from yaml import dump + +from ocpmodels.common import dist_utils +from ocpmodels.common.flags import flags +from ocpmodels.common.registry import registry +from ocpmodels.common.utils import ( + JOB_ID, + auto_note, + build_config, + merge_dicts, + move_lmdb_data_to_slurm_tmpdir, + resolve, + setup_imports, + setup_logging, + update_from_sbatch_py_vars, + set_min_hidden_channels, +) +from ocpmodels.common.orion_utils import ( + continue_orion_exp, + load_orion_exp, + sample_orion_hparams, +) +from ocpmodels.trainers import BaseTrainer + +# os.environ["CUDA_LAUNCH_BLOCKING"] = "1" +torch.multiprocessing.set_sharing_strategy("file_system") + + +def print_warnings(): + warnings = [ + "`max_num_neighbors` is set to 40. 
This should be tuned per model.", + "`tag_specific_weights` is not handled for " + + "`regress_forces: direct_with_gradient_target` in compute_loss()", + ] + print("\n" + "-" * 80 + "\n") + print("🛑 OCP-DR-Lab Warnings (nota benes):") + for warning in warnings: + print(f" • {warning}") + print("Remove warnings when they are fixed in the code/configs.") + print("\n" + "-" * 80 + "\n") + + +def wrap_up(args, start_time, error=None, signal=None, trainer=None): + total_time = time.time() - start_time + logging.info(f"Total time taken: {total_time}") + if trainer and trainer.logger is not None: + trainer.logger.log({"Total time": total_time}) + + if args.distributed: + print( + "\nWaiting for all processes to finish with dist_utils.cleanup()...", + end="", + ) + dist_utils.cleanup() + print("Done!") + + if "interactive" not in os.popen(f"squeue -hj {JOB_ID}").read(): + print("\nSelf-canceling SLURM job in 32s", JOB_ID) + os.popen(f"sleep 32 && scancel {JOB_ID}") + + if trainer and trainer.logger: + trainer.logger.finish(error or signal) + + +if __name__ == "__main__": + error = signal = orion_exp = orion_trial = trainer = None + orion_race_condition = False + hparams = {} + + setup_logging() + + parser = flags.get_parser() + args, override_args = parser.parse_known_args() + args = update_from_sbatch_py_vars(args) + if args.logdir: + args.logdir = resolve(args.logdir) + + # -- Build config + + args.wandb_name = "alvaro-carbonero-math" + args.wandb_project = "ocp-alvaro" + args.tag_hidden_channels: 32 + args.pg_hidden_channels: 32 + args.phys_embeds = True + args.phys_hidden_channels = 0 + args.energy_head = False + args.num_targets = 1 + args.otf_graph = False + args.max_num_neighbors = 40 + args.hidden_channels = 142 + args.graph_rewiring = "remove-tag-0" + args.config = "indschnet-is2re-10k" + + trainer_config = build_config(args, override_args) + + if dist_utils.is_master(): + trainer_config = move_lmdb_data_to_slurm_tmpdir(trainer_config) + dist_utils.synchronize() + + trainer_config["dataset"] = dist_utils.broadcast_from_master( + trainer_config["dataset"] + ) + + trainer_config["optim"]["batch_size"] = 64 + trainer_config["optim"]["eval_batch_size"] = 64 + trainer_config["optim"]["lr_initial"] = 0.0005 + trainer_config["optim"]["max_epochs"] = 30 + trainer_config["optim"]["es_patience"] = 5 + + trainer_config["model"]["gat_mode"] = "v1" + + # -- Initial setup + + setup_imports() + print("\n🚩 All things imported.\n") + start_time = time.time() + + try: + # -- Orion + + if args.orion_exp_config_path and dist_utils.is_master(): + orion_exp = load_orion_exp(args) + hparams, orion_trial = sample_orion_hparams(orion_exp, trainer_config) + + if hparams.get("orion_race_condition"): + logging.warning("\n\n ⛔️ Orion race condition. 
Stopping here.\n\n") + wrap_up(args, start_time, error, signal) + sys.exit() + + hparams = dist_utils.broadcast_from_master(hparams) + if hparams: + print("\n💎 Received hyper-parameters from Orion:") + print(dump(hparams), end="\n") + trainer_config = merge_dicts(trainer_config, hparams) + + # -- Setup trainer + trainer_config = continue_orion_exp(trainer_config) + trainer_config = auto_note(trainer_config) + trainer_config = set_min_hidden_channels(trainer_config) + + try: + cls = registry.get_trainer_class(trainer_config["trainer"]) + trainer: BaseTrainer = cls(**trainer_config) + except Exception as e: + traceback.print_exc() + logging.warning(f"\n💀 Error in trainer initialization: {e}\n") + signal = "trainer_init_error" + + if signal is None: + task = registry.get_task_class(trainer_config["mode"])(trainer_config) + task.setup(trainer) + print_warnings() + + # -- Start Training + + signal = task.run() + + # -- End of training + + # handle job preemption / time limit + if signal == "SIGTERM": + print("\nJob was preempted. Wrapping up...\n") + if trainer: + trainer.close_datasets() + + dist_utils.synchronize() + + objective = dist_utils.broadcast_from_master( + trainer.objective if trainer else None + ) + + if orion_exp is not None: + if objective is None: + if signal == "loss_is_nan": + objective = 1e12 + print("Received NaN objective from worker. Setting to 1e12.") + if signal == "trainer_init_error": + objective = 1e12 + print( + "Received trainer_init_error from worker.", + "Setting objective to 1e12.", + ) + if objective is not None: + orion_exp.observe( + orion_trial, + [{"type": "objective", "name": "energy_mae", "value": objective}], + ) + else: + print("Received None objective from worker. Skipping observation.") + + except Exception: + error = True + print(traceback.format_exc()) + + finally: + wrap_up(args, start_time, error, signal, trainer=trainer) diff --git a/mila/launch_exp.py b/mila/launch_exp.py index 8bd00e7c9c..dec6bde850 100644 --- a/mila/launch_exp.py +++ b/mila/launch_exp.py @@ -1,3 +1,4 @@ +import copy import os import re import subprocess @@ -5,10 +6,8 @@ from pathlib import Path from minydra import resolved_args -from yaml import safe_load, dump - from sbatch import now -import copy +from yaml import dump, safe_load ROOT = Path(__file__).resolve().parent.parent @@ -143,14 +142,16 @@ def cli_arg(args, key=""): s += cli_arg(v, key=f"{parent}{k}") else: if " " in str(v) or "," in str(v) or isinstance(v, str): - if "'" in str(v) and '"' in str(v): - v = str(v).replace("'", "\\'") + if '"' in str(v): + v = str(v).replace('"', '\\"') v = f"'{v}'" elif "'" in str(v): - v = f'"{v}"' + v = f'\\"{v}\\"' else: v = f"'{v}'" s += f" --{parent}{k}={v}" + if "ads" in k: + print(s.split(" --")[-1]) return s @@ -175,10 +176,15 @@ def get_args_or_exp(key, args, exp): n_jobs = None args = resolved_args() assert "exp" in args - regex = args.get("match", ".*") + + regex = args.pop("match", ".*") + exp_name = args.pop("exp").replace(".yml", "").replace(".yaml", "") + no_confirm = args.pop("no_confirm", False) + + sbatch_overrides = args.to_dict() + ts = now() - exp_name = args.exp.replace(".yml", "").replace(".yaml", "") exp_file = find_exp(exp_name) exp = safe_load(exp_file.open("r")) @@ -231,6 +237,8 @@ def get_args_or_exp(key, args, exp): else: params["wandb_tags"] = exp_name + job = merge_dicts(job, sbatch_overrides) + py_args = f'py_args="{cli_arg(params).strip()}"' sbatch_args = " ".join( @@ -253,7 +261,7 @@ def get_args_or_exp(key, args, exp): text += "\n<><><> Experiment 
config:\n\n-----" + exp_file.read_text() + "-----" text += "\n<><><> Experiment runs:\n\n • " + "\n\n • ".join(commands) + separator - confirm = args.no_confirm or "y" in input("\n🚦 Confirm? [y/n] : ") + confirm = no_confirm or "y" in input("\n🚦 Confirm? [y/n] : ") if confirm: try: @@ -267,6 +275,10 @@ def get_args_or_exp(key, args, exp): for c, command in enumerate(commands): print(f"Launching job {c+1:3}", end="\r") outputs.append(os.popen(command).read().strip()) + if "Aborting" in outputs[-1]: + print("\nError submitting job", c + 1, ":", command) + print(outputs[-1].replace("Error while launching job:\n", "")) + print("\n") if " verbose=true" in command.lower(): print(outputs[-1]) except KeyboardInterrupt: @@ -283,6 +295,8 @@ def get_args_or_exp(key, args, exp): if is_interrupted: print("\n💀 Interrupted. Kill jobs with:\n$ scancel" + " ".join(jobs)) + elif not jobs: + print("\n❌ No jobs launched") else: text += f"{separator}All jobs launched: {' '.join(jobs)}" with outfile.open("w") as f: diff --git a/mila/sbatch.py b/mila/sbatch.py index de82809f8b..a4b24095c2 100644 --- a/mila/sbatch.py +++ b/mila/sbatch.py @@ -1,12 +1,13 @@ -from minydra import resolved_args, MinyDict -from pathlib import Path -from datetime import datetime import os +import re import subprocess -from shutil import copyfile import sys -import re +from datetime import datetime +from pathlib import Path +from shutil import copyfile + import yaml +from minydra import MinyDict, resolved_args IS_DRAC = ( "narval.calcul.quebec" in os.environ.get("HOSTNAME", "") @@ -24,13 +25,13 @@ # git commit: {git_commit} # cwd: {cwd} -{git_checkout} {sbatch_py_vars} export MASTER_PORT=$(expr 10000 + $(echo -n $SLURM_JOBID | tail -c 4)) echo "Master port $MASTER_PORT" cd {code_loc} +{git_checkout} {modules} @@ -41,7 +42,7 @@ conda activate {env} fi {wandb_offline} -srun --output={output} {python_command} +srun --gpus-per-task=1 --output={output} {python_command} """ @@ -247,7 +248,6 @@ def load_sbatch_args_from_dir(dir): "cpus": int(sbatch_args["cpus-per-task"]), "mem": sbatch_args["mem"], "gres": sbatch_args["gres"], - "output": sbatch_args["output"], } return args @@ -417,7 +417,17 @@ def load_sbatch_args_from_dir(dir): print("\nDev mode: not actually executing the command 🤓\n") else: # not dev mode: run the command, make directories - out = subprocess.check_output(command.split(" ")).decode("utf-8").strip() + try: + out = ( + subprocess.check_output(command.split(" "), stderr=subprocess.STDOUT) + .decode("utf-8") + .strip() + ) + except subprocess.CalledProcessError as error: + print("Error while launching job:\n```") + print(error.output.decode("utf-8").strip()) + print("```\nAborting...") + sys.exit(1) jobid = out.split(" job ")[-1].strip() success = out.startswith("Submitted batch job") diff --git a/ocdata/LiFePO4.cif b/ocdata/LiFePO4.cif new file mode 100644 index 0000000000..2b01776980 --- /dev/null +++ b/ocdata/LiFePO4.cif @@ -0,0 +1,54 @@ +# generated using pymatgen +data_LiFePO4 +_symmetry_space_group_name_H-M 'P 1' +_cell_length_a 4.74644100 +_cell_length_b 10.44373000 +_cell_length_c 6.09022600 +_cell_angle_alpha 89.99726981 +_cell_angle_beta 90.00071024 +_cell_angle_gamma 90.00075935 +_symmetry_Int_Tables_number 1 +_chemical_formula_structural LiFePO4 +_chemical_formula_sum 'Li4 Fe4 P4 O16' +_cell_volume 301.89584168 +_cell_formula_units_Z 4 +loop_ + _symmetry_equiv_pos_site_id + _symmetry_equiv_pos_as_xyz + 1 'x, y, z' +loop_ + _atom_site_type_symbol + _atom_site_label + _atom_site_symmetry_multiplicity + 
_atom_site_fract_x + _atom_site_fract_y + _atom_site_fract_z + _atom_site_occupancy + Li Li0 1 0.00000100 0.00001200 0.00003300 1 + Li Li1 1 0.50000600 0.50000900 0.00003100 1 + Li Li2 1 0.50000300 0.50001200 0.49996900 1 + Li Li3 1 0.00000300 0.00001400 0.49996700 1 + Fe Fe4 1 0.47524900 0.21803400 0.75000400 1 + Fe Fe5 1 0.02475600 0.71803500 0.75000300 1 + Fe Fe6 1 0.97510000 0.28190400 0.25000000 1 + Fe Fe7 1 0.52491600 0.78190200 0.25000000 1 + P P8 1 0.41781800 0.09476500 0.24999700 1 + P P9 1 0.91789000 0.40522300 0.75000500 1 + P P10 1 0.08218100 0.59476800 0.24999700 1 + P P11 1 0.58210900 0.90522200 0.75000500 1 + O O12 1 0.74186800 0.09673300 0.25000400 1 + O O13 1 0.24192700 0.40323600 0.74999600 1 + O O14 1 0.75813200 0.59672800 0.25000200 1 + O O15 1 0.25807400 0.90323500 0.74999400 1 + O O16 1 0.20694800 0.45707900 0.25000000 1 + O O17 1 0.70682400 0.04293300 0.74999900 1 + O O18 1 0.29304600 0.95707800 0.25000400 1 + O O19 1 0.79318000 0.54293300 0.75000400 1 + O O20 1 0.28451400 0.16549800 0.04703500 1 + O O21 1 0.78450500 0.33453800 0.95297200 1 + O O22 1 0.78447400 0.33453000 0.54706100 1 + O O23 1 0.28449100 0.16550500 0.45292700 1 + O O24 1 0.21550200 0.66551200 0.45292200 1 + O O25 1 0.71552100 0.83452400 0.54706500 1 + O O26 1 0.71548600 0.83453300 0.95297000 1 + O O27 1 0.21547700 0.66550400 0.04703700 1 diff --git a/ocdata/__init__.py b/ocdata/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/ocdata/adsorbates.py b/ocdata/adsorbates.py new file mode 100644 index 0000000000..316ab2e739 --- /dev/null +++ b/ocdata/adsorbates.py @@ -0,0 +1,64 @@ +import numpy as np +import pickle +import os + + +class Adsorbate: + """ + This class handles all things with the adsorbate. + Selects one (either specified or random), and stores info as an object + + Attributes + ---------- + atoms : Atoms + actual atoms of the adsorbate + smiles : str + SMILES representation of the adsorbate + bond_indices : list + indices of the atoms meant to be bonded to the surface + adsorbate_sampling_str : str + string capturing the adsorbate index and total possible adsorbates + """ + + def __init__( + self, adsorbate_database=None, specified_index=None, adsorbate_atoms=None + ): + if adsorbate_atoms is None: + assert adsorbate_database is not None + self.choose_adsorbate_pkl(adsorbate_database, specified_index) + else: + ( + self.adsorbate_sampling_str, + self.atoms, + self.smiles, + self.bond_indices, + ) = adsorbate_atoms + + def choose_adsorbate_pkl(self, adsorbate_database, specified_index=None): + """ + Chooses an adsorbate from our pkl based inverted index at random. + + Args: + adsorbate_database: A string pointing to the a pkl file that contains + an inverted index over different adsorbates. 
+ specified_index: adsorbate index to choose instead of choosing a random one + Sets: + atoms `ase.Atoms` object of the adsorbate + smiles SMILES-formatted representation of the adsorbate + bond_indices list of integers indicating the indices of the atoms in + the adsorbate that are meant to be bonded to the surface + adsorbate_sampling_str Enum string specifying the sample, [index] + adsorbate_db_fname filename denoting which version was used to sample + """ + with open(adsorbate_database, "rb") as f: + inv_index = pickle.load(f) + + if specified_index is not None: + element = specified_index + else: + element = np.random.choice(len(inv_index)) + print(f"args.actions.adsorbate_id is None, choosing {element}") + + self.adsorbate_sampling_str = str(element) + self.atoms, self.smiles, self.bond_indices = inv_index[element] + self.adsorbate_db_fname = os.path.basename(adsorbate_database) diff --git a/ocdata/base_atoms/__init__.py b/ocdata/base_atoms/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/ocdata/base_atoms/ase_dbs/__init__.py b/ocdata/base_atoms/ase_dbs/__init__.py new file mode 100644 index 0000000000..525320ee30 --- /dev/null +++ b/ocdata/base_atoms/ase_dbs/__init__.py @@ -0,0 +1,5 @@ +import os + + +BULK_DB = os.path.join(__path__[0], 'bulks.db') +ADSORBATE_DB = os.path.join(__path__[0], 'adsorbates.db') \ No newline at end of file diff --git a/ocdata/base_atoms/ase_dbs/adsorbates.db b/ocdata/base_atoms/ase_dbs/adsorbates.db new file mode 100644 index 0000000000..6ebf41a736 Binary files /dev/null and b/ocdata/base_atoms/ase_dbs/adsorbates.db differ diff --git a/ocdata/base_atoms/ase_dbs/bulks.db b/ocdata/base_atoms/ase_dbs/bulks.db new file mode 100644 index 0000000000..bc13ef5011 Binary files /dev/null and b/ocdata/base_atoms/ase_dbs/bulks.db differ diff --git a/ocdata/base_atoms/pkls/__init__.py b/ocdata/base_atoms/pkls/__init__.py new file mode 100644 index 0000000000..fdf425a977 --- /dev/null +++ b/ocdata/base_atoms/pkls/__init__.py @@ -0,0 +1,7 @@ +import os + + +BULK_PKL = os.path.join(__path__[0], 'bulks.pkl') +MAY12_BULK_PKL = os.path.join(__path__[0], 'bulks_may12.pkl') +ADSORBATE_PKL = os.path.join(__path__[0], 'adsorbates.pkl') +MAY12_SURFACE_ENUM_PKL = os.path.join(__path__[0], 'for_surface_enumeration_bulk_may12.pkl') diff --git a/ocdata/base_atoms/pkls/convert_db_to_pkl.py b/ocdata/base_atoms/pkls/convert_db_to_pkl.py new file mode 100644 index 0000000000..9f7fcc3d39 --- /dev/null +++ b/ocdata/base_atoms/pkls/convert_db_to_pkl.py @@ -0,0 +1,130 @@ +''' +Helper script convert db files to pkl files. 
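+
+The produced pkl files are inverted indices: for bulks, a dict mapping the number of
+elements to a list of (atoms, mpid, sampling_str, index) tuples; for adsorbates, a dict
+mapping an integer id to an (atoms, smiles, bond_indices) tuple. Illustrative read-back
+(file name taken from main() below):
+
+    import pickle
+    with open("bulks_may12.pkl", "rb") as f:
+        bulk_index = pickle.load(f)
+    atoms, mpid, sampling_str, idx = bulk_index[2][0]  # first bulk with 2 elements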
+''' + +__author__ = 'Siddharth Goyal' + +import ase +import ase.db +import pickle + + +def get_bulk_inverted_index_1(input_bulk_database, max_num_elements): + ''' + Converts an input ASE.db to an inverted index to efficiently sample bulks + ''' + assert max_num_elements > 0 + db = ase.db.connect(input_bulk_database) + + index = {} + total_entries = 0 + for i in range(1, max_num_elements + 1): + index[i] = [] + rows = list(db.select(n_elements=i)) + print(len(rows)) + for r in range(len(rows)): + index[i].append((rows[r].toatoms(), rows[r].mpid)) + total_entries += 1 + + return index, total_entries + +def get_bulk_inverted_index_2(input_bulk_database, max_num_elements): + ''' + Converts an input ASE.db to an inverted index to efficiently sample bulks + ''' + assert max_num_elements > 0 + db = ase.db.connect(input_bulk_database) + rows = list(db.select()) + + index = {} + total_entries = 0 + for r in range(len(rows)): + bulk = rows[r].toatoms() + mpid = rows[r].mpid + formula_str = str(bulk.symbols) + num_ele = sum(1 for c in formula_str if c.isupper()) + if num_ele > max_num_elements: + continue + if num_ele not in index: + index[num_ele] = [] + index[num_ele].append((bulk, mpid)) + total_entries += 1 + + return index, total_entries + + +# handling 2 dbs +def convert_bulk(bulk_path1, bulk_path2, max_num_elements, output_pkl, precompute_pkl_for_surface_enumeration): + + index1, total_entries1 = get_bulk_inverted_index_1(bulk_path1, max_num_elements) + index2, total_entries2 = get_bulk_inverted_index_2(bulk_path2, max_num_elements) + + # As of bulk.db file from Kevin on 01 May 2020 + assert total_entries1 == 11010 + assert total_entries2 == 491 + + combined_total_entries = total_entries1 + total_entries2 + lst_for_surface_enumeration = [] + combined_index = {} + all_index_counter = 0 + + # Handle first db elements + for i in range(1, max_num_elements + 1): + combined_index[i] = [] + + for j in range(len(index1[i])): + sampling_str = str(j) + "/" + str(len(index1[i])) + "_" + str(all_index_counter) + "/11010" + bulk, mpid = index1[i][j] + current_obj = (bulk, mpid, sampling_str, all_index_counter) + print(current_obj) + combined_index[i].append(current_obj) + all_index_counter += 1 + lst_for_surface_enumeration.append(current_obj) + + # Handle second db elements + for i in range(1, max_num_elements + 1): + for j in range(len(index2[i])): + sampling_str = str(j + len(index1[i])) + "/" + str(len(index1[i]) + len(index2[i])) + "_" + str(all_index_counter) + "/" + str(combined_total_entries) + bulk, mpid = index2[i][j] + current_obj = (bulk, mpid, sampling_str, all_index_counter) + print(current_obj) + combined_index[i].append(current_obj) + all_index_counter += 1 + lst_for_surface_enumeration.append(current_obj) + + with open(output_pkl, 'wb') as f: + pickle.dump(combined_index, f) + + with open(precompute_pkl_for_surface_enumeration, 'wb') as g: + pickle.dump(lst_for_surface_enumeration, g) + + +def convert_adsorbate(input_adsorbate_database, output_pkl): + ''' + Converts an input ASE.db to an inverted index to efficiently sample adsorbates + ''' + db = ase.db.connect(input_adsorbate_database) + + index = {} + + for i, row in enumerate(db.select()): + atoms = row.toatoms() + data = row.data + smiles = data['SMILE'] + bond_indices = data['bond_idx'] + index[i] = (atoms, smiles, bond_indices) + + with open(output_pkl, 'wb') as f: + pickle.dump(index, f) + + # As of adsorbates.db file in master on April 28 2020 + assert len(index) == 82 + + +def main(): + convert_bulk("../ase_dbs/bulks.db", 
"../ase_dbs/new_bulks.db", 3, "bulks_may12.pkl", "for_surface_enumeration_bulk_may12.pkl") +# convert_adsorbate("../ase_dbs/adsorbates.db", "adsorbates.pkl") + + +if __name__ == "__main__": + main() diff --git a/ocdata/bulk_obj.py b/ocdata/bulk_obj.py new file mode 100644 index 0000000000..a1974f25f4 --- /dev/null +++ b/ocdata/bulk_obj.py @@ -0,0 +1,343 @@ +import math +import os +import pickle + +import numpy as np +from pymatgen.core.surface import ( + SlabGenerator, + get_symmetrically_distinct_miller_indices, +) +from pymatgen.io.ase import AseAtomsAdaptor +from pymatgen.symmetry.analyzer import SpacegroupAnalyzer + +from .constants import COVALENT_MATERIALS_MPIDS, MAX_MILLER + + +class Bulk: + """ + This class handles all things with the bulk. + It also provides possible surfaces, later used to create a Surface object. + + Attributes + ---------- + precomputed_structures : str + root dir of precomputed structures + bulk_atoms : Atoms + actual atoms of the bulk + mpid : str + mpid of the bulk + bulk_sampling_str : str + string capturing the bulk index and number of possible bulks + index_of_bulk_atoms : int + index of bulk in the db + n_elems : int + number of elements of the bulk + elem_sampling_str : str + string capturing n_elems and the max possible elements + + Public methods + -------------- + get_possible_surfaces() + returns a list of possible surfaces for this bulk instance + """ + + def __init__( + self, bulk_database=None, precomputed_structures=None, bulk_index=None, max_elems=3 + ): + """ + Initializes the object by choosing or sampling from the bulk database + + Args: + bulk_database: either a list of dict of bulks + precomputed_structures: Root directory of precomputed structures for + surface enumeration + bulk_index: index of bulk to select if not doing a random sample + max_elems: max number of elements for any bulk + """ + self.precomputed_structures = precomputed_structures + if bulk_database is not None: + self.choose_bulk_pkl(bulk_database, bulk_index, max_elems) + + def choose_bulk_pkl(self, bulk_db, bulk_index, max_elems): + """ + Chooses a bulk from our pkl file at random as long as the bulk contains + the specified number of elements in any composition. + + Args: + bulk_db Unpickled dict or list of bulks + bulk_index Index of which bulk to select. If None, randomly sample one. + max_elems Max elems for any bulk structure. Currently it is 3 by default. + + Sets as class attributes: + bulk_atoms `ase.Atoms` of the chosen bulk structure. + mpid A string indicating which MPID the bulk is + bulk_sampling_str A string to enumerate the sampled structure + index_of_bulk_atoms Index of the chosen bulk in the array (should match + bulk_index if provided) + """ + + try: + if bulk_index is not None: + assert ( + len(bulk_db) > max_elems + ), f"Bulk db only has {len(bulk_db)} entries. Did you pass in the correct bulk database?" + assert isinstance(bulk_db[bulk_index], tuple) + + ( + self.bulk_atoms, + self.mpid, + self.bulk_sampling_str, + self.index_of_bulk_atoms, + ) = bulk_db[bulk_index] + self.bulk_sampling_str = f"{self.index_of_bulk_atoms}" + self.n_elems = len(set(self.bulk_atoms.symbols)) # 1, 2, or 3 + self.elem_sampling_str = f"{self.n_elems}" + + else: + self.sample_n_elems() + assert isinstance( + bulk_db, dict + ), "Did you pass in the correct bulk database?" 
+ assert ( + self.n_elems in bulk_db.keys() + ), f"Bulk db does not have bulks of {self.n_elems} elements" + assert isinstance( + bulk_db[self.n_elems], list + ), "Did you pass in the correct bulk database?" + + total_elements_for_key = len(bulk_db[self.n_elems]) + row_bulk_index = np.random.choice(total_elements_for_key) + ( + self.bulk_atoms, + self.mpid, + self.bulk_sampling_str, + self.index_of_bulk_atoms, + ) = bulk_db[self.n_elems][row_bulk_index] + + except IndexError: + raise ValueError( + "Randomly chose to look for a %i-component material, " + "but no such materials exist. Please add one " + "to the database or change the weights to exclude " + "this number of components." % self.n_elems + ) + + def sample_n_elems(self, n_cat_elems_weights={1: 0.05, 2: 0.65, 3: 0.3}): + """ + Chooses the number of species we should look for in this sample. + + Arg: + n_cat_elems_weights A dictionary whose keys are integers containing the + number of species you want to consider and whose + values are the probabilities of selecting this + number. The probabilities must sum to 1. + Sets: + n_elems An integer showing how many species have been chosen. + elem_sampling_str Enum string of [chosen n_elems]/[total number of choices] + """ + + possible_n_elems = list(n_cat_elems_weights.keys()) + weights = list(n_cat_elems_weights.values()) + assert math.isclose(sum(weights), 1) + + self.n_elems = np.random.choice(possible_n_elems, p=weights) + self.elem_sampling_str = str(self.n_elems) + "/" + str(len(possible_n_elems)) + + def get_possible_surfaces(self): + """ + Returns a list of possible surfaces for this bulk instance. + This can be later used to iterate through all surfaces, + or select one at random, to make a Surface object. + """ + if self.precomputed_structures: + surfaces_info = self.read_from_precomputed_enumerations( + self.index_of_bulk_atoms + ) + else: + surfaces_info = self.enumerate_surfaces() + return surfaces_info + + def read_from_precomputed_enumerations(self, index): + """ + Loads relevant pickle of precomputed surfaces. + + Args: + index: bulk index + Returns: + surfaces_info: a list of surface_info tuples (atoms, miller, shift, top) + """ + with open( + os.path.join(self.precomputed_structures, str(index) + ".pkl"), "rb" + ) as f: + surfaces_info = pickle.load(f) + return surfaces_info + + def enumerate_surfaces( + self, miller_indices=None, sample_miller_indices=False, max_miller=MAX_MILLER + ): + """ + Enumerate all the symmetrically distinct surfaces of a bulk structure. It + will not enumerate surfaces with Miller indices above the `max_miller` + argument. Note that we also look at the bottoms of surfaces if they are + distinct from the top. If they are distinct, we flip the surface so the bottom + is pointing upwards. + + Args: + bulk_atoms `ase.Atoms` object of the bulk you want to enumerate + surfaces from. + miller_indices: A tuple of Miller indices as tuples you want to enumerate. (victor) + sample_miller_indices: whether to select a Miller indices tuple from + get_symmetrically_distinct_miller_indices (victor) + max_miller An integer indicating the maximum Miller index of the surfaces + you are willing to enumerate. Increasing this argument will + increase the number of surfaces, but the surfaces will + generally become larger. + Returns: + all_slabs_info A list of 4-tuples containing: `pymatgen.Structure` + objects for surfaces we have enumerated, the Miller + indices, floats for the shifts, and Booleans for "top". 
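+                            A single entry could look like (slab_structure, (1, 1, 1), 0.25, True); the shift value here is illustrative.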
+ """ + bulk_struct = self.standardize_bulk(self.bulk_atoms) + + all_millers = [] + + if miller_indices: + assert isinstance(miller_indices, tuple) + assert len(miller_indices) == 3 + if sample_miller_indices: + print( + "Warning: sample_miller_indices is True, but miller_indices is not None. Ignoring sample_miller_indices." + ) + all_millers = [miller_indices] + else: + all_millers = get_symmetrically_distinct_miller_indices( + bulk_struct, MAX_MILLER + ) + if sample_miller_indices: + all_millers = [np.random.choice(all_millers)] + + all_slabs_info = [] + for millers in all_millers: + slab_gen = SlabGenerator( + initial_structure=bulk_struct, + miller_index=millers, + min_slab_size=7.0, + min_vacuum_size=20.0, + lll_reduce=False, + center_slab=True, + primitive=True, + max_normal_search=1, + ) + slabs = slab_gen.get_slabs( + tol=0.3, bonds=None, max_broken_bonds=0, symmetrize=False + ) + + # Additional filtering for the 2D materials' slabs + if self.mpid in COVALENT_MATERIALS_MPIDS: + slabs = [ + slab for slab in slabs if self.is_2D_slab_reasonsable(slab) is True + ] + + # If the bottoms of the slabs are different than the tops, then we want + # to consider them, too + if len(slabs) != 0: + flipped_slabs_info = [ + (self.flip_struct(slab), millers, slab.shift, False) + for slab in slabs + if self.is_structure_invertible(slab) is False + ] + + # Concatenate all the results together + slabs_info = [(slab, millers, slab.shift, True) for slab in slabs] + all_slabs_info.extend(slabs_info + flipped_slabs_info) + return all_slabs_info + + def is_2D_slab_reasonsable(self, struct): + """ + There are 400+ 2D bulk materials whose slabs generated by pymaten require + additional filtering: some slabs are cleaved where one or more surface atoms + have no bonds with other atoms on the slab. + + Arg: + struct `pymatgen.Structure` object of a slab + Returns: + A boolean indicating whether or not the slab is + reasonable. + """ + for site in struct: + if len(struct.get_neighbors(site, 3)) == 0: + return False + return True + + def standardize_bulk(self, atoms): + """ + There are many ways to define a bulk unit cell. If you change the unit cell + itself but also change the locations of the atoms within the unit cell, you + can get effectively the same bulk structure. To address this, there is a + standardization method used to reduce the degrees of freedom such that each + unit cell only has one "true" configuration. This function will align a + unit cell you give it to fit within this standardization. + + Args: + atoms: `ase.Atoms` object of the bulk you want to standardize + Returns: + standardized_struct: `pymatgen.Structure` of the standardized bulk + """ + struct = AseAtomsAdaptor.get_structure(atoms) + sga = SpacegroupAnalyzer(struct, symprec=0.1) + standardized_struct = sga.get_conventional_standard_structure() + return standardized_struct + + def flip_struct(self, struct): + """ + Flips an atoms object upside down. Normally used to flip surfaces. + + Arg: + struct `pymatgen.Structure` object + Returns: + flipped_struct: The same `ase.Atoms` object that was fed as an + argument, but flipped upside down. + """ + atoms = AseAtomsAdaptor.get_atoms(struct) + + # This is black magic wizardry to me. Good look figuring it out. 
diff --git a/ocdata/bulks.py b/ocdata/bulks.py
new file mode 100644
index 0000000000..01b15af3e2
--- /dev/null
+++ b/ocdata/bulks.py
@@ -0,0 +1,40 @@
+'''
+This submodule contains the scripts that the Ulissi group used to pull the
+relaxed bulk structures from our database.
+''' + +__author__ = 'Kevin Tran' +__email__ = 'ktran@andrew.cmu.edu' + +import warnings +from tqdm import tqdm +import ase.db +from gaspy.gasdb import get_mongo_collection +from gaspy.mongo import make_atoms_from_doc + + +with get_mongo_collection('atoms') as collection: + docs = list(tqdm(collection.find({'fwname.calculation_type': 'unit cell optimization', + 'fwname.vasp_settings.gga': 'RP', + 'fwname.vasp_settings.pp': 'PBE', + 'fwname.vasp_settings.xc': {'$exists': False}, + 'fwname.vasp_settings.pp_version': '5.4', + 'fwname.vasp_settings.encut': 500, + 'fwname.vasp_settings.isym': 0}), + desc='pulling from FireWorks')) + +mpids = set() +db = ase.db.connect('bulks.db') +for doc in tqdm(docs, desc='writing to database'): + atoms = make_atoms_from_doc(doc) + n_elements = len(set(atoms.symbols)) + + if n_elements <= 3: + mpid = doc['fwname']['mpid'] + if mpid not in mpids: + mpids.add(mpid) + _ = db.write(atoms, mpid=doc['fwname']['mpid'], n_elements=n_elements) + + else: + warnings.warn('Found a duplicate MPID: %s; adding the first one' % mpid, + RuntimeWarning) diff --git a/ocdata/combined.py b/ocdata/combined.py new file mode 100644 index 0000000000..58b7cbb2b5 --- /dev/null +++ b/ocdata/combined.py @@ -0,0 +1,366 @@ +import warnings + +import catkit +import numpy as np +from ase import neighborlist +from ase.neighborlist import natural_cutoffs +from pymatgen.analysis.local_env import VoronoiNN +from pymatgen.io.ase import AseAtomsAdaptor + +from .constants import COVALENT_RADIUS +from .loader import Loader +from .surfaces import constrain_surface + + +class Combined: + """ + This class handles all things with the adsorbate placed on a surface + Needs one adsorbate and one surface to create this class. + + Attributes + ---------- + adsorbate : Adsorbate + object representing the adsorbate + surface : Surface + object representing the surface + enumerate_all_configs : boolean + whether to enumerate all adslab placements instead of choosing one random + adsorbed_surface_atoms : list + `Atoms` objects containing both the adsorbate and surface for all desired placements + adsorbed_surface_sampling_strs : list + list of strings capturing the config index for each adslab placement + constrained_adsorbed_surfaces : list + list of all constrained adslab atoms + all_sites : list + list of binding coordinates for all the adslab configs + + Public methods + -------------- + get_adsorbed_bulk_dict(ind) + returns a dict of info for the adsorbate+surface of the specified config index + """ + + def __init__( + self, + adsorbate, + surface, + enumerate_all_configs, + animate=False, + no_loader=False, + index=-1, + early_init=False, + ): + """ + Adds adsorbate to surface, does the constraining, and aggregates all data necessary to write out. + Can either pick a random configuration or store all possible ones. 
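Assuming a `bulks.db` produced by the script above, reading it back with `ase.db` is a one-liner per query. The snippet below only relies on the keys written above (`mpid`, `n_elements`):

    import ase.db

    db = ase.db.connect("bulks.db")
    binaries = [(row.toatoms(), row.mpid) for row in db.select(n_elements=2)]
    print(f"{len(binaries)} two-element bulks in the database")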
+ + Args: + adsorbate: the `Adsorbate` object + surface: the `Surface` object + enumerate_all_configs: whether to enumerate all adslab placements instead of choosing one random + index: list of adsorbstion site indices (-1, list or int) (victor) + early_init: whether to skip the actual work and just initialize the class + """ + self.adsorbate = adsorbate + self.surface = surface + self.animate = animate + self.no_loader = no_loader + self.enumerate_all_configs = enumerate_all_configs + self.index = index + self.early_init = early_init + + if early_init: + return + + self.add_adsorbate_onto_surface( + self.adsorbate.atoms, + self.surface.surface_atoms, + self.adsorbate.bond_indices, + self.index, + ) + + self.constrained_adsorbed_surfaces = [] + self.all_sites = [] + for a, atoms in enumerate(self.adsorbed_surface_atoms): + # Add appropriate constraints + self.constrained_adsorbed_surfaces.append(constrain_surface(atoms)) + + # Do the hashing + self.all_sites.append( + self.find_sites( + self.surface.constrained_surface, + self.constrained_adsorbed_surfaces[-1], + self.adsorbate.bond_indices, + ) + ) + + def add_adsorbate_onto_surface(self, adsorbate, surface, bond_indices, index=-1): + """ + There are a lot of small details that need to be considered when adding an + adsorbate onto a surface. This function will take care of those details for + you. + + Args: + adsorbate: An `ase.Atoms` object of the adsorbate + surface: An `ase.Atoms` object of the surface + bond_indices: A list of integers indicating the indices of the + binding atoms of the adsorbate + index: list of adsorbstion site indices (-1, list or int) (victor) ; + if None, choose the first "reasonable" binding site according to + `is_config_reasonable` + Sets these values: + adsorbed_surface_atoms: An `ase graphic Atoms` object containing the adsorbate and + surface. The bulk atoms will be tagged with `0`; the + surface atoms will be tagged with `1`, and the the + adsorbate atoms will be tagged with `2` or above. + adsorbed_surface_sampling_strs: String specifying the sample, [index]/[total] + of reasonable adsorbed surfaces + """ + iterate_index = False + if index is None: + index = 0 + iterate_index = True + + with Loader( + " [AAOS] make adsorbate_gratoms", + animate=self.animate, + ignore=self.no_loader, + ): + # convert surface atoms into graphic atoms object + surface_gratoms = catkit.Gratoms(surface) + surface_atom_indices = [ + i for i, atom in enumerate(surface) if atom.tag == 1 + ] + surface_gratoms.set_surface_atoms(surface_atom_indices) + surface_gratoms.pbc = np.array([True, True, False]) + + # set up the adsorbate into graphic atoms object + # with its connectivity matrix + adsorbate_gratoms = self.convert_adsorbate_atoms_to_gratoms( + adsorbate, bond_indices + ) + + # generate all possible adsorption configurations on that surface. + # The "bonds" argument automatically take care of mono vs. + # bidentate adsorption configuration. 
+ with Loader( + " [AAOS] make all adsorbed_surfaces", + animate=self.animate, + ignore=self.no_loader, + ): + builder = catkit.gen.adsorption.Builder(surface_gratoms) + with warnings.catch_warnings(): # suppress potential square root warnings + warnings.simplefilter("ignore") + adsorbed_surfaces = builder.add_adsorbate( + adsorbate_gratoms, bonds=bond_indices, index=index + ) + if not isinstance(adsorbed_surfaces, list): + # account for the case where an int index is given + adsorbed_surfaces = [adsorbed_surfaces] + + print(">>> Total adsorbed_surfaces:", len(adsorbed_surfaces)) + + with Loader( + " [AAOS] filter reasonable_adsorbed_surfaces", + animate=self.animate, + ignore=self.no_loader, + ): + # Filter out unreasonable structures. + # Then pick one from the reasonable configurations list as an output. + reasonable_adsorbed_surfaces = [ + surface + for surface in adsorbed_surfaces + if self.is_config_reasonable(surface) + ] + + # if index was None and index 0 is not a valid adsorption site, iterate until a valid site is found + while not reasonable_adsorbed_surfaces and iterate_index: + index += 1 + adsorbed_surfaces = [ + builder.add_adsorbate( + adsorbate_gratoms, bonds=bond_indices, index=index + ) + ] + reasonable_adsorbed_surfaces = [ + surface + for surface in adsorbed_surfaces + if self.is_config_reasonable(surface) + ] + if iterate_index: + print(f">>> Chosen adsorption site: {index}") + print( + f">>> Reasonable configs: {len(reasonable_adsorbed_surfaces)}/{len(adsorbed_surfaces)}" + ) + + self.adsorbed_surface_atoms = [] + self.adsorbed_surface_sampling_strs = [] + if self.enumerate_all_configs: + self.num_configs = len(reasonable_adsorbed_surfaces) + for ind, reasonable_config in enumerate(reasonable_adsorbed_surfaces): + self.adsorbed_surface_atoms.append(reasonable_config) + self.adsorbed_surface_sampling_strs.append( + str(ind) + "/" + str(len(reasonable_adsorbed_surfaces)) + ) + else: + self.num_configs = 1 + reasonable_adsorbed_surface_index = np.random.choice( + len(reasonable_adsorbed_surfaces) + ) + self.adsorbed_surface_atoms.append( + reasonable_adsorbed_surfaces[reasonable_adsorbed_surface_index] + ) + self.adsorbed_surface_sampling_strs.append( + str(reasonable_adsorbed_surface_index) + + "/" + + str(len(reasonable_adsorbed_surfaces)) + ) + + def convert_adsorbate_atoms_to_gratoms(self, adsorbate, bond_indices): + """ + Convert adsorbate atoms object into graphic atoms object, + so the adsorbate can be placed onto the surface with optimal + configuration. Set tags for adsorbate atoms to 2, to distinguish + them from surface atoms. + + Args: + adsorbate An `ase.Atoms` object of the adsorbate + bond_indices A list of integers indicating the indices of the + binding atoms of the adsorbate + + Returns: + adsorbate_gratoms An graphic atoms object of the adsorbate. + """ + connectivity = self.get_connectivity(adsorbate) + adsorbate_gratoms = catkit.Gratoms(adsorbate, edges=connectivity) + # tag adsorbate atoms: non-binding atoms as 2, the binding atom(s) as 3 for now to + # track adsorption site for analyzing if adslab configuration is reasonable. + adsorbate_gratoms.set_tags( + [3 if idx in bond_indices else 2 for idx in range(len(adsorbate_gratoms))] + ) + return adsorbate_gratoms + + def get_connectivity(self, adsorbate): + """ + Generate the connectivity of an adsorbate atoms obj. + + Args: + adsorbate An `ase.Atoms` object of the adsorbate + + Returns: + matrix The connectivity matrix of the adsorbate. 
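Stripped of the Loader bookkeeping, the placement call chain above amounts to the sketch below. The Cu(111) slab and bare O adsorbate are example inputs only, and the catkit calls mirror the ones used above rather than the full catkit API:

    import catkit
    import numpy as np
    from ase import Atoms
    from ase.build import fcc111

    # Example inputs: a Cu(111) slab (ase tags the top layer as 1) and a single O atom.
    surface_atoms = fcc111("Cu", size=(3, 3, 4), vacuum=10.0)
    adsorbate_atoms = Atoms("O")

    # Surface -> Gratoms with surface sites marked and no periodicity along z.
    surface_gratoms = catkit.Gratoms(surface_atoms)
    surface_gratoms.set_surface_atoms([i for i, a in enumerate(surface_atoms) if a.tag == 1])
    surface_gratoms.pbc = np.array([True, True, False])

    adsorbate_gratoms = catkit.Gratoms(adsorbate_atoms)  # single atom, so no edges needed

    # index=-1 asks the builder for every enumerated placement.
    builder = catkit.gen.adsorption.Builder(surface_gratoms)
    adslabs = builder.add_adsorbate(adsorbate_gratoms, bonds=[0], index=-1)
    print(len(adslabs), "candidate placements")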
+ """ + cutoff = natural_cutoffs(adsorbate) + neighborList = neighborlist.NeighborList( + cutoff, self_interaction=False, bothways=True + ) + neighborList.update(adsorbate) + matrix = neighborlist.get_connectivity_matrix(neighborList.nl).toarray() + return matrix + + def is_config_reasonable(self, adslab): + """ + Function that check whether the adsorbate placement is reasonable. + Two criteria are: 1. The adsorbate should be placed on the slab: + the fractional coordinates of the adsorption site is bounded by the unit cell. + 2. The adsorbate should not be buried into the surface: for any atom + in the adsorbate, if the distance between the atom and slab atoms + are closer than 80% of their expected covalent bond, we reject that placement. + + Args: + adslab An `ase.Atoms` object of the adsorbate+slab complex. + + Returns: + A boolean indicating whether or not the adsorbate placement is + reasonable. + """ + vnn = VoronoiNN(allow_pathological=True, tol=0.2, cutoff=10) + adsorbate_indices = [atom.index for atom in adslab if atom.tag >= 2] + adsorbate_bond_indices = [atom.index for atom in adslab if atom.tag == 3] + structure = AseAtomsAdaptor.get_structure(adslab) + slab_lattice = structure.lattice + + # Check to see if the fractional coordinates of the adsorption site is bounded + # by the slab unit cell. We loosen the threshold to -0.01 and 1.01 + # to not wrongly exclude reasonable edge adsorption site. + for idx in adsorbate_bond_indices: + coord = slab_lattice.get_fractional_coords(structure[idx].coords) + if np.any((coord < -0.01) | (coord > 1.01)): + return False + + # Then, check the covalent radius between each adsorbate atoms + # and its nearest neighbors that are slab atoms + # to make sure adsorbate is not buried into the surface + for idx in adsorbate_indices: + try: + nearneighbors = vnn.get_nn_info(structure, n=idx) + except ValueError: + return False + + slab_nn = [ + nn for nn in nearneighbors if nn["site_index"] not in adsorbate_indices + ] + for nn in slab_nn: + ads_elem = structure[idx].species_string + nn_elem = structure[nn["site_index"]].species_string + cov_bond_thres = ( + 0.8 * (COVALENT_RADIUS[ads_elem] + COVALENT_RADIUS[nn_elem]) / 100 + ) + actual_dist = adslab.get_distance(idx, nn["site_index"], mic=True) + if actual_dist < cov_bond_thres: + return False + + # If the structure is reasonable, change tags of adsorbate atoms from 2 and 3 to 2 only + # for ML model compatibility and data cleanliness of the output adslab configurations + old_tags = adslab.get_tags() + adslab.set_tags(np.where(old_tags == 3, 2, old_tags)) + return True + + def find_sites(self, surface, adsorbed_surface, bond_indices): + """ + Finds the Cartesian coordinates of the bonding atoms of the adsorbate. + + Args: + surface `ase.Atoms` of the chosen surface + adsorbed_surface An `ase graphic Atoms` object containing the + adsorbate and surface. + bond_indices A list of integers indicating the indices of the + binding atoms of the adsorbate + Returns: + sites A tuple of 3-tuples containing the Cartesian coordinates of + each of the binding atoms + """ + sites = [] + for idx in bond_indices: + binding_atom_index = len(surface) + idx + atom = adsorbed_surface[binding_atom_index] + positions = tuple(round(coord, 2) for coord in atom.position) + sites.append(positions) + + return tuple(sites) + + def get_adsorbed_bulk_dict(self, ind): + """ + Returns an organized dict for writing to files. + All info is already processed and stored in class variables. 
+ """ + ads_sampling_str = ( + self.adsorbate.adsorbate_sampling_str + + "_" + + self.adsorbed_surface_sampling_strs[ind] + ) + + return { + "adsorbed_bulk_atomsobject": self.constrained_adsorbed_surfaces[ind], + "adsorbed_bulk_metadata": ( + self.surface.bulk_object.mpid, + self.surface.millers, + round(self.surface.shift, 3), + self.surface.top, + self.adsorbate.smiles, + self.all_sites[ind], + ), + "adsorbed_bulk_samplingstr": self.surface.overall_sampling_str + + "_" + + ads_sampling_str, + "adsorbed_db_version": self.adsorbate.adsorbate_db_fname, + } diff --git a/ocdata/constants.py b/ocdata/constants.py new file mode 100644 index 0000000000..990b0481e4 --- /dev/null +++ b/ocdata/constants.py @@ -0,0 +1,113 @@ + +# unused? +ELEMENTS = {1: 'H', 2: 'He', 3: 'Li', 4: 'Be', 5: 'B', 6: 'C', 7: 'N', 8: 'O', + 9: 'F', 10: 'Ne', 11: 'Na', 12: 'Mg', 13: 'Al', 14: 'Si', 15: 'P', + 16: 'S', 17: 'Cl', 18: 'Ar', 19: 'K', 20: 'Ca', 21: 'Sc', 22: 'Ti', + 23: 'V', 24: 'Cr', 25: 'Mn', 26: 'Fe', 27: 'Co', 28: 'Ni', 29: + 'Cu', 30: 'Zn', 31: 'Ga', 32: 'Ge', 33: 'As', 34: 'Se', 35: 'Br', + 36: 'Kr', 37: 'Rb', 38: 'Sr', 39: 'Y', 40: 'Zr', 41: 'Nb', 42: + 'Mo', 43: 'Tc', 44: 'Ru', 45: 'Rh', 46: 'Pd', 47: 'Ag', 48: 'Cd', + 49: 'In', 50: 'Sn', 51: 'Sb', 52: 'Te', 53: 'I', 54: 'Xe', 55: + 'Cs', 56: 'Ba', 57: 'La', 58: 'Ce', 59: 'Pr', 60: 'Nd', 61: 'Pm', + 62: 'Sm', 63: 'Eu', 64: 'Gd', 65: 'Tb', 66: 'Dy', 67: 'Ho', 68: + 'Er', 69: 'Tm', 70: 'Yb', 71: 'Lu', 72: 'Hf', 73: 'Ta', 74: 'W', + 75: 'Re', 76: 'Os', 77: 'Ir', 78: 'Pt', 79: 'Au', 80: 'Hg', 81: + 'Tl', 82: 'Pb', 83: 'Bi', 84: 'Po', 85: 'At', 86: 'Rn', 87: 'Fr', + 88: 'Ra', 89: 'Ac', 90: 'Th', 91: 'Pa', 92: 'U', 93: 'Np', 94: + 'Pu', 95: 'Am', 96: 'Cm', 97: 'Bk', 98: 'Cf', 99: 'Es', 100: 'Fm', + 101: 'Md', 102: 'No', 103: 'Lr', 104: 'Rf', 105: 'Db', 106: 'Sg', + 107: 'Bh', 108: 'Hs', 109: 'Mt', 110: 'Ds', 111: 'Rg', 112: 'Cn', + 113: 'Nh', 114: 'Fl', 115: 'Mc', 116: 'Lv', 117: 'Ts', 118: 'Og'} + +# Covalent radius of elements (unit is pm, 1pm=0.01 angstrom) +# Value are taken from https://github.com/lmmentel/mendeleev +COVALENT_RADIUS = {'H': 32.0, 'He': 46.0, 'O': 63.0, 'F': 64.0, 'Ne': 67.0, 'N': 71.0, + 'C': 75.0, 'B': 85.0, 'Ar': 96.0, 'Cl': 99.0, 'Be': 102.0, 'S': 103.0, + 'Ni': 110.0, 'P': 111.0, 'Co': 111.0, 'Cu': 112.0, 'Br': 114.0, + 'Si': 116.0, 'Fe': 116.0, 'Se': 116.0, 'Kr': 117.0, 'Zn': 118.0, + 'Mn': 119.0, 'Pd': 120.0, 'Ge': 121.0, 'As': 121.0, 'Rg': 121.0, + 'Cr': 122.0, 'Ir': 122.0, 'Cn': 122.0, 'Pt': 123.0, 'Ga': 124.0, + 'Au': 124.0, 'Ru': 125.0, 'Rh': 125.0, 'Al': 126.0, 'Tc': 128.0, + 'Ag': 128.0, 'Ds': 128.0, 'Os': 129.0, 'Mt': 129.0, 'Xe': 131.0, 'Re': 131.0, + 'Li': 133.0, 'I': 133.0, 'Hg': 133.0, 'V': 134.0, 'Hs': 134.0, 'Ti': 136.0, + 'Cd': 136.0, 'Te': 136.0, 'Nh': 136.0, 'W': 137.0, 'Mo': 138.0, 'Mg': 139.0, + 'Sn': 140.0, 'Sb': 140.0, 'Bh': 141.0, 'In': 142.0, 'Rn': 142.0, 'Sg': 143.0, + 'Fl': 143.0, 'Tl': 144.0, 'Pb': 144.0, 'Po': 145.0, 'Ta': 146.0, 'Nb': 147.0, + 'At': 147.0, 'Sc': 148.0, 'Db': 149.0, 'Bi': 151.0, 'Hf': 152.0, 'Zr': 154.0, + 'Na': 155.0, 'Rf': 157.0, 'Og': 157.0, 'Lr': 161.0, 'Lu': 162.0, 'Mc': 162.0, + 'Y': 163.0, 'Ce': 163.0, 'Tm': 164.0, 'Er': 165.0, 'Es': 165.0, 'Ts': 165.0, + 'Ho': 166.0, 'Am': 166.0, 'Cm': 166.0, 'Dy': 167.0, 'Fm': 167.0, 'Eu': 168.0, + 'Tb': 168.0, 'Bk': 168.0, 'Cf': 168.0, 'Gd': 169.0, 'Pa': 169.0, 'Yb': 170.0, + 'U': 170.0, 'Ca': 171.0, 'Np': 171.0, 'Sm': 172.0, 'Pu': 172.0, 'Pm': 173.0, + 'Md': 173.0, 'Nd': 174.0, 'Th': 175.0, 'Lv': 175.0, 'Pr': 176.0, 'No': 176.0, + 'La': 180.0, 
'Sr': 185.0, 'Ac': 186.0, 'K': 196.0, 'Ba': 196.0, 'Ra': 201.0, + 'Rb': 210.0, 'Fr': 223.0, 'Cs': 232.0} + +# We will enumerate surfaces with Miller indices <= MAX_MILLER +MAX_MILLER = 2 + +# We will create surfaces that are at least MIN_XY Angstroms wide. GASpy uses +# 4.5, but our larger adsorbates here can be up to 3.6 Angstroms long. So 4.5 + +# 3.6 ~= 8 Angstroms +MIN_XY = 8. + +COVALENT_MATERIALS_MPIDS = ['mp-104', 'mp-79', 'mp-94', 'mp-11', 'mp-48', 'mp-23152', 'mp-1014111', 'mp-567409', + 'mp-157', 'mp-10021', 'mp-140', 'mp-1094075', 'mp-571550', 'mp-1067758', 'mp-570875', + 'mp-1371', 'mp-2128', 'mp-995193', 'mp-2294', 'mp-31053', 'mp-556225', 'mp-568971', + 'mp-1009834', 'mp-909', 'mp-9548', 'mp-1863', 'mp-570325', 'mp-17524', 'mp-978553', + 'mp-2793', 'mp-998972', 'mp-641', 'mp-15700', 'mp-1009581', 'mp-580226', 'mp-604910', + 'mp-542640', 'mp-20311', 'mp-505531', 'mp-1634', 'mp-700', 'mp-2242', 'mp-628773', + 'mp-21405', 'mp-1379', 'mp-2194', 'mp-1115', 'mp-762', 'mp-572758', 'mp-630528', + 'mp-850131', 'mp-1943', 'mp-2418', 'mp-9889', 'mp-2160', 'mp-2231', 'mp-2156', + 'mp-541582', 'mp-2815', 'mp-604914', 'mp-22691', 'mp-1017565', 'mp-665', 'mp-13682', + 'mp-22375', 'mp-21296', 'mp-1821', 'mp-9920', 'mp-691', 'mp-694', 'mp-1013525', 'mp-1170', + 'mp-2809', 'mp-755263', 'mp-224', 'mp-571033', 'mp-19932', 'mp-2507', 'mp-30485', 'mp-2330', + 'mp-2686', 'mp-525', 'mp-562100', 'mp-7597', 'mp-850083', 'mp-570356', 'mp-684898', + 'mp-10033', 'mp-604908', 'mp-1023900', 'mp-672372', 'mp-1070580', 'mp-1057015', 'mp-30500', + 'mp-1080586', 'mp-1095294', 'mp-1071032', 'mp-1096986', 'mp-1078500', 'mp-1091375', 'mp-28919', + 'mp-7459', 'mp-27666', 'mp-1025459', 'mp-29652', 'mp-9922', 'mp-25469', 'mp-20050', 'mp-1018150', + 'mp-11693', 'mp-570122', 'mp-27507', 'mp-567279', 'mp-1018891', 'mp-1907', 'mp-985829', 'mp-9983', + 'mp-782', 'mp-684690', 'mp-676241', 'mp-22693', 'mp-28233', 'mp-28116', 'mp-1984', 'mp-27455', + 'mp-542449', 'mp-541885', 'mp-1245', 'mp-9996', 'mp-568328', 'mp-10009', 'mp-27164', + 'mp-11675', 'mp-2285', 'mp-13683', 'mp-2798', 'mp-11687', 'mp-21273', 'mp-2430', 'mp-1168', + 'mp-945077', 'mp-1063670', 'mp-1008626', 'mp-22853', 'mp-27628', 'mp-632403', 'mp-500', + 'mp-20826', 'mp-601823', 'mp-1079574', 'mp-27770', 'mp-985831', 'mp-1078443', 'mp-28117', + 'mp-1009641', 'mp-1017540', 'mp-542634', 'mp-23240', 'mp-938', 'mp-1100795', 'mp-9481', + 'mp-9897', 'mp-27513', 'mp-1683', 'mp-1007758', 'mp-1078708', 'mp-23162', 'mp-1186', 'mp-570858', + 'mp-1662', 'mp-602', 'mp-1080459', 'mp-1068510', 'mp-228', 'mp-1025402', 'mp-1967', 'mp-582549', + 'mp-22881', 'mp-9254', 'mp-22877', 'mp-542495', 'mp-605', 'mp-7541', 'mp-542812', 'mp-540922', + 'mp-540884', 'mp-865373', 'mp-570451', 'mp-23309', 'mp-23229', 'mp-484', 'mp-29772', 'mp-9921', + 'mp-2089', 'mp-1019322', 'mp-399', 'mp-972889', 'mp-568746', 'mp-541837', 'mp-22856', + 'mp-570197', 'mp-569581', 'mp-542615', 'mp-1078313', 'mp-27902', 'mp-486', 'mp-34202', + 'mp-23174', 'mp-27396', 'mp-23164', 'mp-2578', 'mp-27411', 'mp-556516', 'mp-10264', + 'mp-1018020', 'mp-35835', 'mp-3532', 'mp-20757', 'mp-29249', 'mp-990091', 'mp-555269', + 'mp-624190', 'mp-1025340', 'mp-8976', 'mp-20331', 'mp-20793', 'mp-560370', 'mp-5045', + 'mp-6959', 'mp-4468', 'mp-19885', 'mp-10412', 'mp-22253', 'mp-27532', 'mp-560262', + 'mp-627601', 'mp-21365', 'mp-504564', 'mp-560806', 'mp-541937', 'mp-542644', 'mp-985304', + 'mp-674328', 'mp-569662', 'mp-637614', 'mp-5807', 'mp-675326', 'mp-34289', 'mp-10232', + 'mp-866941', 'mp-28019', 'mp-20612', 'mp-675290', 
'mp-675367', 'mp-675066', 'mp-1024076', + 'mp-12433', 'mp-9378', 'mp-7263', 'mp-23918', 'mp-21096', 'mp-1018658', 'mp-22152', 'mp-1071623', + 'mp-22035', 'mp-8147', 'mp-1029479', 'mp-1029779', 'mp-4384', 'mp-1068653', 'mp-1080466', + 'mp-1078896', 'mp-24428', 'mp-1094008', 'mp-19727', 'mp-1077470', 'mp-27749', 'mp-8435', 'mp-8848', + 'mp-20422', 'mp-13963', 'mp-540997', 'mp-1029316', 'mp-1029309', 'mp-14790', 'mp-571471', + 'mp-989651', 'mp-29300', 'mp-28846', 'mp-9379', 'mp-28557', 'mp-989586', 'mp-13962', 'mp-28580', + 'mp-1078140', 'mp-998560', 'mp-13542', 'mp-7505', 'mp-3208', 'mp-998512', 'mp-542096', + 'mp-605028', 'mp-14815', 'mp-574169', 'mp-3006', 'mp-28866', 'mp-8211', 'mp-8695', 'mp-628726', + 'mp-16765', 'mp-19917', 'mp-3779', 'mp-14791', 'mp-3342', 'mp-36381', 'mp-1095516', + 'mp-24081', 'mp-27178', 'mp-1079752', 'mp-568592', 'mp-674984', 'mp-27656', 'mp-13923', + 'mp-676437', 'mp-1094079', 'mp-27449', 'mp-541312', 'mp-28220', 'mp-620190', 'mp-567931', + 'mp-28487', 'mp-28480', 'mp-30971', 'mp-29072', 'mp-18279', 'mp-29607', 'mp-505164', 'mp-570340', + 'mp-31220', 'mp-27171', 'mp-17287', 'mp-1024958', 'mp-1078645', 'mp-30979', 'mp-567817', + 'mp-29022', 'mp-675163', 'mp-28224', 'mp-9797', 'mp-4628', 'mp-8436', 'mp-8768', 'mp-998233', + 'mp-2977', 'mp-28189', 'mp-18625', 'mp-541449', 'mp-7280', 'mp-29419', 'mp-675801', + 'mp-22945', 'mp-28178', 'mp-27361', 'mp-1079559', 'mp-541911', 'mp-17801', 'mp-12743', + 'mp-8613', 'mp-8677', 'mp-768680', 'mp-504630', 'mp-12527', 'mp-3123', 'mp-23396', 'mp-1013900', + 'mp-28361', 'mp-23434', 'mp-616481', 'mp-14474', 'mp-28126', 'mp-977371', 'mp-571661', + 'mp-29073', 'mp-505206', 'mp-541149', 'mp-672273', 'mp-22982', 'mp-20235', 'mp-27850', + 'mp-675543', 'mp-554921', 'mp-977592', 'mp-570930', 'mp-9622', 'mp-29666', 'mp-3534', + 'mp-9010', 'mp-23472', 'mp-38605', 'mp-8190', 'mp-20242', 'mp-3525', 'mp-9251', 'mp-15121', + 'mp-14242', 'mp-27947', 'mp-540818', 'mp-1025457', 'mp-504957', 'mp-27948', 'mp-19810', 'mp-580748', + 'mp-8612', 'mp-27195', 'mp-769218', 'mp-9272', 'mp-9391', 'mp-31406', 'mp-4988', 'mp-7277', 'mp-3849', + 'mp-14241', 'mp-31507', 'mp-12307', 'mp-1079754', 'mp-30183', 'mp-20506', 'mp-7038', 'mp-540687', + 'mp-569044', 'mp-541487', 'mp-17945', 'mp-1201', 'mp-130', 'mp-611219', 'mp-1057273', 'mp-158', + 'mp-9798', 'mp-676250', 'mp-998787', 'mp-21413', 'mp-27910'] \ No newline at end of file diff --git a/ocdata/loader.py b/ocdata/loader.py new file mode 100644 index 0000000000..db7c0f2442 --- /dev/null +++ b/ocdata/loader.py @@ -0,0 +1,81 @@ +from itertools import cycle +from shutil import get_terminal_size +from threading import Thread +from time import sleep, time + + +class Loader: + def __init__( + self, + desc="Loading...", + end="Done!", + timeout=0.1, + timer=True, + erase=False, + animate=True, + ignore=False, + out=None, + ): + """ + A loader-like context manager + + Args: + desc (str, optional): The loader's description. Defaults to "Loading...". + end (str, optional): Final print. Defaults to "Done!". + timeout (float, optional): Sleep time between prints. Defaults to 0.1. 
+ """ + self.desc = desc + self.end = end + self.timeout = timeout + self.timer = timer + self.erase = erase + self.out = out + self.duration = None + + self._thread = Thread(target=self._animate, daemon=True) + self.steps = ["⢿", "⣻", "⣽", "⣾", "⣷", "⣯", "⣟", "⡿"] + self.done = (not animate) or ignore + self.ignore = ignore + + def start(self): + self._start_time = time() + self._thread.start() + return self + + def _animate(self): + for c in cycle(self.steps): + if self.done: + break + print(f"\r{self.desc} {c}", flush=True, end="") + sleep(self.timeout) + + def __enter__(self): + self.start() + return self + + def stop(self): + if self.ignore: + return + self.done = True + cols = get_terminal_size((80, 20)).columns + + if self.erase: + end = f"\r{self.end}" + else: + end = self.desc + " | " + self.end + + if self.timer: + end_time = time() + self.duration = end_time - self._start_time + if isinstance(self.out, list): + self.out.append(self.duration) + elif isinstance(self.out, dict): + self.out[self.desc].append(self.duration) + end += f" ({self.duration:.2f}s)" + + print("\r" + " " * cols, end="\r", flush=True) + print(end, flush=True) + + def __exit__(self, exc_type, exc_value, tb): + # handle exceptions with those variables ^ + self.stop() diff --git a/ocdata/precompute_sample_structures.py b/ocdata/precompute_sample_structures.py new file mode 100644 index 0000000000..507b64c61f --- /dev/null +++ b/ocdata/precompute_sample_structures.py @@ -0,0 +1,174 @@ +''' +This submodule contains the scripts that the we used to sample the adsorption +structures. + +Note that some of these scripts were taken from +[GASpy](https://github.com/ulissigroup/GASpy) with permission of author. +''' + +__authors__ = ['Kevin Tran', 'Aini Palizhati', 'Siddharth Goyal', 'Zachary Ulissi'] +__email__ = ['ktran@andrew.cmu.edu'] + +import math +from collections import defaultdict +import random +import pickle +import numpy as np +import catkit +import ase +import ase.db +from ase import neighborlist +from ase.constraints import FixAtoms +from ase.neighborlist import natural_cutoffs +from pymatgen.io.ase import AseAtomsAdaptor +from pymatgen.core.surface import SlabGenerator, get_symmetrically_distinct_miller_indices +from pymatgen.symmetry.analyzer import SpacegroupAnalyzer +from pymatgen.analysis.local_env import VoronoiNN +from .base_atoms.pkls import BULK_PKL, ADSORBATE_PKL +from .constants import MAX_MILLER +import sys +import time + +def enumerate_surfaces_for_saving(bulk_atoms, max_miller=MAX_MILLER): + ''' + Enumerate all the symmetrically distinct surfaces of a bulk structure. It + will not enumerate surfaces with Miller indices above the `max_miller` + argument. Note that we also look at the bottoms of surfaces if they are + distinct from the top. If they are distinct, we flip the surface so the bottom + is pointing upwards. + + Args: + bulk_atoms `ase.Atoms` object of the bulk you want to enumerate + surfaces from. + max_miller An integer indicating the maximum Miller index of the surfaces + you are willing to enumerate. Increasing this argument will + increase the number of surfaces, but the surfaces will + generally become larger. + Returns: + all_slabs_info A list of 4-tuples containing: `pymatgen.Structure` + objects for surfaces we have enumerated, the Miller + indices, floats for the shifts, and Booleans for "top". 
+ ''' + bulk_struct = standardize_bulk(bulk_atoms) + + all_slabs_info = [] + for millers in get_symmetrically_distinct_miller_indices(bulk_struct, MAX_MILLER): + slab_gen = SlabGenerator(initial_structure=bulk_struct, + miller_index=millers, + min_slab_size=7., + min_vacuum_size=20., + lll_reduce=False, + center_slab=True, + primitive=True, + max_normal_search=1) + slabs = slab_gen.get_slabs(tol=0.3, + bonds=None, + max_broken_bonds=0, + symmetrize=False) + + # If the bottoms of the slabs are different than the tops, then we want + # to consider them, too + flipped_slabs_info = [(flip_struct(slab), millers, slab.shift, False) + for slab in slabs if is_structure_invertible(slab) is False] + + # Concatenate all the results together + slabs_info = [(slab, millers, slab.shift, True) for slab in slabs] + all_slabs_info.extend(slabs_info + flipped_slabs_info) + return all_slabs_info + + +def standardize_bulk(atoms): + ''' + There are many ways to define a bulk unit cell. If you change the unit cell + itself but also change the locations of the atoms within the unit cell, you + can get effectively the same bulk structure. To address this, there is a + standardization method used to reduce the degrees of freedom such that each + unit cell only has one "true" configuration. This function will align a + unit cell you give it to fit within this standardization. + + Arg: + atoms `ase.Atoms` object of the bulk you want to standardize + Returns: + standardized_struct `pymatgen.Structure` of the standardized bulk + ''' + struct = AseAtomsAdaptor.get_structure(atoms) + sga = SpacegroupAnalyzer(struct, symprec=0.1) + standardized_struct = sga.get_conventional_standard_structure() + return standardized_struct + + +def is_structure_invertible(structure): + ''' + This function figures out whether or not an `pymatgen.Structure` object has + symmetricity. In this function, the affine matrix is a rotation matrix that + is multiplied with the XYZ positions of the crystal. If the z,z component + of that is negative, it means symmetry operation exist, it could be a + mirror operation, or one that involves multiple rotations/etc. Regardless, + it means that the top becomes the bottom and vice-versa, and the structure + is the symmetric. i.e. structure_XYZ = structure_XYZ*M. + + In short: If this function returns `False`, then the input structure can + be flipped in the z-direction to create a new structure. + + Arg: + structure A `pymatgen.Structure` object. + Returns + A boolean indicating whether or not your `ase.Atoms` object is + symmetric in z-direction (i.e. symmetric with respect to x-y plane). + ''' + # If any of the operations involve a transformation in the z-direction, + # then the structure is invertible. + sga = SpacegroupAnalyzer(structure, symprec=0.1) + for operation in sga.get_symmetry_operations(): + xform_matrix = operation.affine_matrix + z_xform = xform_matrix[2, 2] + if z_xform == -1: + return True + return False + + +def flip_struct(struct): + ''' + Flips an atoms object upside down. Normally used to flip surfaces. + + Arg: + atoms `pymatgen.Structure` object + Returns: + flipped_struct The same `ase.Atoms` object that was fed as an + argument, but flipped upside down. + ''' + atoms = AseAtomsAdaptor.get_atoms(struct) + + # This is black magic wizardry to me. Good look figuring it out. 
+ atoms.wrap() + atoms.rotate(180, 'x', rotate_cell=True, center='COM') + if atoms.cell[2][2] < 0.: + atoms.cell[2] = -atoms.cell[2] + if np.cross(atoms.cell[0], atoms.cell[1])[2] < 0.0: + atoms.cell[1] = -atoms.cell[1] + atoms.wrap() + + flipped_struct = AseAtomsAdaptor.get_structure(atoms) + return flipped_struct + + +def precompute_enumerate_surface(bulk_database, bulk_index, opfile): + + with open(bulk_database, 'rb') as f: + inv_index = pickle.load(f) + flatten = inv_index[1] + inv_index[2] + inv_index[3] + assert bulk_index < len(flatten) + + bulk, mpid = flatten[bulk_index] + + print(bulk, mpid) + surfaces_info = enumerate_surfaces_for_saving(bulk) + + with open(opfile, 'wb') as g: + pickle.dump(surfaces_info, g) + +if __name__ == "__main__": + s = time.time() + precompute_enumerate_surface(BULK_PKL, int(sys.argv[1]), sys.argv[2]) + e = time.time() + print(sys.argv[1], "Done in", e - s ) diff --git a/ocdata/structure_sampler.py b/ocdata/structure_sampler.py new file mode 100644 index 0000000000..4b0e3893e6 --- /dev/null +++ b/ocdata/structure_sampler.py @@ -0,0 +1,187 @@ + +from ocdata.vasp import write_vasp_input_files +from ocdata.adsorbates import Adsorbate +from ocdata.bulk_obj import Bulk +from ocdata.surfaces import Surface +from ocdata.combined import Combined + +import logging +import numpy as np +import os +import pickle +import time + +class StructureSampler(): + ''' + A class that creates adsorbate/bulk/surface objects and + writes vasp input files for one of the following options: + - one random adsorbate/bulk/surface/config, based on a specified random seed + - one specified adsorbate, n specified bulks, and all possible surfaces and configs + - one specified adsorbate, n specified bulks, one specified surface, and all possible configs + + The output directory structure will look like the following: + - For sampling a random structure, the directories will be `random{seed}/surface` and + `random{seed}/adslab` for the surface alone and the adsorbate+surface, respectively. + - For enumerating all structures, the directories will be `{adsorbate}_{bulk}_{surface}/surface` + and `{adsorbate}_{bulk}_{surface}/adslab{config}`, where everything in braces are the + respective indices. + + Attributes + ---------- + args : argparse.Namespace + contains all command line args + logger : logging.RootLogger + logging class to print info + adsorbate : Adsorbate + the selected adsorbate object + all_bulks : list + list of `Bulk` objects + bulk_indices_list : list + list of specified bulk indices (ints) that we want to select + + Public methods + -------------- + run() + selects the appropriate materials and writes to files + ''' + + def __init__(self, args): + ''' + Set up args from argparse, random seed, and logging. + ''' + self.args = args + + self.logger = logging.getLogger() + logging.basicConfig(format='[%(asctime)s] %(levelname)s: %(message)s', + datefmt='%H:%M:%S') + self.logger.setLevel(logging.INFO if self.args.verbose else logging.WARNING) + + if self.args.enumerate_all_structures: + self.bulk_indices_list = [int(ind) for ind in args.bulk_indices.split(',')] + self.logger.info(f'Enumerating all surfaces/configs for adsorbate {self.args.adsorbate_index} and bulks {self.bulk_indices_list}') + else: + self.logger.info('Sampling one random structure') + np.random.seed(self.args.seed) + + def run(self): + ''' + Runs the entire job: generates adsorbate/bulk/surface objects and writes to files. 
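The precomputed-surface workflow is a plain pickle round trip, one file per bulk index. A sketch of what `precompute_enumerate_surface` writes and what `Bulk.read_from_precomputed_enumerations` later reads back; the directory name and `bulk_atoms` are placeholders:

    import os
    import pickle

    precomputed_dir = "precomputed_surfaces"  # hypothetical output directory
    bulk_index = 42

    # Writing: a list of (structure, millers, shift, top) tuples for one bulk.
    surfaces_info = enumerate_surfaces_for_saving(bulk_atoms)  # bulk_atoms: an `ase.Atoms` bulk
    with open(os.path.join(precomputed_dir, f"{bulk_index}.pkl"), "wb") as f:
        pickle.dump(surfaces_info, f)

    # Reading: load the same list back and unpack one entry.
    with open(os.path.join(precomputed_dir, f"{bulk_index}.pkl"), "rb") as f:
        surfaces_info = pickle.load(f)
    slab, millers, shift, top = surfaces_info[0]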
+ ''' + start = time.time() + + if self.args.enumerate_all_structures: + self.adsorbate = Adsorbate(self.args.adsorbate_db, self.args.adsorbate_index) + self._load_bulks() + self._load_and_write_surfaces() + + end = time.time() + self.logger.info(f'Done! ({round(end - start, 2)}s)') + + def _load_bulks(self): + ''' + Loads bulk structures (one random or a list of specified ones) + and stores them in self.all_bulks + ''' + self.all_bulks = [] + with open(self.args.bulk_db, 'rb') as f: + bulk_db_lookup = pickle.load(f) + + if self.args.enumerate_all_structures: + for ind in self.bulk_indices_list: + self.all_bulks.append(Bulk(bulk_db_lookup, self.args.precomputed_structures, ind)) + else: + self.all_bulks.append(Bulk(bulk_db_lookup, self.args.precomputed_structures)) + + def _load_and_write_surfaces(self): + ''' + Loops through all bulks and chooses one random or all possible surfaces; + writes info for that surface and combined surface+adsorbate + ''' + for bulk_ind, bulk in enumerate(self.all_bulks): + possible_surfaces = bulk.get_possible_surfaces() + if self.args.enumerate_all_structures: + if self.args.surface_index is not None: + assert 0 <= self.args.surface_index < len(possible_surfaces), 'Invalid surface index provided' + self.logger.info(f'Loading only surface {self.args.surface_index} for bulk {self.bulk_indices_list[bulk_ind]}') + included_surface_indices = [self.args.surface_index] + else: + self.logger.info(f'Enumerating all {len(possible_surfaces)} surfaces for bulk {self.bulk_indices_list[bulk_ind]}') + included_surface_indices = range(len(possible_surfaces)) + + for cur_surface_ind in included_surface_indices: + surface_info = possible_surfaces[cur_surface_ind] + surface = Surface(bulk, surface_info, cur_surface_ind, len(possible_surfaces)) + self._combine_and_write(surface, self.bulk_indices_list[bulk_ind], cur_surface_ind) + else: + surface_info_index = np.random.choice(len(possible_surfaces)) + surface = Surface(bulk, possible_surfaces[surface_info_index], surface_info_index, len(possible_surfaces)) + self.adsorbate = Adsorbate(self.args.adsorbate_db) + self._combine_and_write(surface) + + + def _combine_and_write(self, surface, cur_bulk_index=None, cur_surface_index=None): + ''' + Add the adsorbate onto a given surface in a Combined object. + Writes output files for the surface itself and the combined surface+adsorbate + + Args: + surface: a Surface object to combine with self.adsorbate + cur_bulk_index: current bulk index from self.bulk_indices_list + cur_surface_index: current surface index if enumerating all + ''' + if self.args.enumerate_all_structures: + output_name_template = f'{self.args.adsorbate_index}_{cur_bulk_index}_{cur_surface_index}' + else: + output_name_template = f'random{self.args.seed}' + + self._write_surface(surface, output_name_template) + + combined = Combined(self.adsorbate, surface, self.args.enumerate_all_structures) + self._write_adsorbed_surface(combined, output_name_template) + + def _write_surface(self, surface, output_name_template): + ''' + Write VASP input files and metadata for the surface alone. 
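`StructureSampler` only needs an argparse-style namespace, so a hypothetical minimal invocation of the random-sampling path might look like the following. Every field value here is a placeholder; the real flags are defined by the repo's CLI, not by this sketch:

    from types import SimpleNamespace
    from ocdata.structure_sampler import StructureSampler

    args = SimpleNamespace(
        seed=0,
        verbose=True,
        enumerate_all_structures=False,   # one random adsorbate/bulk/surface/config
        adsorbate_db="adsorbates.pkl",    # placeholder paths
        bulk_db="bulks_flat.pkl",
        bulk_indices=None,
        adsorbate_index=None,
        surface_index=None,
        precomputed_structures=None,
        output_dir="outputs",
    )
    StructureSampler(args).run()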
+ + Args: + surface: the Surface object to write info for + output_name_template: parent directory name for output files + ''' + bulk_dict = surface.get_bulk_dict() + bulk_dir = os.path.join(self.args.output_dir, output_name_template, 'surface') + write_vasp_input_files(bulk_dict['bulk_atomsobject'], bulk_dir) + self._write_metadata_pkl(bulk_dict, os.path.join(bulk_dir, 'metadata.pkl')) + self.logger.info(f"wrote surface ({bulk_dict['bulk_samplingstr']}) to {bulk_dir}") + + def _write_adsorbed_surface(self, combined, output_name_template): + ''' + Write VASP input files and metadata for the adsorbate placed on surface. + + Args: + combined: the Combined object to write info for, containing any number of adslabs + output_name_template: parent directory name for output files + ''' + self.logger.info(f'Writing {combined.num_configs} adslab configs') + for config_ind in range(combined.num_configs): + if self.args.enumerate_all_structures: + adsorbed_bulk_dir = os.path.join(self.args.output_dir, output_name_template, f'adslab{config_ind}') + else: + adsorbed_bulk_dir = os.path.join(self.args.output_dir, output_name_template, 'adslab') + adsorbed_bulk_dict = combined.get_adsorbed_bulk_dict(config_ind) + write_vasp_input_files(adsorbed_bulk_dict['adsorbed_bulk_atomsobject'], adsorbed_bulk_dir) + self._write_metadata_pkl(adsorbed_bulk_dict, os.path.join(adsorbed_bulk_dir, 'metadata.pkl')) + if config_ind == 0: + self.logger.info(f"wrote adsorbed surface ({adsorbed_bulk_dict['adsorbed_bulk_samplingstr']}) to {adsorbed_bulk_dir}") + + def _write_metadata_pkl(self, dict_to_write, path): + ''' + Writes a dict as a metadata pickle + + Args: + dict_to_write: dict containing all info to dump as file + path: output file path + ''' + file_path = os.path.join(path, 'metadata.pkl') + with open(path, 'wb') as f: + pickle.dump(dict_to_write, f) + diff --git a/ocdata/surfaces.py b/ocdata/surfaces.py new file mode 100644 index 0000000000..b96f571e81 --- /dev/null +++ b/ocdata/surfaces.py @@ -0,0 +1,348 @@ +import math +import os +import pickle +from collections import defaultdict + +import numpy as np +from ase import neighborlist +from ase.constraints import FixAtoms +from pymatgen.analysis.local_env import VoronoiNN +from pymatgen.core import Composition +from pymatgen.io.ase import AseAtomsAdaptor +from pymatgen.symmetry.analyzer import SpacegroupAnalyzer + +from .constants import MIN_XY +from .loader import Loader + + +def constrain_surface(atoms): + """ + This function fixes sub-surface atoms of a surface. Also works on systems + that have surface + adsorbate(s), as long as the bulk atoms are tagged with + `0`, surface atoms are tagged with `1`, and the adsorbate atoms are tagged + with `2` or above. + + This function is used for both surface atoms and the combined surface+adsorbate + + Inputs: + atoms `ase.Atoms` class of the surface system. The tags of + these atoms must be set such that any bulk/surface + atoms are tagged with `0` or `1`, resectively, and any + adsorbate atom is tagged with a 2 or above. + Returns: + atoms A deep copy of the `atoms` argument, but where the appropriate + atoms are constrained. + """ + # Work on a copy so that we don't modify the original + atoms = atoms.copy() + + # We'll be making a `mask` list to feed to the `FixAtoms` class. 
This list + # should contain a `True` if we want an atom to be constrained, and `False` + # otherwise + mask = [True if atom.tag == 0 else False for atom in atoms] + atoms.constraints += [FixAtoms(mask=mask)] + return atoms + + +class Surface: + """ + This class handles all things with a surface. + Create one with a bulk and one of its selected surfaces + + Attributes + ---------- + bulk_object : Bulk + bulk object that the surface comes from + surface_sampling_str : str + string capturing the surface index and total possible surfaces + surface_atoms : Atoms + actual atoms of the surface + constrained_surface : Atoms + constrained version of surface_atoms + millers : tuple + miller indices of the surface + shift : float + shift applied in the c-direction of bulk unit cell to get a termination + top : boolean + indicates the top or bottom termination of the pymatgen generated slab + + Public methods + -------------- + get_bulk_dict() + returns a dict containing info about the surface + """ + + def __init__( + self, + bulk_object, + surface_info, + surface_index, + total_surfaces_possible, + no_loader=True, + ): + """ + Initialize the surface object, tag atoms, and constrain the surface. + + Args: + bulk_object: `Bulk()` object of the corresponding bulk + surface_info: tuple containing atoms, millers, shift, top + surface_index: index of surface out of all possible ones for the bulk + total_surfaces_possible: number of possible surfaces from this bulk + """ + self.bulk_object = bulk_object + self.no_loader = no_loader + surface_struct, self.millers, self.shift, self.top = surface_info + self.surface_sampling_str = ( + str(surface_index) + "/" + str(total_surfaces_possible) + ) + + unit_surface_atoms = AseAtomsAdaptor.get_atoms(surface_struct) + self.surface_atoms = self.tile_atoms(unit_surface_atoms) + + # verify that the bulk and surface elements and stoichiometry match: + assert ( + Composition(self.surface_atoms.get_chemical_formula()).reduced_formula + == Composition( + bulk_object.bulk_atoms.get_chemical_formula() + ).reduced_formula + ), "Mismatched bulk and surface" + + self.tag_surface_atoms(self.bulk_object.bulk_atoms, self.surface_atoms) + self.constrained_surface = constrain_surface(self.surface_atoms) + + def tile_atoms(self, atoms): + """ + This function will repeat an atoms structure in the x and y direction until + the x and y dimensions are at least as wide as the MIN_XY constant. + + Args: + atoms `ase.Atoms` object of the structure that you want to tile + Returns: + atoms_tiled An `ase.Atoms` object that's just a tiled version of + the `atoms` argument. + """ + x_length = np.linalg.norm(atoms.cell[0]) + y_length = np.linalg.norm(atoms.cell[1]) + nx = int(math.ceil(MIN_XY / x_length)) + ny = int(math.ceil(MIN_XY / y_length)) + n_xyz = (nx, ny, 1) + atoms_tiled = atoms.repeat(n_xyz) + return atoms_tiled + + def tag_surface_atoms(self, bulk_atoms, surface_atoms): + """ + Sets the tags of an `ase.Atoms` object. Any atom that we consider a "bulk" + atom will have a tag of 0, and any atom that we consider a "surface" atom + will have a tag of 1. We use a combination of Voronoi neighbor algorithms + (adapted from from `pymatgen.core.surface.Slab.get_surface_sites`; see + https://pymatgen.org/pymatgen.core.surface.html) and a distance cutoff. 
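The tiling rule in `tile_atoms` is just a ceiling division against `MIN_XY`. For a one-by-one Cu(111) cell, used here purely as an example and about 2.55 Angstroms on a side, it gives a 4 x 4 repeat:

    import math
    import numpy as np
    from ase.build import fcc111

    MIN_XY = 8.0
    slab = fcc111("Cu", size=(1, 1, 4), vacuum=10.0)  # example unit slab, ~2.55 A wide
    nx = int(math.ceil(MIN_XY / np.linalg.norm(slab.cell[0])))
    ny = int(math.ceil(MIN_XY / np.linalg.norm(slab.cell[1])))
    tiled = slab.repeat((nx, ny, 1))  # nx = ny = 4 for this cell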
+ + Arg: + bulk_atoms `ase.Atoms` format of the respective bulk structure + surface_atoms The surface where you are trying to find surface sites in + `ase.Atoms` format + """ + with Loader( + " [surface][tag_surface_atoms] _find_surface_atoms_with_voronoi", + animate=False, + ignore=self.no_loader, + ) as loader: + voronoi_tags = self._find_surface_atoms_with_voronoi( + bulk_atoms, surface_atoms + ) + + height_tags = self._find_surface_atoms_by_height(surface_atoms) + # If either of the methods consider an atom a "surface atom", then tag it as such. + tags = [max(v_tag, h_tag) for v_tag, h_tag in zip(voronoi_tags, height_tags)] + surface_atoms.set_tags(tags) + + def _find_surface_atoms_with_voronoi(self, bulk_atoms, surface_atoms): + """ + Labels atoms as surface or bulk atoms according to their coordination + relative to their bulk structure. If an atom's coordination is less than it + normally is in a bulk, then we consider it a surface atom. We calculate the + coordination using pymatgen's Voronoi algorithms. + + Note that if a single element has different sites within a bulk and these + sites have different coordinations, then we consider slab atoms + "under-coordinated" only if they are less coordinated than the most under + undercoordinated bulk atom. For example: Say we have a bulk with two Cu + sites. One site has a coordination of 12 and another a coordination of 9. + If a slab atom has a coordination of 10, we will consider it a bulk atom. + + Args: + bulk_atoms `ase.Atoms` of the bulk structure the surface was cut + from. + surface_atoms `ase.Atoms` of the surface + Returns: + tags A list of 0's and 1's whose indices align with the atoms in + `surface_atoms`. 0's indicate a bulk atom and 1 indicates a + surface atom. + """ + # Initializations + surface_struct = AseAtomsAdaptor.get_structure(surface_atoms) + center_of_mass = self.calculate_center_of_mass(surface_struct) + bulk_cn_dict = self.calculate_coordination_of_bulk_atoms(bulk_atoms) + voronoi_nn = VoronoiNN(tol=0.1) # 0.1 chosen for better detection + default_cutoff = voronoi_nn.cutoff + + tags = [] + for idx, site in enumerate(surface_struct): + # Tag as surface atom only if it's above the center of mass + if site.frac_coords[2] > center_of_mass[2]: + # Run the voronoi tesselation with increasing cutoffs until it's + # possible to compute the coordination number + cutoff = default_cutoff + max_cutoff = ( + surface_struct.lattice.a**2 + + surface_struct.lattice.b**2 + + surface_struct.lattice.c**2 + ) ** 0.5 + while True: + try: + # Tag as surface if atom is under-coordinated + voronoi_nn.cutoff = cutoff + cn = voronoi_nn.get_cn(surface_struct, idx, use_weights=True) + cn = round(cn, 5) + if cn < min(bulk_cn_dict[site.species_string]): + tags.append(1) + else: + tags.append(0) + break + + # Tag as surface if we get a pathological error + except RuntimeError: + tags.append(1) + break + + # A ValueError can occur if the cutoff is too small. + except ValueError: + # Increase cutoff if max_cutoff has not been reached. 
Tag atom + # at surface otherwise + if cutoff < max_cutoff: + cutoff = min(cutoff * 2, max_cutoff) + else: + tags.append(1) + break + + # Tag as bulk otherwise + else: + tags.append(0) + return tags + + def calculate_center_of_mass(self, struct): + """ + Determine the surface atoms indices from here + """ + weights = [site.species.weight for site in struct] + center_of_mass = np.average(struct.frac_coords, weights=weights, axis=0) + return center_of_mass + + def calculate_coordination_of_bulk_atoms(self, bulk_atoms): + """ + Finds all unique atoms in a bulk structure and then determines their + coordination number. Then parses these coordination numbers into a + dictionary whose keys are the elements of the atoms and whose values are + their possible coordination numbers. + For example: `bulk_cns = {'Pt': {3., 12.}, 'Pd': {12.}}` + + Arg: + bulk_atoms An `ase.Atoms` object of the bulk structure. + Returns: + bulk_cn_dict A defaultdict whose keys are the elements within + `bulk_atoms` and whose values are a set of integers of the + coordination numbers of that element. + """ + voronoi_nn = VoronoiNN(tol=0.1) # 0.1 chosen for better detection + default_cutoff = voronoi_nn.cutoff + + # Object type conversion so we can use Voronoi + bulk_struct = AseAtomsAdaptor.get_structure(bulk_atoms) + sga = SpacegroupAnalyzer(bulk_struct) + sym_struct = sga.get_symmetrized_structure() + + # We'll only loop over the symmetrically distinct sites for speed's sake + bulk_cn_dict = defaultdict(set) + for idx in sym_struct.equivalent_indices: + site = sym_struct[idx[0]] + + # Run the voronoi tesselation with increasing cutoffs until it's + # possible to compute the coordination number + cutoff = default_cutoff + max_cutoff = ( + bulk_struct.lattice.a**2 + + bulk_struct.lattice.b**2 + + bulk_struct.lattice.c**2 + ) ** 0.5 + while True: + try: + voronoi_nn.cutoff = cutoff + cn = voronoi_nn.get_cn(bulk_struct, idx[0], use_weights=True) + cn = round(cn, 5) + break + + # A ValueError can occur if the cutoff is too small. + except ValueError: + # Increase cutoff if max_cutoff has not been reached. + if cutoff < max_cutoff: + cutoff = min(cutoff * 2, max_cutoff) + else: + raise RuntimeError("No neighbor found even with max cutoff.") + + bulk_cn_dict[site.species_string].add(cn) + return bulk_cn_dict + + def _find_surface_atoms_by_height(self, surface_atoms): + """ + As discussed in the docstring for `_find_surface_atoms_with_voronoi`, + sometimes we might accidentally tag a surface atom as a bulk atom if there + are multiple coordination environments for that atom type within the bulk. + One heuristic that we use to address this is to simply figure out if an + atom is close to the surface. This function will figure that out. + + Specifically: We consider an atom a surface atom if it is within 2 + Angstroms of the heighest atom in the z-direction (or more accurately, the + direction of the 3rd unit cell vector). 
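The escalating-cutoff loop above, reduced to a single site: `surface_struct` is assumed to be a `pymatgen.Structure` of a slab, and the 30 Angstrom cap merely stands in for the lattice-diagonal bound used above.

    from pymatgen.analysis.local_env import VoronoiNN

    vnn = VoronoiNN(tol=0.1)
    cutoff, max_cutoff = vnn.cutoff, 30.0
    while True:
        try:
            vnn.cutoff = cutoff
            cn = round(vnn.get_cn(surface_struct, 0, use_weights=True), 5)
            break
        except ValueError:  # cutoff too small to build the Voronoi polyhedra
            if cutoff >= max_cutoff:
                raise RuntimeError("No neighbor found even with max cutoff.")
            cutoff = min(cutoff * 2, max_cutoff)
    print("weighted coordination number of site 0:", cn)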
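The height heuristic that follows converts the 2 Angstrom window into fractional coordinates before comparing; with made-up numbers:

    import numpy as np

    cell_height = 30.0                                    # |c| of the slab cell, in Angstroms
    scaled_z = np.array([0.30, 0.42, 0.55, 0.58, 0.60])   # fractional z of each atom
    threshold = scaled_z.max() - 2.0 / cell_height        # 2 A expressed as a fraction of |c|
    tags = (scaled_z >= threshold).astype(int)            # -> [0, 0, 1, 1, 1]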
+ + Arg: + surface_atoms The surface where you are trying to find surface sites in + `ase.Atoms` format + Returns: + tags A list that contains the indices of + the surface atoms + """ + unit_cell_height = np.linalg.norm(surface_atoms.cell[2]) + scaled_positions = surface_atoms.get_scaled_positions() + scaled_max_height = max( + scaled_position[2] for scaled_position in scaled_positions + ) + scaled_threshold = scaled_max_height - 2.0 / unit_cell_height + + tags = [ + 0 if scaled_position[2] < scaled_threshold else 1 + for scaled_position in scaled_positions + ] + return tags + + def get_bulk_dict(self): + """ + Returns an organized dict for writing to files. + All info is already processed and stored in class variables. + """ + self.overall_sampling_str = ( + self.bulk_object.elem_sampling_str + + "_" + + self.bulk_object.bulk_sampling_str + + "_" + + self.surface_sampling_str + ) + return { + "bulk_atomsobject": self.constrained_surface, + "bulk_metadata": ( + self.bulk_object.mpid, + self.millers, + round(self.shift, 3), + self.top, + ), + "bulk_samplingstr": self.overall_sampling_str, + } diff --git a/ocdata/vasp.py b/ocdata/vasp.py new file mode 100644 index 0000000000..59fba9097c --- /dev/null +++ b/ocdata/vasp.py @@ -0,0 +1,230 @@ +''' +This submodule contains the scripts that the we used to run VASP. + +Note that some of these scripts were taken and modified from +[GASpy](https://github.com/ulissigroup/GASpy) with permission of authors. +''' + +__author__ = 'Kevin Tran' +__email__ = 'ktran@andrew.cmu.edu' + +import os +import numpy as np +import ase.io +from ase.io.trajectory import TrajectoryWriter +from ase.calculators.vasp import Vasp2 +from ase.calculators.singlepoint import SinglePointCalculator as SPC + +# NOTE: this is the setting for slab and adslab +VASP_FLAGS = {'ibrion': 2, + 'nsw': 2000, + 'isif': 0, + 'isym': 0, + 'lreal': 'Auto', + 'ediffg': -0.03, + 'symprec': 1e-10, + 'encut': 350., + 'laechg': True, + 'lwave': False, + 'ncore': 4, + 'gga': 'RP', + 'pp': 'PBE', + 'xc': 'PBE'} + +# This is the setting for bulk optmization. +# Only use when expanding the bulk_db with other crystal structures. +BULK_VASP_FLAGS = {'ibrion': 1, + 'nsw': 100, + 'isif': 7, + 'isym':0, + 'ediffg': 1e-08, + 'encut': 500., + 'kpts': (10, 10, 10), + 'prec':'Accurate', + 'gga': 'RP', + 'pp': 'PBE', + 'lwave':False, + 'lcharg':False} + +def run_vasp(atoms, vasp_flags=None): + ''' + Will relax the input atoms given the VASP flag inputs. + + Args: + atoms `ase.Atoms` object that we want to relax. + vasp_flags A dictionary of settings we want to pass to the `Vasp2` + calculator. Defaults to a standerd set of values if `None` + Returns: + trajectory A list of `ase.Atoms` objects where each element represents + each step during the relaxation. + ''' + if vasp_flags is None: # Immutable default + vasp_flags = VASP_FLAGS.copy() + + atoms, vasp_flags = _clean_up_inputs(atoms, vasp_flags) + vasp_flags = _set_vasp_command(vasp_flags) + trajectory = relax_atoms(atoms, vasp_flags) + return trajectory + + +def _clean_up_inputs(atoms, vasp_flags): + ''' + Parses the inputs and makes sure some things are straightened out. 
+ + Arg: + atoms `ase.Atoms` object of the structure we want to relax + vasp_flags A dictionary of settings we want to pass to the `Vasp2` + calculator + Returns: + atoms `ase.Atoms` object of the structure we want to relax, but + with the unit vectors fixed (if needed) + vasp_flags A modified version of the 'vasp_flags' argument + ''' + # Check that the unit vectors obey the right-hand rule, (X x Y points in + # Z). If not, then flip the order of X and Y to enforce this so that VASP + # is happy. + if np.dot(np.cross(atoms.cell[0], atoms.cell[1]), atoms.cell[2]) < 0: + atoms.set_cell(atoms.cell[[1, 0, 2], :]) + + # Calculate and set the k points + if 'kpts' not in vasp_flags.keys(): + k_pts = calculate_surface_k_points(atoms) + vasp_flags['kpts'] = k_pts + + return atoms, vasp_flags + + +def calculate_surface_k_points(atoms): + ''' + For surface calculations, it's a good practice to calculate the k-point + mesh given the unit cell size. We do that on-the-spot here. + + Arg: + atoms `ase.Atoms` object of the structure we want to relax + Returns: + k_pts A 3-tuple of integers indicating the k-point mesh to use + ''' + cell = atoms.get_cell() + order = np.inf + a0 = np.linalg.norm(cell[0], ord=order) + b0 = np.linalg.norm(cell[1], ord=order) + multiplier = 40 + k_pts = (max(1, int(round(multiplier/a0))), + max(1, int(round(multiplier/b0))), + 1) + return k_pts + + +def _set_vasp_command(n_processors=16, vasp_executable='vasp_std'): + ''' + This function assigns the appropriate call to VASP to the `$VASP_COMMAND` + variable. + ''' + # TODO: Sid and/or Caleb to figure out what exactly to put here to make + # things work. Here are some examples: + # https://github.com/ulissigroup/GASpy/blob/master/gaspy/vasp_functions.py#L167 + # https://github.com/ulissigroup/GASpy/blob/master/gaspy/vasp_functions.py#L200 + command = 'srun -n %d %s' % (n_processors, vasp_executable) + os.environ['VASP_COMMAND'] = command + raise NotImplementedError + + +def relax_atoms(atoms, vasp_flags): + ''' + Perform a DFT relaxation with VASP and then write the trajectory to the + 'relaxation.traj' file. + + Args: + atoms `ase.Atoms` object of the structure we want to relax + vasp_flags A dictionary of settings we want to pass to the `Vasp2` + calculator + Returns: + images A list of `ase.Atoms` that comprise the relaxation + trajectory + ''' + # Run the calculation + calc = Vasp2(**vasp_flags) + atoms.set_calculator(calc) + atoms.get_potential_energy() + + # Read the trajectory from the output file + images = [] + for atoms in ase.io.read('vasprun.xml', ':'): + image = atoms.copy() + image = image[calc.resort] + image.set_calculator(SPC(image, + energy=atoms.get_potential_energy(), + forces=atoms.get_forces()[calc.resort])) + images += [image] + + # Write the trajectory + with TrajectoryWriter('relaxation.traj', 'a') as writer: + for atoms in images: + writer.write(atoms) + return images + + +def write_vasp_input_files(atoms, outdir='.', vasp_flags=None): + ''' + Effectively goes through the same motions as the `run_vasp` function, + except it only writes the input files instead of running. + + Args: + atoms `ase.Atoms` object that we want to relax. + outdir A string indicating where you want to save the input files. + Defaults to '.' + vasp_flags A dictionary of settings we want to pass to the `Vasp2` + calculator. 
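The k-point rule in `calculate_surface_k_points` scales the in-plane mesh inversely with the cell's maximum x/y extent (infinity norm of the first two cell vectors) and keeps a single point along z; with a made-up adslab cell:

    import numpy as np

    cell = np.array([[9.8, 0.0, 0.0],
                     [0.0, 8.4, 0.0],
                     [0.0, 0.0, 32.0]])        # made-up adslab cell, Angstroms
    a0 = np.linalg.norm(cell[0], ord=np.inf)   # 9.8
    b0 = np.linalg.norm(cell[1], ord=np.inf)   # 8.4
    kpts = (max(1, int(round(40 / a0))), max(1, int(round(40 / b0))), 1)  # -> (4, 5, 1)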
Defaults to a standerd set of values if `None` + ''' + if vasp_flags is None: # Immutable default + vasp_flags = VASP_FLAGS.copy() + + atoms, vasp_flags = _clean_up_inputs(atoms, vasp_flags) + calc = Vasp2(directory=outdir, **vasp_flags) + calc.write_input(atoms) + + +def xml_to_tuples(xml='vasprun.xml'): + ''' + Converts an XML file into both a trajectory file while also returning the + trajectory as a list of `ase.Atoms` objects + + Args: + xml String indicating the XML file to read from + Returns: + images A list of 5-tuples for each images in the trajectory. The + tuples include a list of symbols for each atom; the positions + of the atoms; the forces each atom sees; the unit cell + dimensions; and the potential energy of the whole system. + ''' + traj = xml_to_traj(xml) + + images = [] + for atoms in traj: + symbols = atoms.get_chemical_symbols() + positions = atoms.get_positions() + forces = atoms.get_forces() + cell = np.array(atoms.get_cell()) + energy = atoms.get_potential_energy() + atoms_tuple = (symbols, positions, forces, cell, energy) + images.append(atoms_tuple) + + return images + + +def xml_to_traj(xml='vasprun.xml'): + ''' + Converts an XML file into both a trajectory file while also returning the + trajectory as a list of `ase.Atoms` objects + + Args: + xml String indicating the XML file to read from + Returns: + traj A list of `ase.Atoms` objects + ''' + traj = ase.io.read(xml, ':') + for atoms in traj: + atoms.set_calculator(SPC(atoms, + energy=atoms.get_potential_energy(), + forces=atoms.get_forces())) + return traj diff --git a/ocpmodels/common/data_parallel.py b/ocpmodels/common/data_parallel.py index 6f0ceca86b..98f19984b8 100644 --- a/ocpmodels/common/data_parallel.py +++ b/ocpmodels/common/data_parallel.py @@ -13,6 +13,7 @@ import numpy as np import torch from torch.utils.data import BatchSampler, DistributedSampler, Sampler +from torch_geometric.data import Data from ocpmodels.common import dist_utils from ocpmodels.datasets import data_list_collater @@ -53,6 +54,14 @@ def forward(self, batch_list, **kwargs): return self.module(batch_list[0], **kwargs) if len(self.device_ids) == 1: + if type(batch_list[0]) is list: + return self.module( + [ + batch_list[0][0].to(f"cuda:{self.device_ids[0]}"), + batch_list[0][1].to(f"cuda:{self.device_ids[0]}"), + ], + **kwargs, + ) return self.module(batch_list[0].to(f"cuda:{self.device_ids[0]}"), **kwargs) for t in chain(self.module.parameters(), self.module.buffers()): diff --git a/ocpmodels/common/flags.py b/ocpmodels/common/flags.py index 8389fa373f..a32e44c20a 100644 --- a/ocpmodels/common/flags.py +++ b/ocpmodels/common/flags.py @@ -87,12 +87,14 @@ def add_core_args(self): "--checkpoint", type=str, help="Model checkpoint to load" ) self.parser.add_argument( - "--continue_from_dir", type=str, help="Run to continue, loading its config" + "--continue_from_dir", + type=str, + help="Continue an existing run, loading its config and overwriting desired arguments", ) self.parser.add_argument( "--restart_from_dir", type=str, - help="Run to restart, loading its config and overwriting " + help="Restart training from an existing run, loading its config and overwriting args" + "from the command-line", ) self.parser.add_argument( @@ -287,6 +289,18 @@ def add_core_args(self): help="Number of validation loops to run in order to collect inference" + " timing stats", ) + self.parser.add_argument( + "--is_disconnected", + type=bool, + default=False, + help="Eliminates edges between catalyst and adsorbate.", + ) + self.parser.add_argument( + 
"--lowest_energy_only", + type=bool, + default=False, + help="Makes trainer use the lowest energy data point for every (catalyst, adsorbate, cell) tuple. ONLY USE WITH ALL DATASET", + ) flags = Flags() diff --git a/ocpmodels/common/gfn.py b/ocpmodels/common/gfn.py new file mode 100644 index 0000000000..00ecd0cd38 --- /dev/null +++ b/ocpmodels/common/gfn.py @@ -0,0 +1,296 @@ +import os +from copy import deepcopy +from pathlib import Path +from typing import Callable, List, Union + +import torch.nn as nn +from torch_geometric.data.batch import Batch +from torch_geometric.data.data import Data + +from ocpmodels.common.utils import make_trainer_from_dir, resolve +from ocpmodels.datasets.data_transforms import get_transforms +from ocpmodels.models.faenet import FAENet + + +class FAENetWrapper(nn.Module): + def __init__( + self, + faenet: FAENet, + transform: Callable = None, + frame_averaging: str = None, + trainer_config: dict = None, + ): + """ + `FAENetWrapper` is a wrapper class for the FAENet model. It is used to perform + a forward pass of the model when frame averaging is applied. + + Args: + faenet (FAENet, optional): The FAENet model to use. Defaults to None. + transform (Transform, optional): The data transform to use. Defaults to None. + frame_averaging (str, optional): The frame averaging method to use. + trainer_config (dict, optional): The trainer config used to create the model. + Defaults to None. + """ + super().__init__() + + self.faenet = faenet + self.transform = transform + self.frame_averaging = frame_averaging + self.trainer_config = trainer_config + self._is_frozen = None + + @property + def frozen(self): + """ + Returns whether or not the model is frozen. A model is frozen if all of its + parameters are set to not require gradients. + + This is a lazy property, meaning that it is only computed once and then cached. + + Returns: + bool: Whether or not the model is frozen. + """ + if self._is_frozen is None: + frozen = True + for param in self.parameters(): + if param.requires_grad: + frozen = False + break + self._is_frozen = frozen + return self._is_frozen + + def preprocess(self, batch: Union[Batch, Data, List[Data], List[Batch]]): + """ + Preprocess a batch of graphs using the data transform. + + * if batch is a list with one element: + * it could be a batch from the FAENet data loader which produces + lists of Batch with 1 element (because of multi-GPU features) + * if the single element is a Batch, extract it (`batch=batch[0]`) + * if batch is a Data instance, it is a single graph and we turn + it back into a list of 1 element (`batch=[batch]`) + * if it is a Batch instance, it is a collection of graphs and we turn it + into a list of Data graphs (`batch=batch.to_data_list()`) + + Finally we transform the list of Data graphs with the pre-processing transforms + and collate them into a Batch. + + .. code-block:: python + + In [7]: %timeit wrapper.preprocess(batch) + The slowest run took 4.94 times longer than the fastest. + This could mean that an intermediate result is being cached. + 67.1 ms ± 58.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) + + In [8]: %timeit wrapper.preprocess(batch) + 43.8 ms ± 1.66 ms per loop (mean ± std. dev. of 7 runs, 10 loops each) + + Args: + batch (List?[Data, Batch]): The batch of graphs to transform + + Returns: + torch_geometric.Batch: The transformed batch. If frame averaging is + disabled, this is the same as the input batch. 
+ """ + if isinstance(batch, list): + if len(batch) == 1 and isinstance(batch[0], Batch): + batch = batch[0] + if isinstance(batch, Data): + batch = [batch] + if isinstance(batch, Batch): + batch = batch.to_data_list() + + return Batch.from_data_list([self.transform(b) for b in batch]) + + def forward( + self, + batch: Union[Batch, Data, List[Data], List[Batch]], + preprocess: bool = True, + ): + """Perform a forward pass of the model when frame averaging is applied. + + Adapted from + ocmpodels.trainers.single_point_trainer.SingleTrainer.model_forward + + This implementation assumes only the energy is being predicted, and only + frame-averages this prediction. + + Args: + batch (List?[Data, Batch]): The batch of graphs to predict on. + preprocess (bool, optional): Whether or not to apply the data transforms. + Defaults to True. + + Returns: + (dict): model predictions tensor for "energy" and "forces". + """ + if preprocess: + batch = self.preprocess(batch) + if not self.frozen: + raise RuntimeError( + "FAENetWrapper must be frozen before calling forward." + + " Use .freeze() to freeze it." + ) + # Distinguish frame averaging from base case. + if self.frame_averaging and self.frame_averaging != "DA": + original_pos = batch[0].pos + original_cell = batch[0].cell + e_all = [] + + # Compute model prediction for each frame + for i in range(len(batch[0].fa_pos)): + batch[0].pos = batch[0].fa_pos[i] + batch[0].cell = batch[0].fa_cell[i] + + # forward pass + preds = self.faenet( + deepcopy(batch), + mode="inference", + regress_forces=False, + q=None, + ) + e_all.append(preds["energy"]) + + batch[0].pos = original_pos + batch[0].cell = original_cell + + # Average predictions over frames + preds["energy"] = sum(e_all) / len(e_all) + else: + preds = self.faenet(batch) + + if preds["energy"].shape[-1] == 1: + preds["energy"] = preds["energy"].view(-1) + + return preds["energy"] # denormalize? + + def freeze(self): + """Freeze the model parameters.""" + for param in self.parameters(): + param.requires_grad = False + + +def parse_loc() -> str: + """ + Parse the current location from the environment variables. If the location is a + number, assume it is a SLURM job ID and return "mila". Otherwise, return the + location name. + + Returns: + str: Where the current job is running, typically Mila or DRAC or laptop. + """ + loc = os.environ.get( + "SLURM_CLUSTER_NAME", os.environ.get("SLURM_JOB_ID", os.environ["USER"]) + ) + if all(s.isdigit() for s in loc): + loc = "mila" + return loc + + +def find_ckpt(ckpt_paths: dict, release: str) -> Path: + """ + Finds a checkpoint in a dictionary of paths, based on the current cluster name and + release. If the path is a file, use it directly. Otherwise, look for a single + checkpoint file in a ${release}/sub-fodler. E.g.: + ckpt_paths = {"mila": "/path/to/ckpt_dir"} release = v2.3_graph_phys + find_ckpt(ckpt_paths, release) -> /path/to/ckpt_dir/v2.3_graph_phys/name.ckpt + + ckpt_paths = {"mila": "/path/to/ckpt_dir/file.ckpt"} release = v2.3_graph_phys + find_ckpt(ckpt_paths, release) -> /path/to/ckpt_dir/file.ckpt + + Args: + ckpt_paths (dict): Where to look for the checkpoints. + Maps cluster names to paths. + + Raises: + ValueError: The current location is not in the checkpoint path dict. + ValueError: The checkpoint path does not exist. ValueError: The checkpoint path + is a directory and contains no .ckpt file. ValueError: The checkpoint path is a + directory and contains >1 .ckpt files. + + Returns: + Path: Path to the checkpoint for that release on this host. 
+ """ + loc = parse_loc() + if loc not in ckpt_paths: + raise ValueError(f"FAENet proxy checkpoint path not found for location {loc}.") + path = resolve(ckpt_paths[loc]) + if not path.exists(): + raise ValueError(f"FAENet proxy checkpoint not found at {str(path)}.") + if path.is_file(): + return path + path = path / release + ckpts = list(path.glob("**/*.ckpt")) + if len(ckpts) == 0: + raise ValueError(f"No FAENet proxy checkpoint found at {str(path)}.") + if len(ckpts) > 1: + raise ValueError( + f"Multiple FAENet proxy checkpoints found at {str(path)}. " + "Please specify the checkpoint explicitly." + ) + return ckpts[0] + + +def prepare_for_gfn(ckpt_paths: dict, release: str) -> tuple: + """ + Prepare a FAENet model for use in GFN. Loads the checkpoint for the given release + on the current host, and wraps it in a FAENetWrapper. + + Example ckpt_paths: + + ckpt_paths = { + "mila": "/path/to/releases_dir", + "drac": "/path/to/releases_dir", + "laptop": "/path/to/releases_dir", + } + + The loaded model is frozen (all parameters are set to not require gradients). + + Args: + ckpt_paths (dict): Where to look for the checkpoints as {loc: path}. + release (str): Which release to load. + + Returns: + tuple: (model, loaders) where loaders is a dict of loaders for the model. + """ + ckpt_path = find_ckpt(ckpt_paths, release) + assert ckpt_path.exists(), f"Path {ckpt_path} does not exist." + trainer = make_trainer_from_dir( + ckpt_path, + mode="continue", + overrides={ + "is_debug": True, + "silent": True, + "cp_data_to_tmpdir": False, + }, + silent=True, + ) + + wrapper = FAENetWrapper( + faenet=trainer.model, + transform=get_transforms(trainer.config), + frame_averaging=trainer.config.get("frame_averaging", ""), + trainer_config=trainer.config, + ) + wrapper.freeze() + loaders = trainer.loaders + + return wrapper, loaders + + +if __name__ == "__main__": + # for instance in ipython: + # In [1]: run ocpmodels/common/gfn.py + # + from ocpmodels.common.gfn import prepare_for_gfn + + ckpt_paths = {"mila": "/path/to/releases_dir"} + release = "v2.3_graph_phys" + # or + ckpt_paths = { + "mila": "/network/scratch/a/alexandre.duval/ocp/runs/3785941/checkpoints/best_checkpoint.pt" + } + release = None + wrapper, loaders = prepare_for_gfn(ckpt_paths, release) + data_gen = iter(loaders["train"]) + batch = next(data_gen) + preds = wrapper(batch) diff --git a/ocpmodels/common/utils.py b/ocpmodels/common/utils.py index dfa48f51d4..05974fa44e 100644 --- a/ocpmodels/common/utils.py +++ b/ocpmodels/common/utils.py @@ -33,12 +33,12 @@ from matplotlib.figure import Figure from torch_geometric.data import Data from torch_geometric.utils import remove_self_loops -from torch_scatter import segment_coo, segment_csr, scatter +from torch_scatter import scatter, segment_coo, segment_csr import ocpmodels -from ocpmodels.common.flags import flags -from ocpmodels.common.registry import registry import ocpmodels.common.dist_utils as dist_utils +from ocpmodels.common.flags import Flags, flags +from ocpmodels.common.registry import registry class Cluster: @@ -759,6 +759,9 @@ def setup_imports(): # manual model imports importlib.import_module("ocpmodels.models.gemnet_oc.gemnet_oc") + importlib.import_module("ocpmodels.models.gemnet_oc.depgemnet_oc") + importlib.import_module("ocpmodels.models.gemnet_oc.indgemnet_oc") + importlib.import_module("ocpmodels.models.gemnet_oc.agemnet_oc") experimental_folder = os.path.join(root_folder, "../experimental/") if os.path.exists(experimental_folder): @@ -898,6 +901,37 @@ def 
set_cpus_to_workers(config, silent=False): return config +def set_dataset_split(config): + """ + Set the split for all datasets in the config to the one specified in the + config's name. + + Resulting dict: + { + "dataset": { + "train": { + "split": "all" + ... + }, + ... + } + } + + Args: + config (dict): The full trainer config dict + + Returns: + dict: The updated config dict + """ + split = config["config"].split("-")[-1] + for d, dataset in config["dataset"].items(): + if d == "default_val": + continue + assert isinstance(dataset, dict) + config["dataset"][d]["split"] = split + return config + + def check_regress_forces(config): if "regress_forces" in config["model"]: if config["model"]["regress_forces"] == "": @@ -973,7 +1007,7 @@ def load_config(config_str): return config -def build_config(args, args_override, silent=False): +def build_config(args, args_override=[], silent=None): config, overrides, loaded_config = {}, {}, {} if hasattr(args, "config_yml") and args.config_yml: @@ -997,10 +1031,11 @@ def build_config(args, args_override, silent=False): if args.continue_from_dir else resolve(args.restart_from_dir) ) + already_ckpt = load_dir.exists() and load_dir.is_file() # find configs: from checkpoints first, from the dropped config file # otherwise ckpts = list(load_dir.glob("checkpoints/checkpoint-*.pt")) - if not ckpts: + if not ckpts and not already_ckpt: print(f"💥 Could not find checkpoints in {str(load_dir)}.") configs = list(load_dir.glob("config-*.y*ml")) if not configs: @@ -1011,11 +1046,14 @@ def build_config(args, args_override, silent=False): loaded_config = yaml.safe_load(configs[0].read_text()) load_path = str(configs[0]) else: - latest_ckpt = str( - sorted(ckpts, key=lambda c: float(c.stem.split("-")[-1]))[-1] - ) + if already_ckpt: + latest_ckpt = load_dir + else: + latest_ckpt = str( + sorted(ckpts, key=lambda c: float(c.stem.split("-")[-1]))[-1] + ) load_path = latest_ckpt - loaded_config = torch.load((latest_ckpt), map_location="cpu")["config"] + loaded_config = torch.load(latest_ckpt, map_location="cpu")["config"] # config has been found. We need to prune/modify it depending on whether # we're restarting or continuing. 
@@ -1036,7 +1074,7 @@ def build_config(args, args_override, silent=False): loaded_config["checkpoint"] = str(latest_ckpt) loaded_config["job_ids"] = loaded_config["job_ids"] + f", {JOB_ID}" loaded_config["job_id"] = JOB_ID - loaded_config["local_rank"] = config["local_rank"] + loaded_config["local_rank"] = config.get("local_rank", 0) else: # restarting from scratch keep_keys = [ @@ -1044,7 +1082,7 @@ def build_config(args, args_override, silent=False): "config", "dataset", "energy_head", - "fa_frames", + "fa_method", "frame_averaging", "graph_rewiring", "model", @@ -1054,6 +1092,7 @@ def build_config(args, args_override, silent=False): "test_ri", "use_pbc", "wandb_project", + "grad_fine_tune", ] loaded_config = { k: loaded_config[k] for k in keep_keys if k in loaded_config @@ -1174,6 +1213,7 @@ def build_config(args, args_override, silent=False): config = override_drac_paths(config) config = continue_from_slurm_job_id(config) config = read_slurm_env(config) + config = set_dataset_split(config) config["optim"]["eval_batch_size"] = config["optim"]["batch_size"] dist_utils.setup(config) @@ -1281,15 +1321,24 @@ def get_pbc_distances( def radius_graph_pbc(data, radius, max_num_neighbors_threshold): - device = data.pos.device - batch_size = len(data.natoms) - - # position of the atoms atom_pos = data.pos + natoms = data.natoms + cell = data.cell + + return radius_graph_pbc_inputs( + atom_pos, natoms, cell, radius, max_num_neighbors_threshold + ) + + +def radius_graph_pbc_inputs( + atom_pos, natoms, cell, radius, max_num_neighbors_threshold +): + device = atom_pos.device + batch_size = len(natoms) # Before computing the pairwise distances between atoms, first create a list # of atom indices to compare for the entire batch - num_atoms_per_image = data.natoms + num_atoms_per_image = natoms num_atoms_per_image_sqr = (num_atoms_per_image**2).long() # index offset between images @@ -1335,22 +1384,22 @@ def radius_graph_pbc(data, radius, max_num_neighbors_threshold): # Note that the unit cell volume V = a1 * (a2 x a3) and that # (a2 x a3) / V is also the reciprocal primitive vector # (crystallographer's definition). - cross_a2a3 = torch.cross(data.cell[:, 1], data.cell[:, 2], dim=-1) - cell_vol = torch.sum(data.cell[:, 0] * cross_a2a3, dim=-1, keepdim=True) + cross_a2a3 = torch.cross(cell[:, 1], cell[:, 2], dim=-1) + cell_vol = torch.sum(cell[:, 0] * cross_a2a3, dim=-1, keepdim=True) inv_min_dist_a1 = torch.norm(cross_a2a3 / cell_vol, p=2, dim=-1) rep_a1 = torch.ceil(radius * inv_min_dist_a1) - cross_a3a1 = torch.cross(data.cell[:, 2], data.cell[:, 0], dim=-1) + cross_a3a1 = torch.cross(cell[:, 2], cell[:, 0], dim=-1) inv_min_dist_a2 = torch.norm(cross_a3a1 / cell_vol, p=2, dim=-1) rep_a2 = torch.ceil(radius * inv_min_dist_a2) if radius >= 20: # Cutoff larger than the vacuum layer of 20A - cross_a1a2 = torch.cross(data.cell[:, 0], data.cell[:, 1], dim=-1) + cross_a1a2 = torch.cross(cell[:, 0], cell[:, 1], dim=-1) inv_min_dist_a3 = torch.norm(cross_a1a2 / cell_vol, p=2, dim=-1) rep_a3 = torch.ceil(radius * inv_min_dist_a3) else: - rep_a3 = data.cell.new_zeros(1) + rep_a3 = cell.new_zeros(1) # Take the max over all images for uniformity. This is essentially padding. 
# Note that this can significantly increase the number of computed distances # if the required repetitions are very different between images @@ -1371,7 +1420,7 @@ def radius_graph_pbc(data, radius, max_num_neighbors_threshold): unit_cell_batch = unit_cell.view(1, 3, num_cells).expand(batch_size, -1, -1) # Compute the x, y, z positional offsets for each cell in each image - data_cell = torch.transpose(data.cell, 1, 2) + data_cell = torch.transpose(cell, 1, 2) pbc_offsets = torch.bmm(data_cell, unit_cell_batch) pbc_offsets_per_atom = torch.repeat_interleave( pbc_offsets, num_atoms_per_image_sqr, dim=0 @@ -1403,7 +1452,7 @@ def radius_graph_pbc(data, radius, max_num_neighbors_threshold): atom_distance_sqr = torch.masked_select(atom_distance_sqr, mask) mask_num_neighbors, num_neighbors_image = get_max_neighbors_mask( - natoms=data.natoms, + natoms=natoms, index=index1, atom_distance=atom_distance_sqr, max_num_neighbors_threshold=max_num_neighbors_threshold, @@ -1430,6 +1479,7 @@ def get_max_neighbors_mask(natoms, index, atom_distance, max_num_neighbors_thres `max_num_neighbors_threshold` neighbors. Assumes that `index` is sorted. """ + device = natoms.device num_atoms = natoms.sum() @@ -1744,3 +1794,64 @@ def scatter_det(*args, **kwargs): torch.use_deterministic_algorithms(mode=False) return out + + +def make_config_from_dir(path, mode, overrides={}, silent=None): + """ + Make a config from a directory. This is useful when restarting or continuing from a + previous run. + + Args: + path (str): Where to load the config from. mode (str): Either 'continue' or + 'restart'. overrides (dict, optional): Dictionary to update the config with . + Defaults to {}. silent (bool, optional): Whether or not to print loading + status. Defaults to None. + + Returns: + dict: The loaded and overridden config. + """ + path = resolve(path) + assert path.exists() + assert mode in { + "continue", + "restart", + }, f"Invalid mode: {mode}. Expected 'continue' or 'restart'." + assert isinstance( + overrides, dict + ), f"Overrides must be a dict. Received {overrides}" + + argv = deepcopy(sys.argv) + sys.argv[1:] = [] + default_args = Flags().get_parser().parse_args() + sys.argv = argv + + if mode == "continue": + default_args.continue_from_dir = str(path) + else: + default_args.restart_from_dir = str(path) + + config = build_config(default_args, silent=silent) + config = merge_dicts(config, overrides) + + setup_imports() + return config + + +def make_trainer_from_dir(path, mode, overrides={}, silent=None): + """ + Make a trainer from a directory. + + Load a config with `make_config_from_dir` and then make a trainer from it. + + Args: + path (str): Where to load the config from. + mode (str): Either 'continue' or 'restart'. + overrides (dict, optional): Dictionary to update the config with. + Defaults to {}. + silent (bool, optional): _description_. Defaults to None. + + Returns: + Trainer: The loaded trainer. 
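+    Example (sketch; the run directory path is hypothetical):
+
+    .. code-block:: python
+
+        trainer = make_trainer_from_dir(
+            "/path/to/runs/12345",
+            mode="continue",
+            overrides={"optim": {"batch_size": 16}},
+        )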
+ """ + config = make_config_from_dir(path, mode, overrides, silent) + return registry.get_trainer_class(config["trainer"])(**config) diff --git a/ocpmodels/datasets/data_transforms.py b/ocpmodels/datasets/data_transforms.py index 64556a0038..1b3e0a37c9 100644 --- a/ocpmodels/datasets/data_transforms.py +++ b/ocpmodels/datasets/data_transforms.py @@ -92,14 +92,45 @@ def __init__(self, rewiring_type=None) -> None: def __call__(self, data): if self.inactive: return data - - data.batch = torch.zeros(data.num_nodes, dtype=torch.long) - data.natoms = torch.tensor([data.natoms]) - data.ptr = torch.tensor([0, data.natoms]) + if not hasattr(data, "batch") or data.batch is None: + data.batch = torch.zeros(data.num_nodes, dtype=torch.long) + if isinstance(data.natoms, int) or data.natoms.ndim == 0: + data.natoms = torch.tensor([data.natoms]) + if not hasattr(data, "ptr") or data.ptr is None: + data.ptr = torch.tensor([0, data.natoms]) return self.rewiring_func(data) +class Disconnected(Transform): + def __init__(self, is_disconnected=False) -> None: + self.inactive = not is_disconnected + + def edge_classifier(self, edge_index, tags): + edges_with_tags = tags[ + edge_index.type(torch.long) + ] # Tensor with shape=edge_index.shape where every entry is a tag + filt1 = edges_with_tags[0] == edges_with_tags[1] + filt2 = (edges_with_tags[0] != 2) * (edges_with_tags[1] != 2) + + # Edge is removed if tags are different (R1), and at least one end has tag 2 (R2). We want ~(R1*R2) = ~R1+~R2. + # filt1 = ~R1. Let L1 be that head has tag 2, and L2 is that tail has tag 2. Then R2 = L1+L2, so ~R2 = ~L1*~L2 = filt2. + + return filt1 + filt2 + + def __call__(self, data): + if self.inactive: + return data + + values = self.edge_classifier(data.edge_index, data.tags) + + data.edge_index = data.edge_index[:, values] + data.cell_offsets = data.cell_offsets[values, :] + data.distances = data.distances[values] + + return data + + class Compose: # https://pytorch.org/vision/stable/_modules/torchvision/transforms/transforms.html#Compose def __init__(self, transforms): @@ -140,5 +171,6 @@ def get_transforms(trainer_config): AddAttributes(), GraphRewiring(trainer_config.get("graph_rewiring")), FrameAveraging(trainer_config["frame_averaging"], trainer_config["fa_frames"]), + Disconnected(trainer_config["is_disconnected"]), ] return Compose(transforms) diff --git a/ocpmodels/datasets/lmdb_dataset.py b/ocpmodels/datasets/lmdb_dataset.py index 2eaef01200..c07ae1773a 100644 --- a/ocpmodels/datasets/lmdb_dataset.py +++ b/ocpmodels/datasets/lmdb_dataset.py @@ -6,6 +6,7 @@ """ import bisect +import json import logging import pickle import time @@ -36,15 +37,37 @@ class LmdbDataset(Dataset): config (dict): Dataset configuration transform (callable, optional): Data transform function. (default: :obj:`None`) + fa_frames (str, optional): type of frame averaging method applied, if any. + adsorbates (str, optional): comma-separated list of adsorbates to filter. + If None or "all", no filtering is applied. + (default: None) + adsorbates_ref_dir: where metadata files for adsorbates are stored. 
+ (default: "/network/scratch/s/schmidtv/ocp/datasets/ocp/per_ads") """ - def __init__(self, config, transform=None, fa_frames=None): - super(LmdbDataset, self).__init__() + def __init__( + self, + config, + transform=None, + fa_frames=None, + lmdb_glob=None, + adsorbates=None, + adsorbates_ref_dir=None, + silent=False, + ): + super().__init__() self.config = config + self.adsorbates = adsorbates + self.adsorbates_ref_dir = adsorbates_ref_dir + self.silent = silent self.path = Path(self.config["src"]) if not self.path.is_file(): db_paths = sorted(self.path.glob("*.lmdb")) + if lmdb_glob: + db_paths = [ + p for p in db_paths if any(lg in p.stem for lg in lmdb_glob) + ] assert len(db_paths) > 0, f"No LMDBs found in '{self.path}'" self.metadata_path = self.path / "metadata.npz" @@ -58,7 +81,7 @@ def __init__(self, config, transform=None, fa_frames=None): else: length = self.envs[-1].stat()["entries"] assert length is not None, f"Could not find length of LMDB {db_path}" - self._keys.append(list(range(length))) + self._keys.append([str(i).encode("ascii") for i in range(length)]) keylens = [len(k) for k in self._keys] self._keylen_cumulative = np.cumsum(keylens).tolist() @@ -71,14 +94,99 @@ def __init__(self, config, transform=None, fa_frames=None): ] self.num_samples = len(self._keys) + self.filter_per_adsorbates() self.transform = transform - self.fa_frames = fa_frames + self.fa_method = fa_frames + + def filter_per_adsorbates(self): + """Filter the dataset to only include structures with a specific + adsorbate. + """ + # no adsorbates specified, or asked for all: return + if not self.adsorbates or self.adsorbates == "all": + return + + # val_ood_ads and val_ood_both don't have targeted adsorbates + if Path(self.config["src"]).parts[-1] in {"val_ood_ads", "val_ood_both"}: + return + + # make set of adsorbates from a list or a string. If a string, split on comma. + ads = [] + if isinstance(self.adsorbates, str): + if "," in self.adsorbates: + ads = [a.strip() for a in self.adsorbates.split(",")] + else: + ads = [self.adsorbates] + else: + ads = self.adsorbates + ads = set(ads) + + # find reference file for this dataset + ref_path = self.adsorbates_ref_dir + if not ref_path: + print("No adsorbate reference directory provided as `adsorbate_ref_dir`.") + return + ref_path = Path(ref_path) + if not ref_path.is_dir(): + print(f"Adsorbate reference directory {ref_path} does not exist.") + return + pattern = f"{self.config['split']}-{self.path.parts[-1]}" + candidates = list(ref_path.glob(f"*{pattern}*.json")) + if not candidates: + print( + f"No adsorbate reference files found for {self.path.name}.:" + + "\n".join( + [ + str(p) + for p in [ + ref_path, + pattern, + list(ref_path.glob(f"*{pattern}*.json")), + list(ref_path.glob("*")), + ] + ] + ) + ) + return + if len(candidates) > 1: + print( + f"Multiple adsorbate reference files found for {self.path.name}." + "Using the first one." 
+ ) + ref = json.loads(candidates[0].read_text()) + + # find dataset indices with the appropriate adsorbates + allowed_idxs = set( + str(i).encode("ascii") + for i, a in zip(ref["ds_idx"], ref["ads_symbols"]) + if a in ads + ) + + previous_samples = self.num_samples + + # filter the dataset indices + if isinstance(self._keys[0], bytes): + self._keys = [i for i in self._keys if i in allowed_idxs] + self.num_samples = len(self._keys) + else: + assert isinstance(self._keys[0], list) + self._keys = [[i for i in k if i in allowed_idxs] for k in self._keys] + keylens = [len(k) for k in self._keys] + self._keylen_cumulative = np.cumsum(keylens).tolist() + self.num_samples = sum(keylens) + + if not self.silent: + print( + f"Filtered dataset {pattern} from {previous_samples} to", + f"{self.num_samples} samples. (adsorbates: {ads})", + ) + + assert self.num_samples > 0, f"No samples found for adsorbates {ads}." def __len__(self): return self.num_samples - def __getitem__(self, idx): - t0 = time.time_ns() + def get_pickled_from_db(self, idx): if not self.path.is_file(): # Figure out which db this should be indexed from. db_idx = bisect.bisect(self._keylen_cumulative, idx) @@ -89,16 +197,20 @@ def __getitem__(self, idx): assert el_idx >= 0 # Return features. - datapoint_pickled = ( - self.envs[db_idx] - .begin() - .get(f"{self._keys[db_idx][el_idx]}".encode("ascii")) + return ( + f"{db_idx}_{el_idx}", + self.envs[db_idx].begin().get(self._keys[db_idx][el_idx]), ) - data_object = pyg2_data_transform(pickle.loads(datapoint_pickled)) - data_object.id = f"{db_idx}_{el_idx}" - else: - datapoint_pickled = self.env.begin().get(self._keys[idx]) - data_object = pyg2_data_transform(pickle.loads(datapoint_pickled)) + + return None, self.env.begin().get(self._keys[idx]) + + def __getitem__(self, idx): + t0 = time.time_ns() + + el_id, datapoint_pickled = self.get_pickled_from_db(idx) + data_object = pyg2_data_transform(pickle.loads(datapoint_pickled)) + if el_id: + data_object.id = el_id t1 = time.time_ns() if self.transform is not None: @@ -112,6 +224,7 @@ def __getitem__(self, idx): data_object.load_time = load_time data_object.transform_time = transform_time data_object.total_get_time = total_get_time + data_object.idx_in_dataset = idx return data_object @@ -137,6 +250,30 @@ def close_db(self): self.env.close() +@registry.register_dataset("deup_lmdb") +class DeupDataset(LmdbDataset): + def __init__(self, all_datasets_configs, deup_split, transform=None, silent=False): + # ! WARNING: this does not (yet?) 
handle adsorbate filtering + super().__init__( + all_datasets_configs[deup_split], + lmdb_glob=deup_split.replace("deup-", "").split("-"), + silent=silent, + ) + ocp_splits = deup_split.split("-")[1:] + self.ocp_datasets = { + d: LmdbDataset(all_datasets_configs[d], transform, silent=silent) + for d in ocp_splits + } + + def __getitem__(self, idx): + _, datapoint_pickled = self.get_pickled_from_db(idx) + deup_sample = pickle.loads(datapoint_pickled) + ocp_sample = self.ocp_datasets[deup_sample["ds"]][deup_sample["idx_in_dataset"]] + for k, v in deup_sample.items(): + setattr(ocp_sample, f"deup_{k}", v) + return ocp_sample + + class SinglePointLmdbDataset(LmdbDataset): def __init__(self, config, transform=None): super(SinglePointLmdbDataset, self).__init__(config, transform) diff --git a/ocpmodels/datasets/other_datasets.py b/ocpmodels/datasets/other_datasets.py new file mode 100644 index 0000000000..5f615a75a0 --- /dev/null +++ b/ocpmodels/datasets/other_datasets.py @@ -0,0 +1,177 @@ +import bisect +import logging +import pickle +import time +from pathlib import Path + +import torch +from torch_geometric.data import Data, HeteroData + +from ocpmodels.datasets.lmdb_dataset import LmdbDataset +from ocpmodels.common.registry import registry +from ocpmodels.common.utils import pyg2_data_transform + + +# This is a function that receives an adsorbate/catalyst system and returns +# each of these parts separately. +def graph_splitter(graph): + edge_index = graph.edge_index + pos = graph.pos + cell = graph.cell + atomic_numbers = graph.atomic_numbers + natoms = graph.natoms + cell_offsets = graph.cell_offsets + force = graph.force + distances = graph.distances + fixed = graph.fixed + tags = graph.tags + y_init = graph.y_init + y_relaxed = graph.y_relaxed + pos_relaxed = graph.pos_relaxed + id = graph.id + + # Make masks to filter most data we need + adsorbate_v_mask = tags == 2 + catalyst_v_mask = ~adsorbate_v_mask + + adsorbate_e_mask = (tags[edge_index][0] == 2) * (tags[edge_index][1] == 2) + catalyst_e_mask = (tags[edge_index][0] != 2) * (tags[edge_index][1] != 2) + + # Reindex the edge indices. 
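+    # The masks above decide which atoms/edges belong to each subgraph; the
+    # "assoc" tensors below remap global atom indices to per-subgraph indices.
+    # Worked example (illustrative values): with tags = [2, 0, 1, 2], the
+    # adsorbate keeps global atoms {0, 3}, so ads_assoc = [0, -1, -1, 1] and a
+    # kept global edge 3 -> 0 becomes the local edge 1 -> 0. The -1 entries are
+    # never read because the edge masks only keep edges whose endpoints both
+    # stay in the subgraph.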
+ device = graph.edge_index.device + + ads_assoc = torch.full((natoms,), -1, dtype=torch.long, device=device) + cat_assoc = torch.full((natoms,), -1, dtype=torch.long, device=device) + + ads_natoms = adsorbate_v_mask.sum() + cat_natoms = catalyst_v_mask.sum() + + ads_assoc[adsorbate_v_mask] = torch.arange(ads_natoms, device=device) + cat_assoc[catalyst_v_mask] = torch.arange(cat_natoms, device=device) + + ads_edge_index = ads_assoc[edge_index[:, adsorbate_e_mask]] + cat_edge_index = cat_assoc[edge_index[:, catalyst_e_mask]] + + # Create the graphs + adsorbate = Data( + edge_index=ads_edge_index, + pos=pos[adsorbate_v_mask, :], + cell=cell, + atomic_numbers=atomic_numbers[adsorbate_v_mask], + natoms=ads_natoms, + cell_offsets=cell_offsets[adsorbate_e_mask, :], + force=force[adsorbate_v_mask, :], + tags=tags[adsorbate_v_mask], + y_init=y_init, + y_relaxed=y_relaxed, + pos_relaxed=pos_relaxed[adsorbate_v_mask, :], + id=id, + mode="adsorbate", + ) + + catalyst = Data( + edge_index=cat_edge_index, + pos=pos[catalyst_v_mask, :], + cell=cell, + atomic_numbers=atomic_numbers[catalyst_v_mask], + natoms=cat_natoms, + cell_offsets=cell_offsets[catalyst_e_mask, :], + force=force[catalyst_v_mask, :], + tags=tags[catalyst_v_mask], + y_init=y_init, + y_relaxed=y_relaxed, + pos_relaxed=pos_relaxed[catalyst_v_mask, :], + id=id, + mode="catalyst", + ) + + return adsorbate, catalyst + + +# This dataset class sends back a tuple with the adsorbate and catalyst. +@registry.register_dataset("separate") +class SeparateLmdbDataset( + LmdbDataset +): # Check that the dataset works as intended, with an specific example. + def __getitem__(self, idx): + t0 = time.time_ns() + if not self.path.is_file(): + # Figure out which db this should be indexed from. + db_idx = bisect.bisect(self._keylen_cumulative, idx) + # Extract index of element within that db. + el_idx = idx + if db_idx != 0: + el_idx = idx - self._keylen_cumulative[db_idx - 1] + assert el_idx >= 0 + + # Return features. 
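+            # Note: `LmdbDataset.__init__` now stores `self._keys` as
+            # ASCII-encoded bytes, so the f-string re-encoding below appears to
+            # double-encode the key (b"b'0'" instead of b"0"); reusing the
+            # parent's `get_pickled_from_db` would sidestep this.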
+ datapoint_pickled = ( + self.envs[db_idx] + .begin() + .get(f"{self._keys[db_idx][el_idx]}".encode("ascii")) + ) + data_object = pyg2_data_transform(pickle.loads(datapoint_pickled)) + data_object.id = f"{db_idx}_{el_idx}" + else: + datapoint_pickled = self.env.begin().get(self._keys[idx]) + data_object = pyg2_data_transform(pickle.loads(datapoint_pickled)) + + # We separate the graphs + adsorbate, catalyst = graph_splitter(data_object) + + t1 = time.time_ns() + if self.transform is not None: + adsorbate = self.transform(adsorbate) + catalyst = self.transform(catalyst) + t2 = time.time_ns() + + load_time = (t1 - t0) * 1e-9 # time in s + transform_time = (t2 - t1) * 1e-9 # time in s + total_get_time = (t2 - t0) * 1e-9 # time in s + + adsorbate.load_time = load_time + adsorbate.transform_time = transform_time + adsorbate.total_get_time = total_get_time + + catalyst.load_time = load_time + catalyst.transform_time = transform_time + catalyst.total_get_time = total_get_time + + return (adsorbate, catalyst) + + +@registry.register_dataset("heterogeneous") +class HeterogeneousDataset(SeparateLmdbDataset): + def __getitem__(self, idx): + # We start by separating the adsorbate and catalyst + adsorbate, catalyst = super().__getitem__(idx) + + # We save each into the heterogeneous graph + reaction = HeteroData() + for graph in [adsorbate, catalyst]: + mode = graph.mode + for key in graph.keys: + if key == "edge_index": + continue + reaction[mode][key] = graph[key] + + reaction[mode, "is_close", mode].edge_index = graph.edge_index + + # We create the edges between both parts of the graph. + sender = torch.repeat_interleave( + torch.arange(catalyst.natoms.item()), adsorbate.natoms.item() + ) + receiver = torch.arange(0, adsorbate.natoms.item()).repeat( + catalyst.natoms.item() + ) + reaction["catalyst", "is_disc", "adsorbate"].edge_index = torch.stack( + [sender, receiver] + ) + reaction[ + "catalyst", "is_disc", "adsorbate" + ].edge_weight = torch.repeat_interleave( + reaction["catalyst"].pos[:, 2], + adsorbate.natoms.item(), + ) + + return reaction diff --git a/ocpmodels/models/__init__.py b/ocpmodels/models/__init__.py index a722f78170..417433ac0e 100644 --- a/ocpmodels/models/__init__.py +++ b/ocpmodels/models/__init__.py @@ -8,6 +8,8 @@ from .dimenet import DimeNet # noqa: F401 from .faenet import FAENet # noqa: F401 from .gemnet.gemnet import GemNetT # noqa: F401 +from .gemnet.depgemnet_t import depGemNetT # noqa: F401 +from .gemnet.indgemnet_t import indGemNetT # noqa: F401 from .dimenet_plus_plus import DimeNetPlusPlus # noqa: F401 from .forcenet import ForceNet # noqa: F401 from .schnet import SchNet # noqa: F401 diff --git a/ocpmodels/models/adpp.py b/ocpmodels/models/adpp.py new file mode 100644 index 0000000000..725d51bd54 --- /dev/null +++ b/ocpmodels/models/adpp.py @@ -0,0 +1,964 @@ +from math import pi as PI +from math import sqrt + +import torch +from torch import nn +from torch.nn import Embedding, Linear +from torch_geometric.nn import radius_graph +from torch_geometric.nn.inits import glorot_orthogonal +from torch_geometric.nn.models.dimenet import ( + Envelope, + ResidualLayer, + SphericalBasisLayer, +) +from torch_scatter import scatter +from torch_sparse import SparseTensor + +from ocpmodels.common.registry import registry +from ocpmodels.common.utils import ( + conditional_grad, + get_pbc_distances, + radius_graph_pbc_inputs, +) +from ocpmodels.models.base_model import BaseModel +from ocpmodels.models.utils.pos_encodings import PositionalEncoding +from 
ocpmodels.modules.phys_embeddings import PhysEmbedding +from ocpmodels.modules.pooling import Graclus, Hierarchical_Pooling +from ocpmodels.models.utils.activations import swish +from ocpmodels.models.afaenet import GATInteraction, GaussianSmearing + + +try: + import sympy as sym +except ImportError: + sym = None + +NUM_CLUSTERS = 20 +NUM_POOLING_LAYERS = 1 + + +class BesselBasisLayer(torch.nn.Module): + def __init__(self, num_radial, cutoff=5.0, envelope_exponent=5): + super().__init__() + self.cutoff = cutoff + self.envelope = Envelope(envelope_exponent) + + self.freq = torch.nn.Parameter(torch.Tensor(num_radial)) + + self.reset_parameters() + + @torch.no_grad() + def reset_parameters(self): + torch.arange(1, self.freq.numel() + 1, out=self.freq).mul_(PI) + + def forward(self, dist): + dist = dist.unsqueeze(-1) / self.cutoff + return self.envelope(dist) * (self.freq * dist).sin() + + +class EmbeddingBlock(torch.nn.Module): + def __init__(self, num_radial, hidden_channels, act=swish): + super().__init__() + self.act = act + + self.emb = Embedding(85, hidden_channels) + self.lin_rbf = Linear(num_radial, hidden_channels) + self.lin = Linear(3 * hidden_channels, hidden_channels) + + self.reset_parameters() + + def reset_parameters(self): + self.emb.weight.data.uniform_(-sqrt(3), sqrt(3)) + self.lin_rbf.reset_parameters() + self.lin.reset_parameters() + + def forward(self, x, rbf, i, j, tags=None, subnodes=None): + x = self.emb(x) + rbf = self.act(self.lin_rbf(rbf)) + return self.act(self.lin(torch.cat([x[i], x[j], rbf], dim=-1))) + + +class AdvancedEmbeddingBlock(torch.nn.Module): + def __init__( + self, + num_radial, + hidden_channels, + tag_hidden_channels, + pg_hidden_channels, + phys_hidden_channels, + phys_embeds, + graph_rewiring, + act=swish, + ): + super().__init__() + self.act = act + self.use_tag = tag_hidden_channels > 0 + self.use_pg = pg_hidden_channels > 0 + self.use_mlp_phys = phys_hidden_channels > 0 + self.use_positional_embeds = graph_rewiring in { + "one-supernode-per-graph", + "one-supernode-per-atom-type", + "one-supernode-per-atom-type-dist", + } + # self.use_positional_embeds = False + + # Phys embeddings + self.phys_emb = PhysEmbedding(props=phys_embeds, pg=self.use_pg) + # With MLP + if self.use_mlp_phys: + self.phys_lin = Linear(self.phys_emb.n_properties, phys_hidden_channels) + else: + phys_hidden_channels = self.phys_emb.n_properties + # Period + group embeddings + if self.use_pg: + self.period_embedding = Embedding( + self.phys_emb.period_size, pg_hidden_channels + ) + self.group_embedding = Embedding( + self.phys_emb.group_size, pg_hidden_channels + ) + # Tag embedding + if tag_hidden_channels: + self.tag = Embedding(3, tag_hidden_channels) + + # Position encoding + if self.use_positional_embeds: + self.pe = PositionalEncoding(hidden_channels, 210) + + # Main embedding + self.emb = Embedding( + 85, + hidden_channels + - tag_hidden_channels + - phys_hidden_channels + - 2 * pg_hidden_channels, + ) + + self.lin_rbf = Linear(num_radial, hidden_channels) + self.lin = Linear(3 * hidden_channels, hidden_channels) + + self.reset_parameters() + + def reset_parameters(self): + self.emb.weight.data.uniform_(-sqrt(3), sqrt(3)) + if self.use_mlp_phys: + self.phys_lin.reset_parameters() + if self.use_tag: + self.tag.weight.data.uniform_(-sqrt(3), sqrt(3)) + if self.use_pg: + self.period_embedding.weight.data.uniform_(-sqrt(3), sqrt(3)) + self.group_embedding.weight.data.uniform_(-sqrt(3), sqrt(3)) + self.lin_rbf.reset_parameters() + self.lin.reset_parameters() + + def 
forward(self, x, rbf, i, j, tag=None, subnodes=None): + x_ = self.emb(x) + rbf = self.act(self.lin_rbf(rbf)) + + if self.phys_emb.device != x.device: + self.phys_emb = self.phys_emb.to(x.device) + + if self.use_tag: + x_tag = self.tag(tag) + x_ = torch.cat((x_, x_tag), dim=1) + + if self.phys_emb.n_properties > 0: + x_phys = self.phys_emb.properties[x] + if self.use_mlp_phys: + x_phys = self.phys_lin(x_phys) + x_ = torch.cat((x_, x_phys), dim=1) + + if self.use_pg: + x_period = self.period_embedding(self.phys_emb.period[x]) + x_group = self.group_embedding(self.phys_emb.group[x]) + x_ = torch.cat((x_, x_period, x_group), dim=1) + + if self.use_positional_embeds: + idx_of_non_zero_val = (tag == 0).nonzero().T.squeeze(0) + x_pos = torch.zeros_like(x_, device=x_.device) + x_pos[idx_of_non_zero_val, :] = self.pe(subnodes).to(device=x_pos.device) + x_ += x_pos + + return self.act( + self.lin( + torch.cat( + [ + x_[i], + x_[j], + rbf, + ], + dim=-1, + ) + ) + ) + + +class InteractionPPBlock(torch.nn.Module): + def __init__( + self, + hidden_channels, + int_emb_size, + basis_emb_size, + num_spherical, + num_radial, + num_before_skip, + num_after_skip, + act=swish, + ): + super(InteractionPPBlock, self).__init__() + self.act = act + + # Transformations of Bessel and spherical basis representations. + self.lin_rbf1 = nn.Linear(num_radial, basis_emb_size, bias=False) + self.lin_rbf2 = nn.Linear(basis_emb_size, hidden_channels, bias=False) + self.lin_sbf1 = nn.Linear( + num_spherical * num_radial, basis_emb_size, bias=False + ) + self.lin_sbf2 = nn.Linear(basis_emb_size, int_emb_size, bias=False) + + # Dense transformations of input messages. + self.lin_kj = nn.Linear(hidden_channels, hidden_channels) + self.lin_ji = nn.Linear(hidden_channels, hidden_channels) + + # Embedding projections for interaction triplets. + self.lin_down = nn.Linear(hidden_channels, int_emb_size, bias=False) + self.lin_up = nn.Linear(int_emb_size, hidden_channels, bias=False) + + # Residual layers before and after skip connection. + self.layers_before_skip = torch.nn.ModuleList( + [ResidualLayer(hidden_channels, act) for _ in range(num_before_skip)] + ) + self.lin = nn.Linear(hidden_channels, hidden_channels) + self.layers_after_skip = torch.nn.ModuleList( + [ResidualLayer(hidden_channels, act) for _ in range(num_after_skip)] + ) + + self.reset_parameters() + + def reset_parameters(self): + glorot_orthogonal(self.lin_rbf1.weight, scale=2.0) + glorot_orthogonal(self.lin_rbf2.weight, scale=2.0) + glorot_orthogonal(self.lin_sbf1.weight, scale=2.0) + glorot_orthogonal(self.lin_sbf2.weight, scale=2.0) + + glorot_orthogonal(self.lin_kj.weight, scale=2.0) + self.lin_kj.bias.data.fill_(0) + glorot_orthogonal(self.lin_ji.weight, scale=2.0) + self.lin_ji.bias.data.fill_(0) + + glorot_orthogonal(self.lin_down.weight, scale=2.0) + glorot_orthogonal(self.lin_up.weight, scale=2.0) + + for res_layer in self.layers_before_skip: + res_layer.reset_parameters() + glorot_orthogonal(self.lin.weight, scale=2.0) + self.lin.bias.data.fill_(0) + for res_layer in self.layers_after_skip: + res_layer.reset_parameters() + + def forward(self, x, rbf, sbf, idx_kj, idx_ji): + # Initial transformations. + x_ji = self.act(self.lin_ji(x)) + x_kj = self.act(self.lin_kj(x)) + + # Transformation via Bessel basis. + rbf = self.lin_rbf1(rbf) + rbf = self.lin_rbf2(rbf) + x_kj = x_kj * rbf + + # Down-project embeddings and generate interaction triplet embeddings. + x_kj = self.act(self.lin_down(x_kj)) + + # Transform via 2D spherical basis. 
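+        # `x_kj` holds one message per edge; `idx_kj`/`idx_ji` map edges to
+        # triplets, so each incoming (k->j) message is modulated by the
+        # angle/distance basis of its triplet and summed back onto the
+        # receiving (j->i) edge before being up-projected.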
+ sbf = self.lin_sbf1(sbf) + sbf = self.lin_sbf2(sbf) + x_kj = x_kj[idx_kj] * sbf + + # Aggregate interactions and up-project embeddings. + x_kj = scatter(x_kj, idx_ji, dim=0, dim_size=x.size(0)) + x_kj = self.act(self.lin_up(x_kj)) + + h = x_ji + x_kj + for layer in self.layers_before_skip: + h = layer(h) + h = self.act(self.lin(h)) + x + for layer in self.layers_after_skip: + h = layer(h) + + return h + + +class EHOutputPPBlock(torch.nn.Module): + def __init__( + self, + num_radial, + hidden_channels, + out_emb_channels, + out_channels, + num_layers, + energy_head, + act=swish, + ): + super(EHOutputPPBlock, self).__init__() + self.act = act + self.energy_head = energy_head + + self.lin_rbf = nn.Linear(num_radial, hidden_channels, bias=False) + self.lin_up = nn.Linear(hidden_channels, out_emb_channels, bias=True) + self.lins = torch.nn.ModuleList() + for _ in range(num_layers): + self.lins.append(nn.Linear(out_emb_channels, out_emb_channels)) + self.lin = nn.Linear(out_emb_channels, out_channels, bias=False) + + # weighted average & pooling + if self.energy_head in {"pooling", "random"}: + self.hierarchical_pooling = Hierarchical_Pooling( + hidden_channels, + self.act, + NUM_POOLING_LAYERS, + NUM_CLUSTERS, + self.energy_head, + ) + elif self.energy_head == "graclus": + self.graclus = Graclus(hidden_channels, self.act) + elif self.energy_head == "weighted-av-final-embeds": + self.w_lin = Linear(hidden_channels, 1) + + self.reset_parameters() + + def reset_parameters(self): + glorot_orthogonal(self.lin_rbf.weight, scale=2.0) + glorot_orthogonal(self.lin_up.weight, scale=2.0) + for lin in self.lins: + glorot_orthogonal(lin.weight, scale=2.0) + lin.bias.data.fill_(0) + self.lin.weight.data.fill_(0) + if self.energy_head == "weighted-av-final-embeds": + self.w_lin.bias.data.fill_(0) + torch.nn.init.xavier_uniform_(self.w_lin.weight) + + def forward(self, x, rbf, i, edge_index, edge_weight, batch, num_nodes=None): + x = self.lin_rbf(rbf) * x + x = scatter(x, i, dim=0, dim_size=num_nodes) + + pooling_loss = None + if self.energy_head == "weighted-av-final-embeds": + alpha = self.w_lin(x) + elif self.energy_head == "graclus": + x, batch = self.graclus(x, edge_index, edge_weight, batch) + elif self.energy_head in {"pooling", "random"}: + x, batch, pooling_loss = self.hierarchical_pooling( + x, edge_index, edge_weight, batch + ) + + x = self.lin_up(x) + for lin in self.lins: + x = self.act(lin(x)) + x = self.lin(x) + + if self.energy_head == "weighted-av-final-embeds": + x = x * alpha + + return x, pooling_loss, batch + + +class OutputPPBlock(torch.nn.Module): + def __init__( + self, + num_radial, + hidden_channels, + out_emb_channels, + out_channels, + num_layers, + act=swish, + ): + super(OutputPPBlock, self).__init__() + self.act = act + + self.lin_rbf = nn.Linear(num_radial, hidden_channels, bias=False) + self.lin_up = nn.Linear(hidden_channels, out_emb_channels, bias=True) + self.lins = torch.nn.ModuleList() + for _ in range(num_layers): + self.lins.append(nn.Linear(out_emb_channels, out_emb_channels)) + self.lin = nn.Linear(out_emb_channels, out_channels, bias=False) + + self.reset_parameters() + + def reset_parameters(self): + glorot_orthogonal(self.lin_rbf.weight, scale=2.0) + glorot_orthogonal(self.lin_up.weight, scale=2.0) + for lin in self.lins: + glorot_orthogonal(lin.weight, scale=2.0) + lin.bias.data.fill_(0) + self.lin.weight.data.fill_(0) + + def forward(self, x, rbf, i, num_nodes=None): + x = self.lin_rbf(rbf) * x + x = scatter(x, i, dim=0, dim_size=num_nodes) + x = self.lin_up(x) 
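+        # Output head: `num_layers` activated linear layers followed by a
+        # final zero-initialised projection to `out_channels`.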
+ for lin in self.lins: + x = self.act(lin(x)) + return self.lin(x) + + +@registry.register_model("adpp") +class ADPP(BaseModel): + r"""DimeNet++ implementation based on https://github.com/klicperajo/dimenet. + + Args: + hidden_channels (int): Hidden embedding size. + tag_hidden_channels (int): tag embedding size + pg_hidden_channels (int): period & group embedding size + phys_hidden_channels (int): MLP hidden size for physics embedding + phys_embeds (bool): whether we use physics embeddings or not + graph_rewiring (str): name of rewiring method. Default=False. + out_channels (int): Size of each output sample. + num_blocks (int): Number of building blocks. + int_emb_size (int): Embedding size used for interaction triplets + basis_emb_size (int): Embedding size used in the basis transformation + out_emb_channels(int): Embedding size used for atoms in the output block + num_spherical (int): Number of spherical harmonics. + num_radial (int): Number of radial basis functions. + cutoff: (float, optional): Cutoff distance for interatomic + interactions. (default: :obj:`5.0`) + use_pbc (bool, optional): Use of periodic boundary conditions. + (default: true) + otf_graph (bool, optional): Recompute radius graph. + (default: false) + envelope_exponent (int, optional): Shape of the smooth cutoff. + (default: :obj:`5`) + num_before_skip: (int, optional): Number of residual layers in the + interaction blocks before the skip connection. (default: :obj:`1`) + num_after_skip: (int, optional): Number of residual layers in the + interaction blocks after the skip connection. (default: :obj:`2`) + num_output_layers: (int, optional): Number of linear layers for the + output blocks. (default: :obj:`3`) + act: (function, optional): The activation function. + (default: :obj:`swish`) + regress_forces: (bool, optional): Compute atom forces from energy. + (default: false). 
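+        energy_head (str, optional): method used to pool atom representations
+            into a graph-level energy; the code supports
+            'weighted-av-initial-embeds', 'weighted-av-final-embeds',
+            'graclus', 'pooling' and 'random'.
+        gat_mode (str): variant of the cross-graph `GATInteraction` blocks that
+            exchange information between the adsorbate and catalyst graphs.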
+ """ + + url = "https://github.com/klicperajo/dimenet/raw/master/pretrained" + + def __init__(self, **kwargs): + super().__init__() + + kwargs["num_targets"] = kwargs["hidden_channels"] // 2 + self.act = swish + + self.cutoff = kwargs["cutoff"] + self.use_pbc = kwargs["use_pbc"] + self.otf_graph = kwargs["otf_graph"] + self.regress_forces = kwargs["regress_forces"] + self.energy_head = kwargs["energy_head"] + use_tag = kwargs["tag_hidden_channels"] > 0 + use_pg = kwargs["pg_hidden_channels"] > 0 + act = ( + getattr(nn.functional, kwargs["act"]) if kwargs["act"] != "swish" else swish + ) + + assert ( + kwargs["tag_hidden_channels"] + 2 * kwargs["pg_hidden_channels"] + 16 + < kwargs["hidden_channels"] + ) + if sym is None: + raise ImportError("Package `sympy` could not be found.") + + self.rbf_ads = BesselBasisLayer( + kwargs["num_radial"], self.cutoff, kwargs["envelope_exponent"] + ) + self.rbf_cat = BesselBasisLayer( + kwargs["num_radial"], self.cutoff, kwargs["envelope_exponent"] + ) + self.sbf_ads = SphericalBasisLayer( + kwargs["num_spherical"], + kwargs["num_radial"], + self.cutoff, + kwargs["envelope_exponent"], + ) + self.sbf_cat = SphericalBasisLayer( + kwargs["num_spherical"], + kwargs["num_radial"], + self.cutoff, + kwargs["envelope_exponent"], + ) + # Disconnected interaction embedding + self.distance_expansion_disc = GaussianSmearing(0.0, 20.0, 100) + self.disc_edge_embed = Linear(100, kwargs["hidden_channels"]) + + if use_tag or use_pg or kwargs["phys_embeds"] or kwargs["graph_rewiring"]: + self.emb_ads = AdvancedEmbeddingBlock( + kwargs["num_radial"], + kwargs["hidden_channels"], + kwargs["tag_hidden_channels"], + kwargs["pg_hidden_channels"], + kwargs["phys_hidden_channels"], + kwargs["phys_embeds"], + kwargs["graph_rewiring"], + act, + ) + self.emb_cat = AdvancedEmbeddingBlock( + kwargs["num_radial"], + kwargs["hidden_channels"], + kwargs["tag_hidden_channels"], + kwargs["pg_hidden_channels"], + kwargs["phys_hidden_channels"], + kwargs["phys_embeds"], + kwargs["graph_rewiring"], + act, + ) + else: + self.emb_ads = EmbeddingBlock( + kwargs["num_radial"], kwargs["hidden_channels"], act + ) + self.emb_cat = EmbeddingBlock( + kwargs["num_radial"], kwargs["hidden_channels"], act + ) + + if self.energy_head: + self.output_blocks_ads = torch.nn.ModuleList( + [ + EHOutputPPBlock( + kwargs["num_radial"], + kwargs["hidden_channels"], + kwargs["out_emb_channels"], + kwargs["num_targets"], + kwargs["num_output_layers"], + self.energy_head, + act, + ) + for _ in range(kwargs["num_blocks"] + 1) + ] + ) + self.output_blocks_cat = torch.nn.ModuleList( + [ + EHOutputPPBlock( + kwargs["num_radial"], + kwargs["hidden_channels"], + kwargs["out_emb_channels"], + kwargs["num_targets"], + kwargs["num_output_layers"], + self.energy_head, + act, + ) + for _ in range(kwargs["num_blocks"] + 1) + ] + ) + else: + self.output_blocks_ads = torch.nn.ModuleList( + [ + OutputPPBlock( + kwargs["num_radial"], + kwargs["hidden_channels"], + kwargs["out_emb_channels"], + kwargs["num_targets"], + kwargs["num_output_layers"], + act, + ) + for _ in range(kwargs["num_blocks"] + 1) + ] + ) + self.output_blocks_cat = torch.nn.ModuleList( + [ + OutputPPBlock( + kwargs["num_radial"], + kwargs["hidden_channels"], + kwargs["out_emb_channels"], + kwargs["num_targets"], + kwargs["num_output_layers"], + act, + ) + for _ in range(kwargs["num_blocks"] + 1) + ] + ) + + self.interaction_blocks_ads = torch.nn.ModuleList( + [ + InteractionPPBlock( + kwargs["hidden_channels"], + kwargs["int_emb_size"], + kwargs["basis_emb_size"], 
+ kwargs["num_spherical"], + kwargs["num_radial"], + kwargs["num_before_skip"], + kwargs["num_after_skip"], + act, + ) + for _ in range(kwargs["num_blocks"]) + ] + ) + self.interaction_blocks_cat = torch.nn.ModuleList( + [ + InteractionPPBlock( + kwargs["hidden_channels"], + kwargs["int_emb_size"], + kwargs["basis_emb_size"], + kwargs["num_spherical"], + kwargs["num_radial"], + kwargs["num_before_skip"], + kwargs["num_after_skip"], + act, + ) + for _ in range(kwargs["num_blocks"]) + ] + ) + self.inter_interactions = torch.nn.ModuleList( + [ + GATInteraction( + kwargs["hidden_channels"], + kwargs["gat_mode"], + kwargs["hidden_channels"], + ) + for _ in range(kwargs["num_blocks"]) + ] + ) + + if self.energy_head == "weighted-av-initial-embeds": + self.w_lin_ads = Linear(kwargs["hidden_channels"], 1) + self.w_lin_cat = Linear(kwargs["hidden_channels"], 1) + + self.task = kwargs["task_name"] + + self.combination = nn.Sequential( + Linear(kwargs["hidden_channels"] // 2 * 2, kwargs["hidden_channels"] // 2), + self.act, + Linear(kwargs["hidden_channels"] // 2, 1), + ) + + self.reset_parameters() + + def reset_parameters(self): + self.rbf_ads.reset_parameters() + self.rbf_cat.reset_parameters() + self.emb_ads.reset_parameters() + self.emb_cat.reset_parameters() + for out in self.output_blocks_ads: + out.reset_parameters() + for out in self.output_blocks_cat: + out.reset_parameters() + for interaction in self.interaction_blocks_ads: + interaction.reset_parameters() + for interaction in self.interaction_blocks_cat: + interaction.reset_parameters() + if self.energy_head == "weighted-av-initial-embeds": + self.w_lin_ads.bias.data.fill_(0) + self.w_lin_cat.bias.data.fill_(0) + torch.nn.init.xavier_uniform_(self.w_lin.weight) + + def triplets(self, edge_index, cell_offsets, num_nodes): + row, col = edge_index # j->i + + value = torch.arange(row.size(0), device=row.device) + adj_t = SparseTensor( + row=col, col=row, value=value, sparse_sizes=(num_nodes, num_nodes) + ) + adj_t_row = adj_t[row] + num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long) + + # Node indices (k->j->i) for triplets. + idx_i = col.repeat_interleave(num_triplets) + idx_j = row.repeat_interleave(num_triplets) + idx_k = adj_t_row.storage.col() + + # Edge indices (k->j, j->i) for triplets. 
+ idx_kj = adj_t_row.storage.value() + idx_ji = adj_t_row.storage.row() + + # Remove self-loop triplets d->b->d + # Check atom as well as cell offset + cell_offset_kji = cell_offsets[idx_kj] + cell_offsets[idx_ji] + mask = (idx_i != idx_k) | torch.any(cell_offset_kji != 0, dim=-1) + + idx_i, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask] + idx_kj, idx_ji = idx_kj[mask], idx_ji[mask] + + return col, row, idx_i, idx_j, idx_k, idx_kj, idx_ji + + @conditional_grad(torch.enable_grad()) + def energy_forward(self, data): + ( + pos_ads, + edge_index_ads, + cell_ads, + cell_offsets_ads, + neighbors_ads, + batch_ads, + atomic_numbers_ads, + tags_ads, + ) = ( + data["adsorbate"].pos, + data["adsorbate", "is_close", "adsorbate"].edge_index, + data["adsorbate"].cell, + data["adsorbate"].cell_offsets, + data["adsorbate"].neighbors, + data["adsorbate"].batch, + data["adsorbate"].atomic_numbers, + data["adsorbate"].tags, + ) + ( + pos_cat, + edge_index_cat, + cell_cat, + cell_offsets_cat, + neighbors_cat, + batch_cat, + atomic_numbers_cat, + tags_cat, + ) = ( + data["catalyst"].pos, + data["catalyst", "is_close", "catalyst"].edge_index, + data["catalyst"].cell, + data["catalyst"].cell_offsets, + data["catalyst"].neighbors, + data["catalyst"].batch, + data["catalyst"].atomic_numbers, + data["catalyst"].tags, + ) + + if self.otf_graph: # NOT IMPLEMENTED!! + edge_index, cell_offsets, neighbors = radius_graph_pbc_inputs( + pos, natoms, cell, self.cutoff, 50 + ) + data.edge_index = edge_index + data.cell_offsets = cell_offsets + data.neighbors = neighbors + + # Rewire the graph + subnodes = False + + if self.use_pbc: + out = get_pbc_distances( + pos_ads, + edge_index_ads, + cell_ads, + cell_offsets_ads, + neighbors_ads, + return_offsets=True, + ) + + edge_index_ads = out["edge_index"] + dist_ads = out["distances"] + offsets_ads = out["offsets"] + + j_ads, i_ads = edge_index_ads + + out = get_pbc_distances( + pos_cat, + edge_index_cat, + cell_cat, + cell_offsets_cat, + neighbors_cat, + return_offsets=True, + ) + + edge_index_cat = out["edge_index"] + dist_cat = out["distances"] + offsets_cat = out["offsets"] + + j_cat, i_cat = edge_index_cat + else: # NOT IMPLEMENTED + edge_index = radius_graph(pos, r=self.cutoff, batch=batch) + j, i = edge_index + dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt() + + _, _, idx_i_ads, idx_j_ads, idx_k_ads, idx_kj_ads, idx_ji_ads = self.triplets( + edge_index_ads, + cell_offsets_ads, + num_nodes=atomic_numbers_ads.size(0), + ) + _, _, idx_i_cat, idx_j_cat, idx_k_cat, idx_kj_cat, idx_ji_cat = self.triplets( + edge_index_cat, + cell_offsets_cat, + num_nodes=atomic_numbers_cat.size(0), + ) + + # Calculate angles. 
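+        # The angle at the central atom j of each (k -> j -> i) triplet is computed
+        # as atan2(|r_ji x r_kj|, r_ji . r_kj), which is numerically more stable
+        # than taking acos of the normalized dot product.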
+ pos_i_ads = pos_ads[idx_i_ads].detach() + pos_j_ads = pos_ads[idx_j_ads].detach() + + pos_i_cat = pos_cat[idx_i_cat].detach() + pos_j_cat = pos_cat[idx_j_cat].detach() + if self.use_pbc: + pos_ji_ads, pos_kj_ads = ( + pos_ads[idx_j_ads].detach() - pos_i_ads + offsets_ads[idx_ji_ads], + pos_ads[idx_k_ads].detach() - pos_j_ads + offsets_ads[idx_kj_ads], + ) + pos_ji_cat, pos_kj_cat = ( + pos_cat[idx_j_cat].detach() - pos_i_cat + offsets_cat[idx_ji_cat], + pos_cat[idx_k_cat].detach() - pos_j_cat + offsets_cat[idx_kj_cat], + ) + else: # NOT IMPLEMENTED + pos_ji, pos_kj = ( + pos[idx_j].detach() - pos_i, + pos[idx_k].detach() - pos_j, + ) + + a_ads = (pos_ji_ads * pos_kj_ads).sum(dim=-1) + b_ads = torch.cross(pos_ji_ads, pos_kj_ads).norm(dim=-1) + angle_ads = torch.atan2(b_ads, a_ads) + + a_cat = (pos_ji_cat * pos_kj_cat).sum(dim=-1) + b_cat = torch.cross(pos_ji_cat, pos_kj_cat).norm(dim=-1) + angle_cat = torch.atan2(b_cat, a_cat) + + rbf_ads = self.rbf_ads(dist_ads) + sbf_ads = self.sbf_ads(dist_ads, angle_ads, idx_kj_ads) + + rbf_cat = self.rbf_cat(dist_cat) + sbf_cat = self.sbf_cat(dist_cat, angle_cat, idx_kj_cat) + + pooling_loss = None # deal with pooling loss + + # Embedding block. + x_ads = self.emb_ads( + atomic_numbers_ads.long(), rbf_ads, i_ads, j_ads, tags_ads, subnodes + ) + if self.energy_head: + P_ads, pooling_loss, batch_ads = self.output_blocks_ads[0]( + x_ads, + rbf_ads, + i_ads, + edge_index_ads, + dist_ads, + batch_ads, + num_nodes=pos_ads.size(0), + ) + else: + P_ads = self.output_blocks_ads[0]( + x_ads, rbf_ads, i_ads, num_nodes=pos_ads.size(0) + ) + + if self.energy_head == "weighted-av-initial-embeds": + alpha_ads = self.w_lin_ads( + scatter(x_ads, i_ads, dim=0, dim_size=pos_ads.size(0)) + ) + + x_cat = self.emb_cat( + atomic_numbers_cat.long(), rbf_cat, i_cat, j_cat, tags_cat, subnodes + ) + if self.energy_head: + P_cat, pooling_loss, batch_cat = self.output_blocks_cat[0]( + x_cat, + rbf_cat, + i_cat, + edge_index_cat, + dist_cat, + batch_cat, + num_nodes=pos_cat.size(0), + ) + else: + P_cat = self.output_blocks_cat[0]( + x_cat, rbf_cat, i_cat, num_nodes=pos_cat.size(0) + ) + + if self.energy_head == "weighted-av-initial-embeds": + alpha_cat = self.w_lin_cat( + scatter(x_cat, i_cat, dim=0, dim_size=pos_cat.size(0)) + ) + + edge_weights = self.distance_expansion_disc(data["is_disc"].edge_weight) + edge_weights = self.disc_edge_embed(edge_weights) + + # Interaction blocks. 
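+        # Each block runs DimeNet++-style message passing inside the adsorbate and
+        # catalyst graphs separately, then exchanges information across the
+        # adsorbate-catalyst bipartite graph ("is_disc" edges) with a GAT layer,
+        # followed by a residual update and normalization of both sets of embeddings.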
+ energy_Ps_ads = [] + energy_Ps_cat = [] + + for ( + interaction_block_ads, + interaction_block_cat, + output_block_ads, + output_block_cat, + disc_interaction, + ) in zip( + self.interaction_blocks_ads, + self.interaction_blocks_cat, + self.output_blocks_ads[1:], + self.output_blocks_cat[1:], + self.inter_interactions, + ): + intra_ads = interaction_block_ads( + x_ads, rbf_ads, sbf_ads, idx_kj_ads, idx_ji_ads + ) + intra_cat = interaction_block_cat( + x_cat, rbf_cat, sbf_cat, idx_kj_cat, idx_ji_cat + ) + + inter_ads, inter_cat = disc_interaction( + intra_ads, intra_cat, data["is_disc"].edge_index, edge_weights + ) + + x_ads, x_cat = x_ads + inter_ads, x_cat + inter_cat + x_ads, x_cat = nn.functional.normalize(x_ads), nn.functional.normalize( + x_cat + ) + + if self.energy_head: + P_bis_ads, pooling_loss_bis_ads, _ = output_block_ads( + x_ads, + rbf_ads, + i_ads, + edge_index_ads, + dist_ads, + batch_ads, + num_nodes=pos_ads.size(0), + ) + energy_Ps_ads.append( + P_bis_ads.sum(0) / len(P) + if batch_ads is None + else scatter(P_bis_ads, batch_ads, dim=0) + ) + if pooling_loss_bis_ads is not None: + pooling_loss += pooling_loss_bis_ads + + P_bis_cat, pooling_loss_bis_cat, _ = output_block_cat( + x_cat, + rbf_cat, + i_cat, + edge_index_cat, + dist_cat, + batch_cat, + num_nodes=pos_cat.size(0), + ) + energy_Ps_cat.append( + P_bis_cat.sum(0) / len(P) + if batch_cat is None + else scatter(P_bis_cat, batch_cat, dim=0) + ) + if pooling_loss_bis_cat is not None: + pooling_loss += pooling_loss_bis_cat + else: + P_ads += output_block_ads( + x_ads, rbf_ads, i_ads, num_nodes=pos_ads.size(0) + ) + P_cat += output_block_cat( + x_cat, rbf_cat, i_cat, num_nodes=pos_cat.size(0) + ) + + if self.energy_head == "weighted-av-initial-embeds": + P = P * alpha + + # Output + # scatter + energy_ads = self.scattering(batch_ads, P_ads) + energy_cat = self.scattering(batch_cat, P_cat) + energy = torch.cat([energy_ads, energy_cat], dim=1) + energy = self.combination(energy) + + return { + "energy": energy, + "pooling_loss": pooling_loss, + } + + def scattering(self, batch, P, P_bis=0): + energy = scatter(P, batch, dim=0, reduce="add") + + return energy + + @conditional_grad(torch.enable_grad()) + def forces_forward(self, preds): + return + + @property + def num_params(self): + return sum(p.numel() for p in self.parameters()) diff --git a/ocpmodels/models/afaenet.py b/ocpmodels/models/afaenet.py new file mode 100644 index 0000000000..1aa4f6f7e4 --- /dev/null +++ b/ocpmodels/models/afaenet.py @@ -0,0 +1,438 @@ +import torch +import math +from torch import nn +from torch.nn import Linear, Transformer, Softmax + +from torch_geometric.data import Batch +from torch_geometric.nn import radius_graph, GATConv, GATv2Conv + +from torch_sparse import SparseTensor, spspmm +from torch_sparse import transpose as transpose_sparse +from scipy import sparse + +from ocpmodels.models.faenet import ( + GaussianSmearing, + EmbeddingBlock, + InteractionBlock, + OutputBlock, +) +from ocpmodels.models.indfaenet import PositionalEncoding +from ocpmodels.common.registry import registry +from ocpmodels.models.base_model import BaseModel +from ocpmodels.common.utils import conditional_grad, get_pbc_distances +from ocpmodels.models.utils.activations import swish + + +class GATInteraction(nn.Module): + def __init__(self, d_model, version, edge_dim, dropout=0.1): + super(GATInteraction, self).__init__() + + if version not in {"v1", "v2"}: + raise ValueError( + f"Invalid GAT version. Received {version}, available: v1, v2." 
+            )
+
+        # Not quite sure what is the impact of increasing or decreasing the number of heads
+        if version == "v1":
+            self.interaction = GATConv(
+                in_channels=d_model,
+                out_channels=d_model,
+                heads=3,
+                concat=False,
+                edge_dim=edge_dim,
+                dropout=dropout,
+            )
+        else:
+            self.interaction = GATv2Conv(
+                in_channels=d_model,
+                out_channels=d_model,
+                heads=3,
+                concat=False,
+                edge_dim=edge_dim,
+                dropout=dropout,
+            )
+
+    def forward(self, h_ads, h_cat, bipartite_edges, bipartite_weights):
+        # We first do the message passing
+        separation_pt = h_ads.shape[0]
+        combined = torch.concat([h_ads, h_cat], dim=0)
+        combined = self.interaction(combined, bipartite_edges, bipartite_weights)
+
+        # We separate again and we return
+        ads, cat = combined[:separation_pt], combined[separation_pt:]
+        # QUESTION: Should normalization happen before separating them?
+        # ads, cat = nn.functional.normalize(ads), nn.functional.normalize(cat)
+        # ads, cat = ads + h_ads, cat + h_cat
+
+        return ads, cat
+
+
+@registry.register_model("afaenet")
+class AFaenet(BaseModel):
+    def __init__(self, **kwargs):
+        super(AFaenet, self).__init__()
+
+        self.cutoff = kwargs["cutoff"]
+        self.energy_head = kwargs["energy_head"]
+        self.regress_forces = kwargs["regress_forces"]
+        self.use_pbc = kwargs["use_pbc"]
+        self.max_num_neighbors = kwargs["max_num_neighbors"]
+        self.edge_embed_type = kwargs["edge_embed_type"]
+        self.skip_co = kwargs["skip_co"]
+        if kwargs["mp_type"] == "sfarinet":
+            kwargs["num_filters"] = kwargs["hidden_channels"]
+        self.hidden_channels = kwargs["hidden_channels"]
+
+        self.act = (
+            getattr(nn.functional, kwargs["act"]) if kwargs["act"] != "swish" else swish
+        )
+        self.use_positional_embeds = kwargs["graph_rewiring"] in {
+            "one-supernode-per-graph",
+            "one-supernode-per-atom-type",
+            "one-supernode-per-atom-type-dist",
+        }
+
+        # Gaussian Basis
+        self.distance_expansion_ads = GaussianSmearing(
+            0.0, self.cutoff, kwargs["num_gaussians"]
+        )
+        self.distance_expansion_cat = GaussianSmearing(
+            0.0, self.cutoff, kwargs["num_gaussians"]
+        )
+        self.distance_expansion_disc = GaussianSmearing(
+            0.0, 20.0, kwargs["num_gaussians"]
+        )
+        # Set the second parameter as the highest possible z-axis value
+
+        # Embedding block
+        self.embed_block_ads = EmbeddingBlock(
+            kwargs["num_gaussians"],
+            kwargs["num_filters"],
+            kwargs["hidden_channels"],
+            kwargs["tag_hidden_channels"],
+            kwargs["pg_hidden_channels"],
+            kwargs["phys_hidden_channels"],
+            kwargs["phys_embeds"],
+            kwargs["graph_rewiring"],
+            self.act,
+            kwargs["second_layer_MLP"],
+            kwargs["edge_embed_type"],
+        )
+        self.embed_block_cat = EmbeddingBlock(
+            kwargs["num_gaussians"],
+            kwargs["num_filters"],
+            kwargs["hidden_channels"],
+            kwargs["tag_hidden_channels"],
+            kwargs["pg_hidden_channels"],
+            kwargs["phys_hidden_channels"],
+            kwargs["phys_embeds"],
+            kwargs["graph_rewiring"],
+            self.act,
+            kwargs["second_layer_MLP"],
+            kwargs["edge_embed_type"],
+        )
+        self.disc_edge_embed = Linear(kwargs["num_gaussians"], kwargs["num_filters"])
+
+        # Interaction block
+        self.interaction_blocks_ads = nn.ModuleList(
+            [
+                InteractionBlock(
+                    kwargs["hidden_channels"],
+                    kwargs["num_filters"],
+                    self.act,
+                    kwargs["mp_type"],
+                    kwargs["complex_mp"],
+                    kwargs["att_heads"],
+                    kwargs["graph_norm"],
+                )
+                for _ in range(kwargs["num_interactions"])
+            ]
+        )
+        self.interaction_blocks_cat = nn.ModuleList(
+            [
+                InteractionBlock(
+                    kwargs["hidden_channels"],
+                    kwargs["num_filters"],
+                    self.act,
+                    kwargs["mp_type"],
+                    kwargs["complex_mp"],
+                    kwargs["att_heads"],
+                    kwargs["graph_norm"],
+                )
+                for
_ in range(kwargs["num_interactions"]) + ] + ) + + assert ( + "afaenet_gat_mode" in kwargs + ), "GAT version needs to be specified. Options: v1, v2" + # Inter Interaction + self.inter_interactions = nn.ModuleList( + [ + GATInteraction( + kwargs["hidden_channels"], + kwargs["afaenet_gat_mode"], + kwargs["num_filters"], + ) + for _ in range(kwargs["num_interactions"]) + ] + ) + + # Output blocks + self.output_block_ads = OutputBlock( + self.energy_head, kwargs["hidden_channels"], self.act, kwargs["model_name"] + ) + self.output_block_cat = OutputBlock( + self.energy_head, kwargs["hidden_channels"], self.act, kwargs["model_name"] + ) + + # Energy head + if self.energy_head == "weighted-av-initial-embeds": + self.w_lin_ads = Linear(kwargs["hidden_channels"], 1) + self.w_lin_cat = Linear(kwargs["hidden_channels"], 1) + + # Skip co + if ( + self.skip_co == "concat" + ): # for the implementation of independent faenet, make sure the input is large enough + self.mlp_skip_co_ads = Linear( + (kwargs["num_interactions"] + 1) * kwargs["hidden_channels"] // 2, + kwargs["hidden_channels"] // 2, + ) + self.mlp_skip_co_cat = Linear( + (kwargs["num_interactions"] + 1) * kwargs["hidden_channels"] // 2, + kwargs["hidden_channels"] // 2, + ) + + elif self.skip_co == "concat_atom": + self.mlp_skip_co = Linear( + ((kwargs["num_interactions"] + 1) * kwargs["hidden_channels"]), + kwargs["hidden_channels"], + ) + + self.transformer_out = kwargs.get("transformer_out", False) + if self.transformer_out: + self.combination = Transformer( + d_model=kwargs["hidden_channels"] // 2, + nhead=2, + num_encoder_layers=2, + num_decoder_layers=2, + dim_feedforward=kwargs["hidden_channels"], + batch_first=True, + ) + self.positional_encoding = PositionalEncoding( + kwargs["hidden_channels"] // 2, + dropout=0.1, + max_len=5, + ) + self.query_pos = nn.Parameter(torch.rand(kwargs["hidden_channels"] // 2)) + self.transformer_lin = Linear(kwargs["hidden_channels"] // 2, 1) + else: + self.combination = nn.Sequential( + Linear(kwargs["hidden_channels"], kwargs["hidden_channels"] // 2), + swish, + Linear(kwargs["hidden_channels"] // 2, 1), + ) + + @conditional_grad(torch.enable_grad()) + def energy_forward(self, data): + batch_size = len(data) + batch_ads = data["adsorbate"]["batch"] + batch_cat = data["catalyst"]["batch"] + + # Graph rewiring + ads_rewiring, cat_rewiring = self.graph_rewiring(data, batch_ads, batch_cat) + edge_index_ads, edge_weight_ads, rel_pos_ads, edge_attr_ads = ads_rewiring + edge_index_cat, edge_weight_cat, rel_pos_cat, edge_attr_cat = cat_rewiring + + # Embedding + h_ads, e_ads = self.embedding( + data["adsorbate"].atomic_numbers.long(), + edge_weight_ads, + rel_pos_ads, + edge_attr_ads, + data["adsorbate"].tags, + self.embed_block_ads, + ) + h_cat, e_cat = self.embedding( + data["catalyst"].atomic_numbers.long(), + edge_weight_cat, + rel_pos_cat, + edge_attr_cat, + data["catalyst"].tags, + self.embed_block_cat, + ) + + # Compute atom weights for late energy head + if self.energy_head == "weighted-av-initial-embeds": + alpha_ads = self.w_lin_ads(h_ads) + alpha_cat = self.w_lin_cat(h_cat) + else: + alpha_ads = None + alpha_cat = None + + # Edge embeddings of the complete bipartite graph. + edge_weights = self.distance_expansion_disc(data["is_disc"].edge_weight) + edge_weights = self.disc_edge_embed(edge_weights) + + # Now we do interactions. 
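+        # Per block: FAENet message passing within each subgraph (intra), then a GAT
+        # exchange over the adsorbate-catalyst bipartite "is_disc" edges (inter),
+        # with a residual add and normalization of the node embeddings; skip-connection
+        # energies are accumulated before each block when skip_co is enabled.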
+ energy_skip_co_ads = [] + energy_skip_co_cat = [] + for interaction_ads, interaction_cat, inter_interaction in zip( + self.interaction_blocks_ads, + self.interaction_blocks_cat, + self.inter_interactions, + ): + if self.skip_co == "concat_atom": + energy_skip_co_ads.append(h_ads) + energy_skip_co_cat.append(h_cat) + elif self.skip_co: + energy_skip_co_ads.append( + self.output_block_ads( + h_ads, edge_index_ads, edge_weight_ads, batch_ads, alpha_ads + ) + ) + energy_skip_co_cat.append( + self.output_block_cat( + h_cat, edge_index_cat, edge_weight_cat, batch_cat, alpha_cat + ) + ) + # First we do intra interaction + intra_ads = interaction_ads(h_ads, edge_index_ads, e_ads) + intra_cat = interaction_cat(h_cat, edge_index_cat, e_cat) + + # Then we do inter interaction + inter_ads, inter_cat = inter_interaction( + intra_ads, + intra_cat, + data["is_disc"].edge_index, + edge_weights, + ) + # QUESTION: Can we do both simultaneously? + + h_ads, h_cat = h_ads + inter_ads, h_cat + inter_cat + h_ads, h_cat = nn.functional.normalize(h_ads), nn.functional.normalize( + h_cat + ) + + # Atom skip-co + if self.skip_co == "concat_atom": + energy_skip_co_ads.append(h_ads) + energy_skip_co_cat.append(h_cat) + + h_ads = self.act(self.mlp_skip_co_ads(torch.cat(energy_skip_co_ads, dim=1))) + h_cat = self.act(self.mlp_skip_co_cat(torch.cat(energy_skip_co_cat, dim=1))) + + energy_ads = self.output_block_ads( + h_ads, edge_index_ads, edge_weight_ads, batch_ads, alpha_ads + ) + energy_cat = self.output_block_cat( + h_cat, edge_index_cat, edge_weight_cat, batch_cat, alpha_cat + ) + + # Skip-connection + energy_skip_co_ads.append(energy_ads) + energy_skip_co_cat.append(energy_cat) + if self.skip_co == "concat": + energy_ads = self.mlp_skip_co_ads(torch.cat(energy_skip_co_ads, dim=1)) + energy_cat = self.mlp_skip_co_cat(torch.cat(energy_skip_co_cat, dim=1)) + elif self.skip_co == "add": + energy_ads = sum(energy_skip_co_ads) + energy_cat = sum(energy_skip_co_cat) + + # Combining hidden representations + if self.transformer_out: + batch_size = energy_ads.shape[0] + + fake_target_sequence = ( + self.query_pos.unsqueeze(0).expand(batch_size, -1).unsqueeze(1) + ) + system_energy = torch.cat( + [energy_ads.unsqueeze(1), energy_cat.unsqueeze(1)], dim=1 + ) + + system_energy = self.positional_encoding(system_energy) + + system_energy = self.combination( + system_energy, fake_target_sequence + ).squeeze(1) + system_energy = self.transformer_lin(system_energy) + else: + system_energy = torch.cat([energy_ads, energy_cat], dim=1) + system_energy = self.combination(system_energy) + + # We combine predictions and return them + pred_system = { + "energy": system_energy, + "pooling_loss": None, # This might break something. 
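+            # pooling_loss is only produced by the hierarchical-pooling energy heads,
+            # which this model does not use.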
+ "hidden_state": torch.cat([energy_ads, energy_cat], dim=1), + } + + return pred_system + + @conditional_grad(torch.enable_grad()) + def embedding(self, z, edge_weight, rel_pos, edge_attr, tags, embed_func): + # Normalize and squash to [0,1] for gaussian basis + rel_pos_normalized = None + if self.edge_embed_type in {"sh", "all_rij", "all"}: + rel_pos_normalized = (rel_pos / edge_weight.view(-1, 1) + 1) / 2.0 + + pooling_loss = None # deal with pooling loss + + # Embedding block + h, e = embed_func(z, rel_pos, edge_attr, tags, rel_pos_normalized) + + return h, e + + @conditional_grad(torch.enable_grad()) + def graph_rewiring(self, data, batch_ads, batch_cat): + z = data["adsorbate"].atomic_numbers.long() + + # Use periodic boundary conditions + results = [] + if self.use_pbc: + assert z.dim() == 1 and z.dtype == torch.long + + for mode in ["adsorbate", "catalyst"]: + out = get_pbc_distances( + data[mode].pos, + data[mode, "is_close", mode].edge_index, + data[mode].cell, + data[mode].cell_offsets, + data[mode].neighbors, + return_distance_vec=True, + ) + + edge_index = out["edge_index"] + edge_weight = out["distances"] + rel_pos = out["distance_vec"] + if mode == "adsorbate": + distance_expansion = self.distance_expansion_ads + else: + distance_expansion = self.distance_expansion_cat + edge_attr = distance_expansion(edge_weight) + results.append([edge_index, edge_weight, rel_pos, edge_attr]) + else: + for mode in ["adsorbate", "catalyst"]: + edge_index = radius_graph( + data[mode].pos, + r=self.cutoff, + batch=batch_ads if mode == "adsorbate" else batch_cat, + max_num_neighbors=self.max_num_neighbors, + ) + # edge_index = data.edge_index + row, col = edge_index + rel_pos = data[mode].pos[row] - data[mode].pos[col] + edge_weight = rel_pos.norm(dim=-1) + if mode == "adsorbate": + distance_expansion = self.distance_expansion_ads + else: + distance_expansion = self.distance_expansion_cat + edge_attr = distance_expansion(edge_weight) + results.append([edge_index, edge_weight, rel_pos, edge_attr]) + + return results + + @conditional_grad(torch.enable_grad()) + def forces_forward(self, preds): + pass diff --git a/ocpmodels/models/aschnet.py b/ocpmodels/models/aschnet.py new file mode 100644 index 0000000000..d0dcd5d6d9 --- /dev/null +++ b/ocpmodels/models/aschnet.py @@ -0,0 +1,469 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. 
+""" +from math import pi as PI + +import torch +import torch.nn.functional as F +from torch.nn import Embedding, Linear, ModuleList, Sequential +from torch import nn +from torch_geometric.nn import MessagePassing, radius_graph +from torch_scatter import scatter + +from ocpmodels.common.registry import registry +from ocpmodels.common.utils import ( + conditional_grad, + get_pbc_distances, + radius_graph_pbc, +) +from ocpmodels.models.base_model import BaseModel +from ocpmodels.models.utils.pos_encodings import PositionalEncoding +from ocpmodels.modules.phys_embeddings import PhysEmbedding +from ocpmodels.modules.pooling import Graclus, Hierarchical_Pooling +from ocpmodels.models.utils.activations import swish +from ocpmodels.models.schnet import ( + InteractionBlock, + CFConv, + GaussianSmearing, + ShiftedSoftplus, +) +from ocpmodels.models.afaenet import GATInteraction + +NUM_CLUSTERS = 20 +NUM_POOLING_LAYERS = 1 + + +@registry.register_model("aschnet") +class ASchNet(BaseModel): + r"""The continuous-filter convolutional neural network SchNet from the + `"SchNet: A Continuous-filter Convolutional Neural Network for Modeling + Quantum Interactions" `_ paper that uses + the interactions blocks of the form + + .. math:: + \mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \odot + h_{\mathbf{\Theta}} ( \exp(-\gamma(\mathbf{e}_{j,i} - \mathbf{\mu}))), + + here :math:`h_{\mathbf{\Theta}}` denotes an MLP and + :math:`\mathbf{e}_{j,i}` denotes the interatomic distances between atoms. + + Args: + cutoff (float, optional): Cutoff distance for interatomic interactions. + (default: :obj:`10.0`) + use_pbc (bool, optional): Use of periodic boundary conditions. + (default: true) + otf_graph (bool, optional): Recompute radius graph. + (default: false) + max_num_neighbors (int, optional): The maximum number of neighbors to + collect for each node within the :attr:`cutoff` distance. + (default: :obj:`32`) + graph_rewiring (str, optional): Method used to create the graph, + among "", remove-tag-0, supernodes. + energy_head (str, optional): Method to compute energy prediction + from atom representations. + hidden_channels (int, optional): Hidden embedding size. + (default: :obj:`128`) + tag_hidden_channels (int, optional): Hidden tag embedding size. + (default: :obj:`32`) + pg_hidden_channels (int, optional): Hidden period and group embed size. + (default: obj:`32`) + phys_embed (bool, optional): Concat fixed physics-aware embeddings. + phys_hidden_channels (int, optional): Hidden size of learnable phys embed. + (default: obj:`32`) + num_filters (int, optional): The number of filters to use. + (default: :obj:`128`) + num_interactions (int, optional): The number of interaction blocks. + (default: :obj:`6`) + num_gaussians (int, optional): The number of gaussians :math:`\mu`. + (default: :obj:`50`) + readout (string, optional): Whether to apply :obj:`"add"` or + :obj:`"mean"` global aggregation. (default: :obj:`"add"`) + atomref (torch.Tensor, optional): The reference of single-atom + properties. + Expects a vector of shape :obj:`(max_atomic_number, )`. 
+ """ + + url = "http://www.quantum-machine.org/datasets/trained_schnet_models.zip" + + def __init__(self, **kwargs): + super().__init__() + + import ase + + self.use_pbc = kwargs["use_pbc"] + self.cutoff = kwargs["cutoff"] + self.otf_graph = kwargs["otf_graph"] + self.scale = None + self.regress_forces = kwargs["regress_forces"] + + self.num_filters = kwargs["num_filters"] + self.num_interactions = kwargs["num_interactions"] + self.num_gaussians = kwargs["num_gaussians"] + self.max_num_neighbors = kwargs["max_num_neighbors"] + self.readout = kwargs["readout"] + self.hidden_channels = kwargs["hidden_channels"] + self.tag_hidden_channels = kwargs["tag_hidden_channels"] + self.use_tag = self.tag_hidden_channels > 0 + self.pg_hidden_channels = kwargs["pg_hidden_channels"] + self.use_pg = self.pg_hidden_channels > 0 + self.phys_hidden_channels = kwargs["phys_hidden_channels"] + self.energy_head = kwargs["energy_head"] + self.use_phys_embeddings = kwargs["phys_embeds"] + self.use_mlp_phys = self.phys_hidden_channels > 0 and kwargs["phys_embeds"] + self.use_positional_embeds = kwargs["graph_rewiring"] in { + "one-supernode-per-graph", + "one-supernode-per-atom-type", + "one-supernode-per-atom-type-dist", + } + + self.register_buffer( + "initial_atomref", + torch.tensor(kwargs["atomref"]) if kwargs["atomref"] is not None else None, + ) + self.atomref = None + if kwargs["atomref"] is not None: + self.atomref = Embedding(100, 1) + self.atomref.weight.data.copy_(torch.tensor(kwargs["atomref"])) + + atomic_mass = torch.from_numpy(ase.data.atomic_masses) + # self.covalent_radii = torch.from_numpy(ase.data.covalent_radii) + # self.vdw_radii = torch.from_numpy(ase.data.vdw_radii) + self.register_buffer("atomic_mass", atomic_mass) + + if self.use_tag: + self.tag_embedding = Embedding(3, self.tag_hidden_channels) + + # Phys embeddings + self.phys_emb = PhysEmbedding(props=kwargs["phys_embeds"], pg=self.use_pg) + if self.use_mlp_phys: + self.phys_lin = Linear( + self.phys_emb.n_properties, self.phys_hidden_channels + ) + else: + self.phys_hidden_channels = self.phys_emb.n_properties + + # Period + group embeddings + if self.use_pg: + self.period_embedding = Embedding( + self.phys_emb.period_size, self.pg_hidden_channels + ) + self.group_embedding = Embedding( + self.phys_emb.group_size, self.pg_hidden_channels + ) + + assert ( + self.tag_hidden_channels + + 2 * self.pg_hidden_channels + + self.phys_hidden_channels + < self.hidden_channels + ) + + # Main embedding + self.embedding_ads = Embedding( + 85, + self.hidden_channels + - self.tag_hidden_channels + - self.phys_hidden_channels + - 2 * self.pg_hidden_channels, + ) + self.embedding_cat = Embedding( + 85, + self.hidden_channels + - self.tag_hidden_channels + - self.phys_hidden_channels + - 2 * self.pg_hidden_channels, + ) + + # Gaussian basis and linear transformation of disc edges + self.distance_expansion_disc = GaussianSmearing(0.0, 20.0, self.num_gaussians) + self.disc_edge_embed = Linear(self.num_gaussians, self.num_filters) + + # Position encoding + if self.use_positional_embeds: + self.pe = PositionalEncoding(self.hidden_channels, 210) + + # Interaction block + self.distance_expansion = GaussianSmearing(0.0, self.cutoff, self.num_gaussians) + + self.interactions_ads = ModuleList() + for _ in range(self.num_interactions): + block = InteractionBlock( + self.hidden_channels, self.num_gaussians, self.num_filters, self.cutoff + ) + self.interactions_ads.append(block) + + self.interactions_cat = ModuleList() + for _ in range(self.num_interactions): + 
block = InteractionBlock( + self.hidden_channels, self.num_gaussians, self.num_filters, self.cutoff + ) + self.interactions_cat.append(block) + + self.interactions_disc = ModuleList() + assert ( + "gat_mode" in kwargs + ), "GAT version needs to be specified. Options: v1, v2" + for _ in range(self.num_interactions): + block = GATInteraction( + self.hidden_channels, kwargs["gat_mode"], self.num_filters + ) + self.interactions_disc.append(block) + + # Output block + self.lin1_ads = Linear(self.hidden_channels, self.hidden_channels // 2) + self.lin1_cat = Linear(self.hidden_channels, self.hidden_channels // 2) + self.act = ShiftedSoftplus() + self.lin2_ads = Linear(self.hidden_channels // 2, self.hidden_channels // 2) + self.lin2_cat = Linear(self.hidden_channels // 2, self.hidden_channels // 2) + + # weighted average & pooling + if self.energy_head in {"pooling", "random"}: + self.hierarchical_pooling = Hierarchical_Pooling( + self.hidden_channels, + self.act, + NUM_POOLING_LAYERS, + NUM_CLUSTERS, + self.energy_head, + ) + elif self.energy_head == "graclus": + self.graclus = Graclus(self.hidden_channels, self.act) + elif self.energy_head in { + "weighted-av-initial-embeds", + "weighted-av-final-embeds", + }: + self.w_lin = Linear(self.hidden_channels, 1) + + self.combination = nn.Sequential( + Linear(self.hidden_channels, self.hidden_channels // 2), + swish, + Linear(kwargs["hidden_channels"] // 2, 1), + ) + + self.reset_parameters() + + def reset_parameters(self): + self.embedding_ads.reset_parameters() + self.embedding_cat.reset_parameters() + if self.use_mlp_phys: + torch.nn.init.xavier_uniform_(self.phys_lin.weight) + if self.use_tag: + self.tag_embedding.reset_parameters() + if self.use_pg: + self.period_embedding.reset_parameters() + self.group_embedding.reset_parameters() + if self.energy_head in {"weighted-av-init-embeds", "weighted-av-final-embeds"}: + self.w_lin.bias.data.fill_(0) + torch.nn.init.xavier_uniform_(self.w_lin.weight) + for interaction_ads, interaction_cat, interaction_disc in zip( + self.interactions_ads, self.interactions_cat, self.interactions_disc + ): + interaction_ads.reset_parameters() + interaction_cat.reset_parameters() + # interaction_disc.reset_parameters() # need to implement this! 
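+            # Note: the underlying GATConv / GATv2Conv layers expose reset_parameters(),
+            # so this could presumably delegate to interaction_disc.interaction.reset_parameters().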
+ torch.nn.init.xavier_uniform_(self.lin1_ads.weight) + self.lin1_ads.bias.data.fill_(0) + torch.nn.init.xavier_uniform_(self.lin2_ads.weight) + self.lin2_ads.bias.data.fill_(0) + torch.nn.init.xavier_uniform_(self.lin1_cat.weight) + self.lin1_cat.bias.data.fill_(0) + torch.nn.init.xavier_uniform_(self.lin2_cat.weight) + self.lin2_cat.bias.data.fill_(0) + if self.atomref is not None: + self.atomref.weight.data.copy_(self.initial_atomref) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"hidden_channels={self.hidden_channels}, " + f"tag_hidden_channels={self.tag_hidden_channels}, " + f"properties={self.phys_hidden_channels}, " + f"period_hidden_channels={self.pg_hidden_channels}, " + f"group_hidden_channels={self.pg_hidden_channels}, " + f"energy_head={self.energy_head}", + f"num_filters={self.num_filters}, " + f"num_interactions={self.num_interactions}, " + f"num_gaussians={self.num_gaussians}, " + f"cutoff={self.cutoff})", + ) + + @conditional_grad(torch.enable_grad()) + def forces_forward(self, preds): + return + + @conditional_grad(torch.enable_grad()) + def energy_forward(self, data): + """""" + # Re-compute on the fly the graph + if self.otf_graph: + edge_index, cell_offsets, neighbors = radius_graph_pbc_inputs( + data["adsorbate"].pos, + data["adsorbate"].natoms, + data["adsorbate"].cell, + self.cutoff, + 50, + ) + data["adsorbate", "is_close", "adsorbate"].edge_index = edge_index + data["adsorbate"].cell_offsets = cell_offsets + data["adsorbate"].neighbors = neighbors + + edge_index, cell_offsets, neighbors = radius_graph_pbc_inputs( + data["catalyst"].pos, + data["catalyst"].natoms, + data["catalyst"].cell, + self.cutoff, + 50, + ) + data["catalyst", "is_close", "catalyst"].edge_index = edge_index + data["catalyst"].cell_offsets = cell_offsets + data["catalyst"].neighbors = neighbors + + # Rewire the graph + # Use periodic boundary conditions + ads_rewiring, cat_rewiring = self.graph_rewiring( + data, + ) + edge_index_ads, edge_weight_ads, edge_attr_ads = ads_rewiring + edge_index_cat, edge_weight_cat, edge_attr_cat = cat_rewiring + + h_ads = self.embedding_ads(data["adsorbate"].atomic_numbers.long()) + h_cat = self.embedding_cat(data["catalyst"].atomic_numbers.long()) + + edge_weights_disc = self.distance_expansion_disc(data["is_disc"].edge_weight) + edge_weights_disc = self.disc_edge_embed(edge_weights_disc) + + if self.use_tag: # NOT IMPLEMENTED + assert data["adsorbate"].tags is not None + h_tag = self.tag_embedding(data.tags) + h = torch.cat((h, h_tag), dim=1) + + if self.phys_emb.device != data["adsorbate"].batch.device: # NOT IMPLEMENTED + self.phys_emb = self.phys_emb.to(data["adsorbate"].batch.device) + + if self.use_phys_embeddings: # NOT IMPLEMENTED + h_phys = self.phys_emb.properties[z] + if self.use_mlp_phys: + h_phys = self.phys_lin(h_phys) + h = torch.cat((h, h_phys), dim=1) + + if self.use_pg: # NOT IMPLEMENTED + # assert self.phys_emb.period is not None + h_period = self.period_embedding(self.phys_emb.period[z]) + h_group = self.group_embedding(self.phys_emb.group[z]) + h = torch.cat((h, h_period, h_group), dim=1) + + if self.use_positional_embeds: # NOT IMPLEMENTED + idx_of_non_zero_val = (data.tags == 0).nonzero().T.squeeze(0) + h_pos = torch.zeros_like(h, device=h.device) + h_pos[idx_of_non_zero_val, :] = self.pe(data.subnodes).to( + device=h_pos.device + ) + h += h_pos + + if self.energy_head == "weighted-av-initial-embeds": + alpha = self.w_lin(h) + + for interaction_ads, interaction_cat, interaction_disc in zip( + self.interactions_ads, 
self.interactions_cat, self.interactions_disc + ): + intra_ads = interaction_ads( + h_ads, edge_index_ads, edge_weight_ads, edge_attr_ads + ) + intra_cat = interaction_cat( + h_cat, edge_index_cat, edge_weight_cat, edge_attr_cat + ) + inter_ads, inter_cat = interaction_disc( + intra_ads, intra_cat, data["is_disc"].edge_index, edge_weights_disc + ) + h_ads, h_cat = h_ads + inter_ads, h_cat + inter_cat + h_ads, h_cat = nn.functional.normalize(h_ads), nn.functional.normalize( + h_cat + ) + + pooling_loss = None # deal with pooling loss + + if self.energy_head == "weighted-av-final-embeds": # NOT IMPLEMENTED + alpha = self.w_lin(h) + + elif self.energy_head == "graclus": + h, batch = self.graclus( + h, edge_index, edge_weight, batch + ) # NOT IMPLEMENTED + + if self.energy_head in {"pooling", "random"}: # NOT IMPLEMENTED + h, batch, pooling_loss = self.hierarchical_pooling( + h, edge_index, edge_weight, batch + ) + + # MLP + h_ads = self.lin1_ads(h_ads) + h_ads = self.act(h_ads) + h_ads = self.lin2_ads(h_ads) + + h_cat = self.lin1_cat(h_cat) + h_cat = self.act(h_cat) + h_cat = self.lin2_cat(h_cat) + + if self.energy_head in { # NOT IMPLEMENTED + "weighted-av-initial-embeds", + "weighted-av-final-embeds", + }: + h = h * alpha + + if self.atomref is not None: # NOT IMPLEMENTED + h = h + self.atomref(z) + + # Global pooling + out_ads = self.scattering(h_ads, data["adsorbate"].batch) + out_cat = self.scattering(h_cat, data["catalyst"].batch) + + if self.scale is not None: + out = self.scale * out + + system = torch.concat([out_ads, out_cat], dim=1) + out = self.combination(system) + + return { + "energy": out, + "pooling_loss": pooling_loss, + } + + @conditional_grad(torch.enable_grad()) + def graph_rewiring(self, data): + results = [] + + if self.use_pbc: + for mode in ["adsorbate", "catalyst"]: + out = get_pbc_distances( + data[mode].pos, + data[mode, "is_close", mode].edge_index, + data[mode].cell, + data[mode].cell_offsets, + data[mode].neighbors, + return_distance_vec=True, + ) + + edge_index = out["edge_index"] + edge_weight = out["distances"] + edge_attr = self.distance_expansion(edge_weight) + results.append([edge_index, edge_weight, edge_attr]) + else: + for mode in ["adsorbate", "catalyst"]: + edge_index = radius_graph( + data[mode].pos, + r=self.cutoff, + batch=data[mode].batch, + max_num_neighbors=self.max_num_neighbors, + ) + row, col = edge_index + edge_weight = (pos[row] - pos[col]).norm(dim=-1) + edge_attr = self.distance_expansion(edge_weight) + results.append([edge_index, edge_weight, edge_attr]) + + return results + + @conditional_grad(torch.enable_grad()) + def scattering(self, h, batch): + return scatter(h, batch, dim=0, reduce=self.readout) diff --git a/ocpmodels/models/base_model.py b/ocpmodels/models/base_model.py index 240cc4a7b0..8f92000072 100644 --- a/ocpmodels/models/base_model.py +++ b/ocpmodels/models/base_model.py @@ -8,6 +8,7 @@ import torch import torch.nn as nn +from torch_geometric.data import HeteroData from torch_geometric.nn import radius_graph from ocpmodels.common.utils import ( @@ -39,12 +40,19 @@ def energy_forward(self, data): def forces_forward(self, preds): raise NotImplementedError - def forward(self, data, mode="train"): + def forward(self, data, mode="train", **kwargs): grad_forces = forces = None # energy gradient w.r.t. 
positions will be computed
        if mode == "train" or self.regress_forces == "from_energy":
-            data.pos.requires_grad_(True)
+            if type(data) is list:
+                data[0].pos.requires_grad_(True)
+                data[1].pos.requires_grad_(True)
+            elif type(data[0]) is HeteroData:
+                data["adsorbate"].pos.requires_grad_(True)
+                data["catalyst"].pos.requires_grad_(True)
+            else:
+                data.pos.requires_grad_(True)

        # predict energy
        preds = self.energy_forward(data)

@@ -63,7 +71,12 @@ def forward(self, data, mode="train"):
                grad_forces = forces
            else:
                # compute forces from energy gradient
-                grad_forces = self.forces_as_energy_grad(data.pos, preds["energy"])
+                try:
+                    grad_forces = self.forces_as_energy_grad(
+                        data.pos, preds["energy"]
+                    )
+                except:
+                    grad_forces = self.forces_as_energy_grad(data["adsorbate"].pos, preds["energy"])

            if self.regress_forces == "from_energy":
                # predicted forces are the energy gradient
diff --git a/ocpmodels/models/depdpp.py b/ocpmodels/models/depdpp.py
new file mode 100644
index 0000000000..2b47114880
--- /dev/null
+++ b/ocpmodels/models/depdpp.py
@@ -0,0 +1,52 @@
+import torch
+from torch import nn
+from torch.nn import Linear
+from torch_scatter import scatter
+
+from ocpmodels.models.dimenet_plus_plus import DimeNetPlusPlus
+from ocpmodels.common.registry import registry
+from ocpmodels.common.utils import conditional_grad
+from ocpmodels.models.utils.activations import swish
+
+from torch_geometric.data import Batch
+
+
+@registry.register_model("depdpp")
+class depSchNet(DimeNetPlusPlus):
+    def __init__(self, **kwargs):
+        self.hidden_channels = kwargs["hidden_channels"]
+
+        kwargs["num_targets"] = kwargs["hidden_channels"] // 2
+        super().__init__(**kwargs)
+
+        self.act = swish
+        self.combination = nn.Sequential(
+            Linear(self.hidden_channels // 2 * 2, self.hidden_channels // 2),
+            self.act,
+            Linear(self.hidden_channels // 2, 1),
+        )
+
+    @conditional_grad(torch.enable_grad())
+    def energy_forward(self, data):
+        # We need to save the tags so this step is necessary.
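+        # Tags (2 = adsorbate atoms) are stashed on the module so that the overridden
+        # scattering() can pool adsorbate and catalyst contributions into separate
+        # per-graph vectors before the combination MLP maps them to a single energy.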
+ self.tags_saver(data.tags) + pred = super().energy_forward(data) + + return pred + + def tags_saver(self, tags): + self.current_tags = tags + + @conditional_grad(torch.enable_grad()) + def scattering(self, batch, h, P_bis): + ads = self.current_tags == 2 + cat = ~ads + + ads_out = scatter(h, batch * ads, dim=0) + cat_out = scatter(h, batch * cat, dim=0) + + system = torch.cat([ads_out, cat_out], dim=1) + system = self.combination(system) + system = system + P_bis + + return system diff --git a/ocpmodels/models/depfaenet.py b/ocpmodels/models/depfaenet.py new file mode 100644 index 0000000000..25f6a09683 --- /dev/null +++ b/ocpmodels/models/depfaenet.py @@ -0,0 +1,97 @@ +import torch +from torch.nn import Linear +from torch import nn +from torch_scatter import scatter + +from ocpmodels.models.faenet import FAENet +from ocpmodels.models.faenet import OutputBlock as conOutputBlock +from ocpmodels.common.registry import registry +from ocpmodels.common.utils import conditional_grad +from ocpmodels.models.utils.activations import swish + +from torch_geometric.data import Batch + + +class discOutputBlock(conOutputBlock): + def __init__(self, energy_head, hidden_channels, act, disconnected_mlp=False): + super(discOutputBlock, self).__init__(energy_head, hidden_channels, act) + + # We modify the last output linear function to make the output a vector + self.lin2 = Linear(hidden_channels // 2, hidden_channels // 2) + + self.disconnected_mlp = disconnected_mlp + if self.disconnected_mlp: + self.ads_lin = Linear(hidden_channels // 2, hidden_channels // 2) + self.cat_lin = Linear(hidden_channels // 2, hidden_channels // 2) + + # Combines the hidden representation of each to a scalar. + self.combination = nn.Sequential( + Linear(hidden_channels // 2 * 2, hidden_channels // 2), + swish, + Linear(hidden_channels // 2, 1), + ) + + def tags_saver(self, tags): + self.current_tags = tags + + def forward(self, h, edge_index, edge_weight, batch, alpha): + if ( + self.energy_head == "weighted-av-final-embeds" + ): # Right now, this is the only available option. + alpha = self.w_lin(h) + + elif self.energy_head == "graclus": + h, batch = self.graclus(h, edge_index, edge_weight, batch) + + elif self.energy_head in {"pooling", "random"}: + h, batch, pooling_loss = self.hierarchical_pooling( + h, edge_index, edge_weight, batch + ) + + # MLP + h = self.lin1(h) + h = self.lin2(self.act(h)) + + if self.energy_head in { + "weighted-av-initial-embeds", + "weighted-av-final-embeds", + }: + h = h * alpha + + # We pool separately and then we concatenate. + ads = self.current_tags == 2 + cat = ~ads + + ads_out = scatter(h, batch * ads, dim=0, reduce="add") + cat_out = scatter(h, batch * cat, dim=0, reduce="add") + + if self.disconnected_mlp: + ads_out = self.ads_lin(ads_out) + cat_out = self.cat_lin(cat_out) + + system = torch.cat([ads_out, cat_out], dim=1) + + # Finally, we predict a number. + energy = self.combination(system) + + return energy + + +@registry.register_model("depfaenet") +class depFAENet(FAENet): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + # We replace the old output block by the new output block + self.disconnected_mlp = kwargs.get("disconnected_mlp", False) + self.output_block = discOutputBlock( + self.energy_head, kwargs["hidden_channels"], self.act, self.disconnected_mlp + ) + + @conditional_grad(torch.enable_grad()) + def energy_forward(self, data): + # We need to save the tags so this step is necessary. 
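+        # Here the tags are forwarded to the output block, which pools adsorbate
+        # (tag == 2) and catalyst atoms separately before predicting the energy.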
+ self.output_block.tags_saver(data.tags) + pred = super().energy_forward(data) + + return pred diff --git a/ocpmodels/models/depschnet.py b/ocpmodels/models/depschnet.py new file mode 100644 index 0000000000..69f9adf83c --- /dev/null +++ b/ocpmodels/models/depschnet.py @@ -0,0 +1,53 @@ +import torch +from torch import nn +from torch.nn import Linear +from torch_scatter import scatter + +from ocpmodels.models.schnet import SchNet +from ocpmodels.models.faenet import OutputBlock as conOutputBlock +from ocpmodels.common.registry import registry +from ocpmodels.common.utils import conditional_grad +from ocpmodels.models.utils.activations import swish + +from torch_geometric.data import Batch + + +@registry.register_model("depschnet") +class depSchNet(SchNet): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + # We replace the last linear transform to keep dimentionality + self.lin2 = Linear(self.hidden_channels // 2, self.hidden_channels // 2) + torch.nn.init.xavier_uniform_(self.lin2.weight) + self.lin2.bias.data.fill_(0) + + self.combination = nn.Sequential( + Linear(self.hidden_channels // 2 * 2, self.hidden_channels // 2), + swish, + Linear(self.hidden_channels // 2, 1), + ) + + @conditional_grad(torch.enable_grad()) + def energy_forward(self, data): + # We need to save the tags so this step is necessary. + self.tags_saver(data.tags) + pred = super().energy_forward(data) + + return pred + + def tags_saver(self, tags): + self.current_tags = tags + + @conditional_grad(torch.enable_grad()) + def scattering(self, h, batch): + ads = self.current_tags == 2 + cat = ~ads + + ads_out = scatter(h, batch * ads, dim=0, reduce=self.readout) + cat_out = scatter(h, batch * cat, dim=0, reduce=self.readout) + + system = torch.cat([ads_out, cat_out], dim=1) + system = self.combination(system) + + return system diff --git a/ocpmodels/models/dimenet_plus_plus.py b/ocpmodels/models/dimenet_plus_plus.py index c3dec2989b..03570b3faa 100644 --- a/ocpmodels/models/dimenet_plus_plus.py +++ b/ocpmodels/models/dimenet_plus_plus.py @@ -726,14 +726,19 @@ def energy_forward(self, data): # Output # scatter - energy = P.sum(dim=0) if batch is None else scatter(P, batch, dim=0) - energy = energy + P_bis + energy = self.scattering(batch, P, P_bis) return { "energy": energy, "pooling_loss": pooling_loss, } + def scattering(self, batch, P, P_bis): + energy = P.sum(dim=0) if batch is None else scatter(P, batch, dim=0) + energy = energy + P_bis + + return energy + @conditional_grad(torch.enable_grad()) def forces_forward(self, preds): return diff --git a/ocpmodels/models/faenet.py b/ocpmodels/models/faenet.py index 17fd05754d..cccaf24304 100644 --- a/ocpmodels/models/faenet.py +++ b/ocpmodels/models/faenet.py @@ -153,7 +153,6 @@ def forward( self, z, rel_pos, edge_attr, tag=None, normalised_rel_pos=None, subnodes=None ): # --- Edge embedding -- - if self.edge_embed_type == "rij": e = self.lin_e1(rel_pos) elif self.edge_embed_type == "all_rij": @@ -397,13 +396,19 @@ def message(self, x_j, W, local_env=None): class OutputBlock(nn.Module): - def __init__(self, energy_head, hidden_channels, act): + def __init__(self, energy_head, hidden_channels, act, model_name="faenet"): super().__init__() self.energy_head = energy_head self.act = act self.lin1 = Linear(hidden_channels, hidden_channels // 2) - self.lin2 = Linear(hidden_channels // 2, 1) + if model_name == "faenet": + self.lin2 = Linear(hidden_channels // 2, 1) + elif model_name in { + "indfaenet", + "afaenet", + }: # These are models that output more than one 
scalar. + self.lin2 = Linear(hidden_channels // 2, hidden_channels // 2) # weighted average & pooling if self.energy_head in {"pooling", "random"}: @@ -527,6 +532,7 @@ def __init__(self, **kwargs): "one-supernode-per-atom-type", "one-supernode-per-atom-type-dist", } + # Gaussian Basis self.distance_expansion = GaussianSmearing( 0.0, self.cutoff, kwargs["num_gaussians"] @@ -565,7 +571,7 @@ def __init__(self, **kwargs): # Output block self.output_block = OutputBlock( - self.energy_head, kwargs["hidden_channels"], self.act + self.energy_head, kwargs["hidden_channels"], self.act, kwargs["model_name"] ) # Energy head @@ -586,7 +592,13 @@ def __init__(self, **kwargs): # Skip co if self.skip_co == "concat": - self.mlp_skip_co = Linear((kwargs["num_interactions"] + 1), 1) + if kwargs["model_name"] in ["faenet", "depfaenet"]: + self.mlp_skip_co = Linear((kwargs["num_interactions"] + 1), 1) + elif kwargs["model_name"] == "indfaenet": + self.mlp_skip_co = Linear( + (kwargs["num_interactions"] + 1) * kwargs["hidden_channels"] // 2, + kwargs["hidden_channels"] // 2, + ) elif self.skip_co == "concat_atom": self.mlp_skip_co = Linear( ((kwargs["num_interactions"] + 1) * kwargs["hidden_channels"]), @@ -600,6 +612,7 @@ def forces_forward(self, preds): @conditional_grad(torch.enable_grad()) def energy_forward(self, data): # Rewire the graph + z = data.atomic_numbers.long() pos = data.pos batch = data.batch diff --git a/ocpmodels/models/gemnet/depgemnet_t.py b/ocpmodels/models/gemnet/depgemnet_t.py new file mode 100644 index 0000000000..8782d1aac4 --- /dev/null +++ b/ocpmodels/models/gemnet/depgemnet_t.py @@ -0,0 +1,46 @@ +import torch +from torch.nn import Linear +from torch_scatter import scatter + +from ocpmodels.models.gemnet.gemnet import GemNetT +from ocpmodels.common.registry import registry +from ocpmodels.common.utils import conditional_grad, scatter_det + +from torch_geometric.data import Batch + + +@registry.register_model("depgemnet_t") +class depGemNetT(GemNetT): + def __init__(self, **kwargs): + self.hidden_channels = kwargs["emb_size_atom"] + + kwargs["num_targets"] = self.hidden_channels // 2 + super().__init__(**kwargs) + + self.sys_lin1 = Linear(self.hidden_channels // 2 * 2, self.hidden_channels // 2) + self.sys_lin2 = Linear(self.hidden_channels // 2, 1) + + @conditional_grad(torch.enable_grad()) + def energy_forward(self, data): + # We need to save the tags so this step is necessary. 
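+        # As in the other "dep" variants, the saved tags let scattering() aggregate
+        # adsorbate and catalyst atoms into separate embeddings before sys_lin1/sys_lin2.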
+ self.tags_saver(data.tags) + pred = super().energy_forward(data) + + return pred + + def tags_saver(self, tags): + self.current_tags = tags + + @conditional_grad(torch.enable_grad()) + def scattering(self, E_t, batch, dim, dim_size, reduce="add"): + ads = self.current_tags == 2 + cat = ~ads + + ads_out = scatter_det(src=E_t, index=batch * ads, dim=dim, reduce=reduce) + cat_out = scatter_det(src=E_t, index=batch * cat, dim=dim, reduce=reduce) + + system = torch.cat([ads_out, cat_out], dim=1) + system = self.sys_lin1(system) + system = self.sys_lin2(system) + + return system diff --git a/ocpmodels/models/gemnet/gemnet.py b/ocpmodels/models/gemnet/gemnet.py index a1b1d4f28b..10bb2837c8 100644 --- a/ocpmodels/models/gemnet/gemnet.py +++ b/ocpmodels/models/gemnet/gemnet.py @@ -414,7 +414,7 @@ def select_edges( edge_vector = edge_vector[edge_mask] empty_image = neighbors == 0 - if torch.any(empty_image): + if torch.any(empty_image) and "mode" not in data.keys: raise ValueError( f"An image has no neighbors: id={data.id[empty_image]}, " f"sid={data.sid[empty_image]}, fid={data.fid[empty_image]}" @@ -473,7 +473,13 @@ def generate_interaction_graph(self, data): select_cutoff = None else: select_cutoff = self.cutoff - (edge_index, cell_offsets, neighbors, D_st, V_st,) = self.select_edges( + ( + edge_index, + cell_offsets, + neighbors, + D_st, + V_st, + ) = self.select_edges( data=data, edge_index=edge_index, cell_offsets=cell_offsets, @@ -582,11 +588,11 @@ def energy_forward(self, data): nMolecules = torch.max(batch) + 1 if self.extensive: - E_t = scatter( + E_t = self.scattering( E_t, batch, dim=0, dim_size=nMolecules, reduce="add" ) # (nMolecules, num_targets) else: - E_t = scatter( + E_t = self.scattering( E_t, batch, dim=0, dim_size=nMolecules, reduce="mean" ) # (nMolecules, num_targets) @@ -599,6 +605,10 @@ def energy_forward(self, data): "pos": pos, } + def scattering(self, E_t, batch, dim, dim_size, reduce="add"): + E_t = scatter(E_t, batch, dim=0, dim_size=dim_size, reduce=reduce) + return E_t + @conditional_grad(torch.enable_grad()) def forces_forward(self, preds): F_st = preds["F_st"] diff --git a/ocpmodels/models/gemnet/indgemnet_t.py b/ocpmodels/models/gemnet/indgemnet_t.py new file mode 100644 index 0000000000..2453541d91 --- /dev/null +++ b/ocpmodels/models/gemnet/indgemnet_t.py @@ -0,0 +1,64 @@ +import torch, math +from torch import nn +from torch.nn import Linear + +from ocpmodels.models.gemnet.gemnet import GemNetT +from ocpmodels.models.base_model import BaseModel +from ocpmodels.common.registry import registry +from ocpmodels.models.utils.activations import swish + +from torch_geometric.data import Batch + + +@registry.register_model("indgemnet_t") +class indGemNetT(BaseModel): # Change to make it inherit from base model. + def __init__(self, **kwargs): + super().__init__() + + self.regress_forces = kwargs["regress_forces"] + + kwargs["num_targets"] = kwargs["emb_size_atom"] // 2 + + self.ads_model = GemNetT(**kwargs) + self.cat_model = GemNetT(**kwargs) + + self.act = swish + self.combination = nn.Sequential( + Linear(kwargs["emb_size_atom"] // 2 * 2, kwargs["emb_size_atom"] // 2), + self.act, + Linear(kwargs["emb_size_atom"] // 2, 1), + ) + + def energy_forward( + self, data, mode="train" + ): # PROBLEM TO FIX: THE PREDICTION IS BY AN AVERAGE! 
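+        # indGemNetT runs two independent GemNet-T towers (adsorbate and catalyst
+        # graphs) and fuses their per-graph embeddings with a small MLP; the ipdb
+        # breakpoint below is a leftover debugging aid and will halt execution.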
+ import ipdb + + ipdb.set_trace() + + adsorbates = data[0] + catalysts = data[1] + + # We make predictions for each + pred_ads = self.ads_model(adsorbates, mode) + pred_cat = self.cat_model(catalysts, mode) + + ads_energy = pred_ads["energy"] + cat_energy = pred_cat["energy"] + + # We combine predictions + system_energy = torch.cat([ads_energy, cat_energy], dim=1) + system_energy = self.combination(system_energy) + + # We return them + pred_system = { + "energy": system_energy, + "E_t": pred_ads["E_t"], + "idx_t": pred_ads["idx_t"], + "main_graph": pred_ads["main_graph"], + "num_atoms": pred_ads["num_atoms"], + "pos": pred_ads["pos"], + "F_st": pred_ads["F_st"], + } + + return pred_system diff --git a/ocpmodels/models/gemnet_oc/agemnet_oc.py b/ocpmodels/models/gemnet_oc/agemnet_oc.py new file mode 100644 index 0000000000..faf556e183 --- /dev/null +++ b/ocpmodels/models/gemnet_oc/agemnet_oc.py @@ -0,0 +1,117 @@ +import torch, math +from torch import nn +from torch.nn import Linear +from torch_geometric.data import Data, Batch + +from ocpmodels.models.gemnet_oc.gemnet_oc import GemNetOC +from ocpmodels.models.base_model import BaseModel +from ocpmodels.common.registry import registry +from ocpmodels.models.utils.activations import swish + +from torch_geometric.data import Batch + + +@registry.register_model("agemnet_oc") +class aGemNetOC(BaseModel): # Change to make it inherit from base model. + def __init__(self, **kwargs): + super().__init__() + + self.regress_forces = kwargs["regress_forces"] + self.direct_forces = kwargs["direct_forces"] + + self.regress_forces = kwargs["regress_forces"] + + kwargs["num_targets"] = kwargs["emb_size_atom"] // 2 + + self.ads_model = GemNetOC(**kwargs) + self.cat_model = GemNetOC(**kwargs) + + self.act = swish + self.combination = nn.Sequential( + Linear(kwargs["emb_size_atom"] // 2 * 2, kwargs["emb_size_atom"] // 2), + self.act, + Linear(kwargs["emb_size_atom"] // 2, 1), + ) + + def energy_forward( + self, data, mode="train" + ): # PROBLEM TO FIX: THE PREDICTION IS BY AN AVERAGE! 
+ import ipdb + + ipdb.set_trace() + + bip_edges = data["is_disc"].edge_index + bip_weights = data["is_disc"].edge_weight + + adsorbates, catalysts = [], [] + for i in range(len(data)): + adsorbates.append( + Data( + **data[i]["adsorbate"]._mapping, + edge_index=data[i]["adsorbate", "is_close", "adsorbate"] + ) + ) + catalyst.append( + Data( + **data[i]["catalyst"]._mapping, + edge_index=data[i]["catalyst", "is_close", "catalyst"] + ) + ) + del data + adsorbates = Batch.from_data_list(adsorbates) + catalysts = Batch.from_data_list(catalysts) + + # We make predictions for each + pos_ads = adsorbates.pos + batch_ads = adsorbates.batch + atomic_numbers_ads = adsorbates.atomic_numbers.long() + num_atoms_ads = adsorbates.shape[0] + + pos_cat = catalysts.pos + batch_cat = catalysts.batch + atomic_numbers_cat = catalysts.atomic_numbers.long() + num_atoms_cat = catalysts.shape[0] + + if self.regress_forces and not self.direct_forces: + pos_ads.requires_grad_(True) + pos_cat.requires_grad_(True) + + output_ads = self.ads_model.pre_interaction( + pos_ads, batch_ads, atomic_numbers_ads, num_atoms_ads, adsorbates + ) + output_cat = self.cat_model.pre_interaction( + pos_cat, batch_cat, atomic_numbers_cat, num_atoms_cat, catalysts + ) + + inter_outputs_ads, inter_outputs_cat = self.interactions(output_ads, output_cat) + + ads_energy = pred_ads["energy"] + cat_energy = pred_cat["energy"] + + # We combine predictions + system_energy = torch.cat([ads_energy, cat_energy], dim=1) + system_energy = self.combination(system_energy) + + # We return them + pred_system = { + "energy": system_energy, + "pooling_loss": pred_ads["pooling_loss"] + if pred_ads["pooling_loss"] is None + else pred_ads["pooling_loss"] + pred_cat["pooling_loss"], + } + + return pred_system + + def interactions(self, output_ads, output_cat): + h_ads, m_ads = output_ads["h"], output_ads["m"] + h_cat, m_cat = output_cat["h"], output_cat["m"] + del output_ads["h"] + del output_ads["m"] + del output_cat["h"] + del output_cat["m"] + + # basis_output_ads, idx + + return 1, 2 + + # GOT UP TO HERE. I NEED TO DO INTERACTIONS. HERE. diff --git a/ocpmodels/models/gemnet_oc/depgemnet_oc.py b/ocpmodels/models/gemnet_oc/depgemnet_oc.py new file mode 100644 index 0000000000..741e1b7cb7 --- /dev/null +++ b/ocpmodels/models/gemnet_oc/depgemnet_oc.py @@ -0,0 +1,46 @@ +import torch +from torch.nn import Linear +from torch_scatter import scatter + +from ocpmodels.models.gemnet_oc.gemnet_oc import GemNetOC +from ocpmodels.common.registry import registry +from ocpmodels.common.utils import conditional_grad, scatter_det + +from torch_geometric.data import Batch + + +@registry.register_model("depgemnet_oc") +class depGemNetOC(GemNetOC): + def __init__(self, **kwargs): + self.hidden_channels = kwargs["emb_size_atom"] + + kwargs["num_targets"] = self.hidden_channels // 2 + super().__init__(**kwargs) + + self.sys_lin1 = Linear(self.hidden_channels // 2 * 2, self.hidden_channels // 2) + self.sys_lin2 = Linear(self.hidden_channels // 2, 1) + + @conditional_grad(torch.enable_grad()) + def energy_forward(self, data): + # We need to save the tags so this step is necessary. 
+ self.tags_saver(data.tags) + pred = super().energy_forward(data) + + return pred + + def tags_saver(self, tags): + self.current_tags = tags + + @conditional_grad(torch.enable_grad()) + def scattering(self, E_t, batch, dim, dim_size, reduce="add"): + ads = self.current_tags == 2 + cat = ~ads + + ads_out = scatter_det(src=E_t, index=batch * ads, dim=dim, reduce=reduce) + cat_out = scatter_det(src=E_t, index=batch * cat, dim=dim, reduce=reduce) + + system = torch.cat([ads_out, cat_out], dim=1) + system = self.sys_lin1(system) + system = self.sys_lin2(system) + + return system diff --git a/ocpmodels/models/gemnet_oc/gemnet_oc.py b/ocpmodels/models/gemnet_oc/gemnet_oc.py index da00442486..fdcd882b6f 100644 --- a/ocpmodels/models/gemnet_oc/gemnet_oc.py +++ b/ocpmodels/models/gemnet_oc/gemnet_oc.py @@ -358,6 +358,7 @@ def __init__( for _ in range(num_global_out_layers) ] self.out_mlp_E = torch.nn.Sequential(*out_mlp_E) + self.out_energy = Dense(emb_size_atom, num_targets, bias=False, activation=None) if direct_forces: out_mlp_F = [ @@ -861,6 +862,9 @@ def subselect_edges( empty_image = subgraph["num_neighbors"] == 0 if torch.any(empty_image): + import ipdb + + ipdb.set_trace() raise ValueError( f"An image has no neighbors: id={data.id[empty_image]}, " f"sid={data.sid[empty_image]}, fid={data.fid[empty_image]}" @@ -1212,6 +1216,101 @@ def energy_forward(self, data): if self.regress_forces and not self.direct_forces: pos.requires_grad_(True) + outputs = self.pre_interaction(pos, batch, atomic_numbers, num_atoms, data) + + # h, m, basis_output, idx_t, x_E, x_F, xs_E, xs_F + interaction_outputs = self.interactions(outputs) + + E_t, idx_t, F_st = self.post_interactions( + batch=batch, + **interaction_outputs, + ) + + return { + "energy": E_t.squeeze(1), # (num_molecules) + "E_t": E_t, + "idx_t": idx_t, + "main_graph": outputs["main_graph"], + "num_atoms": num_atoms, + "pos": pos, + "F_st": F_st, + } + + def post_interactions(self, h, m, basis_output, idx_t, x_E, x_F, xs_E, xs_F, batch): + # Global output block for final predictions + x_E = self.out_mlp_E(torch.cat(xs_E, dim=-1)) + if self.direct_forces: + x_F = self.out_mlp_F(torch.cat(xs_F, dim=-1)) + with torch.cuda.amp.autocast(False): + E_t = self.out_energy(x_E.float()) + if self.direct_forces: + F_st = self.out_forces(x_F.float()) + + nMolecules = torch.max(batch) + 1 + + if self.extensive: + E_t = self.scattering( + E_t, batch, dim=0, dim_size=nMolecules, reduce="add" + ) # (nMolecules, num_targets) + else: + E_t = self.scattering( + E_t, batch, dim=0, dim_size=nMolecules, reduce="mean" + ) # (nMolecules, num_targets) + + return E_t, idx_t, F_st + + def interactions(self, outputs): + h, m = outputs["h"], outputs["m"] + # del outputs["h"]; del outputs["m"] + + basis_output, idx_t = outputs["basis_output"], outputs["idx_t"] + # del outputs["basis_output"]; del outputs["idx_t"] + + x_E, x_F = outputs["x_E"], outputs["x_F"] + # del outputs["x_E"]; outputs["x_F"] + + xs_E, xs_F = outputs["xs_E"], outputs["xs_F"] + # del outputs["xs_E"]; del outputs["xs_F"] + + for i in range(self.num_blocks): + # Interaction block + h, m = self.int_blocks[i]( + h=h, + m=m, + bases_qint=outputs["bases_qint"], + bases_e2e=outputs["bases_e2e"], + bases_a2e=outputs["bases_a2e"], + bases_e2a=outputs["bases_e2a"], + basis_a2a_rad=outputs["basis_a2a_rad"], + basis_atom_update=outputs["basis_atom_update"], + edge_index_main=outputs["main_graph"]["edge_index"], + a2ee2a_graph=outputs["a2ee2a_graph"], + a2a_graph=outputs["a2a_graph"], + id_swap=outputs["id_swap"], + 
trip_idx_e2e=outputs["trip_idx_e2e"], + trip_idx_a2e=outputs["trip_idx_a2e"], + trip_idx_e2a=outputs["trip_idx_e2a"], + quad_idx=outputs["quad_idx"], + ) # (nAtoms, emb_size_atom), (nEdges, emb_size_edge) + + x_E, x_F = self.out_blocks[i + 1](h, m, basis_output, idx_t) + # (nAtoms, emb_size_atom), (nEdges, emb_size_edge) + xs_E.append(x_E) + xs_F.append(x_F) + + interaction_outputs = { + "h": h, + "m": m, + "basis_output": basis_output, + "idx_t": idx_t, + "x_E": x_E, + "x_F": x_F, + "xs_E": xs_E, + "xs_F": xs_F, + } + return interaction_outputs + + def pre_interaction(self, pos, batch, atomic_numbers, num_atoms, data): ( main_graph, a2a_graph, @@ -1256,64 +1355,44 @@ def energy_forward(self, data): # (nAtoms, emb_size_atom), (nEdges, emb_size_edge) xs_E, xs_F = [x_E], [x_F] - for i in range(self.num_blocks): - # Interaction block - h, m = self.int_blocks[i]( - h=h, - m=m, - bases_qint=bases_qint, - bases_e2e=bases_e2e, - bases_a2e=bases_a2e, - bases_e2a=bases_e2a, - basis_a2a_rad=basis_a2a_rad, - basis_atom_update=basis_atom_update, - edge_index_main=main_graph["edge_index"], - a2ee2a_graph=a2ee2a_graph, - a2a_graph=a2a_graph, - id_swap=id_swap, - trip_idx_e2e=trip_idx_e2e, - trip_idx_a2e=trip_idx_a2e, - trip_idx_e2a=trip_idx_e2a, - quad_idx=quad_idx, - ) # (nAtoms, emb_size_atom), (nEdges, emb_size_edge) + outputs = { + "main_graph": main_graph, + "a2a_graph": a2a_graph, + "a2ee2a_graph": a2ee2a_graph, + "id_swap": id_swap, + "trip_idx_e2e": trip_idx_e2e, + "trip_idx_a2e": trip_idx_a2e, + "trip_idx_e2a": trip_idx_e2a, + "quad_idx": quad_idx, + "idx_t": idx_t, + "basis_rad_raw": basis_rad_raw, + "basis_atom_update": basis_atom_update, + "basis_output": basis_output, + "bases_qint": bases_qint, + "bases_e2e": bases_e2e, + "bases_a2e": bases_a2e, + "bases_e2a": bases_e2a, + "basis_a2a_rad": basis_a2a_rad, + "h": h, + "m": m, + "x_E": x_E, + "x_F": x_F, + "xs_E": xs_E, + "xs_F": xs_F, + } - x_E, x_F = self.out_blocks[i + 1](h, m, basis_output, idx_t) - # (nAtoms, emb_size_atom), (nEdges, emb_size_edge) - xs_E.append(x_E) - xs_F.append(x_F) + return outputs - # Global output block for final predictions - x_E = self.out_mlp_E(torch.cat(xs_E, dim=-1)) - if self.direct_forces: - x_F = self.out_mlp_F(torch.cat(xs_F, dim=-1)) - with torch.cuda.amp.autocast(False): - E_t = self.out_energy(x_E.float()) - if self.direct_forces: - F_st = self.out_forces(x_F.float()) - - nMolecules = torch.max(batch) + 1 - if self.extensive: - E_t = scatter_det( - E_t, batch, dim=0, dim_size=nMolecules, reduce="add" - ) # (nMolecules, num_targets) - else: - E_t = scatter_det( - E_t, batch, dim=0, dim_size=nMolecules, reduce="mean" - ) # (nMolecules, num_targets) + @conditional_grad(torch.enable_grad()) + def scattering(self, E_t, batch, dim, dim_size, reduce="add"): + E_t = scatter_det( + src=E_t, index=batch, dim=dim, dim_size=dim_size, reduce=reduce + ) - return { - "energy": E_t.squeeze(1), # (num_molecules) - "E_t": E_t, - "idx_t": idx_t, - "main_graph": main_graph, - "num_atoms": num_atoms, - "pos": pos, - "F_st": F_st, - } + return E_t @conditional_grad(torch.enable_grad()) def forces_forward(self, preds): - idx_t = preds["idx_t"] main_graph = preds["main_graph"] num_atoms = preds["num_atoms"] diff --git a/ocpmodels/models/gemnet_oc/indgemnet_oc.py b/ocpmodels/models/gemnet_oc/indgemnet_oc.py new file mode 100644 index 0000000000..b2f5ff7f4b --- /dev/null +++ b/ocpmodels/models/gemnet_oc/indgemnet_oc.py @@ -0,0 +1,64 @@ +import torch, math +from torch import nn +from torch.nn import Linear + +from 
ocpmodels.models.gemnet_oc.gemnet_oc import GemNetOC +from ocpmodels.models.base_model import BaseModel +from ocpmodels.common.registry import registry +from ocpmodels.models.utils.activations import swish + +from torch_geometric.data import Batch + + +@registry.register_model("indgemnet_oc") +class indGemNetOC(BaseModel): # Change to make it inherit from base model. + def __init__(self, **kwargs): + super().__init__() + + self.regress_forces = kwargs["regress_forces"] + + kwargs["num_targets"] = kwargs["emb_size_atom"] // 2 + + self.ads_model = GemNetOC(**kwargs) + self.cat_model = GemNetOC(**kwargs) + + self.act = swish + self.combination = nn.Sequential( + Linear(kwargs["emb_size_atom"] // 2 * 2, kwargs["emb_size_atom"] // 2), + self.act, + Linear(kwargs["emb_size_atom"] // 2, 1), + ) + + def energy_forward( + self, data, mode="train" + ): # PROBLEM TO FIX: THE PREDICTION IS BY AN AVERAGE! + import ipdb + + ipdb.set_trace() + + adsorbates = data[0] + catalysts = data[1] + + # We make predictions for each + pred_ads = self.ads_model(adsorbates, mode) + pred_cat = self.cat_model(catalysts, mode) + + ads_energy = pred_ads["energy"] + cat_energy = pred_cat["energy"] + + # We combine predictions + system_energy = torch.cat([ads_energy, cat_energy], dim=1) + system_energy = self.combination(system_energy) + + # We return them + pred_system = { + "energy": system_energy, + "E_t": pred_ads["E_t"], + "idx_t": pred_ads["idx_t"], + "main_graph": pred_ads["main_graph"], + "num_atoms": pred_ads["num_atoms"], + "pos": pred_ads["pos"], + "F_st": pred_ads["F_st"], + } + + return pred_system diff --git a/ocpmodels/models/inddpp.py b/ocpmodels/models/inddpp.py new file mode 100644 index 0000000000..a5130424f6 --- /dev/null +++ b/ocpmodels/models/inddpp.py @@ -0,0 +1,69 @@ +import torch, math +from torch import nn +from torch.nn import Linear, Transformer + +from ocpmodels.models.dimenet_plus_plus import DimeNetPlusPlus +from ocpmodels.models.base_model import BaseModel +from ocpmodels.common.registry import registry +from ocpmodels.models.utils.activations import swish + +from torch_geometric.data import Batch + + +@registry.register_model("inddpp") +class indDimeNetPlusPlus(BaseModel): # Change to make it inherit from base model. + def __init__(self, **kwargs): + super().__init__() + + self.regress_forces = kwargs["regress_forces"] + kwargs["num_targets"] = kwargs["hidden_channels"] // 2 + + self.cat_model = DimeNetPlusPlus(**kwargs) + + old_hc = kwargs["hidden_channels"] + old_sphr = kwargs["num_spherical"] + old_radi = kwargs["num_radial"] + old_out_emb = kwargs["out_emb_channels"] + old_targets = kwargs["num_targets"] + + kwargs["hidden_channels"] = kwargs["hidden_channels"] // 2 + kwargs["num_spherical"] = kwargs["num_spherical"] // 2 + kwargs["num_radial"] = kwargs["num_radial"] // 2 + kwargs["out_emb_channesl"] = kwargs["out_emb_channels"] // 2 + kwargs["num_targets"] = kwargs["num_targets"] // 2 + + self.ads_model = DimeNetPlusPlus(**kwargs) + + self.act = swish + self.combination = nn.Sequential( + Linear(kwargs["num_targets"] + old_targets, kwargs["num_targets"] // 2), + self.act, + Linear(kwargs["num_targets"] // 2, 1), + ) + + def energy_forward( + self, data, mode="train" + ): # PROBLEM TO FIX: THE PREDICTION IS BY AN AVERAGE! 
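# Illustrative sketch of the combination step shared by the ind* models above: the
# adsorbate and catalyst branches each emit a width-D vector per system, the two are
# concatenated and mapped to a single scalar energy by a small MLP. D and the batch
# size below are assumptions; SiLU stands in for the swish activation used above.
import torch
from torch import nn

D = 128  # hypothetical per-branch embedding width
combination = nn.Sequential(
    nn.Linear(2 * D, D),
    nn.SiLU(),
    nn.Linear(D, 1),
)

ads_energy = torch.randn(4, D)   # 4 systems, adsorbate branch
cat_energy = torch.randn(4, D)   # 4 systems, catalyst branch
system_energy = combination(torch.cat([ads_energy, cat_energy], dim=1))  # (4, 1)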
+ adsorbates = data[0] + catalysts = data[1] + + # We make predictions for each + pred_ads = self.ads_model(adsorbates, mode) + pred_cat = self.cat_model(catalysts, mode) + + ads_energy = pred_ads["energy"] + cat_energy = pred_cat["energy"] + + # We combine predictions + system_energy = torch.cat([ads_energy, cat_energy], dim=1) + system_energy = self.combination(system_energy) + + # We return them + pred_system = { + "energy": system_energy, + "pooling_loss": pred_ads["pooling_loss"] + if pred_ads["pooling_loss"] is None + else pred_ads["pooling_loss"] + pred_cat["pooling_loss"], + } + + return pred_system diff --git a/ocpmodels/models/indfaenet.py b/ocpmodels/models/indfaenet.py new file mode 100644 index 0000000000..56d27ee680 --- /dev/null +++ b/ocpmodels/models/indfaenet.py @@ -0,0 +1,141 @@ +import torch, math +from torch import nn +from torch.nn import Linear, Transformer + +from ocpmodels.models.faenet import FAENet +from ocpmodels.models.faenet import OutputBlock +from ocpmodels.models.base_model import BaseModel +from ocpmodels.common.registry import registry +from ocpmodels.models.utils.activations import swish + +from torch_geometric.data import Batch + + +# Implementation of positional encoding obtained from Harvard's annotated transformer's guide +class PositionalEncoding(nn.Module): + def __init__(self, d_model, dropout=0.1, max_len=5): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + # Compute the positional encodings once in log space. + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.register_buffer("pe", pe) + + def forward(self, x): + x = x + self.pe[:, : x.size(1)].requires_grad_(False) + return self.dropout(x) + + +@registry.register_model("indfaenet") +class indFAENet(BaseModel): # Change to make it inherit from base model. 
+ def __init__(self, **kwargs): + super(indFAENet, self).__init__() + + self.regress_forces = kwargs["regress_forces"] + + old_hc = kwargs["hidden_channels"] + old_gaus = kwargs["num_gaussians"] + old_filt = kwargs["num_filters"] + + self.cat_model = FAENet(**kwargs) + + kwargs["hidden_channels"] = kwargs["hidden_channels"] // 2 + kwargs["num_gaussians"] = kwargs["num_gaussians"] // 2 + kwargs["num_filters"] = kwargs["num_filters"] // 2 + + self.ads_model = FAENet(**kwargs) + + self.act = ( + getattr(nn.functional, kwargs["act"]) if kwargs["act"] != "swish" else swish + ) + + self.disconnected_mlp = kwargs.get("disconnected_mlp", False) + if self.disconnected_mlp: + self.ads_lin = Linear( + kwargs["hidden_channels"] // 2, kwargs["hidden_channels"] // 2 + ) + self.cat_lin = Linear( + kwargs["hidden_channels"] // 2, kwargs["hidden_channels"] // 2 + ) + + self.transformer_out = kwargs.get("transformer_out", False) + if self.transformer_out: + self.combination = Transformer( + d_model=kwargs["hidden_channels"] // 2, + nhead=2, + num_encoder_layers=2, + num_decoder_layers=2, + dim_feedforward=kwargs["hidden_channels"], + batch_first=True, + ) + self.positional_encoding = PositionalEncoding( + kwargs["hidden_channels"] // 2, + dropout=0.1, + max_len=5, + ) + self.query_pos = nn.Parameter(torch.rand(kwargs["hidden_channels"] // 2)) + self.transformer_lin = Linear(kwargs["hidden_channels"] // 2, 1) + else: + self.combination = nn.Sequential( + Linear( + kwargs["hidden_channels"] // 2 + old_hc // 2, + kwargs["hidden_channels"] // 2, + ), + self.act, + Linear(kwargs["hidden_channels"] // 2, 1), + ) + + def energy_forward( + self, data, mode="train" + ): # PROBLEM TO FIX: THE PREDICTION IS BY AN AVERAGE! + adsorbates = data[0] + catalysts = data[1] + + # We make predictions for each + pred_ads = self.ads_model(adsorbates, mode) + pred_cat = self.cat_model(catalysts, mode) + + ads_energy = pred_ads["energy"] + cat_energy = pred_cat["energy"] + if self.disconnected_mlp: + ads_energy = self.ads_lin(ads_energy) + cat_energy = self.cat_lin(cat_energy) + + # We combine predictions + if self.transformer_out: + batch_size = ads_energy.shape[0] + + fake_target_sequence = ( + self.query_pos.unsqueeze(0).expand(batch_size, -1).unsqueeze(1) + ) + system_energy = torch.cat( + [ads_energy.unsqueeze(1), cat_energy.unsqueeze(1)], dim=1 + ) + + system_energy = self.positional_encoding(system_energy) + + system_energy = self.combination( + system_energy, fake_target_sequence + ).squeeze(1) + system_energy = self.transformer_lin(system_energy) + else: + system_energy = torch.cat([ads_energy, cat_energy], dim=1) + system_energy = self.combination(system_energy) + + # We return them + pred_system = { + "energy": system_energy, + "pooling_loss": pred_ads["pooling_loss"] + if pred_ads["pooling_loss"] is None + else pred_ads["pooling_loss"] + pred_cat["pooling_loss"], + "hidden_state": pred_ads["hidden_state"], + } + + return pred_system diff --git a/ocpmodels/models/indschnet.py b/ocpmodels/models/indschnet.py new file mode 100644 index 0000000000..16df76945e --- /dev/null +++ b/ocpmodels/models/indschnet.py @@ -0,0 +1,136 @@ +import torch, math +from torch import nn +from torch.nn import Linear, Transformer + +from ocpmodels.models.schnet import SchNet +from ocpmodels.models.base_model import BaseModel +from ocpmodels.common.registry import registry +from ocpmodels.models.utils.activations import swish + +from torch_geometric.data import Batch + + +# Implementation of positional encoding obtained from Harvard's annotated 
transformer's guide +class PositionalEncoding(nn.Module): + def __init__(self, d_model, dropout=0.1, max_len=5): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + # Compute the positional encodings once in log space. + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.register_buffer("pe", pe) + + def forward(self, x): + x = x + self.pe[:, : x.size(1)].requires_grad_(False) + return self.dropout(x) + + +@registry.register_model("indschnet") +class indSchNet(BaseModel): # Change to make it inherit from base model. + def __init__(self, **kwargs): + super(indSchNet, self).__init__() + + self.regress_forces = kwargs["regress_forces"] + + self.cat_model = SchNet(**kwargs) + + old_filt = kwargs["num_filters"] + old_gaus = kwargs["num_gaussians"] + old_hc = kwargs["hidden_channels"] + + kwargs["num_filters"] = kwargs["num_filters"] // 2 + kwargs["num_gaussians"] = kwargs["num_gaussians"] // 2 + kwargs["hidden_channels"] = kwargs["hidden_channels"] // 2 + + self.ads_model = SchNet(**kwargs) + + self.disconnected_mlp = kwargs.get("disconnected_mlp", False) + if self.disconnected_mlp: + self.ads_lin = Linear( + kwargs["hidden_channels"] // 2, kwargs["hidden_channels"] // 2 + ) + self.cat_lin = Linear( + kwargs["hidden_channels"] // 2, kwargs["hidden_channels"] // 2 + ) + + self.transformer_out = kwargs.get("transformer_out", False) + self.act = swish + if self.transformer_out: + self.combination = Transformer( + d_model=kwargs["hidden_channels"] // 2, + nhead=2, + num_encoder_layers=2, + num_decoder_layers=2, + dim_feedforward=kwargs["hidden_channels"], + batch_first=True, + ) + self.positional_encoding = PositionalEncoding( + kwargs["hidden_channels"] // 2, + dropout=0.1, + max_len=5, + ) + self.query_pos = nn.Parameter(torch.rand(kwargs["hidden_channels"] // 2)) + self.transformer_lin = Linear(kwargs["hidden_channels"] // 2, 1) + else: + self.combination = nn.Sequential( + Linear( + kwargs["hidden_channels"] // 2 + old_hc // 2, + kwargs["hidden_channels"] // 2, + ), + self.act, + Linear(kwargs["hidden_channels"] // 2, 1), + ) + + def energy_forward( + self, data, mode="train" + ): # PROBLEM TO FIX: THE PREDICTION IS BY AN AVERAGE! 
+ adsorbates = data[0] + catalysts = data[1] + + # We make predictions for each + pred_ads = self.ads_model(adsorbates, mode) + pred_cat = self.cat_model(catalysts, mode) + + ads_energy = pred_ads["energy"] + cat_energy = pred_cat["energy"] + if self.disconnected_mlp: + ads_energy = self.ads_lin(ads_energy) + cat_energy = self.cat_lin(cat_energy) + + # We combine predictions + if self.transformer_out: + batch_size = ads_energy.shape[0] + + fake_target_sequence = ( + self.query_pos.unsqueeze(0).expand(batch_size, -1).unsqueeze(1) + ) + system_energy = torch.cat( + [ads_energy.unsqueeze(1), cat_energy.unsqueeze(1)], dim=1 + ) + + system_energy = self.positional_encoding(system_energy) + + system_energy = self.combination( + system_energy, fake_target_sequence + ).squeeze(1) + system_energy = self.transformer_lin(system_energy) + else: + system_energy = torch.cat([ads_energy, cat_energy], dim=1) + system_energy = self.combination(system_energy) + + # We return them + pred_system = { + "energy": system_energy, + "pooling_loss": pred_ads["pooling_loss"] + if pred_ads["pooling_loss"] is None + else pred_ads["pooling_loss"] + pred_cat["pooling_loss"], + } + + return pred_system diff --git a/ocpmodels/models/painn.py b/ocpmodels/models/painn.py index 8b2f5d45c3..0d60ef6e0c 100644 --- a/ocpmodels/models/painn.py +++ b/ocpmodels/models/painn.py @@ -612,6 +612,10 @@ def forces_forward(self, preds): @conditional_grad(torch.enable_grad()) def energy_forward(self, data): + import ipdb + + ipdb.set_trace() + pos = data.pos batch = data.batch z = data.atomic_numbers.long() diff --git a/ocpmodels/models/schnet.py b/ocpmodels/models/schnet.py index 063968473f..60528d4681 100644 --- a/ocpmodels/models/schnet.py +++ b/ocpmodels/models/schnet.py @@ -258,7 +258,10 @@ def __init__(self, **kwargs): # Output block self.lin1 = Linear(self.hidden_channels, self.hidden_channels // 2) self.act = ShiftedSoftplus() - self.lin2 = Linear(self.hidden_channels // 2, 1) + if kwargs["model_name"] in ["indschnet"]: + self.lin2 = Linear(self.hidden_channels // 2, self.hidden_channels // 2) + else: + self.lin2 = Linear(self.hidden_channels // 2, 1) # weighted average & pooling if self.energy_head in {"pooling", "random"}: @@ -427,7 +430,7 @@ def energy_forward(self, data): h = h + self.atomref(z) # Global pooling - out = scatter(h, batch, dim=0, reduce=self.readout) + out = self.scattering(h, batch) if self.scale is not None: out = self.scale * out @@ -436,3 +439,7 @@ def energy_forward(self, data): "energy": out, "pooling_loss": pooling_loss, } + + @conditional_grad(torch.enable_grad()) + def scattering(self, h, batch): + return scatter(h, batch, dim=0, reduce=self.readout) diff --git a/ocpmodels/preprocessing/graph_rewiring.py b/ocpmodels/preprocessing/graph_rewiring.py index 2f3b103a6c..b9115e9077 100644 --- a/ocpmodels/preprocessing/graph_rewiring.py +++ b/ocpmodels/preprocessing/graph_rewiring.py @@ -36,6 +36,11 @@ def remove_tag0_nodes(data): data.tags = data.tags[non_sub] if hasattr(data, "pos_relaxed"): data.pos_relaxed = data.pos_relaxed[non_sub, :] + if hasattr(data, "query"): + data.h = data.h[non_sub, :] + data.query = data.query[non_sub, :] + data.key = data.key[non_sub, :] + data.value = data.value[non_sub, :] # per-edge tensors data.edge_index = data.edge_index[:, neither_is_sub] diff --git a/ocpmodels/trainers/base_trainer.py b/ocpmodels/trainers/base_trainer.py index 43f52d4a25..900f1f2008 100644 --- a/ocpmodels/trainers/base_trainer.py +++ b/ocpmodels/trainers/base_trainer.py @@ -8,11 +8,13 @@ import errno 
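# Illustrative sketch of the pooled readout factored into SchNet.scattering above:
# per-atom scalars are reduced per molecule with torch_scatter, using the batch
# assignment vector; "add" gives an extensive energy, "mean" an intensive one.
import torch
from torch_scatter import scatter

h = torch.randn(7, 1)                            # per-atom outputs
batch = torch.tensor([0, 0, 0, 1, 1, 2, 2])      # molecule index of each atom
energy = scatter(h, batch, dim=0, reduce="add")  # (3, 1): one value per molecule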
import logging import os +import pickle import random import time from abc import ABC, abstractmethod from collections import defaultdict from copy import deepcopy +from uuid import uuid4 import numpy as np import torch @@ -22,10 +24,10 @@ from rich.console import Console from rich.table import Table from torch.nn.parallel.distributed import DistributedDataParallel -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, Subset from torch_geometric.data import Batch from tqdm import tqdm -from uuid import uuid4 + from ocpmodels.common import dist_utils from ocpmodels.common.data_parallel import ( BalancedBatchSampler, @@ -35,7 +37,12 @@ from ocpmodels.common.graph_transforms import RandomReflect, RandomRotate from ocpmodels.common.registry import registry from ocpmodels.common.timer import Times -from ocpmodels.common.utils import JOB_ID, get_commit_hash, save_checkpoint, resolve +from ocpmodels.common.utils import ( + JOB_ID, + get_commit_hash, + resolve, + save_checkpoint, +) from ocpmodels.datasets.data_transforms import FrameAveraging, get_transforms from ocpmodels.modules.evaluator import Evaluator from ocpmodels.modules.exponential_moving_average import ( @@ -50,6 +57,7 @@ class BaseTrainer(ABC): def __init__(self, **kwargs): run_dir = kwargs["run_dir"] + model_name = kwargs["model"].pop( "name", kwargs.get("model_name", "Unknown - base_trainer issue") ) @@ -150,8 +158,56 @@ def __init__(self, **kwargs): "to stop the training after the next validation\n", ) (run_dir / f"config-{JOB_ID}.yaml").write_text(yaml.dump(self.config)) - self.load() + # Here's the models whose edges are removed as a transform + transform_models = [ + "depfaenet", + "depschnet", + "depgemnet_oc", + "depgemnet_t", + "depdpp", + ] + if self.config["is_disconnected"]: + print("\n\nHeads up: cat-ads edges being removed!") + if self.config["model_name"] in transform_models: + if not self.config["is_disconnected"]: + print( + f"\n\nWhen using {self.config['model_name']},", + "the flag 'is_disconnected' should be used! 
The flag has been turned on.\n", + ) + self.config["is_disconnected"] = True + + # Here's the models whose graphs are disconnected in the dataset + self.separate_models = [ + "indfaenet", + "indschnet", + "indgemnet_oc", + "indgemnet_t", + "inddpp", + ] + self.heterogeneous_models = [ + "afaenet", + "aschnet", + "agemnet_oc", + "agemnet_t", + "adpp", + ] + self.data_mode = "normal" + self.separate_dataset = False + + if self.config["model_name"] in self.separate_models: + self.data_mode = "separate" + print( + "\n\nHeads up: using separate dataset, so ads/cats are separated before transforms.\n" + ) + + elif self.config["model_name"] in self.heterogeneous_models: + self.data_mode = "heterogeneous" + print( + "\n\nHeads up: using heterogeneous dataset, so ads/cats are stored separately in a het graph.\n" + ) + + self.load() self.evaluator = Evaluator( task=self.task_name, model_regresses_forces=self.config["model"].get("regress_forces", ""), @@ -220,6 +276,7 @@ def get_dataloader(self, dataset, sampler): pin_memory=True, batch_sampler=sampler, ) + return loader def load_datasets(self): @@ -239,9 +296,42 @@ def load_datasets(self): if split == "default_val": continue - self.datasets[split] = registry.get_dataset_class( - self.config["task"]["dataset"] - )(ds_conf, transform=transform) + if self.data_mode == "separate": + self.datasets[split] = registry.get_dataset_class("separate")( + ds_conf, + transform=transform, + adsorbates=self.config.get("adsorbates"), + adsorbates_ref_dir=self.config.get("adsorbates_ref_dir"), + silent=self.silent, + ) + + elif self.data_mode == "heterogeneous": + self.datasets[split] = registry.get_dataset_class("heterogeneous")( + ds_conf, + transform=transform, + adsorbates=self.config.get("adsorbates"), + adsorbates_ref_dir=self.config.get("adsorbates_ref_dir"), + ) + + else: + self.datasets[split] = registry.get_dataset_class( + self.config["task"]["dataset"] + )( + ds_conf, + transform=transform, + adsorbates=self.config.get("adsorbates"), + adsorbates_ref_dir=self.config.get("adsorbates_ref_dir"), + ) + + if self.config["lowest_energy_only"]: + with open( + "/network/scratch/a/alvaro.carbonero/lowest_energy.pkl", "rb" + ) as fp: + good_indices = pickle.load(fp) + good_indices = list(good_indices) + + self.real_dataset = self.datasets["train"] + self.datasets["train"] = Subset(self.datasets["train"], good_indices) shuffle = False if split == "train": @@ -364,6 +454,7 @@ def load_model(self): "task_name": self.task_name, }, **self.config["model"], + "model_name": self.config["model_name"], } self.model = registry.get_model_class(self.config["model_name"])( @@ -1044,7 +1135,7 @@ def measure_inference_time(self, loops=1): self.config["model"].get("regress_forces") == "from_energy" ) self.model.eval() - timer = Times(gpu=True) + timer = Times(gpu=torch.cuda.is_available()) # average inference over multiple loops for _ in range(loops): diff --git a/ocpmodels/trainers/single_trainer.py b/ocpmodels/trainers/single_trainer.py index bd4f32d380..fda4bf54f2 100644 --- a/ocpmodels/trainers/single_trainer.py +++ b/ocpmodels/trainers/single_trainer.py @@ -217,6 +217,8 @@ def train(self, disable_eval_tqdm=True, debug_batches=-1): # Calculate start_epoch from step instead of loading the epoch number # to prevent inconsistencies due to different batch size in checkpoint. 
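# Illustrative sketch of the lowest_energy_only filtering in load_datasets above: a list
# of pre-computed indices (loaded from a pickle in the real code, one lowest-energy frame
# per system) wraps the training set in torch.utils.data.Subset. The toy dataset and
# indices below are placeholders.
import torch
from torch.utils.data import Subset, TensorDataset

full_train = TensorDataset(torch.arange(10))   # stand-in for the LMDB-backed dataset
good_indices = [0, 2, 5, 7]                    # would come from the pickled index file
train_subset = Subset(full_train, good_indices)
assert len(train_subset) == 4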
+ if self.config["continue_from_dir"] is not None and self.config["adsorbates"] not in {None, "all"}: + self.step = 0 start_epoch = self.step // n_train timer = Times() epoch_times = [] @@ -449,6 +451,7 @@ def end_of_training( batch = next(iter(self.loaders[self.config["dataset"]["default_val"]])) self.model_forward(batch) self.logger.log({"Batch time": time.time() - start_time}) + self.logger.log( {"Model run time": model_run_time / len(self.loaders["train"])} ) @@ -456,33 +459,77 @@ def end_of_training( self.logger.log({"Epoch time": np.mean(epoch_times)}) # Check respect of symmetries - if self.test_ri and not is_test_env: - symmetry = self.test_model_symmetries(debug_batches=debug_batches) - if symmetry == "SIGTERM": - return "SIGTERM" - if self.logger: - self.logger.log(symmetry) - if not self.silent: - print(symmetry) + if self.data_mode == "normal": + if self.test_ri and not is_test_env: + symmetry = self.test_model_symmetries(debug_batches=debug_batches) + if symmetry == "SIGTERM": + return "SIGTERM" + if self.logger: + self.logger.log(symmetry) + if not self.silent: + print(symmetry) # Close datasets if debug_batches < 0: for ds in self.datasets.values(): - ds.close_db() + try: + ds.close_db() + except: + assert self.config["lowest_energy_only"] == True + self.real_dataset.close_db() def model_forward(self, batch_list, mode="train"): # Distinguish frame averaging from base case. if self.config["frame_averaging"] and self.config["frame_averaging"] != "DA": - original_pos = batch_list[0].pos - if self.task_name in OCP_TASKS: - original_cell = batch_list[0].cell + if self.data_mode == "heterogeneous": + original_pos_ads = batch_list[0]["adsorbate"].pos + original_pos_cat = batch_list[0]["catalyst"].pos + + if self.task_name in OCP_TASKS: + original_cell = batch_list[0]["catalyst"].cell + + fa_pos_length = len(batch_list[0]["adsorbate"].fa_pos) + elif self.data_mode == "separate": + original_pos_ads = batch_list[0][0].pos + original_pos_cat = batch_list[0][1].pos + + if self.task_name in OCP_TASKS: + original_cell = batch_list[0][1].cell + + fa_pos_length = len(batch_list[0][0].fa_pos) + else: + original_pos = batch_list[0].pos + + if self.task_name in OCP_TASKS: + original_cell = batch_list[0].cell + + fa_pos_length = len(batch_list[0].fa_pos) e_all, p_all, f_all, gt_all = [], [], [], [] # Compute model prediction for each frame - for i in range(len(batch_list[0].fa_pos)): - batch_list[0].pos = batch_list[0].fa_pos[i] - if self.task_name in OCP_TASKS: - batch_list[0].cell = batch_list[0].fa_cell[i] + for i in range(fa_pos_length): + if self.data_mode == "heterogeneous": + batch_list[0]["adsorbate"].pos = batch_list[0]["adsorbate"].fa_pos[ + i + ] + batch_list[0]["catalyst"].pos = batch_list[0]["catalyst"].fa_pos[i] + if self.task_name in OCP_TASKS: + batch_list[0]["adsorbate"].cell = batch_list[0][ + "adsorbate" + ].fa_cell[i] + batch_list[0]["catalyst"].cell = batch_list[0][ + "catalyst" + ].fa_cell[i] + elif self.data_mode == "separate": + batch_list[0][0].pos = batch_list[0][0].fa_pos[i] + batch_list[0][1].pos = batch_list[0][1].fa_pos[i] + if self.task_name in OCP_TASKS: + batch_list[0][0].cell = batch_list[0][0].fa_cell[i] + batch_list[0][1].cell = batch_list[0][1].fa_cell[i] + else: + batch_list[0].pos = batch_list[0].fa_pos[i] + if self.task_name in OCP_TASKS: + batch_list[0].cell = batch_list[0].fa_cell[i] # forward pass preds = self.model(deepcopy(batch_list), mode=mode) @@ -522,9 +569,22 @@ def model_forward(self, batch_list, mode="train"): ) gt_all.append(g_grad_target) - 
batch_list[0].pos = original_pos - if self.task_name in OCP_TASKS: - batch_list[0].cell = original_cell + if self.data_mode == "heterogeneous": + batch_list[0]["adsorbate"].pos = original_pos_ads + batch_list[0]["catalyst"].pos = original_pos_cat + if self.task_name in OCP_TASKS: + batch_list[0]["adsorbate"].cell = original_cell + batch_list[0]["catalyst"].cell = original_cell + elif self.data_mode == "separate": + batch_list[0][0].pos = original_pos_ads + batch_list[0][1].pos = original_pos_cat + if self.task_name in OCP_TASKS: + batch_list[0][0].cell = original_cell + batch_list[0][1].cell = original_cell + else: + batch_list[0].pos = original_pos + if self.task_name in OCP_TASKS: + batch_list[0].cell = original_cell # Average predictions over frames preds["energy"] = sum(e_all) / len(e_all) @@ -546,15 +606,29 @@ def compute_loss(self, preds, batch_list): loss = {"total_loss": []} # Energy loss - energy_target = torch.cat( - [ - batch.y_relaxed.to(self.device) - if self.task_name == "is2re" - else batch.y.to(self.device) - for batch in batch_list - ], - dim=0, - ) + if self.data_mode == "heterogeneous": + energy_target = torch.cat( + [ + batch["adsorbate"].y_relaxed.to(self.device) + if self.task_name == "is2re" + else batch["adsorbate"].y.to(self.device) + for batch in batch_list + ], + dim=0, + ) + + elif self.data_mode == "separate": + energy_target = batch_list[0][0].y_relaxed.to(self.device) + else: + energy_target = torch.cat( + [ + batch.y_relaxed.to(self.device) + if self.task_name == "is2re" + else batch.y.to(self.device) + for batch in batch_list + ], + dim=0, + ) if self.normalizer.get("normalize_labels", False): hofs = None @@ -650,22 +724,40 @@ def compute_loss(self, preds, batch_list): def compute_metrics( self, preds: Dict, batch_list: List[Data], evaluator: Evaluator, metrics={} ): - natoms = torch.cat( - [batch.natoms.to(self.device) for batch in batch_list], dim=0 - ) + if self.data_mode == "heterogeneous": + natoms = batch_list[0]["adsorbate"].natoms.to(self.device) + batch_list[0][ + "catalyst" + ].natoms.to(self.device) + target = { + "energy": batch_list[0]["adsorbate"].y_relaxed.to(self.device), + "natoms": natoms, + } - target = { - "energy": torch.cat( - [ - batch.y_relaxed.to(self.device) - if self.task_name == "is2re" - else batch.y.to(self.device) - for batch in batch_list - ], - dim=0, - ), - "natoms": natoms, - } + elif self.data_mode == "separate": + natoms = batch_list[0][0].natoms.to(self.device) + batch_list[0][ + 1 + ].natoms.to(self.device) + target = { + "energy": batch_list[0][0].y_relaxed.to(self.device), + "natoms": natoms, + } + + else: + natoms = torch.cat( + [batch.natoms.to(self.device) for batch in batch_list], dim=0 + ) + target = { + "energy": torch.cat( + [ + batch.y_relaxed.to(self.device) + if self.task_name == "is2re" + else batch.y.to(self.device) + for batch in batch_list + ], + dim=0, + ), + "natoms": natoms, + } if self.config["model"].get("regress_forces", False): target["forces"] = torch.cat( diff --git a/scripts/gnn_dev.py b/scripts/gnn_dev.py new file mode 100644 index 0000000000..5041231667 --- /dev/null +++ b/scripts/gnn_dev.py @@ -0,0 +1,56 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. 
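# Illustrative sketch of the per-mode target selection in compute_loss above: the relaxed
# energy is read from the adsorbate store for "heterogeneous" batches, from the first
# element of the (adsorbate, catalyst) pair for "separate" batches, and from the batch
# itself otherwise. Attribute names follow the IS2RE conventions used above.
def is2re_energy_target(batch, data_mode, device="cpu"):
    if data_mode == "heterogeneous":
        return batch["adsorbate"].y_relaxed.to(device)
    if data_mode == "separate":
        return batch[0].y_relaxed.to(device)   # batch is an (adsorbate, catalyst) pair
    return batch.y_relaxed.to(device)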
+""" +import sys +import warnings +from pathlib import Path + +sys.path.append(str(Path(__file__).resolve().parent.parent)) + +from ocpmodels.common.utils import make_script_trainer +from ocpmodels.trainers import SingleTrainer + +if __name__ == "__main__": + config = {} + # Customize args + # config["graph_rewiring"] = "remove-tag-0" + # config["frame_averaging"] = "2D" + # config["fa_method"] = "random" # "random" + # config["test_ri"] = False + # config["optim"] = {"max_epochs": 2} + # config["model"] = {"use_pbc": True} + # config["continue_from_dir"] = "/network/scratch/a/alexandre.duval/ocp/runs/4023244" + + checkpoint_path = None + # "checkpoints/2022-04-28-11-42-56-dimenetplusplus/" + "best_checkpoint.pt" + + str_args = sys.argv[1:] + if all("config" not in arg for arg in str_args): + # str_args.append("--is_debug") + # str_args.append("--config=faenet-is2re-all") + str_args.append("--adsorbates='*O, *OH, *OH2, *H'") + str_args.append("--config=depfaenet-is2re-all") + str_args.append("--continue_from_dir=/network/scratch/a/alexandre.duval/ocp/runs/4023244") + str_args.append("--optim.max_epochs=6") + # str_args.append("--is_disconnected=True") + # str_args.append("--silent=0") + warnings.warn( + "No model / mode is given; chosen as default" + f"Using: {str_args[-1]}" + ) + + trainer: SingleTrainer = make_script_trainer(str_args=str_args, overrides=config) + + trainer.train() + + if checkpoint_path: + trainer.load_checkpoint( + checkpoint_path="checkpoints/2022-04-28-11-42-56-dimenetplusplus/" + + "best_checkpoint.pt" + ) + + predictions = trainer.predict( + trainer.val_loader, results_file="is2re_results", disable_tqdm=False + ) diff --git a/setup.py b/setup.py index bf2daae37e..429f2d433d 100644 --- a/setup.py +++ b/setup.py @@ -5,13 +5,31 @@ LICENSE file in the root directory of this source tree. """ -from setuptools import find_packages, setup +from distutils.util import convert_path +from pathlib import Path + +from setuptools import setup + + +def make_ocpmodels_package_dict(): + dirs = [ + convert_path(str(p)) + for p in Path("./ocpmodels/").glob("**") + if (p / "__init__.py").exists() + ] + pkgs = [d.replace("/", ".") for d in dirs] + return {p: d for p, d in zip(pkgs, dirs)} + + +pkg_dict = make_ocpmodels_package_dict() +pkg_dict["ocdata"] = convert_path("ocdata") setup( - name="ocp-models", - version="0.0.3", + name="ocpmodels", + version="0.0.1", description="Machine learning models for use in catalysis as part of the Open Catalyst Project", url="https://github.com/Open-Catalyst-Project/ocp", - packages=find_packages(), + packages=list(pkg_dict.keys()), + package_dir=pkg_dict, include_package_data=True, )
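# Illustrative sketch of what make_ocpmodels_package_dict() in setup.py computes: every
# directory under ocpmodels/ containing an __init__.py is mapped from its dotted package
# name to its path, with an extra "ocdata" entry appended afterwards. The standalone
# helper and example output below are indicative only.
from pathlib import Path

def package_dict(root="ocpmodels"):
    dirs = [p for p in Path(root).glob("**") if (p / "__init__.py").exists()]
    return {d.as_posix().replace("/", "."): d.as_posix() for d in dirs}

# Example (indicative) result:
# {"ocpmodels": "ocpmodels",
#  "ocpmodels.models": "ocpmodels/models",
#  "ocpmodels.models.gemnet_oc": "ocpmodels/models/gemnet_oc", ...}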