Hi author,
I have downloaded the pretrained weights for Multiview Contrast. I want to load these weights and then pass a PDB file to the model, so that I can get the latent representation of the protein from the model.
I think I have successfully loaded the released pretrained weights; however, I am stuck at passing the PDB file to the model. Here is how I pass it:
def extract_representation(model, protein_structure_path):
    protein = Protein.from_pdb(protein_structure_path)
    _protein = Protein.pack([protein])
    input_feature = protein.atom2graph
    with torch.no_grad():
        representation = model(_protein, input_feature)
    # Return the output, which is the protein's representation
    return representation
I think the GearNet model takes two parameters: 1. a Graph object, 2. input (not sure what this is). However, it seems like either _protein or input_feature is wrong. With this code, I am getting this error:
Traceback (most recent call last):
  File "/content/GearNet-main/script/test.py", line 62, in <module>
    main()
  File "/content/GearNet-main/script/test.py", line 58, in main
    representation = extract_representation(model, args.pdb)
  File "/content/GearNet-main/script/test.py", line 43, in extract_representation
    representation = model(_protein, input_feature)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/content/GearNet-main/gearnet/model.py", line 115, in forward
    edge_hidden = self.edge_layers[i](line_graph, edge_hidden)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/content/GearNet-main/gearnet/layer.py", line 132, in forward
    update = self.aggregate(graph, message)
  File "/content/GearNet-main/gearnet/layer.py", line 118, in aggregate
    update = update.view(graph.num_node, self.num_relation * self.input_dim)
RuntimeError: shape '[4944, 472]' is invalid for input of size 711936
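The numbers in the error message can be unpacked with some plain arithmetic. The sketch below only uses the three values printed in the RuntimeError; the helper name `diagnose_view_mismatch` is hypothetical and is not part of GearNet or TorchDrug. It computes how many values per node the tensor actually holds, and lists which `num_relation` values would be arithmetically possible together with the `input_dim` the layer expects versus the feature width the tensor carries.

```python
from math import gcd

# Numbers copied from the RuntimeError above.
NUM_NODE = 4944         # graph.num_node passed to update.view(...)
EXPECTED_WIDTH = 472    # self.num_relation * self.input_dim
ACTUAL_NUMEL = 711936   # number of elements actually in `update`


def diagnose_view_mismatch(num_node, expected_width, actual_numel):
    """Hypothetical helper: return the per-node width actually present and,
    for every num_relation that divides both widths, the tuple
    (num_relation, expected input_dim, actual feature width)."""
    actual_width, remainder = divmod(actual_numel, num_node)
    if remainder:
        raise ValueError("tensor size is not a multiple of num_node")
    common = gcd(expected_width, actual_width)
    candidates = [
        (r, expected_width // r, actual_width // r)
        for r in range(1, common + 1)
        if common % r == 0
    ]
    return actual_width, candidates


actual_width, candidates = diagnose_view_mismatch(
    NUM_NODE, EXPECTED_WIDTH, ACTUAL_NUMEL
)
print("per-node width present:", actual_width)
for num_relation, expected_dim, actual_dim in candidates:
    print(
        f"num_relation={num_relation}: layer expects input_dim={expected_dim}, "
        f"tensor carries {actual_dim}"
    )
```

With these numbers the tensor holds 144 values per node where the layer expects 472, so the feature width of `edge_hidden` does not seem to match the `input_dim` the edge layer was built with. Which of the candidate `num_relation` values applies depends on the model configuration, which is not shown here, so this narrows down the mismatch rather than identifying its cause.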
Could you please help me with how to pass the PDB file to the model?
Thank you