Skip to content

Commit 60c6a01

Browse files
committed
chore: clean up dataset changes
1 parent 0eec2f7 commit 60c6a01

File tree

5 files changed: 138 additions and 113 deletions.

lambench/metrics/direct_task_weights.yml

Lines changed: 37 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -14,22 +14,14 @@ HEA25_S:
1414
energy_std: 0.4030134901622356
1515
force_std: 1.5479359067976695
1616
virial_std: 1.4293255096528095
17-
HEA25_bulk:
17+
MoS2:
1818
domain: Inorganic Materials
1919
energy_weight: 1.0
2020
force_weight: 1.0
2121
virial_weight: 1.0
22-
energy_std: 0.4086027291354181
23-
force_std: 2.075184012071992
24-
virial_std: 2.065014356039771
25-
HEMC_HEMB:
26-
domain: Inorganic Materials
27-
energy_weight: 1.0
28-
force_weight: 1.0
29-
virial_weight: 1.0
30-
energy_std: 0.4750117425061965
31-
force_std: 1.8089415904253994
32-
virial_std: 0.4589409203427954
22+
energy_std: 0.08333066480136275
23+
force_std: 0.9536237886182164
24+
virial_std: 0.42877076652059987
3325
MD22:
3426
domain: Biomolecules & Supramolecules
3527
energy_weight: 1.0
@@ -118,24 +110,41 @@ Si_ZEO22:
118110
energy_std: 0.03534121167926313
119111
force_std: 1.2410267785352673
120112
virial_std: null
121-
WBM_downsampled:
122-
domain: Inorganic Materials
123-
energy_weight: 1.0
124-
force_weight: null
125-
virial_weight: null
126-
energy_std: 0.3743104865117501
127-
force_std: null
128-
virial_std: null
129-
Subalex_9k:
130-
domain: Inorganic Materials
131-
energy_weight: 1.0
132-
force_weight: 1.0
133-
virial_weight: 1.0
134-
energy_std: 0.7749643377228371
135-
force_std: 1.1503770816187873
136-
virial_std: 0.8678699239404154
113+
137114

