diff --git a/contracts/dex_aggregator/schema/execute_msg.json b/contracts/dex_aggregator/schema/execute_msg.json index 6d8401e..b4036c7 100644 --- a/contracts/dex_aggregator/schema/execute_msg.json +++ b/contracts/dex_aggregator/schema/execute_msg.json @@ -156,6 +156,50 @@ } }, "additionalProperties": false + }, + { + "description": "Registers a new tax token that requires special handling.", + "type": "object", + "required": [ + "register_tax_token" + ], + "properties": { + "register_tax_token": { + "type": "object", + "required": [ + "contract_addr" + ], + "properties": { + "contract_addr": { + "type": "string" + } + }, + "additionalProperties": false + } + }, + "additionalProperties": false + }, + { + "description": "Removes a tax token from the registry.", + "type": "object", + "required": [ + "deregister_tax_token" + ], + "properties": { + "deregister_tax_token": { + "type": "object", + "required": [ + "contract_addr" + ], + "properties": { + "contract_addr": { + "type": "string" + } + }, + "additionalProperties": false + } + }, + "additionalProperties": false } ], "definitions": { diff --git a/contracts/dex_aggregator/src/contract.rs b/contracts/dex_aggregator/src/contract.rs index af695fe..d646004 100644 --- a/contracts/dex_aggregator/src/contract.rs +++ b/contracts/dex_aggregator/src/contract.rs @@ -123,6 +123,12 @@ pub fn execute( ExecuteMsg::EmergencyWithdraw { asset_info } => { crate::execute::emergency_withdraw(deps, env, info, asset_info) } + ExecuteMsg::RegisterTaxToken { contract_addr } => { + crate::execute::register_tax_token(deps, info, contract_addr) + } + ExecuteMsg::DeregisterTaxToken { contract_addr } => { + crate::execute::deregister_tax_token(deps, info, contract_addr) + } } } diff --git a/contracts/dex_aggregator/src/execute.rs b/contracts/dex_aggregator/src/execute.rs index 1942c37..d24bd5e 100644 --- a/contracts/dex_aggregator/src/execute.rs +++ b/contracts/dex_aggregator/src/execute.rs @@ -10,7 +10,9 @@ use std::str::FromStr; use crate::error::ContractError; use crate::msg::{self, amm, orderbook, Operation, Stage}; use crate::reply::proceed_to_next_step; -use crate::state::{Awaiting, ExecutionState, RoutePlan, CONFIG, FEE_MAP, REPLY_ID_COUNTER}; +use crate::state::{ + Awaiting, ExecutionState, RoutePlan, CONFIG, FEE_MAP, REPLY_ID_COUNTER, TAX_TOKEN_REGISTRY, +}; pub fn update_admin( deps: DepsMut, @@ -109,17 +111,33 @@ pub fn create_swap_cosmos_msg( }], }), amm::AssetInfo::Token { contract_addr } => { - let cw20_send_msg = Cw20ExecuteMsg::Send { - contract: amm_op.pool_address.clone(), - amount, - msg: to_json_binary(&amm_swap_msg)?, - }; - - CosmosMsg::Wasm(WasmMsg::Execute { - contract_addr: contract_addr.clone(), - msg: to_json_binary(&cw20_send_msg)?, - funds: vec![], - }) + let token_addr = deps.api.addr_validate(contract_addr)?; + if TAX_TOKEN_REGISTRY.has(deps.storage, &token_addr) { + // It's a tax token. We must use its tax-exempt send function. + CosmosMsg::Wasm(WasmMsg::Execute { + contract_addr: contract_addr.clone(), + msg: to_json_binary( + &crate::msg::reflection::ExecuteMsg::TaxExemptSend { + contract: amm_op.pool_address.clone(), + amount, + msg: to_json_binary(&amm_swap_msg)?, + }, + )?, + funds: vec![], + }) + } else { + // It's a standard token. Use the normal Cw20::Send. + let cw20_send_msg = Cw20ExecuteMsg::Send { + contract: amm_op.pool_address.clone(), + amount, + msg: to_json_binary(&amm_swap_msg)?, + }; + CosmosMsg::Wasm(WasmMsg::Execute { + contract_addr: contract_addr.clone(), + msg: to_json_binary(&cw20_send_msg)?, + funds: vec![], + }) + } } } } @@ -328,3 +346,35 @@ pub fn emergency_withdraw( Ok(response) } + +pub fn register_tax_token( + deps: DepsMut, + info: MessageInfo, + contract_addr: String, +) -> Result, ContractError> { + let config = CONFIG.load(deps.storage)?; + if info.sender != config.admin { + return Err(ContractError::Unauthorized {}); + } + let addr = deps.api.addr_validate(&contract_addr)?; + TAX_TOKEN_REGISTRY.save(deps.storage, &addr, &true)?; + Ok(Response::new() + .add_attribute("action", "register_tax_token") + .add_attribute("token_addr", addr)) +} + +pub fn deregister_tax_token( + deps: DepsMut, + info: MessageInfo, + contract_addr: String, +) -> Result, ContractError> { + let config = CONFIG.load(deps.storage)?; + if info.sender != config.admin { + return Err(ContractError::Unauthorized {}); + } + let addr = deps.api.addr_validate(&contract_addr)?; + TAX_TOKEN_REGISTRY.remove(deps.storage, &addr); + Ok(Response::new() + .add_attribute("action", "deregister_tax_token") + .add_attribute("token_addr", addr)) +} diff --git a/contracts/dex_aggregator/src/msg.rs b/contracts/dex_aggregator/src/msg.rs index 133040b..fdd6976 100644 --- a/contracts/dex_aggregator/src/msg.rs +++ b/contracts/dex_aggregator/src/msg.rs @@ -118,6 +118,24 @@ pub mod orderbook { } } +pub mod reflection { + use super::*; + use cosmwasm_std::Binary; + + #[cw_serde] + pub enum ExecuteMsg { + TaxExemptTransfer { + recipient: String, + amount: Uint128, + }, + TaxExemptSend { + contract: String, + amount: Uint128, + msg: Binary, + }, + } +} + #[cw_serde] pub struct AmmSwapOp { pub pool_address: String, @@ -202,6 +220,14 @@ pub enum ExecuteMsg { EmergencyWithdraw { asset_info: amm::AssetInfo, }, + /// Registers a new tax token that requires special handling. + RegisterTaxToken { + contract_addr: String, + }, + /// Removes a tax token from the registry. + DeregisterTaxToken { + contract_addr: String, + }, } #[cw_serde] diff --git a/contracts/dex_aggregator/src/reply.rs b/contracts/dex_aggregator/src/reply.rs index ba1d161..4f0c5dd 100644 --- a/contracts/dex_aggregator/src/reply.rs +++ b/contracts/dex_aggregator/src/reply.rs @@ -10,7 +10,7 @@ use crate::execute::create_swap_cosmos_msg; use crate::msg::{amm, cw20_adapter, Operation, PlannedSwap, Stage, StagePlan}; use crate::state::{ Awaiting, Config, ExecutionState, PendingPathOp, SubmsgReplyState, ACTIVE_ROUTES, CONFIG, - FEE_MAP, REPLY_ID_COUNTER, SUBMSG_REPLY_STATES, + FEE_MAP, REPLY_ID_COUNTER, SUBMSG_REPLY_STATES, TAX_TOKEN_REGISTRY, }; const DECIMAL_FRACTIONAL: u128 = 1_000_000_000_000_000_000; @@ -142,7 +142,7 @@ fn handle_swap_reply( } } - let received_amount = parse_amount_from_swap_reply(events)?; + let received_amount = parse_amount_from_swap_reply(events, &env)?; let received_asset_info = get_operation_output(replied_op)?; let replied_path = ¤t_stage.splits[split_index].path; @@ -230,7 +230,8 @@ fn handle_swap_reply( if !fee.is_zero() { let config = CONFIG.load(deps.storage)?; - let fee_send_msg = create_send_msg(&config.fee_collector, &received_asset_info, fee)?; + let fee_send_msg = + create_send_msg(&deps, &config.fee_collector, &received_asset_info, fee)?; response = response .add_message(fee_send_msg) .add_attribute("fee_collected", fee.to_string()) @@ -256,6 +257,7 @@ fn apply_fee( // A helper to create the final transfer message. fn create_send_msg( + deps: &DepsMut, recipient: &Addr, asset_info: &amm::AssetInfo, amount: Uint128, @@ -268,14 +270,31 @@ fn create_send_msg( amount, }], })), - amm::AssetInfo::Token { contract_addr } => Ok(CosmosMsg::Wasm(WasmMsg::Execute { - contract_addr: contract_addr.clone(), - msg: to_json_binary(&Cw20ExecuteMsg::Transfer { - recipient: recipient.to_string(), - amount, - })?, - funds: vec![], - })), + amm::AssetInfo::Token { contract_addr } => { + let token_addr = deps.api.addr_validate(contract_addr)?; + // Check if we are dealing with a registered tax token. + if TAX_TOKEN_REGISTRY.has(deps.storage, &token_addr) { + // Use the new tax-exempt message. + Ok(CosmosMsg::Wasm(WasmMsg::Execute { + contract_addr: contract_addr.clone(), + msg: to_json_binary(&crate::msg::reflection::ExecuteMsg::TaxExemptTransfer { + recipient: recipient.to_string(), + amount, + })?, + funds: vec![], + })) + } else { + // Use a standard CW20 Transfer for all other tokens. + Ok(CosmosMsg::Wasm(WasmMsg::Execute { + contract_addr: contract_addr.clone(), + msg: to_json_binary(&Cw20ExecuteMsg::Transfer { + recipient: recipient.to_string(), + amount, + })?, + funds: vec![], + })) + } + } } } @@ -327,6 +346,7 @@ fn handle_final_stage( if !total_final_amount.is_zero() { // Use the sender address from the immutable plan let send_msg = create_send_msg( + deps, &exec_state.plan.sender, &target_asset_info, total_final_amount, @@ -400,6 +420,7 @@ fn handle_final_conversion_reply( let mut response = Response::new(); if !total_final_amount.is_zero() { let send_msg = create_send_msg( + &deps, &exec_state.plan.sender, &final_asset_info, total_final_amount, @@ -489,7 +510,36 @@ fn get_operation_output(op: &Operation) -> Result }) } -fn parse_amount_from_swap_reply(events: &[cosmwasm_std::Event]) -> Result { +fn parse_amount_from_swap_reply( + events: &[cosmwasm_std::Event], + env: &Env, +) -> Result { + // Check for `post_tax_amount` from a tax token's transfer event. + for event in events.iter().rev() { + if event.ty != "wasm" { + continue; + } + let is_recipient_self = event + .attributes + .iter() + .any(|attr| attr.key == "to" && attr.value == env.contract.address.to_string()); + + if is_recipient_self { + if let Some(amount_attr) = event + .attributes + .iter() + .find(|attr| attr.key == "post_tax_amount") + { + return amount_attr.value.parse::().map_err(|_| { + ContractError::MalformedAmountInReply { + value: amount_attr.value.clone(), + } + }); + } + } + } + + // 2. Fallback to original logic for standard, non-taxable tokens. let amount_str_opt = events.iter().find_map(|event| { if !event.ty.starts_with("wasm") { return None; @@ -513,12 +563,11 @@ fn parse_amount_from_swap_reply(events: &[cosmwasm_std::Event]) -> Result() .map_err(|_| ContractError::MalformedAmountInReply { value: amount_str }) } - None => Ok(Uint128::zero()), + None => Ok(Uint128::zero()), // Return zero if no relevant amount is found. } } diff --git a/contracts/dex_aggregator/src/state.rs b/contracts/dex_aggregator/src/state.rs index 16abd8e..a7d556a 100644 --- a/contracts/dex_aggregator/src/state.rs +++ b/contracts/dex_aggregator/src/state.rs @@ -57,3 +57,8 @@ pub struct SubmsgReplyState { pub const ACTIVE_ROUTES: Map = Map::new("execution_states"); pub const SUBMSG_REPLY_STATES: Map = Map::new("submsg_reply_states"); pub const REPLY_ID_COUNTER: Item = Item::new("reply_id_counter"); + +/// A registry of known tax tokens that require special handling. +/// The key is the token's contract address. +/// The value is a simple boolean `true` to indicate it's registered. +pub const TAX_TOKEN_REGISTRY: Map<&Addr, bool> = Map::new("tax_tokens"); diff --git a/tests/integration.rs b/tests/integration.rs index cb29c04..558585f 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -1,5 +1,6 @@ #![cfg(test)] +use std::slice; use std::str::FromStr; use cosmwasm_std::{to_json_binary, Addr, Coin, Decimal, Uint128}; @@ -71,23 +72,23 @@ fn setup() -> TestEnv { // Store codes let aggregator_code_id = wasm - .store_code(&get_wasm_byte_code("dex_aggregator.wasm"), None, &admin) + .store_code(get_wasm_byte_code("dex_aggregator.wasm"), None, &admin) .unwrap() .data .code_id; let mock_swap_code_id = wasm - .store_code(&get_wasm_byte_code("mock_swap.wasm"), None, &admin) + .store_code(get_wasm_byte_code("mock_swap.wasm"), None, &admin) .unwrap() .data .code_id; let _cw20_code_id = wasm - .store_code(&get_wasm_byte_code("cw20_base.wasm"), None, &admin) + .store_code(get_wasm_byte_code("cw20_base.wasm"), None, &admin) .unwrap() .data .code_id; let cw20_adapter_code_id = wasm - .store_code(&get_wasm_byte_code("cw20_adapter.wasm"), None, &admin) + .store_code(get_wasm_byte_code("cw20_adapter.wasm"), None, &admin) .unwrap() .data .code_id; @@ -260,8 +261,8 @@ fn setup() -> TestEnv { TestEnv { app, - admin: admin, - user: user, + admin, + user, fee_collector: fee_collector_account, aggregator_addr, mock_amm_1_addr, @@ -452,7 +453,7 @@ fn test_multi_stage_aggregate_swap_success() { let res = wasm.execute( &env.aggregator_addr, &msg, - &[initial_funds.clone()], + slice::from_ref(&initial_funds), &env.user, ); @@ -647,7 +648,7 @@ fn setup_for_conversion_test() -> ConversionTestSetup { &cw20_adapter::ExecuteMsg::RegisterCw20Contract { addr: Addr::unchecked(shroom_cw20_addr.clone()), }, - &[total_fee.clone()], + slice::from_ref(&total_fee), &admin, ) .unwrap(); @@ -656,7 +657,7 @@ fn setup_for_conversion_test() -> ConversionTestSetup { &cw20_adapter::ExecuteMsg::RegisterCw20Contract { addr: Addr::unchecked(sai_cw20_addr.clone()), }, - &[total_fee.clone()], + slice::from_ref(&total_fee), &admin, ) .unwrap(); @@ -1393,7 +1394,12 @@ fn test_failure_if_minimum_receive_not_met() { }; let funds_to_send = Coin::new(100_000_000_000_000_000_000u128, "inj"); - let res = wasm.execute(&env.aggregator_addr, &msg, &[funds_to_send.clone()], user); + let res = wasm.execute( + &env.aggregator_addr, + &msg, + slice::from_ref(&funds_to_send), + user, + ); assert!( res.is_err(), @@ -1977,7 +1983,7 @@ fn test_native_input_with_initial_cw20_requirement() { amount: amount_to_test, }, &[], - &admin, + admin, ) .unwrap(); wasm.execute( @@ -1989,7 +1995,7 @@ fn test_native_input_with_initial_cw20_requirement() { msg: to_json_binary(&"{}").unwrap(), }, &[], - &admin, + admin, ) .unwrap(); bank.send( @@ -2002,7 +2008,7 @@ fn test_native_input_with_initial_cw20_requirement() { amount: amount_to_test.to_string(), }], }, - &admin, + admin, ) .unwrap(); @@ -2185,7 +2191,12 @@ fn test_stage_with_single_hundred_percent_split() { let funds_to_send = Coin::new(100_000_000_000_000_000_000u128, "inj"); // 100 INJ // Execute the transaction - let res = wasm.execute(&env.aggregator_addr, &msg, &[funds_to_send.clone()], user); + let res = wasm.execute( + &env.aggregator_addr, + &msg, + slice::from_ref(&funds_to_send), + user, + ); assert!( res.is_ok(), "Execution with single-split stage failed: {:?}", @@ -2294,7 +2305,12 @@ fn test_intermediate_swap_failure_reverts_transaction() { }; // Execute the transaction - let res = wasm.execute(&env.aggregator_addr, &msg, &[initial_funds.clone()], user); + let res = wasm.execute( + &env.aggregator_addr, + &msg, + slice::from_ref(&initial_funds), + user, + ); // --- ASSERT FAILURE AND ROLLBACK --- @@ -2849,7 +2865,7 @@ fn test_multi_split_with_mixed_fees() { .find(|a| a.key == "final_received") .unwrap(); - let expected_net_output = Uint128::new(1596_000_000u128); // 396 + 1200 = 1596 USDT + let expected_net_output = Uint128::new(1_596_000_000_u128); // 396 + 1200 = 1596 USDT assert_eq!(final_received_attr.value, expected_net_output.to_string()); // Assertion B: Check the fee collector's final balance @@ -3127,7 +3143,7 @@ fn test_multi_hop_path_with_mid_path_conversion() { let res = wasm.execute( &setup.env.aggregator_addr, &msg, - &[funds_to_send.clone()], + slice::from_ref(&funds_to_send), user, ); @@ -3363,7 +3379,7 @@ fn test_multi_split_to_same_orderbook_contract() { }, ], }], - minimum_receive: Some(Uint128::new(2990_000_000)), // Min 2990 USDT + minimum_receive: Some(Uint128::new(2_990_000_000)), // Min 2990 USDT }; // Get user's initial USDT balance for final assertion. @@ -3480,7 +3496,7 @@ fn test_multi_hop_consecutive_orderbook_swaps() { let res = wasm.execute( &env.aggregator_addr, &msg, - &[funds_to_send.clone()], + slice::from_ref(&funds_to_send), &env.user, ); assert!(res.is_ok(), "Execution failed: {:?}", res.unwrap_err());