Skip to content

Just adding a basic pyproject.toml #54

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions calflops/flops_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ def calculate_flops(model,
output_precision=2,
output_unit=None,
ignore_modules=None,
is_sparse=False):
is_sparse=False,
return_output=False,
assume_model_on_device=False,
):
"""Returns the total floating-point operations, MACs, and parameters of a model.

Args:
Expand All @@ -55,6 +58,8 @@ def calculate_flops(model,
output_unit (str, optional): The unit used to output the result value, such as T, G, M, and K. Default is None, that is the unit of the output decide on value.
ignore_modules ([type], optional): the list of modules to ignore during profiling. Defaults to None.
is_sparse (bool, optional): Whether to exclude sparse matrix flops. Defaults to False.
return_output (bool, optional): Whether to return the output of the model, mutually exclusive with output_as_string. Defaults to False.
assume_model_on_device (bool, optional): Whether to assume the model is on the device; if False, the model will be moved to the device. Defaults to False.

Example:
.. code-block:: python
Expand Down Expand Up @@ -108,6 +113,7 @@ def calculate_flops(model,

assert isinstance(model, nn.Module), "model must be a PyTorch module"
# assert transformers_tokenizer and auto_generate_transformers_input and "transformers" in str(type(model)), "The model must be a transformers model if args of auto_generate_transformers_input is True and transformers_tokenizer is not None"
assert not (output_as_string and return_output), "output_as_string and return_output are mutually exclusive"
model.eval()

is_transformer = True if "transformers" in str(type(model)) else False
Expand All @@ -119,7 +125,8 @@ def calculate_flops(model,
calculate_flops_pipline.start_flops_calculate(ignore_list=ignore_modules)

device = next(model.parameters()).device
model = model.to(device)
if not assume_model_on_device:
model = model.to(device)

if input_shape is not None:
assert len(args) == 0 and len(
Expand Down Expand Up @@ -162,9 +169,9 @@ def calculate_flops(model,
args[index] = args[index].to(device)

if forward_mode == 'forward':
_ = model(*args, **kwargs)
model_output = model(*args, **kwargs)
elif forward_mode == 'generate':
_ = model.generate(*args, **kwargs)
model_output = model.generate(*args, **kwargs)
else:
raise NotImplementedError("forward_mode should be either forward or generate")

Expand All @@ -187,5 +194,8 @@ def calculate_flops(model,
return flops_to_string(flops, units=output_unit, precision=output_precision), \
macs_to_string(macs, units=output_unit, precision=output_precision), \
params_to_string(params, units=output_unit, precision=output_precision)

return flops, macs, params

if return_output:
return flops, macs, params, model_output
else:
return flops, macs, params
20 changes: 20 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
[build-system]
# PEP 517/518 build configuration. setuptools >= 61.0 is the minimum that
# understands PEP 621 [project] metadata, which this file uses below.
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "calflops"
# NOTE(review): hard-coded version — confirm there is no second source of
# truth (setup.py or calflops/__init__.py) that must stay in sync.
version = "0.3.3"
description = "A tool to compute FLOPs, MACs, and parameters in various neural networks."
readme = "README.md"
# NOTE(review): presumably the oldest interpreter the code supports — verify;
# recent torch releases no longer publish wheels for Python 3.6.
requires-python = ">=3.6"
license = {text = "MIT"}
authors = [
    {name = "MrYxJ"}
]
# Runtime dependencies. Only torch is declared; transformers support in the
# code is optional and detected at runtime, so it is deliberately not pinned here.
dependencies = [
    "torch>=1.0.0",
]

[tool.setuptools]
# Explicit package list (no auto-discovery): only the calflops package is packaged.
packages = ["calflops"]