diff --git a/tests/functional/codegen/features/test_variable_initialization.py b/tests/functional/codegen/features/test_variable_initialization.py new file mode 100644 index 0000000000..d0e3053196 --- /dev/null +++ b/tests/functional/codegen/features/test_variable_initialization.py @@ -0,0 +1,643 @@ +import pytest + +from tests.utils import decimal_to_int + + +def test_storage_variable_initialization(get_contract): + code = """ +x: uint256 = 42 +y: int128 = -100 +z: bool = True +w: address = 0x0000000000000000000000000000000000000123 + +@external +@view +def get_x() -> uint256: + return self.x + +@external +@view +def get_y() -> int128: + return self.y + +@external +@view +def get_z() -> bool: + return self.z + +@external +@view +def get_w() -> address: + return self.w + """ + + c = get_contract(code) + assert c.get_x() == 42 + assert c.get_y() == -100 + assert c.get_z() is True + assert c.get_w() == "0x0000000000000000000000000000000000000123" + + +def test_storage_variable_initialization_with_constructor_override(get_contract): + code = """ +x: uint256 = 42 +y: int128 = -100 + +@deploy +def __init__(new_x: uint256): + self.x = new_x # Override x + # y keeps its initialized value + +@external +@view +def get_x() -> uint256: + return self.x + +@external +@view +def get_y() -> int128: + return self.y + """ + + c = get_contract(code, 999) + assert c.get_x() == 999 # overridden by constructor + assert c.get_y() == -100 # keeps initialized value + + +def test_immutable_initialization(get_contract): + code = """ +X: immutable(uint256) = 42 +Y: immutable(int128) = -100 + +@deploy +def __init__(override_x: uint256): + X = override_x # Override X + # Y keeps initialized value + +@external +@view +def get_x() -> uint256: + return X + +@external +@view +def get_y() -> int128: + return Y + """ + + c = get_contract(code, 123) + assert c.get_x() == 123 # overridden + assert c.get_y() == -100 # keeps initialized value + + +def test_complex_initialization_expressions(get_contract): + code = """ +# Test various literal types and constant expressions +a: uint256 = 1024 # 2 ** 10 +b: int128 = -100 # -50 * 2 +c: decimal = 3.14159 +d: bool = False +e: bytes32 = 0x1234567890123456789012345678901234567890123456789012345678901234 + +@external +@view +def get_a() -> uint256: + return self.a + +@external +@view +def get_b() -> int128: + return self.b + +@external +@view +def get_c() -> decimal: + return self.c + +@external +@view +def get_d() -> bool: + return self.d + +@external +@view +def get_e() -> bytes32: + return self.e + """ + + c = get_contract(code) + assert c.get_a() == 1024 + assert c.get_b() == -100 + assert c.get_c() == decimal_to_int("3.14159") + assert c.get_d() is False + assert c.get_e() == b"\x124Vx\x90\x124Vx\x90\x124Vx\x90\x124Vx\x90\x124Vx\x90\x124Vx\x90\x124" + + +def test_initialization_order(get_contract): + """Test that initializations happen in declaration order""" + code = """ +a: uint256 = 1 +b: uint256 = 2 +c: uint256 = 3 + +@deploy +def __init__(): + # Check they were initialized in order + assert self.a == 1 + assert self.b == 2 + assert self.c == 3 + + # Now override b + self.b = 20 + +@external +@view +def get_a() -> uint256: + return self.a + +@external +@view +def get_b() -> uint256: + return self.b + +@external +@view +def get_c() -> uint256: + return self.c + """ + + c = get_contract(code) + assert c.get_a() == 1 + assert c.get_b() == 20 # overridden in constructor + assert c.get_c() == 3 + + +def test_no_constructor_with_initialization(get_contract): + """Test that initialization works even without a constructor""" + code = """ +x: uint256 = 100 +y: bool = True + +@external +@view +def get_x() -> uint256: + return self.x + +@external +@view +def get_y() -> bool: + return self.y + """ + + c = get_contract(code) + assert c.get_x() == 100 + assert c.get_y() is True + + +def test_mixed_initialized_and_uninitialized(get_contract): + """Test mixing initialized and uninitialized variables""" + code = """ +a: uint256 = 42 # initialized +b: uint256 # not initialized, should be 0 +c: int128 = -50 # initialized +d: int128 # not initialized, should be 0 + +@external +@view +def get_values() -> (uint256, uint256, int128, int128): + return self.a, self.b, self.c, self.d + """ + + c = get_contract(code) + a, b, c_val, d = c.get_values() + assert a == 42 + assert b == 0 + assert c_val == -50 + assert d == 0 + + +def test_public_variable_initialization(get_contract): + """Test that public variables with initializers work correctly""" + code = """ +x: public(uint256) = 12345 +y: public(bool) = True +z: public(address) = 0x0000000000000000000000000000000000000aBc + """ + + c = get_contract(code) + # public variables automatically get getter functions + assert c.x() == 12345 + assert c.y() is True + assert c.z() == "0x0000000000000000000000000000000000000aBc" + + +@pytest.mark.requires_evm_version("cancun") +def test_transient_storage_initialization(get_contract): + """Test initialization of transient storage variables""" + code = """ +#pragma evm-version cancun + +x: transient(uint256) = 42 +y: transient(bool) = True + +# Storage variables to capture transient values during deployment +stored_x: uint256 +stored_y: bool + +@deploy +def __init__(): + # Capture the initialized transient values + self.stored_x = self.x + self.stored_y = self.y + +@external +@view +def get_stored_x() -> uint256: + return self.stored_x + +@external +@view +def get_stored_y() -> bool: + return self.stored_y + +@external +def get_x() -> uint256: + return self.x + +@external +def get_y() -> bool: + return self.y + """ + + c = get_contract(code) + + # Verify that transient variables were initialized during deployment + assert c.get_stored_x() == 42 + assert c.get_stored_y() is True + + # In test environment, all calls happen in the same transaction, + # so transient storage retains its value from initialization + assert c.get_x() == 42 + assert c.get_y() is True + + +def test_constructor_with_conditional_override(get_contract): + """Test conditional logic in constructor that may override initialized values""" + code = """ +x: uint256 = 100 +y: uint256 = 200 +z: uint256 = 300 + +@deploy +def __init__(override_flag: uint256): + if override_flag == 1: + self.x = 111 + elif override_flag == 2: + self.y = 222 + else: + self.z = 333 + + # nested conditions + if self.x > 100: + if self.y == 200: + self.z = 999 + +@external +@view +def get_values() -> (uint256, uint256, uint256): + return self.x, self.y, self.z + """ + + # Test case 1: override_flag == 1 + c1 = get_contract(code, 1) + x, y, z = c1.get_values() + assert x == 111 # overridden + assert y == 200 # kept initial + assert z == 999 # overridden by nested condition + + # Test case 2: override_flag == 2 + c2 = get_contract(code, 2) + x, y, z = c2.get_values() + assert x == 100 # kept initial + assert y == 222 # overridden + assert z == 300 # kept initial + + # Test case 3: override_flag == other + c3 = get_contract(code, 3) + x, y, z = c3.get_values() + assert x == 100 # kept initial + assert y == 200 # kept initial + assert z == 333 # overridden + + +def test_constructor_with_loop_override(get_contract): + """Test loops in constructor that modify initialized values""" + code = """ +counter: uint256 = 1000 +values: uint256[10] = empty(uint256[10]) + +@deploy +def __init__(iterations: uint256): + # Initialize some array values based on counter + for i: uint256 in range(10): + self.values[i] = self.counter + i + + # Conditionally modify counter in a loop + for i: uint256 in range(10): + if i < iterations: + self.counter += 10 + else: + break + +@external +@view +def get_counter() -> uint256: + return self.counter + +@external +@view +def get_value(idx: uint256) -> uint256: + return self.values[idx] + """ + + # Test with 5 iterations + c = get_contract(code, 5) + assert c.get_counter() == 1050 # 1000 + (5 * 10) + assert c.get_value(0) == 1000 # initial counter value + assert c.get_value(5) == 1005 # counter + 5 + + +def test_early_return_in_constructor(get_contract): + """Test early returns in constructor don't skip initializations""" + code = """ +a: uint256 = 100 +b: uint256 = 200 +c: uint256 = 300 + +@deploy +def __init__(early_exit: bool): + # Variable initializations should have already happened + assert self.a == 100 + assert self.b == 200 + assert self.c == 300 + + if early_exit: + self.a = 111 + return # early return + + # This code only runs if not early_exit + self.b = 222 + self.c = 333 + +@external +@view +def get_values() -> (uint256, uint256, uint256): + return self.a, self.b, self.c + """ + + # Test early exit + c1 = get_contract(code, True) + a, b, c = c1.get_values() + assert a == 111 # modified before return + assert b == 200 # kept initial (after return) + assert c == 300 # kept initial (after return) + + # Test normal flow + c2 = get_contract(code, False) + a, b, c = c2.get_values() + assert a == 100 # kept initial + assert b == 222 # modified + assert c == 333 # modified + + +def test_constructor_with_assert_on_initialized_values(get_contract): + """Test that constructor can make assertions about initialized values""" + code = """ +MIN_VALUE: constant(uint256) = 50 +MAX_VALUE: constant(uint256) = 150 + +x: uint256 = 100 +y: uint256 = 75 +z: uint256 = 125 + +@deploy +def __init__(adjustment: int128): + # Assert initial values are in expected range + assert self.x >= MIN_VALUE and self.x <= MAX_VALUE + assert self.y >= MIN_VALUE and self.y <= MAX_VALUE + assert self.z >= MIN_VALUE and self.z <= MAX_VALUE + + # Adjust values but keep in range + if adjustment > 0: + new_x: uint256 = self.x + convert(adjustment, uint256) + if new_x <= MAX_VALUE: + self.x = new_x + elif adjustment < 0: + sub_amount: uint256 = convert(-adjustment, uint256) + if self.x >= MIN_VALUE + sub_amount: + self.x = self.x - sub_amount + +@external +@view +def get_x() -> uint256: + return self.x + """ + + # Test positive adjustment + c1 = get_contract(code, 25) + assert c1.get_x() == 125 # 100 + 25 + + # Test negative adjustment + c2 = get_contract(code, -40) + assert c2.get_x() == 60 # 100 - 40 + + # Test adjustment that would exceed bounds + c3 = get_contract(code, 60) + assert c3.get_x() == 100 # unchanged because 100 + 60 > 150 + + +def test_msg_sender_initialization(env, get_contract, tx_failed): + """Test that msg.sender can be used in variable initialization""" + code = """ +owner: address = msg.sender +backup_owner: address = msg.sender + +@external +@view +def get_owner() -> address: + return self.owner + +@external +@view +def get_backup_owner() -> address: + return self.backup_owner + +@external +def set_owner(new_owner: address): + assert msg.sender == self.owner, "Only owner can change owner" + self.owner = new_owner + """ + + c = get_contract(code) + + # Check that owner and backup_owner were initialized to deployer + assert c.get_owner() == env.deployer + assert c.get_backup_owner() == env.deployer + + # Test that owner can be changed by the current owner + new_owner = env.accounts[1] + c.set_owner(new_owner) + assert c.get_owner() == new_owner + assert c.get_backup_owner() == env.deployer # unchanged + + # Test that non-owner cannot change owner + with tx_failed(): + env.set_balance(env.accounts[2], 10**18) + c.set_owner(env.accounts[2], sender=env.accounts[2]) + + +def test_msg_sender_with_constructor_override(env, get_contract): + """Test msg.sender initialization with constructor override""" + code = """ +owner: address = msg.sender +admin: address = msg.sender + +@deploy +def __init__(admin_address: address): + # Override admin but keep owner as msg.sender + self.admin = admin_address + +@external +@view +def get_owner() -> address: + return self.owner + +@external +@view +def get_admin() -> address: + return self.admin + """ + + admin_addr = env.accounts[1] + c = get_contract(code, admin_addr) + + # Owner should be the deployer (msg.sender during initialization) + assert c.get_owner() == env.deployer + # Admin should be overridden by constructor + assert c.get_admin() == admin_addr + + +def test_runtime_constants_initialization(env, get_contract): + """Test that runtime constants (block, tx, msg, chain) can be used in initializers""" + code = """ +# All of these are runtime constants and should be allowed +deployer: address = msg.sender +origin: address = tx.origin +deploy_block: uint256 = block.number +deploy_timestamp: uint256 = block.timestamp +chain_id: uint256 = chain.id + +@external +@view +def get_deployer() -> address: + return self.deployer + +@external +@view +def get_origin() -> address: + return self.origin + +@external +@view +def get_deploy_block() -> uint256: + return self.deploy_block + +@external +@view +def get_deploy_timestamp() -> uint256: + return self.deploy_timestamp + +@external +@view +def get_chain_id() -> uint256: + return self.chain_id + """ + + # Record environment values at deployment + c = get_contract(code) + + # Check all values were initialized correctly + assert c.get_deployer() == env.deployer + assert c.get_origin() == env.deployer # In tests, origin == sender + + # Block number should match current environment + assert c.get_deploy_block() == env.block_number + + # Timestamp should match current environment + assert c.get_deploy_timestamp() == env.timestamp + + # Chain ID should be the default (1) + assert c.get_chain_id() == env.DEFAULT_CHAIN_ID + + +def test_self_initialization(get_contract, env): + """Test that self can be used as an initializer""" + code = """ +owner: address = self +backup: address = self + +@external +@view +def get_owner() -> address: + return self.owner + +@external +@view +def get_backup() -> address: + return self.backup + """ + + c = get_contract(code) + + # both should be set to the contract's address + assert c.get_owner() == c.address + assert c.get_backup() == c.address + + +def test_self_initialization_with_override(get_contract, env): + """Test self initialization with constructor override""" + code = """ +owner: address = self + +@deploy +def __init__(): + # override with msg.sender + self.owner = msg.sender + +@external +@view +def get_owner() -> address: + return self.owner + """ + + c = get_contract(code) + + # should be overridden to deployer + assert c.get_owner() == env.deployer + + +def test_immutable_self_initialization(get_contract, env): + """Test that immutables can be initialized with self""" + code = """ +CONTRACT_ADDRESS: immutable(address) = self + +@external +@view +def get_contract_address() -> address: + return CONTRACT_ADDRESS + """ + + c = get_contract(code) + + # immutable should be set to the contract's address + assert c.get_contract_address() == c.address diff --git a/tests/functional/syntax/exceptions/test_instantiation_exception.py b/tests/functional/syntax/exceptions/test_instantiation_exception.py index f693846f81..dc1ea198d5 100644 --- a/tests/functional/syntax/exceptions/test_instantiation_exception.py +++ b/tests/functional/syntax/exceptions/test_instantiation_exception.py @@ -74,6 +74,11 @@ def foo(): def __init__(): b = empty(HashMap[uint256, uint256]) """, + """ +struct S: + x: int128 +s: S = S() + """, ] diff --git a/tests/functional/syntax/exceptions/test_variable_declaration_exception.py b/tests/functional/syntax/exceptions/test_variable_declaration_exception.py index 42c48dbe32..fb26a4f4ef 100644 --- a/tests/functional/syntax/exceptions/test_variable_declaration_exception.py +++ b/tests/functional/syntax/exceptions/test_variable_declaration_exception.py @@ -5,17 +5,6 @@ fail_list = [ """ -q: int128 = 12 -@external -def foo() -> int128: - return self.q - """, - """ -struct S: - x: int128 -s: S = S() - """, - """ foo.a: int128 """, """ @@ -30,3 +19,23 @@ def foo(): def test_variable_declaration_exception(bad_code): with pytest.raises(VariableDeclarationException): compiler.compile_code(bad_code) + + +pass_list = [ + """ +q: int128 = 12 +@external +def foo() -> int128: + return self.q + """, + """ +struct S: + x: int128 +s: S = S(x=5) + """, +] + + +@pytest.mark.parametrize("good_code", pass_list) +def test_variable_initialization_allowed(good_code): + compiler.compile_code(good_code) diff --git a/tests/functional/syntax/test_immutables.py b/tests/functional/syntax/test_immutables.py index 7e5903a6a1..15ee5789d3 100644 --- a/tests/functional/syntax/test_immutables.py +++ b/tests/functional/syntax/test_immutables.py @@ -21,14 +21,6 @@ def __init__(): def get_value() -> uint256: return VALUE """, - # VALUE given an initial value - """ -VALUE: immutable(uint256) = 3 - -@deploy -def __init__(): - pass - """, # setting value outside of constructor """ VALUE: immutable(uint256) @@ -107,7 +99,14 @@ def get_value() -> {typ}: def __init__(_value: uint256): VALUE = _value * 3 x: uint256 = VALUE + 1 + """, """ +VALUE: immutable(uint256) = 3 + +@deploy +def __init__(): + pass + """, ] diff --git a/tests/functional/syntax/test_variable_initialization_errors.py b/tests/functional/syntax/test_variable_initialization_errors.py new file mode 100644 index 0000000000..47e433b034 --- /dev/null +++ b/tests/functional/syntax/test_variable_initialization_errors.py @@ -0,0 +1,148 @@ +import pytest + +from vyper import compile_code +from vyper.exceptions import ( + CallViolation, + ImmutableViolation, + StateAccessViolation, + TypeMismatch, + UndeclaredDefinition, + VariableDeclarationException, +) + + +@pytest.mark.parametrize( + "bad_code,exc", + [ + ( + """ +# Cannot use function calls in initializer +@external +@view +def some_func() -> uint256: + return 42 + +x: uint256 = self.some_func() + """, + CallViolation, + ), + ( + """ +# Cannot use self attributes in initializer +y: uint256 = 10 +x: uint256 = self.y + """, + StateAccessViolation, + ), + ], +) +def test_invalid_initializers(bad_code, exc): + with pytest.raises(exc): + compile_code(bad_code) + + +@pytest.mark.parametrize( + "bad_code,exc", + [ + ( + """ +# Type mismatch in initialization +x: uint256 = -1 # negative number for unsigned + """, + TypeMismatch, + ), + ( + """ +# Type mismatch with wrong literal type +x: address = 123 + """, + TypeMismatch, + ), + ( + """ +# Boolean type mismatch +x: bool = 1 + """, + TypeMismatch, + ), + ( + """ +# String literal not allowed for numeric type +x: uint256 = "hello" + """, + TypeMismatch, + ), + ], +) +def test_type_mismatch_in_initialization(bad_code, exc): + with pytest.raises(exc): + compile_code(bad_code) + + +def test_constant_requires_value(): + """Constants must have an initializer""" + bad_code = """ +X: constant(uint256) # Missing initializer + """ + with pytest.raises(VariableDeclarationException): + compile_code(bad_code) + + +def test_immutable_requires_constructor_assignment_without_initializer(): + """Immutables without initializer must be set in constructor""" + bad_code = """ +X: immutable(uint256) # No initializer + +@deploy +def __init__(): + pass # Forgot to set X + """ + with pytest.raises(ImmutableViolation): + compile_code(bad_code) + + +def test_initializer_cannot_reference_other_storage_vars(): + """Initializers cannot reference other storage variables""" + bad_code = """ +a: uint256 = 100 +b: uint256 = self.a + 50 # Cannot reference self.a + """ + with pytest.raises(StateAccessViolation): + compile_code(bad_code) + + +def test_circular_reference_in_constants(): + """Constants cannot have circular references""" + bad_code = """ +A: constant(uint256) = B +B: constant(uint256) = A + """ + # This will raise VyperException with multiple UndeclaredDefinition errors + from vyper.exceptions import VyperException + + with pytest.raises((UndeclaredDefinition, VyperException)): + compile_code(bad_code) + + +def test_initializer_cannot_use_pure_function_calls(): + """Cannot call even pure functions in initializers""" + bad_code = """ +@internal +@pure +def helper() -> uint256: + return 42 + +x: uint256 = self.helper() + """ + with pytest.raises(StateAccessViolation): + compile_code(bad_code) + + +def test_initializer_cannot_reference_other_vars(): + """Cannot reference other storage variables regardless of order""" + bad_code = """ +y: uint256 = 100 +x: uint256 = self.y # Cannot reference self.y even though it's declared first + """ + with pytest.raises(StateAccessViolation): + compile_code(bad_code) diff --git a/vyper/ast/grammar.lark b/vyper/ast/grammar.lark index 112e0f1549..ca573b9dab 100644 --- a/vyper/ast/grammar.lark +++ b/vyper/ast/grammar.lark @@ -8,7 +8,6 @@ module: ( DOCSTRING | import | struct_def | interface_def - | constant_def | variable_def | enum_def // TODO deprecate at some point in favor of flag | flag_def @@ -34,17 +33,11 @@ import: _IMPORT DOT* _import_path [import_alias] | _import_from _IMPORT ( WILDCARD | _import_name [import_alias] ) | _import_from _IMPORT "(" import_list ")" -// Constant definitions +// Variable definitions (including constants) // NOTE: Temporary until decorators used -constant: "constant" "(" type ")" -constant_private: NAME ":" constant -constant_with_getter: NAME ":" "public" "(" constant ")" -constant_def: (constant_private | constant_with_getter) "=" expr - +variable_annotation: ("public" | "reentrant" | "immutable" | "transient" | "constant") "(" (variable_annotation | type) ")" +variable_def: NAME ":" (variable_annotation | type) ["=" expr] variable: NAME ":" type -// NOTE: Temporary until decorators used -variable_annotation: ("public" | "reentrant" | "immutable" | "transient") "(" (variable_annotation | type) ")" -variable_def: NAME ":" (variable_annotation | type) // A decorator "wraps" a method, modifying it's context. // NOTE: One or more can be applied (some combos might conflict) diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index a9630a5b37..a8b6ff6ef9 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -1454,10 +1454,7 @@ def validate(self): "Only public variables can be marked `reentrant`!", self ) - if not self.is_constant and self.value is not None: - raise VariableDeclarationException( - f"{self._pretty_location} variables cannot have an initial value", self.value - ) + # Allow initialization values for all variable types if not isinstance(self.target, Name): raise VariableDeclarationException("Invalid variable declaration", self.target) diff --git a/vyper/codegen/function_definitions/external_function.py b/vyper/codegen/function_definitions/external_function.py index 4c733ee851..185a4b1049 100644 --- a/vyper/codegen/function_definitions/external_function.py +++ b/vyper/codegen/function_definitions/external_function.py @@ -181,6 +181,13 @@ def generate_ir_for_external_function(code, compilation_target): body += nonreentrant_pre + # if this is a constructor, inject storage variable initializations + if func_t.is_constructor: + from vyper.codegen.stmt import generate_variable_initializations + + init_ir = generate_variable_initializations(compilation_target._module, context) + body.append(init_ir) + body += [parse_body(code.body, context, ensure_terminated=True)] # wrap the body in labeled block diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index 56a8da0f79..d0ef4f0860 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -4,16 +4,18 @@ import vyper.ast as vy_ast from vyper.codegen import core, jumptable_utils +from vyper.codegen.context import Constancy, Context from vyper.codegen.core import shr from vyper.codegen.function_definitions import ( generate_ir_for_external_function, generate_ir_for_internal_function, ) from vyper.codegen.ir_node import IRnode +from vyper.codegen.memory_allocator import MemoryAllocator +from vyper.codegen.stmt import generate_variable_initializations from vyper.compiler.settings import _is_debug_mode -from vyper.exceptions import CompilerPanic from vyper.semantics.types.module import ModuleT -from vyper.utils import OrderedSet, method_id_int +from vyper.utils import MemoryPositions, OrderedSet, method_id_int # calculate globally reachable functions to see which @@ -510,9 +512,22 @@ def generate_ir_for_module(module_t: ModuleT) -> tuple[IRnode, IRnode]: deploy_code.extend(ctor_internal_func_irs) else: - if immutables_len != 0: # pragma: nocover - raise CompilerPanic("unreachable") - deploy_code.append(["deploy", 0, runtime, 0]) + # Generate initialization code for variables even without explicit constructor + # Create a minimal constructor context + memory_allocator = MemoryAllocator(MemoryPositions.RESERVED_MEMORY) + context = Context( + vars_=None, + module_ctx=module_t, + memory_allocator=memory_allocator, + constancy=Constancy.Mutable, + is_ctor_context=True, + ) + + init_ir = generate_variable_initializations(module_t._module, context) + deploy_code.append(init_ir) + + init_mem_used = context.memory_allocator.next_mem + deploy_code.append(["deploy", init_mem_used, runtime, immutables_len]) # compile all remaining internal functions so that _ir_info is populated # (whether or not it makes it into the final IR artifact) diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index ffda836373..b2f6ce2ccb 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -9,6 +9,7 @@ IRnode, add_ofst, clamp_le, + data_location_to_address_space, get_dyn_array_count, get_element_ptr, get_type_for_exact_size, @@ -356,7 +357,9 @@ def _is_terminated(code): # codegen a list of statements -def parse_body(code, context, ensure_terminated=False): +def parse_body( + code: list[vy_ast.VyperNode], context: Context, ensure_terminated: bool = False +) -> IRnode: ir_node = ["seq"] for stmt in code: ir = parse_stmt(stmt, context) @@ -369,3 +372,36 @@ def parse_body(code, context, ensure_terminated=False): # force zerovalent, even last statement ir_node.append("pass") # CMC 2022-01-16 is this necessary? return IRnode.from_list(ir_node) + + +def generate_variable_initializations(module_ast: vy_ast.Module, context: Context) -> IRnode: + """ + Generate initialization IR for storage variables with default values. + Returns an IRnode sequence containing all initialization statements. + """ + assert context.is_ctor_context, "Variable initialization must happen in constructor context" + + init_stmts = [] + + for node in module_ast.body: + if isinstance(node, vy_ast.VariableDecl) and node.value is not None: + # skip constants - they are compile-time only + if node.is_constant: + continue + + # generate assignment: self.var = value + varinfo = node.target._metadata["varinfo"] + location = data_location_to_address_space(varinfo.location, context.is_ctor_context) + + lhs = IRnode.from_list( + varinfo.position.position, + typ=varinfo.typ, + location=location, + annotation=f"self.{node.target.id}", + ) + + rhs = Expr(node.value, context).ir_node + init_stmt = make_setter(lhs, rhs) + init_stmts.append(init_stmt) + + return IRnode.from_list(["seq"] + init_stmts) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 70d8cbdd67..19b62d9dcb 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -35,6 +35,7 @@ get_exact_type_from_node, get_expr_info, get_possible_types_from_node, + is_naked_self_reference, uses_state, validate_expected_type, ) @@ -52,7 +53,6 @@ HashMapT, IntegerT, SArrayT, - SelfT, StringT, StructT, TupleT, @@ -184,17 +184,13 @@ def _validate_pure_access(node: vy_ast.Attribute | vy_ast.Name, typ: VyperType) if isinstance(parent_info.typ, AddressT) and node.attr in AddressT._type_members: raise StateAccessViolation("not allowed to query address members in pure functions") + if is_naked_self_reference(node): + raise StateAccessViolation("not allowed to query `self` in pure functions") + if (varinfo := info.var_info) is None: return - # self is magic. we only need to check it if it is not the root of an Attribute - # node. (i.e. it is bare like `self`, not `self.foo`) - is_naked_self = isinstance(varinfo.typ, SelfT) and not isinstance( - node.get_ancestor(), vy_ast.Attribute - ) - if is_naked_self: - raise StateAccessViolation("not allowed to query `self` in pure functions") - if varinfo.is_state_variable() or is_naked_self: + if varinfo.is_state_variable(): raise StateAccessViolation("not allowed to query state variables in pure functions") diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 3fe847dd35..6565ce368a 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -610,7 +610,8 @@ def visit_VariableDecl(self, node): assignments = self.ast.get_descendants( vy_ast.Assign, filters={"target.id": node.target.id} ) - if not assignments: + # immutables with initialization values don't require assignment + if not assignments and node.value is None: # Special error message for common wrong usages via `self.` wrong_self_attribute = self.ast.get_descendants( vy_ast.Attribute, {"value.id": "self", "attr": node.target.id} @@ -688,7 +689,20 @@ def _validate_self_namespace(): return _finalize() - assert node.value is None # checked in VariableDecl.validate() + # allow initialization for storage variables + if node.value is not None: + # validate the initialization expression + ExprVisitor().visit(node.value, type_) # performs validate_expected_type + + # ensure the initialization expression is constant or runtime constant + # (allows literals, constants, msg.sender, self, etc.) + if not check_modifiability(node.value, Modifiability.RUNTIME_CONSTANT): + raise StateAccessViolation( + "Storage variable initializer must be a literal or runtime constant" + " (e.g. msg.sender, self)", + node.value, + ) + if node.is_immutable: _validate_self_namespace() return _finalize() diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index 515a97e01f..3eacdf437a 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -24,7 +24,7 @@ from vyper.semantics.namespace import get_namespace from vyper.semantics.types.base import TYPE_T, VyperType from vyper.semantics.types.bytestrings import BytesT, StringT -from vyper.semantics.types.primitives import AddressT, BoolT, BytesM_T, IntegerT +from vyper.semantics.types.primitives import AddressT, BoolT, BytesM_T, IntegerT, SelfT from vyper.semantics.types.subscriptable import DArrayT, SArrayT, TupleT from vyper.utils import OrderedSet, checksum_encode, int_to_fourbytes @@ -656,6 +656,29 @@ def validate_unique_method_ids(functions: List) -> None: seen.add(method_id) +def is_naked_self_reference(node: vy_ast.ExprNode) -> bool: + """ + Check if a node is a reference to naked `self` (not `self.attribute`). + + `self` has dual semantics: as an address (runtime constant) vs as storage + access (modifiable). Naked `self` refers to the address, while `self.x` + accesses storage. This distinction matters for modifiability checks, pure + function validation, and other semantic analysis. + """ + if not isinstance(node, vy_ast.Name): + return False + + info = get_expr_info(node) + if info.var_info is None: + return False + + # self is magic. we only need to check it if it is not the root of an Attribute + # node. (i.e. it is bare like `self`, not `self.foo`) + return isinstance(info.var_info.typ, SelfT) and not isinstance( + node.get_ancestor(), vy_ast.Attribute + ) + + def check_modifiability(node: vy_ast.ExprNode, modifiability: Modifiability) -> bool: """ Check if the given node is not more modifiable than the given modifiability. @@ -682,6 +705,10 @@ def check_modifiability(node: vy_ast.ExprNode, modifiability: Modifiability) -> if hasattr(call_type, "check_modifiability_for_call"): return call_type.check_modifiability_for_call(node, modifiability) + # special case: naked `self` is runtime constant (the address), not modifiable (the storage) + if is_naked_self_reference(node): + return modifiability >= Modifiability.RUNTIME_CONSTANT + info = get_expr_info(node) return info.modifiability <= modifiability