diff --git a/ocpmodels/models/faenet.py b/ocpmodels/models/faenet.py
index dc7dd1efd6..279ef0fda9 100644
--- a/ocpmodels/models/faenet.py
+++ b/ocpmodels/models/faenet.py
@@ -1,7 +1,8 @@
 """
 Code of the Scalable Frame Averaging (Rotation Invariant) GNN
 """
-from typing import Dict, Optional
+
+from typing import Optional
 
 import torch
 import torch.nn.functional as F
@@ -240,7 +241,7 @@ def reset_parameters(self):
         nn.init.xavier_uniform_(self.lin_h.weight)
         self.lin_h.bias.data.fill_(0)
 
-    def forward(self, h, edge_index, e):
+    def forward(self, h, edge_index, e, batch=None):
 
         # Define edge embedding
         if self.dropout_lin > 0:
@@ -264,7 +265,7 @@ def forward(self, h, edge_index, e):
             h = self.act(self.lin_down(h))  # downscale node rep.
             h = self.propagate(edge_index, x=h, W=e)  # propagate
             if self.graph_norm:
-                h = self.act(self.graph_norm(h))
+                h = self.act(self.graph_norm(h, batch=batch))
             h = F.dropout(
                 h, p=self.dropout_lin, training=self.training or self.deup_inference
             )
@@ -279,7 +280,7 @@ def forward(self, h, edge_index, e):
             e = self.lin_geom(e)
             h = self.propagate(edge_index, x=h, W=e)  # propagate
             if self.graph_norm:
-                h = self.act(self.graph_norm(h))
+                h = self.act(self.graph_norm(h, batch=batch))
             h = torch.cat((h, chi), dim=1)
             h = F.dropout(
                 h, p=self.dropout_lin, training=self.training or self.deup_inference
             )
@@ -289,7 +290,7 @@ def forward(self, h, edge_index, e):
         elif self.mp_type in {"base", "simple"}:
             h = self.propagate(edge_index, x=h, W=e)  # propagate
             if self.graph_norm:
-                h = self.act(self.graph_norm(h))
+                h = self.act(self.graph_norm(h, batch=batch))
             h = F.dropout(
                 h, p=self.dropout_lin, training=self.training or self.deup_inference
             )
@@ -739,7 +740,7 @@ def energy_forward(self, data, q=None):
                 self.first_trainable_layer.split("_")[1]
             ):
                 q = h.clone().detach()
-            h = h + interaction(h, edge_index, e)
+            h = h + interaction(h, edge_index, e, batch)
 
         # Atom skip-co
         if self.skip_co == "concat_atom":
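
Note on the change: the patch threads the `batch` vector (the per-node graph-membership index) from `energy_forward` into each interaction block, so that `graph_norm` can normalize per graph instead of over the whole mini-batch. Below is a minimal sketch of the difference, assuming the model's `graph_norm` behaves like PyTorch Geometric's `torch_geometric.nn.GraphNorm`, whose `forward(x, batch=None)` treats all nodes as a single graph when `batch` is omitted; the tensor shapes and values here are illustrative only, not taken from the repository.

```python
# Illustrative sketch, not part of the patch. Assumes torch_geometric.nn.GraphNorm.
import torch
from torch_geometric.nn import GraphNorm

norm = GraphNorm(in_channels=8)

h = torch.randn(5, 8)                  # features for 5 nodes across 2 graphs
batch = torch.tensor([0, 0, 0, 1, 1])  # nodes 0-2 -> graph 0, nodes 3-4 -> graph 1

# With `batch`: normalization statistics are computed separately per graph,
# matching the patched calls `self.graph_norm(h, batch=batch)`.
h_per_graph = norm(h, batch=batch)

# Without `batch`: all nodes are normalized together as if they formed one
# graph, which is what the pre-patch `self.graph_norm(h)` calls amounted to.
h_single = norm(h)

print(torch.allclose(h_per_graph, h_single))  # False in general
```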