Skip to content

Commit d7cf909

Browse files
authored
[Feature] Support pip install (#365)
1 parent d698d66 commit d7cf909

File tree

3 files changed

+54
-1
lines changed

3 files changed

+54
-1
lines changed

docs/README_contribute.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ import graph_net
4949
model = ...
5050

5151
# Extract your own model
52-
model = graph_net.torch.extract(name="model_name", dynamic="True")(model)
52+
model = graph_net.torch.extract(name="model_name", dynamic=True)(model)
5353
```
5454

5555
After running, the extracted graph will be saved to: `$GRAPH_NET_EXTRACT_WORKSPACE/model_name/`.

graph_net/__init__.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,21 @@
1+
__all__ = ["torch", "paddle"]
12

3+
from importlib import import_module
4+
from typing import TYPE_CHECKING, Any, List
5+
6+
7+
def __getattr__(name: str) -> Any:
8+
if name in __all__:
9+
module = import_module(f"{__name__}.{name}")
10+
globals()[name] = module
11+
return module
12+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
13+
14+
15+
def __dir__() -> List[str]:
16+
return sorted(list(globals().keys()) + __all__)
17+
18+
19+
if TYPE_CHECKING:
20+
from . import torch as torch # type: ignore
21+
from . import paddle as paddle # type: ignore

pyproject.toml

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
[build-system]
2+
requires = ["hatchling>=1.18", "editables>=0.5"]
3+
build-backend = "hatchling.build"
4+
5+
[project]
6+
name = "graph-net"
7+
version = "0.1.0"
8+
description = "Graph neural network utilities and models."
9+
readme = "README.md"
10+
requires-python = ">=3.8"
11+
license = { file = "LICENSE" }
12+
authors = [
13+
{ name = "GraphNet Authors" }
14+
]
15+
dependencies = []
16+
17+
[project.optional-dependencies]
18+
dev = [
19+
"pre-commit>=3.0",
20+
"astor>=0.8",
21+
]
22+
23+
[tool.hatch.build.targets.wheel]
24+
packages = ["graph_net"]
25+
26+
[tool.hatch.build.targets.sdist]
27+
include = [
28+
"graph_net",
29+
"README.md",
30+
"LICENSE",
31+
]
32+
33+

0 commit comments

Comments
 (0)