138115
## DEPRECATED
116+
# WBM_downsampled:
117+
# domain: Inorganic Materials
118+
# energy_weight: 1.0
119+
# force_weight: null
120+
# virial_weight: null
121+
# energy_std: 0.3743104865117501
122+
# force_std: null
123+
# virial_std: null
124+
# Subalex_9k:
125+
# domain: Inorganic Materials
126+
# energy_weight: 1.0
127+
# force_weight: 1.0
128+
# virial_weight: 1.0
129+
# energy_std: 0.7749643377228371
130+
# force_std: 1.1503770816187873
131+
# virial_std: 0.8678699239404154
132+
# HEA25_bulk:
133+
# domain: Inorganic Materials
134+
# energy_weight: 1.0
135+
# force_weight: 1.0
136+
# virial_weight: 1.0
137+
# energy_std: 0.4086027291354181
138+
# force_std: 2.075184012071992
139+
# virial_std: 2.065014356039771
140+
# HEMC_HEMB:
141+
# domain: Inorganic Materials
142+
# energy_weight: 1.0
143+
# force_weight: 1.0
144+
# virial_weight: 1.0
145+
# energy_std: 0.4750117425061965
146+
# force_std: 1.8089415904253994
147+
# virial_std: 0.4589409203427954
139148
# Torsionnet500:
140149
# domain: Small Molecules
141150
# energy_weight: 1.0
Lines changed: 23 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -1,37 +1,39 @@
11
ANI:
2-
test_data: "/bohr/lambench-ood-pbe-m4bg/v1/ANI"
2+
test_data: "/bohr/lambench-ood-zwtr/v2/LAMBench-TestData-v2/ANI"
33
HEA25_S:
4-
test_data: "/bohr/lambench-ood-zwtr/v1/OOD_test_data_v2/HEA25S"
5-
HEA25_bulk:
6-
test_data: "/bohr/lambench-ood-zwtr/v1/OOD_test_data_v2/HEA25"
7-
HEMC_HEMB:
8-
test_data: "/bohr/lambench-ood-zwtr/v1/OOD_test_data_v2/HEMC_HEMB"
4+
test_data: "/bohr/lambench-ood-zwtr/v2/LAMBench-TestData-v2/HEA25S"
5+
MoS2:
6+
test_data: "/bohr/lambench-ood-zwtr/v2/LAMBench-TestData-v2/MoS2"
97
MD22:
10-
test_data: "/bohr/lambench-ood-zwtr/v1/OOD_test_data_v2/MD22"
8+
test_data: "/bohr/lambench-ood-zwtr/v2/LAMBench-TestData-v2/MD22"
119
Collision:
12-
test_data: "/bohr/lambench-ood-pbe-m4bg/v5/Collision"
10+
test_data: "/bohr/lambench-ood-zwtr/v2/LAMBench-TestData-v2/Collision"
1311
H_nature_2022:
14-
test_data: "/bohr/lambench-ood-pbe-m4bg/v5/H_nature"
12+
test_data: "/bohr/lambench-ood-zwtr/v2/LAMBench-TestData-v2/H_nature"
1513
REANN_CO2_Ni100:
16-
test_data: "/bohr/lambench-ood-zwtr/v1/OOD_test_data_v2/REANN_CO2_Ni100"
14+
test_data: "/bohr/lambench-ood-zwtr/v2/LAMBench-TestData-v2/REANN_CO2_Ni100"
1715
NequIP_NC_2022:
18-
test_data: "/bohr/lambench-ood-zwtr/v1/OOD_test_data_v2/NequIP_NC_2022"
16+
test_data: "/bohr/lambench-ood-zwtr/v2/LAMBench-TestData-v2/NequIP_NC_2022"
1917
AIMD-Chig:
20-
test_data: "/bohr/lambench-ood-pbe-m4bg/v3/AIMD_chig"
18+
test_data: "/bohr/lambench-ood-zwtr/v2/LAMBench-TestData-v2/AIMD_chig"
2119
CGM_MLP_NC2023:
22-
test_data: "/bohr/lambench-ood-zwtr/v1/OOD_test_data_v2/CGM_MLP"
20+
test_data: "/bohr/lambench-ood-zwtr/v2/LAMBench-TestData-v2/CGM_MLP"
2321
Cu_MgO_catalysts:
24-
test_data: "/bohr/lambench-ood-zwtr/v1/OOD_test_data_v2/Cu_MgO_CO2"
25-
Subalex_9k:
26-
test_data: "/bohr/lambench-ood-zwtr/v1/OOD_test_data_v2/subalex_downsample_9k"
27-
WBM_downsampled:
28-
test_data: "/bohr/lambench-ood-zwtr/v1/OOD_test_data_v2/WBM_downsampled"
22+
test_data: "/bohr/lambench-ood-zwtr/v2/LAMBench-TestData-v2/Cu_MgO_CO2"
2923
Si_ZEO22:
30-
test_data: "/bohr/lambench-ood-zwtr/v1/OOD_test_data_v2/Si_ZEO22"
24+
test_data: "/bohr/lambench-ood-zwtr/v2/LAMBench-TestData-v2/Si_ZEO22"
3125
HPt_NC_2022:
32-
test_data: "/bohr/lambench-ood-zwtr/v1/OOD_test_data_v2/HPt_NC2022"
26+
test_data: "/bohr/lambench-ood-zwtr/v2/LAMBench-TestData-v2/HPt_NC2022"
3327
Ca_batteries_CM2021:
34-
test_data: "/bohr/lambench-ood-zwtr/v1/OOD_test_data_v2/Ca_batteries"
28+
test_data: "/bohr/lambench-ood-zwtr/v2/LAMBench-TestData-v2/Ca_batteries"
3529
## DEPRECATED
30+
# Subalex_9k:
31+
# test_data: "/bohr/lambench-ood-zwtr/v1/OOD_test_data_v2/subalex_downsample_9k"
32+
# WBM_downsampled:
33+
# test_data: "/bohr/lambench-ood-zwtr/v1/OOD_test_data_v2/WBM_downsampled"
34+
# HEA25_bulk:
35+
# test_data: "/bohr/lambench-ood-zwtr/v1/OOD_test_data_v2/HEA25"
36+
# HEMC_HEMB:
37+
# test_data: "/bohr/lambench-ood-zwtr/v1/OOD_test_data_v2/HEMC_HEMB"
3638
# Torsionnet500:
3739
# test_data: "/bohr/lambench-ood-zwtr/v1/OOD_test_data_v2/raw_torsionnet500"

