Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import tests.hevm
import vyper.evm.opcodes as evm_opcodes
from tests.dsl.code_model import CodeModel
from tests.evm_backends.base_env import BaseEnv, ExecutionReverted
from tests.evm_backends.pyevm_env import PyEvmEnv
from tests.evm_backends.revm_env import RevmEnv
Expand Down Expand Up @@ -255,6 +256,9 @@ def hevm_marker(request):
@pytest.fixture(scope="module")
def get_contract(env, optimize, output_formats, compiler_settings, hevm, request):
def fn(source_code, *args, **kwargs):
# support CodeModel instances
if isinstance(source_code, CodeModel):
source_code = source_code.build()
if "override_opt_level" in kwargs:
kwargs["compiler_settings"] = Settings(
**dict(compiler_settings.__dict__, optimize=kwargs.pop("override_opt_level"))
Expand Down
53 changes: 53 additions & 0 deletions tests/dsl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""
DSL for building Vyper contracts in tests.

Example usage:
from tests.dsl import CodeModel

# create a model
model = CodeModel()

# define storage variables
balance = model.storage_var('balance: uint256')
owner = model.storage_var('owner: address')

# build a simple contract
code = (model
.function('__init__()')
.deploy()
.body(f'{owner} = msg.sender')
.done()
.function('deposit()')
.external()
.payable()
.body(f'{balance} += msg.value')
.done()
.function('get_balance() -> uint256')
.external()
.view()
.body(f'return {balance}')
.done()
.build())

# The generated code will be:
# balance: uint256
# owner: address
#
# @deploy
# def __init__():
# self.owner = msg.sender
#
# @external
# @payable
# def deposit():
# self.balance += msg.value
#
# @external
# @view
# def get_balance() -> uint256:
# return self.balance
"""

from tests.dsl.code_model import CodeModel, VarRef

__all__ = [CodeModel, VarRef]
227 changes: 227 additions & 0 deletions tests/dsl/code_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
"""
Code model for building Vyper contracts programmatically.

This module provides a fluent API for constructing Vyper contracts
with proper formatting and structure.
"""

from __future__ import annotations

import textwrap
from typing import Optional


class VarRef:
"""Reference to a variable with type and location information."""

def __init__(self, name: str, typ: str, location: str, visibility: Optional[str] = None):
self.name = name
self.typ = typ
self.location = location
self.visibility = visibility

def __str__(self) -> str:
"""Return the variable name for use in expressions."""
# storage and transient vars need self prefix
if self.location in ("storage", "transient"):
return f"self.{self.name}"
return self.name


class FunctionBuilder:
"""Builder for function definitions."""

def __init__(self, signature: str, parent: CodeModel):
self.signature = signature
self.parent = parent
self.decorators: list[str] = []
self.body_code: Optional[str] = None
self.is_internal = True # functions are internal by default

# parse just the name from the signature
paren_idx = signature.find("(")
if paren_idx == -1:
raise ValueError(f"Invalid function signature: {signature}")
self.name = signature[:paren_idx].strip()

def __str__(self) -> str:
"""Return the function name for use in expressions."""
if self.is_internal:
return f"self.{self.name}"
return self.name

def external(self) -> FunctionBuilder:
"""Add @external decorator."""
self.decorators.append("@external")
self.is_internal = False
return self

def internal(self) -> FunctionBuilder:
"""Add @internal decorator."""
self.decorators.append("@internal")
self.is_internal = True
return self

def deploy(self) -> FunctionBuilder:
"""Add @deploy decorator."""
self.decorators.append("@deploy")
self.is_internal = False # deploy functions are not called with self
return self

def view(self) -> FunctionBuilder:
"""Add @view decorator."""
self.decorators.append("@view")
return self

def pure(self) -> FunctionBuilder:
"""Add @pure decorator."""
self.decorators.append("@pure")
return self

def payable(self) -> FunctionBuilder:
"""Add @payable decorator."""
self.decorators.append("@payable")
return self

def nonreentrant(self) -> FunctionBuilder:
"""Add @nonreentrant decorator."""
self.decorators.append("@nonreentrant")
return self

def body(self, code: str) -> FunctionBuilder:
"""Set the function body."""
# dedent the code to handle multi-line strings nicely
self.body_code = textwrap.dedent(code).strip()
return self

def done(self) -> CodeModel:
"""Finish building the function and return to parent CodeModel."""
return self.parent


class CodeModel:
"""Model for building a Vyper contract."""

def __init__(self):
self._storage_vars: list[str] = []
self._transient_vars: list[str] = []
self._constants: list[str] = []
self._immutables: list[str] = []
self._events: list[str] = []
self._structs: list[str] = []
self._flags: list[str] = []
self._imports: list[str] = []
self._local_vars: dict[str, VarRef] = {}
self._function_builders: list[FunctionBuilder] = []

def storage_var(self, declaration: str) -> VarRef:
"""Add a storage variable."""
name, typ = self._parse_declaration(declaration)
self._storage_vars.append(declaration)
return VarRef(name, typ, "storage", "public")

def transient_var(self, declaration: str) -> VarRef:
"""Add a transient storage variable."""
name, typ = self._parse_declaration(declaration)
self._transient_vars.append(f"{name}: transient({typ})")
return VarRef(name, typ, "transient", "public")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vars will be public by default?


def constant(self, declaration: str) -> VarRef:
"""Add a constant."""
# constants have format: "NAME: constant(type) = value"
parts = declaration.split(":", 1)
name = parts[0].strip()
# extract type from constant(...) = value
type_start = parts[1].find("constant(") + 9
type_end = parts[1].find(")", type_start)
typ = parts[1][type_start:type_end].strip()

self._constants.append(declaration)
return VarRef(name, typ, "constant", None)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why does the visibility for constants differ?


def immutable(self, declaration: str) -> VarRef:
"""Add an immutable variable."""
name, typ = self._parse_declaration(declaration)
self._immutables.append(f"{name}: immutable({typ})")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for immutables we don't have to declare with "immutable" but for constants "constant" is required?

return VarRef(name, typ, "immutable", "public")

def local_var(self, name: str, typ: str) -> VarRef:
"""Register a local variable (used in function bodies)."""
ref = VarRef(name, typ, "memory", None)
self._local_vars[name] = ref
return ref

def event(self, definition: str) -> None:
"""Add an event definition."""
self._events.append(f"event {definition}")

def struct(self, definition: str) -> None:
"""Add a struct definition."""
self._structs.append(f"struct {definition}")

def flag(self, definition: str) -> None:
"""Add a flag (enum) definition."""
self._flags.append(f"flag {definition}")

def function(self, signature: str) -> FunctionBuilder:
"""Start building a function."""
fb = FunctionBuilder(signature, self)
self._function_builders.append(fb)
return fb

def build(self) -> str:
"""Build the complete contract code."""
sections = []

if self._imports:
sections.append("\n".join(self._imports))

if self._events:
sections.append("\n".join(self._events))

if self._structs:
sections.append("\n".join(self._structs))

if self._flags:
sections.append("\n".join(self._flags))

if self._constants:
sections.append("\n".join(self._constants))

if self._storage_vars:
sections.append("\n".join(self._storage_vars))

if self._transient_vars:
sections.append("\n".join(self._transient_vars))

if self._immutables:
sections.append("\n".join(self._immutables))

if self._function_builders:
function_strings = []
for fb in self._function_builders:
lines = []
lines.extend(fb.decorators)
lines.append(f"def {fb.signature}:")

if fb.body_code:
indented_body = "\n".join(f" {line}" for line in fb.body_code.split("\n"))
lines.append(indented_body)
else:
lines.append(" pass")

function_strings.append("\n".join(lines))

sections.append("\n\n".join(function_strings))

return "\n\n".join(sections)

def _parse_declaration(self, declaration: str) -> tuple[str, str]:
"""Parse a variable declaration of form 'name: type' into (name, type)."""
parts = declaration.split(":", 1)
if len(parts) != 2:
raise ValueError(f"Invalid declaration format: {declaration}")

name = parts[0].strip()
typ = parts[1].strip()
return name, typ
30 changes: 13 additions & 17 deletions tests/functional/codegen/features/test_constructor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import contextlib

import pytest

from tests.dsl import CodeModel
from tests.evm_backends.base_env import _compile
from vyper.exceptions import StackTooDeep
from vyper.utils import method_id
Expand Down Expand Up @@ -296,27 +295,24 @@ def __init__():
I_ADDR = CONST_ADDR
I_BYTES32 = CONST_BYTES32
"""
print(code)
c = get_contract(code)
assert c.I_UINT() == CONST_UINT
assert c.I_ADDR() == CONST_ADDR
assert c.I_BYTES32() == bytes.fromhex(CONST_BYTES32.removeprefix("0x"))


@pytest.mark.parametrize("should_fail", [True, False])
def test_constructor_payability(env, get_contract, tx_failed, should_fail):
code = f"""
@deploy
{"" if should_fail else "@payable"}
def __init__():
pass
"""
@pytest.mark.parametrize("is_payable", [False, True])
def test_constructor_payability(env, get_contract, tx_failed, is_payable):
model = CodeModel()
env.set_balance(env.deployer, 10)

if should_fail:
ctx = tx_failed
else:
ctx = contextlib.nullcontext
init = model.function("__init__()").deploy().body("pass")

with ctx():
_ = get_contract(code, value=10)
if is_payable:
# payable constructor should deploy successfully with value
init.payable()
get_contract(model, value=10)
else:
# non-payable constructor should fail when deployed with value
with tx_failed():
get_contract(model, value=10)
Loading