From 8890df9a41e6e5cbd54e35125f01bb2b7ba4d666 Mon Sep 17 00:00:00 2001 From: ZQlQZ Date: Mon, 5 Jan 2026 14:17:09 +0800 Subject: [PATCH] fix tool in ark rl --- .../rl/ark/plugins/raw_async_veadk_rollout.py | 35 ++----------------- 1 file changed, 2 insertions(+), 33 deletions(-) diff --git a/veadk/cli/templates/rl/ark/plugins/raw_async_veadk_rollout.py b/veadk/cli/templates/rl/ark/plugins/raw_async_veadk_rollout.py index a3545940..c6fdecd2 100644 --- a/veadk/cli/templates/rl/ark/plugins/raw_async_veadk_rollout.py +++ b/veadk/cli/templates/rl/ark/plugins/raw_async_veadk_rollout.py @@ -15,7 +15,6 @@ import json import asyncio from typing import Optional, Any, cast -from loguru import logger from ark_sdk.resources.pipeline_plugin import rollout from ark_sdk.types.pipeline_plugin import PluginInstance, Runtime from ark_sdk.types.pipeline_plugin.pipeline_plugin import PluginContext @@ -25,7 +24,6 @@ ChatCompletionResponse, RolloutInferenceProxy, RolloutResult, - PluginStatus, ) from veadk.agent import Agent from veadk.memory.short_term_memory import ShortTermMemory @@ -33,9 +31,9 @@ from veadk.tracing.telemetry.opentelemetry_tracer import OpentelemetryTracer from veadk.tracing.telemetry.exporters.cozeloop_exporter import CozeloopExporter from veadk.tracing.telemetry.exporters.cozeloop_exporter import CozeloopExporterConfig +from veadk.tools.demo_tools import get_city_weather from google.adk.models.lite_llm import LiteLLMClient, LiteLlm from litellm import ModelResponse -from cozeloop.decorator import observe # BASE_MODEL 格式 : "{model_provider}/{model_name}" BASE_MODEL = "openai/doubao-seed-1-6-flash-250615" @@ -50,13 +48,6 @@ tracer = OpentelemetryTracer(exporters=cast(Any, exporters)) -@observe() -def get_current_weather(location: str, unit="摄氏度"): - # 实际调用天气查询 API 的逻辑 - # 此处为示例,返回模拟的天气数据 - return f"{location}今天天气晴朗,温度 25 {unit}。" - - class RecordingLiteLlm(LiteLlm): """ 在调用 LiteLlm 的 completion/acompletion 时,拦截并记录原始 ModelResponse。 @@ -171,6 +162,7 @@ async def demo_veadk_rollout( model_provider="openai", model_api_key=proxy.jwt_token, tracers=[tracer], + tools=[get_city_weather], model=model_instance, ) @@ -224,29 +216,6 @@ async def demo_veadk_rollout( if model_response.choices[0].finish_reason != "tool_calls": # 模型最终总结,没有调用工具意愿 break - tool_calls = model_response.choices[0].message.tool_calls - for tool_call in tool_calls or []: - tool_name = tool_call.function.name - if tool_name == "get_current_weather": - try: - args = json.loads(tool_call.function.arguments) - tool_result = get_current_weather(**args) - except Exception as e: - logger.error(f"get_current_weather error: {e}") - return RolloutResult( - status=PluginStatus.SUCCESS, - extra={ - "reward": -1, - }, - ) - # 将工具结果加入消息列表 - messages.append( - { - "role": "tool", - "content": tool_result, - "tool_call_id": tool_call.id, - } - ) # 默认return None则视为rollout成功 return None