tests/metrics/conftest.py

Lines changed: 76 additions & 59 deletions
Original file line number | Diff line number | Diff line change
@@ -40,34 +40,18 @@
4040
DirectPredictRecord(
4141
id=3,
4242
model_name="test_dp",
43-
task_name="HEA25_bulk",
43+
task_name="MoS2",
4444
create_time=None,
45-
energy_rmse=5.76976,
46-
energy_mae=3.89511,
47-
energy_rmse_natoms=0.134269,
48-
energy_mae_natoms=0.0909164,
49-
force_rmse=0.345338,
50-
force_mae=0.209072,
51-
virial_rmse=167.769,
52-
virial_mae=63.4004,
53-
virial_rmse_natoms=3.87829,
54-
virial_mae_natoms=1.47443,
55-
),
56-
DirectPredictRecord(
57-
id=4,
58-
model_name="test_dp",
59-
task_name="HEMC_HEMB",
60-
create_time=None,
61-
energy_rmse=7.87692,
62-
energy_mae=4.38965,
63-
energy_rmse_natoms=0.154871,
64-
energy_mae_natoms=0.0900861,
65-
force_rmse=0.19703,
66-
force_mae=0.121378,
67-
virial_rmse=6.12989,
68-
virial_mae=2.92621,
69-
virial_rmse_natoms=0.127266,
70-
virial_mae_natoms=0.0619396,
45+
energy_rmse=0.500812,
46+
energy_mae=0.390156,
47+
energy_rmse_natoms=0.0238268,
48+
energy_mae_natoms=0.0186542,
49+
force_rmse=0.157006,
50+
force_mae=0.0976945,
51+
virial_rmse=2.74979,
52+
virial_mae=1.52447,
53+
virial_rmse_natoms=0.109107,
54+
virial_mae_natoms=0.0660086,
7155
),
7256
DirectPredictRecord(
7357
id=5,
@@ -197,38 +181,71 @@
197181
virial_rmse_natoms=None,
198182
virial_mae_natoms=None,
199183
),
200-
DirectPredictRecord(
201-
id=16,
202-
model_name="test_dp",
203-
task_name="WBM_downsampled",
204-
create_time=None,
205-
energy_rmse=0.194829,
206-
energy_mae=0.0604359,
207-
energy_rmse_natoms=0.0318466,
208-
energy_mae_natoms=0.00875549,
209-
force_rmse=None,
210-
force_mae=None,
211-
virial_rmse=None,
212-
virial_mae=None,
213-
virial_rmse_natoms=None,
214-
virial_mae_natoms=None,
215-
),
216-
DirectPredictRecord(
217-
id=17,
218-
model_name="test_dp",
219-
task_name="Subalex_9k",
220-
create_time=None,
221-
energy_rmse=1.90841,
222-
energy_mae=0.268596,
223-
energy_rmse_natoms=0.234027,
224-
energy_mae_natoms=0.0286509,
225-
force_rmse=0.624174,
226-
force_mae=0.0437039,
227-
virial_rmse=4.16581,
228-
virial_mae=0.373998,
229-
virial_rmse_natoms=0.382751,
230-
virial_mae_natoms=0.0371473,
231-
),
184+
## Deprecated
185+
# DirectPredictRecord(
186+
# id=3,
187+
# model_name="test_dp",
188+
# task_name="HEA25_bulk",
189+
# create_time=None,
190+
# energy_rmse=5.76976,
191+
# energy_mae=3.89511,
192+
# energy_rmse_natoms=0.134269,
193+
# energy_mae_natoms=0.0909164,
194+
# force_rmse=0.345338,
195+
# force_mae=0.209072,
196+
# virial_rmse=167.769,
197+
# virial_mae=63.4004,
198+
# virial_rmse_natoms=3.87829,
199+
# virial_mae_natoms=1.47443,
200+
# ),
201+
# DirectPredictRecord(
202+
# id=4,
203+
# model_name="test_dp",
204+
# task_name="HEMC_HEMB",
205+
# create_time=None,
206+
# energy_rmse=7.87692,
207+
# energy_mae=4.38965,
208+
# energy_rmse_natoms=0.154871,
209+
# energy_mae_natoms=0.0900861,
210+
# force_rmse=0.19703,
211+
# force_mae=0.121378,
212+
# virial_rmse=6.12989,
213+
# virial_mae=2.92621,
214+
# virial_rmse_natoms=0.127266,
215+
# virial_mae_natoms=0.0619396,
216+
# ),
217+
# DirectPredictRecord(
218+
# id=16,
219+
# model_name="test_dp",
220+
# task_name="WBM_downsampled",
221+
# create_time=None,
222+
# energy_rmse=0.194829,
223+
# energy_mae=0.0604359,
224+
# energy_rmse_natoms=0.0318466,
225+
# energy_mae_natoms=0.00875549,
226+
# force_rmse=None,
227+
# force_mae=None,
228+
# virial_rmse=None,
229+
# virial_mae=None,
230+
# virial_rmse_natoms=None,
231+
# virial_mae_natoms=None,
232+
# ),
233+
# DirectPredictRecord(
234+
# id=17,
235+
# model_name="test_dp",
236+
# task_name="Subalex_9k",
237+
# create_time=None,
238+
# energy_rmse=1.90841,
239+
# energy_mae=0.268596,
240+
# energy_rmse_natoms=0.234027,
241+
# energy_mae_natoms=0.0286509,
242+
# force_rmse=0.624174,
243+
# force_mae=0.0437039,
244+
# virial_rmse=4.16581,
245+
# virial_mae=0.373998,
246+
# virial_rmse_natoms=0.382751,
247+
# virial_mae_natoms=0.0371473,
248+
# ),
232249
]
233250

