From 284211670b3c6a827ab056d69e7aec0335a93f6b Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 11 Jun 2025 00:54:39 +0200 Subject: [PATCH 1/3] feat[test]: add code writing DSL to test suite this makes it easier to programmatically generate vyper source code, which will make it easier to write parametrized tests in the test suite. --- tests/dsl/__init__.py | 53 ++ tests/dsl/code_model.py | 224 +++++++++ tests/functional/codegen/test_dsl_example.py | 488 +++++++++++++++++++ 3 files changed, 765 insertions(+) create mode 100644 tests/dsl/__init__.py create mode 100644 tests/dsl/code_model.py create mode 100644 tests/functional/codegen/test_dsl_example.py diff --git a/tests/dsl/__init__.py b/tests/dsl/__init__.py new file mode 100644 index 0000000000..f9f9589a7a --- /dev/null +++ b/tests/dsl/__init__.py @@ -0,0 +1,53 @@ +""" +DSL for building Vyper contracts in tests. + +Example usage: + from tests.dsl import CodeModel + + # create a model + model = CodeModel() + + # define storage variables + balance = model.storage_var('balance: uint256') + owner = model.storage_var('owner: address') + + # build a simple contract + code = (model + .function('__init__()') + .deploy() + .body(f'{owner} = msg.sender') + .done() + .function('deposit()') + .external() + .payable() + .body(f'{balance} += msg.value') + .done() + .function('get_balance() -> uint256') + .external() + .view() + .body(f'return {balance}') + .done() + .build()) + + # The generated code will be: + # balance: uint256 + # owner: address + # + # @deploy + # def __init__(): + # self.owner = msg.sender + # + # @external + # @payable + # def deposit(): + # self.balance += msg.value + # + # @external + # @view + # def get_balance() -> uint256: + # return self.balance +""" + +from tests.dsl.code_model import CodeModel, VarRef + +__all__ = [CodeModel, VarRef] \ No newline at end of file diff --git a/tests/dsl/code_model.py b/tests/dsl/code_model.py new file mode 100644 index 0000000000..36cffa021a --- /dev/null +++ b/tests/dsl/code_model.py @@ -0,0 +1,224 @@ +""" +Code model for building Vyper contracts programmatically. + +This module provides a fluent API for constructing Vyper contracts +with proper formatting and structure. +""" + +import textwrap +from typing import Optional, List, Dict, Any, Union + +from vyper.ast import parse_to_ast +from vyper.ast.nodes import FunctionDef + + +class VarRef: + """Reference to a variable with type and location information.""" + + def __init__(self, name: str, typ: str, location: str, visibility: Optional[str] = None): + self.name = name + self.typ = typ + self.location = location + self.visibility = visibility + + def __str__(self) -> str: + """Return the variable name for use in expressions.""" + # storage and transient vars need self prefix + if self.location in ("storage", "transient"): + return f"self.{self.name}" + return self.name + + +class FunctionBuilder: + """Builder for function definitions.""" + + def __init__(self, signature: str, parent: "CodeModel"): + self.signature = signature + self.parent = parent + self.decorators: List[str] = [] + self.body_code: Optional[str] = None + self.is_internal = True # functions are internal by default + + # parse just the name from the signature + paren_idx = signature.find('(') + if paren_idx == -1: + raise ValueError(f"Invalid function signature: {signature}") + self.name = signature[:paren_idx].strip() + + def __str__(self) -> str: + """Return the function name for use in expressions.""" + if self.is_internal: + return f"self.{self.name}" + return self.name + + def external(self) -> "FunctionBuilder": + """Add @external decorator.""" + self.decorators.append("@external") + self.is_internal = False + return self + + def internal(self) -> "FunctionBuilder": + """Add @internal decorator.""" + self.decorators.append("@internal") + self.is_internal = True + return self + + def deploy(self) -> "FunctionBuilder": + """Add @deploy decorator.""" + self.decorators.append("@deploy") + self.is_internal = False # deploy functions are not called with self + return self + + def view(self) -> "FunctionBuilder": + """Add @view decorator.""" + self.decorators.append("@view") + return self + + def pure(self) -> "FunctionBuilder": + """Add @pure decorator.""" + self.decorators.append("@pure") + return self + + def payable(self) -> "FunctionBuilder": + """Add @payable decorator.""" + self.decorators.append("@payable") + return self + + def nonreentrant(self) -> "FunctionBuilder": + """Add @nonreentrant decorator.""" + self.decorators.append("@nonreentrant") + return self + + def body(self, code: str) -> "FunctionBuilder": + """Set the function body.""" + # dedent the code to handle multi-line strings nicely + self.body_code = textwrap.dedent(code).strip() + return self + + def done(self) -> "CodeModel": + """Finish building the function and return to parent CodeModel.""" + lines = [] + + lines.extend(self.decorators) + lines.append(f"def {self.signature}:") + + if self.body_code: + indented_body = "\n".join(f" {line}" for line in self.body_code.split("\n")) + lines.append(indented_body) + else: + lines.append(" pass") + + self.parent._functions.append("\n".join(lines)) + return self.parent + + +class CodeModel: + """Model for building a Vyper contract.""" + + def __init__(self): + self._storage_vars: List[str] = [] + self._transient_vars: List[str] = [] + self._constants: List[str] = [] + self._immutables: List[str] = [] + self._events: List[str] = [] + self._structs: List[str] = [] + self._flags: List[str] = [] + self._functions: List[str] = [] + self._imports: List[str] = [] + self._local_vars: Dict[str, VarRef] = {} + + def storage_var(self, declaration: str) -> VarRef: + """Add a storage variable.""" + name, typ = self._parse_declaration(declaration) + self._storage_vars.append(declaration) + return VarRef(name, typ, "storage", "public") + + def transient_var(self, declaration: str) -> VarRef: + """Add a transient storage variable.""" + name, typ = self._parse_declaration(declaration) + self._transient_vars.append(f"{name}: transient({typ})") + return VarRef(name, typ, "transient", "public") + + def constant(self, declaration: str) -> VarRef: + """Add a constant.""" + # constants have format: "NAME: constant(type) = value" + parts = declaration.split(":", 1) + name = parts[0].strip() + # extract type from constant(...) = value + type_start = parts[1].find("constant(") + 9 + type_end = parts[1].find(")", type_start) + typ = parts[1][type_start:type_end].strip() + + self._constants.append(declaration) + return VarRef(name, typ, "constant", None) + + def immutable(self, declaration: str) -> VarRef: + """Add an immutable variable.""" + name, typ = self._parse_declaration(declaration) + self._immutables.append(f"{name}: immutable({typ})") + return VarRef(name, typ, "immutable", "public") + + def local_var(self, name: str, typ: str) -> VarRef: + """Register a local variable (used in function bodies).""" + ref = VarRef(name, typ, "memory", None) + self._local_vars[name] = ref + return ref + + def event(self, definition: str) -> None: + """Add an event definition.""" + self._events.append(f"event {definition}") + + def struct(self, definition: str) -> None: + """Add a struct definition.""" + self._structs.append(f"struct {definition}") + + def flag(self, definition: str) -> None: + """Add a flag (enum) definition.""" + self._flags.append(f"flag {definition}") + + def function(self, signature: str) -> FunctionBuilder: + """Start building a function.""" + return FunctionBuilder(signature, self) + + def build(self) -> str: + """Build the complete contract code.""" + sections = [] + + if self._imports: + sections.append("\n".join(self._imports)) + + if self._events: + sections.append("\n".join(self._events)) + + if self._structs: + sections.append("\n".join(self._structs)) + + if self._flags: + sections.append("\n".join(self._flags)) + + if self._constants: + sections.append("\n".join(self._constants)) + + if self._storage_vars: + sections.append("\n".join(self._storage_vars)) + + if self._transient_vars: + sections.append("\n".join(self._transient_vars)) + + if self._immutables: + sections.append("\n".join(self._immutables)) + + if self._functions: + sections.append("\n\n".join(self._functions)) + + return "\n\n".join(sections) + + def _parse_declaration(self, declaration: str) -> tuple[str, str]: + """Parse a variable declaration of form 'name: type' into (name, type).""" + parts = declaration.split(":", 1) + if len(parts) != 2: + raise ValueError(f"Invalid declaration format: {declaration}") + + name = parts[0].strip() + typ = parts[1].strip() + return name, typ \ No newline at end of file diff --git a/tests/functional/codegen/test_dsl_example.py b/tests/functional/codegen/test_dsl_example.py new file mode 100644 index 0000000000..dc49850a78 --- /dev/null +++ b/tests/functional/codegen/test_dsl_example.py @@ -0,0 +1,488 @@ +""" +Example test showing CodeModel DSL usage. +""" + +import pytest +from eth_utils import to_wei + +from tests.dsl import CodeModel + + +def test_counter_with_init(get_contract): + """Simple counter with initialization.""" + model = CodeModel() + + count = model.storage_var("count: uint256") + owner = model.storage_var("owner: address") + + code = (model + .function("__init__(initial_count: uint256)") + .deploy() + .body(f""" + {count} = initial_count + {owner} = msg.sender + """) + .done() + .function("increment()") + .external() + .body(f"{count} += 1") + .done() + .function("get_count() -> uint256") + .external() + .view() + .body(f"return {count}") + .done() + .build()) + + c = get_contract(code, initial_count=10) + assert c.get_count() == 10 + c.increment() + assert c.get_count() == 11 + + +def test_array_operations_with_internal_helper(get_contract): + """Array operations using internal function.""" + model = CodeModel() + + values = model.storage_var("values: DynArray[uint256, 100]") + + # internal helper to find max value + find_max = model.function("_find_max(arr: DynArray[uint256, 100]) -> uint256").internal().view() + find_max.body(""" + max_val: uint256 = 0 + for val: uint256 in arr: + if val > max_val: + max_val = val + return max_val + """).done() + + code = (model + .function("add(val: uint256)") + .external() + .body(f"{values}.append(val)") + .done() + .function("get_max() -> uint256") + .external() + .view() + .body(f"return {find_max}({values})") + .done() + .build()) + + c = get_contract(code) + + c.add(5) + c.add(10) + c.add(3) + assert c.get_max() == 10 + + +def test_hashmap_with_structs(get_contract, env): + """HashMap with struct values.""" + model = CodeModel() + + model.struct("""User: + balance: uint256 + active: bool + joined_at: uint256""") + + users = model.storage_var("users: HashMap[address, User]") + user_count = model.storage_var("user_count: uint256") + + code = (model + .function("register()") + .external() + .body(f""" + assert {users}[msg.sender].joined_at == 0, "Already registered" + {users}[msg.sender] = User( + balance=0, + active=True, + joined_at=block.timestamp + ) + {user_count} += 1 + """) + .done() + .function("deposit()") + .external() + .payable() + .body(f""" + assert {users}[msg.sender].active, "User not active" + {users}[msg.sender].balance += msg.value + """) + .done() + .function("get_user(addr: address) -> User") + .external() + .view() + .body(f"return {users}[addr]") + .done() + .build()) + + c = get_contract(code) + + # fund the account + env.set_balance(env.deployer, 10**18) + + c.register() + c.deposit(value=100) + user = c.get_user(env.deployer) + assert user[0] == 100 # balance + assert user[1] == True # active + + +def test_constants_and_immutables(get_contract, env): + """Constants and immutables usage.""" + model = CodeModel() + + # constants + max_supply = model.constant("MAX_SUPPLY: constant(uint256) = 10**18") + fee_rate = model.constant("FEE_RATE: constant(uint256) = 250") # 2.5% + fee_divisor = model.constant("FEE_DIVISOR: constant(uint256) = 10000") + + # immutables + owner = model.immutable("OWNER: address") + deployed_at = model.immutable("DEPLOYED_AT: uint256") + + # storage + total_fees = model.storage_var("total_fees: uint256") + + code = (model + .function("__init__()") + .deploy() + .body(f""" + {owner} = msg.sender + {deployed_at} = block.timestamp + """) + .done() + .function("calculate_fee(amount: uint256) -> uint256") + .external() + .pure() + .body(f""" + assert amount <= {max_supply}, "Amount too large" + return amount * {fee_rate} // {fee_divisor} + """) + .done() + .function("collect_fee(amount: uint256) -> uint256") + .external() + .body(f""" + fee: uint256 = amount * {fee_rate} // {fee_divisor} + {total_fees} += fee + return fee + """) + .done() + .function("get_owner() -> address") + .external() + .view() + .body(f"return {owner}") + .done() + .build()) + + c = get_contract(code) + + assert c.calculate_fee(10000) == 250 + assert c.collect_fee(10000) == 250 + assert c.get_owner() == env.deployer + + +def test_events_and_logging(get_contract, get_logs, env): + """Events and logging with get_logs verification.""" + model = CodeModel() + + # events + model.event("""Transfer: + sender: indexed(address) + receiver: indexed(address) + amount: uint256""") + model.event("""Approval: + owner: indexed(address) + spender: indexed(address) + amount: uint256""") + model.event("""Burn: + account: indexed(address) + amount: uint256 + reason: String[100]""") + + balances = model.storage_var("balances: HashMap[address, uint256]") + + code = (model + .function("__init__()") + .deploy() + .body(f"{balances}[msg.sender] = 1000") + .done() + .function("transfer(to: address, amount: uint256)") + .external() + .body(f""" + {balances}[msg.sender] -= amount + {balances}[to] += amount + log Transfer(sender=msg.sender, receiver=to, amount=amount) + """) + .done() + .function("approve(spender: address, amount: uint256)") + .external() + .body("log Approval(owner=msg.sender, spender=spender, amount=amount)") + .done() + .function("burn(amount: uint256, reason: String[100])") + .external() + .body(f""" + {balances}[msg.sender] -= amount + log Burn(account=msg.sender, amount=amount, reason=reason) + """) + .done() + .build()) + + c = get_contract(code) + + # test transfer event + receiver = "0x1234567890123456789012345678901234567890" + c.transfer(receiver, 100) + (log,) = get_logs(c, "Transfer") + assert log.args.sender == env.deployer + assert log.args.receiver == receiver + assert log.args.amount == 100 + + # test approval event + spender = "0x2222222222222222222222222222222222222222" + c.approve(spender, 500) + (log,) = get_logs(c, "Approval") + assert log.args.owner == env.deployer + assert log.args.spender == spender + assert log.args.amount == 500 + + # test burn event with string + c.burn(50, "Testing burn functionality") + (log,) = get_logs(c, "Burn") + assert log.args.account == env.deployer + assert log.args.amount == 50 + assert log.args.reason == "Testing burn functionality" + + +def test_flags_and_enums(get_contract): + """Flags (enums) usage.""" + model = CodeModel() + + model.flag("""OrderStatus: + PENDING + FILLED + CANCELLED + EXPIRED""") + + model.struct("""Order: + amount: uint256 + price: uint256 + status: OrderStatus + created_at: uint256""") + + orders = model.storage_var("orders: HashMap[uint256, Order]") + next_id = model.storage_var("next_order_id: uint256") + + code = (model + .function("create_order(amount: uint256, price: uint256) -> uint256") + .external() + .body(f""" + order_id: uint256 = {next_id} + {orders}[order_id] = Order( + amount=amount, + price=price, + status=OrderStatus.PENDING, + created_at=block.timestamp + ) + {next_id} += 1 + return order_id + """) + .done() + .function("cancel_order(order_id: uint256)") + .external() + .body(f""" + assert {orders}[order_id].status == OrderStatus.PENDING, "Not pending" + {orders}[order_id].status = OrderStatus.CANCELLED + """) + .done() + .function("get_order_status(order_id: uint256) -> OrderStatus") + .external() + .view() + .body(f"return {orders}[order_id].status") + .done() + .build()) + + c = get_contract(code) + + order_id = c.create_order(100, 50) + assert c.get_order_status(order_id) == 1 # PENDING (flags start at 1) + c.cancel_order(order_id) + assert c.get_order_status(order_id) == 4 # CANCELLED + + +def test_payable_and_value_handling(get_contract, env): + """Payable functions and value handling.""" + model = CodeModel() + + deposits = model.storage_var("deposits: HashMap[address, uint256]") + total_deposits = model.storage_var("total_deposits: uint256") + + code = (model + .function("deposit()") + .external() + .payable() + .body(f""" + {deposits}[msg.sender] += msg.value + {total_deposits} += msg.value + """) + .done() + .function("withdraw(amount: uint256)") + .external() + .body(f""" + assert {deposits}[msg.sender] >= amount, "Insufficient balance" + {deposits}[msg.sender] -= amount + {total_deposits} -= amount + send(msg.sender, amount) + """) + .done() + .function("get_balance(addr: address) -> uint256") + .external() + .view() + .body(f"return {deposits}[addr]") + .done() + .build()) + + c = get_contract(code) + + # fund the account + env.set_balance(env.deployer, to_wei(10, "ether")) + + # deposit some ether + c.deposit(value=to_wei(1, "ether")) + assert c.get_balance(env.deployer) == to_wei(1, "ether") + + # withdraw half + c.withdraw(to_wei(0.5, "ether")) + assert c.get_balance(env.deployer) == to_wei(0.5, "ether") + + +def test_nonreentrant_guards(get_contract): + """Nonreentrant modifier usage.""" + model = CodeModel() + + bal = model.storage_var("bal: uint256") + + code = (model + .function("protected_withdraw(amount: uint256)") + .external() + .nonreentrant() + .body(f""" + assert {bal} >= amount + {bal} -= amount + raw_call(msg.sender, b"", value=amount) + """) + .done() + .function("protected_update(new_value: uint256)") + .external() + .nonreentrant() + .body(f"{bal} = new_value") + .done() + .build()) + + c = get_contract(code) + # just check it compiles - actual reentrancy testing would require attack contract + + +def test_complex_internal_function_chain(get_contract): + """Multiple internal functions calling each other.""" + model = CodeModel() + + max_size = model.constant("MAX_DATA_SIZE: constant(uint256) = 100") + data = model.storage_var(f"data: DynArray[uint256, {max_size}]") + + # internal function to validate index + validate_index = model.function("_validate_index(idx: uint256)").internal().view() + validate_index.body(f""" + assert idx < len({data}), "Index out of bounds" + """).done() + + # internal function to swap elements + swap = model.function("_swap(i: uint256, j: uint256)").internal() + swap.body(f""" + {validate_index}(i) + {validate_index}(j) + temp: uint256 = {data}[i] + {data}[i] = {data}[j] + {data}[j] = temp + """).done() + + # internal function to bubble sort + sort = model.function("_bubble_sort()").internal() + sort.body(f""" + n: uint256 = len({data}) + for i: uint256 in range(n, bound={max_size}): + for j: uint256 in range(n - i - 1, bound={max_size}): + if {data}[j] > {data}[j + 1]: + {swap}(j, j + 1) + """).done() + + code = (model + .function("add(val: uint256)") + .external() + .body(f"{data}.append(val)") + .done() + .function("sort_data()") + .external() + .body(f"{sort}()") + .done() + .function("get(idx: uint256) -> uint256") + .external() + .view() + .body(f""" + {validate_index}(idx) + return {data}[idx] + """) + .done() + .build()) + + c = get_contract(code) + + # add unsorted data + c.add(5) + c.add(2) + c.add(8) + c.add(1) + + # sort + c.sort_data() + + # check sorted + assert c.get(0) == 1 + assert c.get(1) == 2 + assert c.get(2) == 5 + assert c.get(3) == 8 + + +@pytest.mark.parametrize("decimals,multiplier", [(6, 10**6), (18, 10**18), (2, 100)]) +def test_parametrized_with_constants(get_contract, decimals, multiplier): + """Parametrized test with constants.""" + model = CodeModel() + + # constant based on parameter + model.constant(f"DECIMALS: constant(uint8) = {decimals}") + model.constant(f"MULTIPLIER: constant(uint256) = {multiplier}") + + bal = model.storage_var("bal: uint256") + + code = (model + .function("deposit(tokens: uint256)") + .external() + .body(f"{bal} += tokens * MULTIPLIER") + .done() + .function("get_balance() -> uint256") + .external() + .view() + .body(f"return {bal}") + .done() + .function("get_decimals() -> uint8") + .external() + .pure() + .body("return DECIMALS") + .done() + .build()) + + c = get_contract(code) + + assert c.get_decimals() == decimals + c.deposit(5) + assert c.get_balance() == 5 * multiplier \ No newline at end of file From 880d622182c0649bb2e8aef96158cce06380d461 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 11 Jun 2025 22:39:58 +0200 Subject: [PATCH 2/3] lint --- tests/dsl/__init__.py | 16 +- tests/dsl/code_model.py | 135 ++++---- tests/functional/codegen/test_dsl_example.py | 332 +++++++++++-------- 3 files changed, 270 insertions(+), 213 deletions(-) diff --git a/tests/dsl/__init__.py b/tests/dsl/__init__.py index f9f9589a7a..f3668f9d29 100644 --- a/tests/dsl/__init__.py +++ b/tests/dsl/__init__.py @@ -3,14 +3,14 @@ Example usage: from tests.dsl import CodeModel - + # create a model model = CodeModel() - + # define storage variables balance = model.storage_var('balance: uint256') owner = model.storage_var('owner: address') - + # build a simple contract code = (model .function('__init__()') @@ -28,20 +28,20 @@ .body(f'return {balance}') .done() .build()) - + # The generated code will be: # balance: uint256 # owner: address - # + # # @deploy # def __init__(): # self.owner = msg.sender - # + # # @external # @payable # def deposit(): # self.balance += msg.value - # + # # @external # @view # def get_balance() -> uint256: @@ -50,4 +50,4 @@ from tests.dsl.code_model import CodeModel, VarRef -__all__ = [CodeModel, VarRef] \ No newline at end of file +__all__ = [CodeModel, VarRef] diff --git a/tests/dsl/code_model.py b/tests/dsl/code_model.py index 36cffa021a..f3f43ada6e 100644 --- a/tests/dsl/code_model.py +++ b/tests/dsl/code_model.py @@ -5,22 +5,21 @@ with proper formatting and structure. """ -import textwrap -from typing import Optional, List, Dict, Any, Union +from __future__ import annotations -from vyper.ast import parse_to_ast -from vyper.ast.nodes import FunctionDef +import textwrap +from typing import Optional class VarRef: """Reference to a variable with type and location information.""" - + def __init__(self, name: str, typ: str, location: str, visibility: Optional[str] = None): self.name = name self.typ = typ self.location = location self.visibility = visibility - + def __str__(self) -> str: """Return the variable name for use in expressions.""" # storage and transient vars need self prefix @@ -31,114 +30,114 @@ def __str__(self) -> str: class FunctionBuilder: """Builder for function definitions.""" - - def __init__(self, signature: str, parent: "CodeModel"): + + def __init__(self, signature: str, parent: CodeModel): self.signature = signature self.parent = parent - self.decorators: List[str] = [] + self.decorators: list[str] = [] self.body_code: Optional[str] = None self.is_internal = True # functions are internal by default - + # parse just the name from the signature - paren_idx = signature.find('(') + paren_idx = signature.find("(") if paren_idx == -1: raise ValueError(f"Invalid function signature: {signature}") self.name = signature[:paren_idx].strip() - + def __str__(self) -> str: """Return the function name for use in expressions.""" if self.is_internal: return f"self.{self.name}" return self.name - - def external(self) -> "FunctionBuilder": + + def external(self) -> FunctionBuilder: """Add @external decorator.""" self.decorators.append("@external") self.is_internal = False return self - - def internal(self) -> "FunctionBuilder": + + def internal(self) -> FunctionBuilder: """Add @internal decorator.""" self.decorators.append("@internal") self.is_internal = True return self - - def deploy(self) -> "FunctionBuilder": + + def deploy(self) -> FunctionBuilder: """Add @deploy decorator.""" self.decorators.append("@deploy") self.is_internal = False # deploy functions are not called with self return self - - def view(self) -> "FunctionBuilder": + + def view(self) -> FunctionBuilder: """Add @view decorator.""" self.decorators.append("@view") return self - - def pure(self) -> "FunctionBuilder": + + def pure(self) -> FunctionBuilder: """Add @pure decorator.""" self.decorators.append("@pure") return self - - def payable(self) -> "FunctionBuilder": + + def payable(self) -> FunctionBuilder: """Add @payable decorator.""" self.decorators.append("@payable") return self - - def nonreentrant(self) -> "FunctionBuilder": + + def nonreentrant(self) -> FunctionBuilder: """Add @nonreentrant decorator.""" self.decorators.append("@nonreentrant") return self - - def body(self, code: str) -> "FunctionBuilder": + + def body(self, code: str) -> FunctionBuilder: """Set the function body.""" # dedent the code to handle multi-line strings nicely self.body_code = textwrap.dedent(code).strip() return self - - def done(self) -> "CodeModel": + + def done(self) -> CodeModel: """Finish building the function and return to parent CodeModel.""" lines = [] - + lines.extend(self.decorators) lines.append(f"def {self.signature}:") - + if self.body_code: indented_body = "\n".join(f" {line}" for line in self.body_code.split("\n")) lines.append(indented_body) else: lines.append(" pass") - + self.parent._functions.append("\n".join(lines)) return self.parent class CodeModel: """Model for building a Vyper contract.""" - + def __init__(self): - self._storage_vars: List[str] = [] - self._transient_vars: List[str] = [] - self._constants: List[str] = [] - self._immutables: List[str] = [] - self._events: List[str] = [] - self._structs: List[str] = [] - self._flags: List[str] = [] - self._functions: List[str] = [] - self._imports: List[str] = [] - self._local_vars: Dict[str, VarRef] = {} - + self._storage_vars: list[str] = [] + self._transient_vars: list[str] = [] + self._constants: list[str] = [] + self._immutables: list[str] = [] + self._events: list[str] = [] + self._structs: list[str] = [] + self._flags: list[str] = [] + self._functions: list[str] = [] + self._imports: list[str] = [] + self._local_vars: dict[str, VarRef] = {} + def storage_var(self, declaration: str) -> VarRef: """Add a storage variable.""" name, typ = self._parse_declaration(declaration) self._storage_vars.append(declaration) return VarRef(name, typ, "storage", "public") - + def transient_var(self, declaration: str) -> VarRef: """Add a transient storage variable.""" name, typ = self._parse_declaration(declaration) self._transient_vars.append(f"{name}: transient({typ})") return VarRef(name, typ, "transient", "public") - + def constant(self, declaration: str) -> VarRef: """Add a constant.""" # constants have format: "NAME: constant(type) = value" @@ -148,77 +147,77 @@ def constant(self, declaration: str) -> VarRef: type_start = parts[1].find("constant(") + 9 type_end = parts[1].find(")", type_start) typ = parts[1][type_start:type_end].strip() - + self._constants.append(declaration) return VarRef(name, typ, "constant", None) - + def immutable(self, declaration: str) -> VarRef: """Add an immutable variable.""" name, typ = self._parse_declaration(declaration) self._immutables.append(f"{name}: immutable({typ})") return VarRef(name, typ, "immutable", "public") - + def local_var(self, name: str, typ: str) -> VarRef: """Register a local variable (used in function bodies).""" ref = VarRef(name, typ, "memory", None) self._local_vars[name] = ref return ref - + def event(self, definition: str) -> None: """Add an event definition.""" self._events.append(f"event {definition}") - + def struct(self, definition: str) -> None: """Add a struct definition.""" self._structs.append(f"struct {definition}") - + def flag(self, definition: str) -> None: """Add a flag (enum) definition.""" self._flags.append(f"flag {definition}") - + def function(self, signature: str) -> FunctionBuilder: """Start building a function.""" return FunctionBuilder(signature, self) - + def build(self) -> str: """Build the complete contract code.""" sections = [] - + if self._imports: sections.append("\n".join(self._imports)) - + if self._events: sections.append("\n".join(self._events)) - + if self._structs: sections.append("\n".join(self._structs)) - + if self._flags: sections.append("\n".join(self._flags)) - + if self._constants: sections.append("\n".join(self._constants)) - + if self._storage_vars: sections.append("\n".join(self._storage_vars)) - + if self._transient_vars: sections.append("\n".join(self._transient_vars)) - + if self._immutables: sections.append("\n".join(self._immutables)) - + if self._functions: sections.append("\n\n".join(self._functions)) - + return "\n\n".join(sections) - + def _parse_declaration(self, declaration: str) -> tuple[str, str]: """Parse a variable declaration of form 'name: type' into (name, type).""" parts = declaration.split(":", 1) if len(parts) != 2: raise ValueError(f"Invalid declaration format: {declaration}") - + name = parts[0].strip() typ = parts[1].strip() - return name, typ \ No newline at end of file + return name, typ diff --git a/tests/functional/codegen/test_dsl_example.py b/tests/functional/codegen/test_dsl_example.py index dc49850a78..64a0a9c13b 100644 --- a/tests/functional/codegen/test_dsl_example.py +++ b/tests/functional/codegen/test_dsl_example.py @@ -11,17 +11,19 @@ def test_counter_with_init(get_contract): """Simple counter with initialization.""" model = CodeModel() - + count = model.storage_var("count: uint256") owner = model.storage_var("owner: address") - - code = (model - .function("__init__(initial_count: uint256)") + + code = ( + model.function("__init__(initial_count: uint256)") .deploy() - .body(f""" + .body( + f""" {count} = initial_count {owner} = msg.sender - """) + """ + ) .done() .function("increment()") .external() @@ -32,8 +34,9 @@ def test_counter_with_init(get_contract): .view() .body(f"return {count}") .done() - .build()) - + .build() + ) + c = get_contract(code, initial_count=10) assert c.get_count() == 10 c.increment() @@ -43,21 +46,23 @@ def test_counter_with_init(get_contract): def test_array_operations_with_internal_helper(get_contract): """Array operations using internal function.""" model = CodeModel() - + values = model.storage_var("values: DynArray[uint256, 100]") - + # internal helper to find max value find_max = model.function("_find_max(arr: DynArray[uint256, 100]) -> uint256").internal().view() - find_max.body(""" + find_max.body( + """ max_val: uint256 = 0 for val: uint256 in arr: if val > max_val: max_val = val return max_val - """).done() - - code = (model - .function("add(val: uint256)") + """ + ).done() + + code = ( + model.function("add(val: uint256)") .external() .body(f"{values}.append(val)") .done() @@ -66,10 +71,11 @@ def test_array_operations_with_internal_helper(get_contract): .view() .body(f"return {find_max}({values})") .done() - .build()) - + .build() + ) + c = get_contract(code) - + c.add(5) c.add(10) c.add(3) @@ -79,19 +85,22 @@ def test_array_operations_with_internal_helper(get_contract): def test_hashmap_with_structs(get_contract, env): """HashMap with struct values.""" model = CodeModel() - - model.struct("""User: + + model.struct( + """User: balance: uint256 active: bool - joined_at: uint256""") - + joined_at: uint256""" + ) + users = model.storage_var("users: HashMap[address, User]") user_count = model.storage_var("user_count: uint256") - - code = (model - .function("register()") + + code = ( + model.function("register()") .external() - .body(f""" + .body( + f""" assert {users}[msg.sender].joined_at == 0, "Already registered" {users}[msg.sender] = User( balance=0, @@ -99,84 +108,95 @@ def test_hashmap_with_structs(get_contract, env): joined_at=block.timestamp ) {user_count} += 1 - """) + """ + ) .done() .function("deposit()") .external() .payable() - .body(f""" + .body( + f""" assert {users}[msg.sender].active, "User not active" {users}[msg.sender].balance += msg.value - """) + """ + ) .done() .function("get_user(addr: address) -> User") .external() .view() .body(f"return {users}[addr]") .done() - .build()) - + .build() + ) + c = get_contract(code) - + # fund the account env.set_balance(env.deployer, 10**18) - + c.register() c.deposit(value=100) user = c.get_user(env.deployer) assert user[0] == 100 # balance - assert user[1] == True # active + assert user[1] is True # active def test_constants_and_immutables(get_contract, env): """Constants and immutables usage.""" model = CodeModel() - + # constants max_supply = model.constant("MAX_SUPPLY: constant(uint256) = 10**18") fee_rate = model.constant("FEE_RATE: constant(uint256) = 250") # 2.5% fee_divisor = model.constant("FEE_DIVISOR: constant(uint256) = 10000") - + # immutables owner = model.immutable("OWNER: address") deployed_at = model.immutable("DEPLOYED_AT: uint256") - + # storage total_fees = model.storage_var("total_fees: uint256") - - code = (model - .function("__init__()") + + code = ( + model.function("__init__()") .deploy() - .body(f""" + .body( + f""" {owner} = msg.sender {deployed_at} = block.timestamp - """) + """ + ) .done() .function("calculate_fee(amount: uint256) -> uint256") .external() .pure() - .body(f""" + .body( + f""" assert amount <= {max_supply}, "Amount too large" return amount * {fee_rate} // {fee_divisor} - """) + """ + ) .done() .function("collect_fee(amount: uint256) -> uint256") .external() - .body(f""" + .body( + f""" fee: uint256 = amount * {fee_rate} // {fee_divisor} {total_fees} += fee return fee - """) + """ + ) .done() .function("get_owner() -> address") .external() .view() .body(f"return {owner}") .done() - .build()) - + .build() + ) + c = get_contract(code) - + assert c.calculate_fee(10000) == 250 assert c.collect_fee(10000) == 250 assert c.get_owner() == env.deployer @@ -185,35 +205,43 @@ def test_constants_and_immutables(get_contract, env): def test_events_and_logging(get_contract, get_logs, env): """Events and logging with get_logs verification.""" model = CodeModel() - + # events - model.event("""Transfer: + model.event( + """Transfer: sender: indexed(address) receiver: indexed(address) - amount: uint256""") - model.event("""Approval: + amount: uint256""" + ) + model.event( + """Approval: owner: indexed(address) spender: indexed(address) - amount: uint256""") - model.event("""Burn: + amount: uint256""" + ) + model.event( + """Burn: account: indexed(address) amount: uint256 - reason: String[100]""") - + reason: String[100]""" + ) + balances = model.storage_var("balances: HashMap[address, uint256]") - - code = (model - .function("__init__()") + + code = ( + model.function("__init__()") .deploy() .body(f"{balances}[msg.sender] = 1000") .done() .function("transfer(to: address, amount: uint256)") .external() - .body(f""" + .body( + f""" {balances}[msg.sender] -= amount {balances}[to] += amount log Transfer(sender=msg.sender, receiver=to, amount=amount) - """) + """ + ) .done() .function("approve(spender: address, amount: uint256)") .external() @@ -221,15 +249,18 @@ def test_events_and_logging(get_contract, get_logs, env): .done() .function("burn(amount: uint256, reason: String[100])") .external() - .body(f""" + .body( + f""" {balances}[msg.sender] -= amount log Burn(account=msg.sender, amount=amount, reason=reason) - """) + """ + ) .done() - .build()) - + .build() + ) + c = get_contract(code) - + # test transfer event receiver = "0x1234567890123456789012345678901234567890" c.transfer(receiver, 100) @@ -237,7 +268,7 @@ def test_events_and_logging(get_contract, get_logs, env): assert log.args.sender == env.deployer assert log.args.receiver == receiver assert log.args.amount == 100 - + # test approval event spender = "0x2222222222222222222222222222222222222222" c.approve(spender, 500) @@ -245,7 +276,7 @@ def test_events_and_logging(get_contract, get_logs, env): assert log.args.owner == env.deployer assert log.args.spender == spender assert log.args.amount == 500 - + # test burn event with string c.burn(50, "Testing burn functionality") (log,) = get_logs(c, "Burn") @@ -257,26 +288,31 @@ def test_events_and_logging(get_contract, get_logs, env): def test_flags_and_enums(get_contract): """Flags (enums) usage.""" model = CodeModel() - - model.flag("""OrderStatus: + + model.flag( + """OrderStatus: PENDING FILLED CANCELLED - EXPIRED""") - - model.struct("""Order: + EXPIRED""" + ) + + model.struct( + """Order: amount: uint256 price: uint256 status: OrderStatus - created_at: uint256""") - + created_at: uint256""" + ) + orders = model.storage_var("orders: HashMap[uint256, Order]") next_id = model.storage_var("next_order_id: uint256") - - code = (model - .function("create_order(amount: uint256, price: uint256) -> uint256") + + code = ( + model.function("create_order(amount: uint256, price: uint256) -> uint256") .external() - .body(f""" + .body( + f""" order_id: uint256 = {next_id} {orders}[order_id] = Order( amount=amount, @@ -286,24 +322,28 @@ def test_flags_and_enums(get_contract): ) {next_id} += 1 return order_id - """) + """ + ) .done() .function("cancel_order(order_id: uint256)") .external() - .body(f""" + .body( + f""" assert {orders}[order_id].status == OrderStatus.PENDING, "Not pending" {orders}[order_id].status = OrderStatus.CANCELLED - """) + """ + ) .done() .function("get_order_status(order_id: uint256) -> OrderStatus") .external() .view() .body(f"return {orders}[order_id].status") .done() - .build()) - + .build() + ) + c = get_contract(code) - + order_id = c.create_order(100, 50) assert c.get_order_status(order_id) == 1 # PENDING (flags start at 1) c.cancel_order(order_id) @@ -313,44 +353,49 @@ def test_flags_and_enums(get_contract): def test_payable_and_value_handling(get_contract, env): """Payable functions and value handling.""" model = CodeModel() - + deposits = model.storage_var("deposits: HashMap[address, uint256]") total_deposits = model.storage_var("total_deposits: uint256") - - code = (model - .function("deposit()") + + code = ( + model.function("deposit()") .external() .payable() - .body(f""" + .body( + f""" {deposits}[msg.sender] += msg.value {total_deposits} += msg.value - """) + """ + ) .done() .function("withdraw(amount: uint256)") .external() - .body(f""" + .body( + f""" assert {deposits}[msg.sender] >= amount, "Insufficient balance" {deposits}[msg.sender] -= amount {total_deposits} -= amount send(msg.sender, amount) - """) + """ + ) .done() .function("get_balance(addr: address) -> uint256") .external() .view() .body(f"return {deposits}[addr]") .done() - .build()) - + .build() + ) + c = get_contract(code) - + # fund the account env.set_balance(env.deployer, to_wei(10, "ether")) - + # deposit some ether c.deposit(value=to_wei(1, "ether")) assert c.get_balance(env.deployer) == to_wei(1, "ether") - + # withdraw half c.withdraw(to_wei(0.5, "ether")) assert c.get_balance(env.deployer) == to_wei(0.5, "ether") @@ -359,65 +404,74 @@ def test_payable_and_value_handling(get_contract, env): def test_nonreentrant_guards(get_contract): """Nonreentrant modifier usage.""" model = CodeModel() - + bal = model.storage_var("bal: uint256") - - code = (model - .function("protected_withdraw(amount: uint256)") + + code = ( + model.function("protected_withdraw(amount: uint256)") .external() .nonreentrant() - .body(f""" + .body( + f""" assert {bal} >= amount {bal} -= amount raw_call(msg.sender, b"", value=amount) - """) + """ + ) .done() .function("protected_update(new_value: uint256)") .external() .nonreentrant() .body(f"{bal} = new_value") .done() - .build()) - - c = get_contract(code) + .build() + ) + + get_contract(code) # just check it compiles - actual reentrancy testing would require attack contract def test_complex_internal_function_chain(get_contract): """Multiple internal functions calling each other.""" model = CodeModel() - + max_size = model.constant("MAX_DATA_SIZE: constant(uint256) = 100") data = model.storage_var(f"data: DynArray[uint256, {max_size}]") - + # internal function to validate index validate_index = model.function("_validate_index(idx: uint256)").internal().view() - validate_index.body(f""" + validate_index.body( + f""" assert idx < len({data}), "Index out of bounds" - """).done() - + """ + ).done() + # internal function to swap elements swap = model.function("_swap(i: uint256, j: uint256)").internal() - swap.body(f""" + swap.body( + f""" {validate_index}(i) {validate_index}(j) temp: uint256 = {data}[i] {data}[i] = {data}[j] {data}[j] = temp - """).done() - + """ + ).done() + # internal function to bubble sort sort = model.function("_bubble_sort()").internal() - sort.body(f""" + sort.body( + f""" n: uint256 = len({data}) for i: uint256 in range(n, bound={max_size}): for j: uint256 in range(n - i - 1, bound={max_size}): if {data}[j] > {data}[j + 1]: {swap}(j, j + 1) - """).done() - - code = (model - .function("add(val: uint256)") + """ + ).done() + + code = ( + model.function("add(val: uint256)") .external() .body(f"{data}.append(val)") .done() @@ -428,24 +482,27 @@ def test_complex_internal_function_chain(get_contract): .function("get(idx: uint256) -> uint256") .external() .view() - .body(f""" + .body( + f""" {validate_index}(idx) return {data}[idx] - """) + """ + ) .done() - .build()) - + .build() + ) + c = get_contract(code) - + # add unsorted data c.add(5) c.add(2) c.add(8) c.add(1) - + # sort c.sort_data() - + # check sorted assert c.get(0) == 1 assert c.get(1) == 2 @@ -457,15 +514,15 @@ def test_complex_internal_function_chain(get_contract): def test_parametrized_with_constants(get_contract, decimals, multiplier): """Parametrized test with constants.""" model = CodeModel() - + # constant based on parameter model.constant(f"DECIMALS: constant(uint8) = {decimals}") model.constant(f"MULTIPLIER: constant(uint256) = {multiplier}") - + bal = model.storage_var("bal: uint256") - - code = (model - .function("deposit(tokens: uint256)") + + code = ( + model.function("deposit(tokens: uint256)") .external() .body(f"{bal} += tokens * MULTIPLIER") .done() @@ -479,10 +536,11 @@ def test_parametrized_with_constants(get_contract, decimals, multiplier): .pure() .body("return DECIMALS") .done() - .build()) - + .build() + ) + c = get_contract(code) - + assert c.get_decimals() == decimals c.deposit(5) - assert c.get_balance() == 5 * multiplier \ No newline at end of file + assert c.get_balance() == 5 * multiplier From fc7d4ef281aebef6c493f7883dc6adee6ce96556 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 27 Jun 2025 14:13:17 +0200 Subject: [PATCH 3/3] refactor a ctor test to use the new model. refactor the builder API --- tests/conftest.py | 4 +++ tests/dsl/code_model.py | 36 ++++++++++--------- .../codegen/features/test_constructor.py | 30 +++++++--------- 3 files changed, 37 insertions(+), 33 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index e3520bc547..eba64c8fcf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,6 +10,7 @@ import tests.hevm import vyper.evm.opcodes as evm_opcodes +from tests.dsl.code_model import CodeModel from tests.evm_backends.base_env import BaseEnv, ExecutionReverted from tests.evm_backends.pyevm_env import PyEvmEnv from tests.evm_backends.revm_env import RevmEnv @@ -255,6 +256,9 @@ def hevm_marker(request): @pytest.fixture(scope="module") def get_contract(env, optimize, output_formats, compiler_settings, hevm, request): def fn(source_code, *args, **kwargs): + # support CodeModel instances + if isinstance(source_code, CodeModel): + source_code = source_code.build() if "override_opt_level" in kwargs: kwargs["compiler_settings"] = Settings( **dict(compiler_settings.__dict__, optimize=kwargs.pop("override_opt_level")) diff --git a/tests/dsl/code_model.py b/tests/dsl/code_model.py index f3f43ada6e..75e1fb3068 100644 --- a/tests/dsl/code_model.py +++ b/tests/dsl/code_model.py @@ -96,18 +96,6 @@ def body(self, code: str) -> FunctionBuilder: def done(self) -> CodeModel: """Finish building the function and return to parent CodeModel.""" - lines = [] - - lines.extend(self.decorators) - lines.append(f"def {self.signature}:") - - if self.body_code: - indented_body = "\n".join(f" {line}" for line in self.body_code.split("\n")) - lines.append(indented_body) - else: - lines.append(" pass") - - self.parent._functions.append("\n".join(lines)) return self.parent @@ -122,9 +110,9 @@ def __init__(self): self._events: list[str] = [] self._structs: list[str] = [] self._flags: list[str] = [] - self._functions: list[str] = [] self._imports: list[str] = [] self._local_vars: dict[str, VarRef] = {} + self._function_builders: list[FunctionBuilder] = [] def storage_var(self, declaration: str) -> VarRef: """Add a storage variable.""" @@ -177,7 +165,9 @@ def flag(self, definition: str) -> None: def function(self, signature: str) -> FunctionBuilder: """Start building a function.""" - return FunctionBuilder(signature, self) + fb = FunctionBuilder(signature, self) + self._function_builders.append(fb) + return fb def build(self) -> str: """Build the complete contract code.""" @@ -207,8 +197,22 @@ def build(self) -> str: if self._immutables: sections.append("\n".join(self._immutables)) - if self._functions: - sections.append("\n\n".join(self._functions)) + if self._function_builders: + function_strings = [] + for fb in self._function_builders: + lines = [] + lines.extend(fb.decorators) + lines.append(f"def {fb.signature}:") + + if fb.body_code: + indented_body = "\n".join(f" {line}" for line in fb.body_code.split("\n")) + lines.append(indented_body) + else: + lines.append(" pass") + + function_strings.append("\n".join(lines)) + + sections.append("\n\n".join(function_strings)) return "\n\n".join(sections) diff --git a/tests/functional/codegen/features/test_constructor.py b/tests/functional/codegen/features/test_constructor.py index 182e2e2ff2..38c74a129f 100644 --- a/tests/functional/codegen/features/test_constructor.py +++ b/tests/functional/codegen/features/test_constructor.py @@ -1,7 +1,6 @@ -import contextlib - import pytest +from tests.dsl import CodeModel from tests.evm_backends.base_env import _compile from vyper.exceptions import StackTooDeep from vyper.utils import method_id @@ -296,27 +295,24 @@ def __init__(): I_ADDR = CONST_ADDR I_BYTES32 = CONST_BYTES32 """ - print(code) c = get_contract(code) assert c.I_UINT() == CONST_UINT assert c.I_ADDR() == CONST_ADDR assert c.I_BYTES32() == bytes.fromhex(CONST_BYTES32.removeprefix("0x")) -@pytest.mark.parametrize("should_fail", [True, False]) -def test_constructor_payability(env, get_contract, tx_failed, should_fail): - code = f""" -@deploy -{"" if should_fail else "@payable"} -def __init__(): - pass -""" +@pytest.mark.parametrize("is_payable", [False, True]) +def test_constructor_payability(env, get_contract, tx_failed, is_payable): + model = CodeModel() env.set_balance(env.deployer, 10) - if should_fail: - ctx = tx_failed - else: - ctx = contextlib.nullcontext + init = model.function("__init__()").deploy().body("pass") - with ctx(): - _ = get_contract(code, value=10) + if is_payable: + # payable constructor should deploy successfully with value + init.payable() + get_contract(model, value=10) + else: + # non-payable constructor should fail when deployed with value + with tx_failed(): + get_contract(model, value=10)