Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions extension_cpp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,3 @@
import torch
from pathlib import Path

so_files = list(Path(__file__).parent.glob("_C*.so"))
assert (
len(so_files) == 1
), f"Expected one _C*.so file, found {len(so_files)}"
torch.ops.load_library(so_files[0])

from . import ops
from . import _C, ops
24 changes: 23 additions & 1 deletion extension_cpp/csrc/muladd.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,29 @@
#include <torch/extension.h>
#include <Python.h>
#include <ATen/Operators.h>
#include <torch/all.h>
#include <torch/library.h>

#include <vector>

extern "C" {
/* Creates a dummy empty _C module that can be imported from Python.
The import from Python will load the .so consisting of this file
in this extension, so that the TORCH_LIBRARY static initializers
below are run. */
PyObject* PyInit__C(void)
{
static struct PyModuleDef module_def = {
PyModuleDef_HEAD_INIT,
"_C", /* name of module */
NULL, /* module documentation, may be NULL */
-1, /* size of per-interpreter state of the module,
or -1 if the module keeps state in global variables. */
NULL, /* methods */
};
return PyModule_Create(&module_def);
}
}

namespace extension_cpp {

at::Tensor mymuladd_cpu(const at::Tensor& a, const at::Tensor& b, double c) {
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def get_extensions():
"cxx": [
"-O3" if not debug_mode else "-O0",
"-fdiagnostics-color=always",
"-DPy_LIMITED_API=0x03090000",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe a comment here about python 3.9

],
"nvcc": [
"-O3" if not debug_mode else "-O0",
Expand Down