234251
RECORDS_CALCULATOR = [

tests/metrics/test_post_process.py

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -33,10 +33,7 @@ def test_process_results_for_one_model(
3333
)
3434
assert result["generalizability_force_field_results"]["Weighted"] is None
3535
assert result["generalizability_force_field_results"]["ANI"]["energy_rmse"] == 467.7
36-
assert (
37-
result["generalizability_force_field_results"]["WBM_downsampled"]["force_rmse"]
38-
is None
39-
)
36+
4037
# Find differences between the calculator tasks and results
4138
calculator_task_differences = (
4239
CALCULATOR_TASKS.keys() - {"inference_efficiency", "nve_md"}

tests/metrics/test_visualization.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@ def test_aggregate_ood_results_for_one_model(
1515
aggregator = ResultsFetcher()
1616
result = aggregator.aggregate_ood_results_for_one_model(model=model)
1717
np.testing.assert_almost_equal(result["Small Molecules"], 0.203700545, decimal=5)
18-
np.testing.assert_almost_equal(result["Inorganic Materials"], 0.283787, decimal=5)
18+
np.testing.assert_almost_equal(result["Inorganic Materials"], 0.2931686, decimal=5)
1919
assert result["Catalysis"] is None
2020
with caplog.at_level(logging.WARNING):
2121
assert (

0 commit comments

Comments
 (0)