From 23bd12343955a663d864e808178ce380bb394197 Mon Sep 17 00:00:00 2001 From: "caijialin.626" Date: Sat, 11 Oct 2025 17:22:00 +0800 Subject: [PATCH 01/12] [feat][prompt] add prompt list openapi --- .../coze/loop/apis/prompt_open_apiservice.go | 6 + .../router/coze/loop/apis/coze.loop.apis.go | 1 + .../api/router/coze/loop/apis/middleware.go | 5 + .../loop/apis/promptopenapiservice/client.go | 6 + .../promptopenapiservice.go | 36 + .../openapi/coze.loop.prompt.openapi.go | 3160 +++++++++++++++-- .../coze.loop.prompt.openapi_validator.go | 40 + .../openapi/k-coze.loop.prompt.openapi.go | 1782 ++++++++++ .../openapi/promptopenapiservice/client.go | 6 + .../promptopenapiservice.go | 36 + .../prompt/promptopenapiservice/client.go | 6 + .../promptopenapiservice.go | 36 + .../loopenapi/local_promptopenapiservice.go | 21 + .../prompt/application/convertor/openapi.go | 25 + .../application/convertor/openapi_test.go | 179 + backend/modules/prompt/application/openapi.go | 59 + .../prompt/application/openapi_test.go | 507 +++ .../prompt/coze.loop.prompt.openapi.thrift | 40 +- 18 files changed, 5671 insertions(+), 280 deletions(-) diff --git a/backend/api/handler/coze/loop/apis/prompt_open_apiservice.go b/backend/api/handler/coze/loop/apis/prompt_open_apiservice.go index 0cbcda2c7..4f47ac866 100644 --- a/backend/api/handler/coze/loop/apis/prompt_open_apiservice.go +++ b/backend/api/handler/coze/loop/apis/prompt_open_apiservice.go @@ -69,3 +69,9 @@ func ExecuteStreaming(ctx context.Context, c *app.RequestContext) { } } } + +// ListPromptBasic . +// @router /v1/loop/prompts/list [POST] +func ListPromptBasic(ctx context.Context, c *app.RequestContext) { + invokeAndRender(ctx, c, promptOpenAPISvc.ListPromptBasic) +} diff --git a/backend/api/router/coze/loop/apis/coze.loop.apis.go b/backend/api/router/coze/loop/apis/coze.loop.apis.go index 2f1416697..51b706200 100644 --- a/backend/api/router/coze/loop/apis/coze.loop.apis.go +++ b/backend/api/router/coze/loop/apis/coze.loop.apis.go @@ -419,6 +419,7 @@ func Register(r *server.Hertz, handler *apis.APIHandler) { _prompts0 := _loop.Group("/prompts", _prompts0Mw(handler)...) _prompts0.POST("/execute", append(_executeMw(handler), apis.Execute)...) _prompts0.POST("/execute_streaming", append(_executestreamingMw(handler), apis.ExecuteStreaming)...) + _prompts0.POST("/list", append(_listpromptbasicMw(handler), apis.ListPromptBasic)...) _prompts0.POST("/mget", append(_batchgetpromptbypromptkeyMw(handler), apis.BatchGetPromptByPromptKey)...) } { diff --git a/backend/api/router/coze/loop/apis/middleware.go b/backend/api/router/coze/loop/apis/middleware.go index a65268804..79c6ffd0c 100644 --- a/backend/api/router/coze/loop/apis/middleware.go +++ b/backend/api/router/coze/loop/apis/middleware.go @@ -1615,3 +1615,8 @@ func _getdrilldownvaluesMw(handler *apis.APIHandler) []app.HandlerFunc { // your code... return nil } + +func _listpromptbasicMw(handler *apis.APIHandler) []app.HandlerFunc { + // your code... + return nil +} diff --git a/backend/kitex_gen/coze/loop/apis/promptopenapiservice/client.go b/backend/kitex_gen/coze/loop/apis/promptopenapiservice/client.go index 32724c2e7..408451a6f 100644 --- a/backend/kitex_gen/coze/loop/apis/promptopenapiservice/client.go +++ b/backend/kitex_gen/coze/loop/apis/promptopenapiservice/client.go @@ -17,6 +17,7 @@ type Client interface { BatchGetPromptByPromptKey(ctx context.Context, req *openapi.BatchGetPromptByPromptKeyRequest, callOptions ...callopt.Option) (r *openapi.BatchGetPromptByPromptKeyResponse, err error) Execute(ctx context.Context, req *openapi.ExecuteRequest, callOptions ...callopt.Option) (r *openapi.ExecuteResponse, err error) ExecuteStreaming(ctx context.Context, req *openapi.ExecuteRequest, callOptions ...streamcall.Option) (stream PromptOpenAPIService_ExecuteStreamingClient, err error) + ListPromptBasic(ctx context.Context, req *openapi.ListPromptBasicRequest, callOptions ...callopt.Option) (r *openapi.ListPromptBasicResponse, err error) } type PromptOpenAPIService_ExecuteStreamingClient streaming.ServerStreamingClient[openapi.ExecuteStreamingResponse] @@ -66,3 +67,8 @@ func (p *kPromptOpenAPIServiceClient) ExecuteStreaming(ctx context.Context, req ctx = client.NewCtxWithCallOptions(ctx, streamcall.GetCallOptions(callOptions)) return p.kClient.ExecuteStreaming(ctx, req) } + +func (p *kPromptOpenAPIServiceClient) ListPromptBasic(ctx context.Context, req *openapi.ListPromptBasicRequest, callOptions ...callopt.Option) (r *openapi.ListPromptBasicResponse, err error) { + ctx = client.NewCtxWithCallOptions(ctx, callOptions) + return p.kClient.ListPromptBasic(ctx, req) +} diff --git a/backend/kitex_gen/coze/loop/apis/promptopenapiservice/promptopenapiservice.go b/backend/kitex_gen/coze/loop/apis/promptopenapiservice/promptopenapiservice.go index 89ad53ee5..36ff79a1a 100644 --- a/backend/kitex_gen/coze/loop/apis/promptopenapiservice/promptopenapiservice.go +++ b/backend/kitex_gen/coze/loop/apis/promptopenapiservice/promptopenapiservice.go @@ -36,6 +36,13 @@ var serviceMethods = map[string]kitex.MethodInfo{ false, kitex.WithStreamingMode(kitex.StreamingServer), ), + "ListPromptBasic": kitex.NewMethodInfo( + listPromptBasicHandler, + newPromptOpenAPIServiceListPromptBasicArgs, + newPromptOpenAPIServiceListPromptBasicResult, + false, + kitex.WithStreamingMode(kitex.StreamingNone), + ), } var ( @@ -128,6 +135,25 @@ func newPromptOpenAPIServiceExecuteStreamingResult() interface{} { return openapi.NewPromptOpenAPIServiceExecuteStreamingResult() } +func listPromptBasicHandler(ctx context.Context, handler interface{}, arg, result interface{}) error { + realArg := arg.(*openapi.PromptOpenAPIServiceListPromptBasicArgs) + realResult := result.(*openapi.PromptOpenAPIServiceListPromptBasicResult) + success, err := handler.(openapi.PromptOpenAPIService).ListPromptBasic(ctx, realArg.Req) + if err != nil { + return err + } + realResult.Success = success + return nil +} + +func newPromptOpenAPIServiceListPromptBasicArgs() interface{} { + return openapi.NewPromptOpenAPIServiceListPromptBasicArgs() +} + +func newPromptOpenAPIServiceListPromptBasicResult() interface{} { + return openapi.NewPromptOpenAPIServiceListPromptBasicResult() +} + type kClient struct { c client.Client sc client.Streaming @@ -174,3 +200,13 @@ func (p *kClient) ExecuteStreaming(ctx context.Context, req *openapi.ExecuteRequ } return stream, nil } + +func (p *kClient) ListPromptBasic(ctx context.Context, req *openapi.ListPromptBasicRequest) (r *openapi.ListPromptBasicResponse, err error) { + var _args openapi.PromptOpenAPIServiceListPromptBasicArgs + _args.Req = req + var _result openapi.PromptOpenAPIServiceListPromptBasicResult + if err = p.c.Call(ctx, "ListPromptBasic", &_args, &_result); err != nil { + return + } + return _result.GetSuccess(), nil +} diff --git a/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi.go b/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi.go index 3abac5b11..b4a924e17 100644 --- a/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi.go +++ b/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi.go @@ -8969,246 +8969,2848 @@ func (p *TokenUsage) Field2DeepEqual(src *int32) bool { return true } -type PromptOpenAPIService interface { - BatchGetPromptByPromptKey(ctx context.Context, req *BatchGetPromptByPromptKeyRequest) (r *BatchGetPromptByPromptKeyResponse, err error) - - Execute(ctx context.Context, req *ExecuteRequest) (r *ExecuteResponse, err error) +type ListPromptBasicRequest struct { + WorkspaceID *int64 `thrift:"workspace_id,1,optional" frugal:"1,optional,i64" json:"workspace_id" form:"workspace_id" ` + PageNumber *int32 `thrift:"page_number,2,optional" frugal:"2,optional,i32" form:"page_number" json:"page_number,omitempty"` + PageSize *int32 `thrift:"page_size,3,optional" frugal:"3,optional,i32" form:"page_size" json:"page_size,omitempty"` + // name/key前缀匹配 + KeyWord *string `thrift:"key_word,4,optional" frugal:"4,optional,string" form:"key_word" json:"key_word,omitempty"` + // 创建人 + Creator *string `thrift:"creator,5,optional" frugal:"5,optional,string" form:"creator" json:"creator,omitempty"` + Base *base.Base `thrift:"Base,255,optional" frugal:"255,optional,base.Base" form:"Base" json:"Base,omitempty" query:"Base"` +} - ExecuteStreaming(ctx context.Context, req *ExecuteRequest, stream PromptOpenAPIService_ExecuteStreamingServer) (err error) +func NewListPromptBasicRequest() *ListPromptBasicRequest { + return &ListPromptBasicRequest{} } -type PromptOpenAPIServiceClient struct { - c thrift.TClient +func (p *ListPromptBasicRequest) InitDefault() { } -func NewPromptOpenAPIServiceClientFactory(t thrift.TTransport, f thrift.TProtocolFactory) *PromptOpenAPIServiceClient { - return &PromptOpenAPIServiceClient{ - c: thrift.NewTStandardClient(f.GetProtocol(t), f.GetProtocol(t)), +var ListPromptBasicRequest_WorkspaceID_DEFAULT int64 + +func (p *ListPromptBasicRequest) GetWorkspaceID() (v int64) { + if p == nil { + return + } + if !p.IsSetWorkspaceID() { + return ListPromptBasicRequest_WorkspaceID_DEFAULT } + return *p.WorkspaceID } -func NewPromptOpenAPIServiceClientProtocol(t thrift.TTransport, iprot thrift.TProtocol, oprot thrift.TProtocol) *PromptOpenAPIServiceClient { - return &PromptOpenAPIServiceClient{ - c: thrift.NewTStandardClient(iprot, oprot), +var ListPromptBasicRequest_PageNumber_DEFAULT int32 + +func (p *ListPromptBasicRequest) GetPageNumber() (v int32) { + if p == nil { + return } + if !p.IsSetPageNumber() { + return ListPromptBasicRequest_PageNumber_DEFAULT + } + return *p.PageNumber } -func NewPromptOpenAPIServiceClient(c thrift.TClient) *PromptOpenAPIServiceClient { - return &PromptOpenAPIServiceClient{ - c: c, +var ListPromptBasicRequest_PageSize_DEFAULT int32 + +func (p *ListPromptBasicRequest) GetPageSize() (v int32) { + if p == nil { + return + } + if !p.IsSetPageSize() { + return ListPromptBasicRequest_PageSize_DEFAULT } + return *p.PageSize } -func (p *PromptOpenAPIServiceClient) Client_() thrift.TClient { - return p.c +var ListPromptBasicRequest_KeyWord_DEFAULT string + +func (p *ListPromptBasicRequest) GetKeyWord() (v string) { + if p == nil { + return + } + if !p.IsSetKeyWord() { + return ListPromptBasicRequest_KeyWord_DEFAULT + } + return *p.KeyWord } -func (p *PromptOpenAPIServiceClient) BatchGetPromptByPromptKey(ctx context.Context, req *BatchGetPromptByPromptKeyRequest) (r *BatchGetPromptByPromptKeyResponse, err error) { - var _args PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs - _args.Req = req - var _result PromptOpenAPIServiceBatchGetPromptByPromptKeyResult - if err = p.Client_().Call(ctx, "BatchGetPromptByPromptKey", &_args, &_result); err != nil { +var ListPromptBasicRequest_Creator_DEFAULT string + +func (p *ListPromptBasicRequest) GetCreator() (v string) { + if p == nil { return } - return _result.GetSuccess(), nil + if !p.IsSetCreator() { + return ListPromptBasicRequest_Creator_DEFAULT + } + return *p.Creator } -func (p *PromptOpenAPIServiceClient) Execute(ctx context.Context, req *ExecuteRequest) (r *ExecuteResponse, err error) { - var _args PromptOpenAPIServiceExecuteArgs - _args.Req = req - var _result PromptOpenAPIServiceExecuteResult - if err = p.Client_().Call(ctx, "Execute", &_args, &_result); err != nil { + +var ListPromptBasicRequest_Base_DEFAULT *base.Base + +func (p *ListPromptBasicRequest) GetBase() (v *base.Base) { + if p == nil { return } - return _result.GetSuccess(), nil + if !p.IsSetBase() { + return ListPromptBasicRequest_Base_DEFAULT + } + return p.Base } -func (p *PromptOpenAPIServiceClient) ExecuteStreaming(ctx context.Context, req *ExecuteRequest, stream PromptOpenAPIService_ExecuteStreamingServer) (err error) { - panic("streaming method PromptOpenAPIService.ExecuteStreaming(mode = server) not available, please use Kitex Thrift Streaming Client.") +func (p *ListPromptBasicRequest) SetWorkspaceID(val *int64) { + p.WorkspaceID = val +} +func (p *ListPromptBasicRequest) SetPageNumber(val *int32) { + p.PageNumber = val +} +func (p *ListPromptBasicRequest) SetPageSize(val *int32) { + p.PageSize = val +} +func (p *ListPromptBasicRequest) SetKeyWord(val *string) { + p.KeyWord = val +} +func (p *ListPromptBasicRequest) SetCreator(val *string) { + p.Creator = val +} +func (p *ListPromptBasicRequest) SetBase(val *base.Base) { + p.Base = val } -type PromptOpenAPIService_ExecuteStreamingServer streaming.ServerStreamingServer[ExecuteStreamingResponse] - -type PromptOpenAPIServiceProcessor struct { - processorMap map[string]thrift.TProcessorFunction - handler PromptOpenAPIService +var fieldIDToName_ListPromptBasicRequest = map[int16]string{ + 1: "workspace_id", + 2: "page_number", + 3: "page_size", + 4: "key_word", + 5: "creator", + 255: "Base", } -func (p *PromptOpenAPIServiceProcessor) AddToProcessorMap(key string, processor thrift.TProcessorFunction) { - p.processorMap[key] = processor +func (p *ListPromptBasicRequest) IsSetWorkspaceID() bool { + return p.WorkspaceID != nil } -func (p *PromptOpenAPIServiceProcessor) GetProcessorFunction(key string) (processor thrift.TProcessorFunction, ok bool) { - processor, ok = p.processorMap[key] - return processor, ok +func (p *ListPromptBasicRequest) IsSetPageNumber() bool { + return p.PageNumber != nil } -func (p *PromptOpenAPIServiceProcessor) ProcessorMap() map[string]thrift.TProcessorFunction { - return p.processorMap +func (p *ListPromptBasicRequest) IsSetPageSize() bool { + return p.PageSize != nil } -func NewPromptOpenAPIServiceProcessor(handler PromptOpenAPIService) *PromptOpenAPIServiceProcessor { - self := &PromptOpenAPIServiceProcessor{handler: handler, processorMap: make(map[string]thrift.TProcessorFunction)} - self.AddToProcessorMap("BatchGetPromptByPromptKey", &promptOpenAPIServiceProcessorBatchGetPromptByPromptKey{handler: handler}) - self.AddToProcessorMap("Execute", &promptOpenAPIServiceProcessorExecute{handler: handler}) - self.AddToProcessorMap("ExecuteStreaming", &promptOpenAPIServiceProcessorExecuteStreaming{handler: handler}) - return self +func (p *ListPromptBasicRequest) IsSetKeyWord() bool { + return p.KeyWord != nil } -func (p *PromptOpenAPIServiceProcessor) Process(ctx context.Context, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { - name, _, seqId, err := iprot.ReadMessageBegin() - if err != nil { - return false, err - } - if processor, ok := p.GetProcessorFunction(name); ok { - return processor.Process(ctx, seqId, iprot, oprot) - } - iprot.Skip(thrift.STRUCT) - iprot.ReadMessageEnd() - x := thrift.NewTApplicationException(thrift.UNKNOWN_METHOD, "Unknown function "+name) - oprot.WriteMessageBegin(name, thrift.EXCEPTION, seqId) - x.Write(oprot) - oprot.WriteMessageEnd() - oprot.Flush(ctx) - return false, x + +func (p *ListPromptBasicRequest) IsSetCreator() bool { + return p.Creator != nil } -type promptOpenAPIServiceProcessorBatchGetPromptByPromptKey struct { - handler PromptOpenAPIService +func (p *ListPromptBasicRequest) IsSetBase() bool { + return p.Base != nil } -func (p *promptOpenAPIServiceProcessorBatchGetPromptByPromptKey) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { - args := PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs{} - if err = args.Read(iprot); err != nil { - iprot.ReadMessageEnd() - x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) - oprot.WriteMessageBegin("BatchGetPromptByPromptKey", thrift.EXCEPTION, seqId) - x.Write(oprot) - oprot.WriteMessageEnd() - oprot.Flush(ctx) - return false, err - } +func (p *ListPromptBasicRequest) Read(iprot thrift.TProtocol) (err error) { + var fieldTypeId thrift.TType + var fieldId int16 - iprot.ReadMessageEnd() - var err2 error - result := PromptOpenAPIServiceBatchGetPromptByPromptKeyResult{} - var retval *BatchGetPromptByPromptKeyResponse - if retval, err2 = p.handler.BatchGetPromptByPromptKey(ctx, args.Req); err2 != nil { - x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing BatchGetPromptByPromptKey: "+err2.Error()) - oprot.WriteMessageBegin("BatchGetPromptByPromptKey", thrift.EXCEPTION, seqId) - x.Write(oprot) - oprot.WriteMessageEnd() - oprot.Flush(ctx) - return true, err2 - } else { - result.Success = retval - } - if err2 = oprot.WriteMessageBegin("BatchGetPromptByPromptKey", thrift.REPLY, seqId); err2 != nil { - err = err2 - } - if err2 = result.Write(oprot); err == nil && err2 != nil { - err = err2 - } - if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil { - err = err2 + if _, err = iprot.ReadStructBegin(); err != nil { + goto ReadStructBeginError } - if err2 = oprot.Flush(ctx); err == nil && err2 != nil { - err = err2 + + for { + _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + + switch fieldId { + case 1: + if fieldTypeId == thrift.I64 { + if err = p.ReadField1(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 2: + if fieldTypeId == thrift.I32 { + if err = p.ReadField2(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 3: + if fieldTypeId == thrift.I32 { + if err = p.ReadField3(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 4: + if fieldTypeId == thrift.STRING { + if err = p.ReadField4(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 5: + if fieldTypeId == thrift.STRING { + if err = p.ReadField5(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 255: + if fieldTypeId == thrift.STRUCT { + if err = p.ReadField255(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } } - if err != nil { - return + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError } - return true, err -} -type promptOpenAPIServiceProcessorExecute struct { - handler PromptOpenAPIService + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_ListPromptBasicRequest[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } -func (p *promptOpenAPIServiceProcessorExecute) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { - args := PromptOpenAPIServiceExecuteArgs{} - if err = args.Read(iprot); err != nil { - iprot.ReadMessageEnd() - x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) - oprot.WriteMessageBegin("Execute", thrift.EXCEPTION, seqId) - x.Write(oprot) - oprot.WriteMessageEnd() - oprot.Flush(ctx) - return false, err +func (p *ListPromptBasicRequest) ReadField1(iprot thrift.TProtocol) error { + + var _field *int64 + if v, err := iprot.ReadI64(); err != nil { + return err + } else { + _field = &v } + p.WorkspaceID = _field + return nil +} +func (p *ListPromptBasicRequest) ReadField2(iprot thrift.TProtocol) error { - iprot.ReadMessageEnd() - var err2 error - result := PromptOpenAPIServiceExecuteResult{} - var retval *ExecuteResponse - if retval, err2 = p.handler.Execute(ctx, args.Req); err2 != nil { - x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing Execute: "+err2.Error()) - oprot.WriteMessageBegin("Execute", thrift.EXCEPTION, seqId) - x.Write(oprot) - oprot.WriteMessageEnd() - oprot.Flush(ctx) - return true, err2 + var _field *int32 + if v, err := iprot.ReadI32(); err != nil { + return err } else { - result.Success = retval + _field = &v } - if err2 = oprot.WriteMessageBegin("Execute", thrift.REPLY, seqId); err2 != nil { - err = err2 + p.PageNumber = _field + return nil +} +func (p *ListPromptBasicRequest) ReadField3(iprot thrift.TProtocol) error { + + var _field *int32 + if v, err := iprot.ReadI32(); err != nil { + return err + } else { + _field = &v } - if err2 = result.Write(oprot); err == nil && err2 != nil { - err = err2 + p.PageSize = _field + return nil +} +func (p *ListPromptBasicRequest) ReadField4(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v } - if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil { - err = err2 + p.KeyWord = _field + return nil +} +func (p *ListPromptBasicRequest) ReadField5(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v } - if err2 = oprot.Flush(ctx); err == nil && err2 != nil { - err = err2 + p.Creator = _field + return nil +} +func (p *ListPromptBasicRequest) ReadField255(iprot thrift.TProtocol) error { + _field := base.NewBase() + if err := _field.Read(iprot); err != nil { + return err + } + p.Base = _field + return nil +} + +func (p *ListPromptBasicRequest) Write(oprot thrift.TProtocol) (err error) { + var fieldId int16 + if err = oprot.WriteStructBegin("ListPromptBasicRequest"); err != nil { + goto WriteStructBeginError + } + if p != nil { + if err = p.writeField1(oprot); err != nil { + fieldId = 1 + goto WriteFieldError + } + if err = p.writeField2(oprot); err != nil { + fieldId = 2 + goto WriteFieldError + } + if err = p.writeField3(oprot); err != nil { + fieldId = 3 + goto WriteFieldError + } + if err = p.writeField4(oprot); err != nil { + fieldId = 4 + goto WriteFieldError + } + if err = p.writeField5(oprot); err != nil { + fieldId = 5 + goto WriteFieldError + } + if err = p.writeField255(oprot); err != nil { + fieldId = 255 + goto WriteFieldError + } + } + if err = oprot.WriteFieldStop(); err != nil { + goto WriteFieldStopError + } + if err = oprot.WriteStructEnd(); err != nil { + goto WriteStructEndError + } + return nil +WriteStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) +WriteFieldError: + return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) +WriteFieldStopError: + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) +WriteStructEndError: + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) +} + +func (p *ListPromptBasicRequest) writeField1(oprot thrift.TProtocol) (err error) { + if p.IsSetWorkspaceID() { + if err = oprot.WriteFieldBegin("workspace_id", thrift.I64, 1); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteI64(*p.WorkspaceID); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) +} +func (p *ListPromptBasicRequest) writeField2(oprot thrift.TProtocol) (err error) { + if p.IsSetPageNumber() { + if err = oprot.WriteFieldBegin("page_number", thrift.I32, 2); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteI32(*p.PageNumber); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) +} +func (p *ListPromptBasicRequest) writeField3(oprot thrift.TProtocol) (err error) { + if p.IsSetPageSize() { + if err = oprot.WriteFieldBegin("page_size", thrift.I32, 3); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteI32(*p.PageSize); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) +} +func (p *ListPromptBasicRequest) writeField4(oprot thrift.TProtocol) (err error) { + if p.IsSetKeyWord() { + if err = oprot.WriteFieldBegin("key_word", thrift.STRING, 4); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.KeyWord); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 4 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 4 end error: ", p), err) +} +func (p *ListPromptBasicRequest) writeField5(oprot thrift.TProtocol) (err error) { + if p.IsSetCreator() { + if err = oprot.WriteFieldBegin("creator", thrift.STRING, 5); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Creator); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 5 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 5 end error: ", p), err) +} +func (p *ListPromptBasicRequest) writeField255(oprot thrift.TProtocol) (err error) { + if p.IsSetBase() { + if err = oprot.WriteFieldBegin("Base", thrift.STRUCT, 255); err != nil { + goto WriteFieldBeginError + } + if err := p.Base.Write(oprot); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 255 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 255 end error: ", p), err) +} + +func (p *ListPromptBasicRequest) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("ListPromptBasicRequest(%+v)", *p) + +} + +func (p *ListPromptBasicRequest) DeepEqual(ano *ListPromptBasicRequest) bool { + if p == ano { + return true + } else if p == nil || ano == nil { + return false + } + if !p.Field1DeepEqual(ano.WorkspaceID) { + return false + } + if !p.Field2DeepEqual(ano.PageNumber) { + return false + } + if !p.Field3DeepEqual(ano.PageSize) { + return false + } + if !p.Field4DeepEqual(ano.KeyWord) { + return false + } + if !p.Field5DeepEqual(ano.Creator) { + return false + } + if !p.Field255DeepEqual(ano.Base) { + return false + } + return true +} + +func (p *ListPromptBasicRequest) Field1DeepEqual(src *int64) bool { + + if p.WorkspaceID == src { + return true + } else if p.WorkspaceID == nil || src == nil { + return false + } + if *p.WorkspaceID != *src { + return false + } + return true +} +func (p *ListPromptBasicRequest) Field2DeepEqual(src *int32) bool { + + if p.PageNumber == src { + return true + } else if p.PageNumber == nil || src == nil { + return false + } + if *p.PageNumber != *src { + return false + } + return true +} +func (p *ListPromptBasicRequest) Field3DeepEqual(src *int32) bool { + + if p.PageSize == src { + return true + } else if p.PageSize == nil || src == nil { + return false + } + if *p.PageSize != *src { + return false + } + return true +} +func (p *ListPromptBasicRequest) Field4DeepEqual(src *string) bool { + + if p.KeyWord == src { + return true + } else if p.KeyWord == nil || src == nil { + return false + } + if strings.Compare(*p.KeyWord, *src) != 0 { + return false + } + return true +} +func (p *ListPromptBasicRequest) Field5DeepEqual(src *string) bool { + + if p.Creator == src { + return true + } else if p.Creator == nil || src == nil { + return false + } + if strings.Compare(*p.Creator, *src) != 0 { + return false + } + return true +} +func (p *ListPromptBasicRequest) Field255DeepEqual(src *base.Base) bool { + + if !p.Base.DeepEqual(src) { + return false + } + return true +} + +type ListPromptBasicResponse struct { + Code *int32 `thrift:"code,1,optional" frugal:"1,optional,i32" form:"code" json:"code,omitempty" query:"code"` + Msg *string `thrift:"msg,2,optional" frugal:"2,optional,string" form:"msg" json:"msg,omitempty" query:"msg"` + Data *ListPromptBasicData `thrift:"data,3,optional" frugal:"3,optional,ListPromptBasicData" form:"data" json:"data,omitempty" query:"data"` + BaseResp *base.BaseResp `thrift:"BaseResp,255,optional" frugal:"255,optional,base.BaseResp" form:"BaseResp" json:"BaseResp,omitempty" query:"BaseResp"` +} + +func NewListPromptBasicResponse() *ListPromptBasicResponse { + return &ListPromptBasicResponse{} +} + +func (p *ListPromptBasicResponse) InitDefault() { +} + +var ListPromptBasicResponse_Code_DEFAULT int32 + +func (p *ListPromptBasicResponse) GetCode() (v int32) { + if p == nil { + return + } + if !p.IsSetCode() { + return ListPromptBasicResponse_Code_DEFAULT + } + return *p.Code +} + +var ListPromptBasicResponse_Msg_DEFAULT string + +func (p *ListPromptBasicResponse) GetMsg() (v string) { + if p == nil { + return + } + if !p.IsSetMsg() { + return ListPromptBasicResponse_Msg_DEFAULT + } + return *p.Msg +} + +var ListPromptBasicResponse_Data_DEFAULT *ListPromptBasicData + +func (p *ListPromptBasicResponse) GetData() (v *ListPromptBasicData) { + if p == nil { + return + } + if !p.IsSetData() { + return ListPromptBasicResponse_Data_DEFAULT + } + return p.Data +} + +var ListPromptBasicResponse_BaseResp_DEFAULT *base.BaseResp + +func (p *ListPromptBasicResponse) GetBaseResp() (v *base.BaseResp) { + if p == nil { + return + } + if !p.IsSetBaseResp() { + return ListPromptBasicResponse_BaseResp_DEFAULT + } + return p.BaseResp +} +func (p *ListPromptBasicResponse) SetCode(val *int32) { + p.Code = val +} +func (p *ListPromptBasicResponse) SetMsg(val *string) { + p.Msg = val +} +func (p *ListPromptBasicResponse) SetData(val *ListPromptBasicData) { + p.Data = val +} +func (p *ListPromptBasicResponse) SetBaseResp(val *base.BaseResp) { + p.BaseResp = val +} + +var fieldIDToName_ListPromptBasicResponse = map[int16]string{ + 1: "code", + 2: "msg", + 3: "data", + 255: "BaseResp", +} + +func (p *ListPromptBasicResponse) IsSetCode() bool { + return p.Code != nil +} + +func (p *ListPromptBasicResponse) IsSetMsg() bool { + return p.Msg != nil +} + +func (p *ListPromptBasicResponse) IsSetData() bool { + return p.Data != nil +} + +func (p *ListPromptBasicResponse) IsSetBaseResp() bool { + return p.BaseResp != nil +} + +func (p *ListPromptBasicResponse) Read(iprot thrift.TProtocol) (err error) { + var fieldTypeId thrift.TType + var fieldId int16 + + if _, err = iprot.ReadStructBegin(); err != nil { + goto ReadStructBeginError + } + + for { + _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + + switch fieldId { + case 1: + if fieldTypeId == thrift.I32 { + if err = p.ReadField1(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 2: + if fieldTypeId == thrift.STRING { + if err = p.ReadField2(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 3: + if fieldTypeId == thrift.STRUCT { + if err = p.ReadField3(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 255: + if fieldTypeId == thrift.STRUCT { + if err = p.ReadField255(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } + } + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError + } + + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_ListPromptBasicResponse[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *ListPromptBasicResponse) ReadField1(iprot thrift.TProtocol) error { + + var _field *int32 + if v, err := iprot.ReadI32(); err != nil { + return err + } else { + _field = &v + } + p.Code = _field + return nil +} +func (p *ListPromptBasicResponse) ReadField2(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Msg = _field + return nil +} +func (p *ListPromptBasicResponse) ReadField3(iprot thrift.TProtocol) error { + _field := NewListPromptBasicData() + if err := _field.Read(iprot); err != nil { + return err + } + p.Data = _field + return nil +} +func (p *ListPromptBasicResponse) ReadField255(iprot thrift.TProtocol) error { + _field := base.NewBaseResp() + if err := _field.Read(iprot); err != nil { + return err + } + p.BaseResp = _field + return nil +} + +func (p *ListPromptBasicResponse) Write(oprot thrift.TProtocol) (err error) { + var fieldId int16 + if err = oprot.WriteStructBegin("ListPromptBasicResponse"); err != nil { + goto WriteStructBeginError + } + if p != nil { + if err = p.writeField1(oprot); err != nil { + fieldId = 1 + goto WriteFieldError + } + if err = p.writeField2(oprot); err != nil { + fieldId = 2 + goto WriteFieldError + } + if err = p.writeField3(oprot); err != nil { + fieldId = 3 + goto WriteFieldError + } + if err = p.writeField255(oprot); err != nil { + fieldId = 255 + goto WriteFieldError + } + } + if err = oprot.WriteFieldStop(); err != nil { + goto WriteFieldStopError + } + if err = oprot.WriteStructEnd(); err != nil { + goto WriteStructEndError + } + return nil +WriteStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) +WriteFieldError: + return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) +WriteFieldStopError: + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) +WriteStructEndError: + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) +} + +func (p *ListPromptBasicResponse) writeField1(oprot thrift.TProtocol) (err error) { + if p.IsSetCode() { + if err = oprot.WriteFieldBegin("code", thrift.I32, 1); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteI32(*p.Code); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) +} +func (p *ListPromptBasicResponse) writeField2(oprot thrift.TProtocol) (err error) { + if p.IsSetMsg() { + if err = oprot.WriteFieldBegin("msg", thrift.STRING, 2); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Msg); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) +} +func (p *ListPromptBasicResponse) writeField3(oprot thrift.TProtocol) (err error) { + if p.IsSetData() { + if err = oprot.WriteFieldBegin("data", thrift.STRUCT, 3); err != nil { + goto WriteFieldBeginError + } + if err := p.Data.Write(oprot); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) +} +func (p *ListPromptBasicResponse) writeField255(oprot thrift.TProtocol) (err error) { + if p.IsSetBaseResp() { + if err = oprot.WriteFieldBegin("BaseResp", thrift.STRUCT, 255); err != nil { + goto WriteFieldBeginError + } + if err := p.BaseResp.Write(oprot); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 255 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 255 end error: ", p), err) +} + +func (p *ListPromptBasicResponse) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("ListPromptBasicResponse(%+v)", *p) + +} + +func (p *ListPromptBasicResponse) DeepEqual(ano *ListPromptBasicResponse) bool { + if p == ano { + return true + } else if p == nil || ano == nil { + return false + } + if !p.Field1DeepEqual(ano.Code) { + return false + } + if !p.Field2DeepEqual(ano.Msg) { + return false + } + if !p.Field3DeepEqual(ano.Data) { + return false + } + if !p.Field255DeepEqual(ano.BaseResp) { + return false + } + return true +} + +func (p *ListPromptBasicResponse) Field1DeepEqual(src *int32) bool { + + if p.Code == src { + return true + } else if p.Code == nil || src == nil { + return false + } + if *p.Code != *src { + return false + } + return true +} +func (p *ListPromptBasicResponse) Field2DeepEqual(src *string) bool { + + if p.Msg == src { + return true + } else if p.Msg == nil || src == nil { + return false + } + if strings.Compare(*p.Msg, *src) != 0 { + return false + } + return true +} +func (p *ListPromptBasicResponse) Field3DeepEqual(src *ListPromptBasicData) bool { + + if !p.Data.DeepEqual(src) { + return false + } + return true +} +func (p *ListPromptBasicResponse) Field255DeepEqual(src *base.BaseResp) bool { + + if !p.BaseResp.DeepEqual(src) { + return false + } + return true +} + +type PromptBasic struct { + // Prompt ID + ID *int64 `thrift:"id,1,optional" frugal:"1,optional,i64" json:"id" form:"id" query:"id"` + // 工作空间ID + WorkspaceID *int64 `thrift:"workspace_id,2,optional" frugal:"2,optional,i64" json:"workspace_id" form:"workspace_id" query:"workspace_id"` + // 唯一标识 + PromptKey *string `thrift:"prompt_key,3,optional" frugal:"3,optional,string" form:"prompt_key" json:"prompt_key,omitempty" query:"prompt_key"` + // Prompt名称 + DisplayName *string `thrift:"display_name,4,optional" frugal:"4,optional,string" form:"display_name" json:"display_name,omitempty" query:"display_name"` + // Prompt描述 + Description *string `thrift:"description,5,optional" frugal:"5,optional,string" form:"description" json:"description,omitempty" query:"description"` + // 最新版本 + LatestVersion *string `thrift:"latest_version,6,optional" frugal:"6,optional,string" form:"latest_version" json:"latest_version,omitempty" query:"latest_version"` + // 创建者 + CreatedBy *string `thrift:"created_by,7,optional" frugal:"7,optional,string" form:"created_by" json:"created_by,omitempty" query:"created_by"` + // 更新者 + UpdatedBy *string `thrift:"updated_by,8,optional" frugal:"8,optional,string" form:"updated_by" json:"updated_by,omitempty" query:"updated_by"` + // 创建时间 + CreatedAt *int64 `thrift:"created_at,9,optional" frugal:"9,optional,i64" json:"created_at" form:"created_at" query:"created_at"` + // 更新时间 + UpdatedAt *int64 `thrift:"updated_at,10,optional" frugal:"10,optional,i64" json:"updated_at" form:"updated_at" query:"updated_at"` + // 最后提交时间 + LatestCommittedAt *int64 `thrift:"latest_committed_at,11,optional" frugal:"11,optional,i64" json:"latest_committed_at" form:"latest_committed_at" query:"latest_committed_at"` +} + +func NewPromptBasic() *PromptBasic { + return &PromptBasic{} +} + +func (p *PromptBasic) InitDefault() { +} + +var PromptBasic_ID_DEFAULT int64 + +func (p *PromptBasic) GetID() (v int64) { + if p == nil { + return + } + if !p.IsSetID() { + return PromptBasic_ID_DEFAULT + } + return *p.ID +} + +var PromptBasic_WorkspaceID_DEFAULT int64 + +func (p *PromptBasic) GetWorkspaceID() (v int64) { + if p == nil { + return + } + if !p.IsSetWorkspaceID() { + return PromptBasic_WorkspaceID_DEFAULT + } + return *p.WorkspaceID +} + +var PromptBasic_PromptKey_DEFAULT string + +func (p *PromptBasic) GetPromptKey() (v string) { + if p == nil { + return + } + if !p.IsSetPromptKey() { + return PromptBasic_PromptKey_DEFAULT + } + return *p.PromptKey +} + +var PromptBasic_DisplayName_DEFAULT string + +func (p *PromptBasic) GetDisplayName() (v string) { + if p == nil { + return + } + if !p.IsSetDisplayName() { + return PromptBasic_DisplayName_DEFAULT + } + return *p.DisplayName +} + +var PromptBasic_Description_DEFAULT string + +func (p *PromptBasic) GetDescription() (v string) { + if p == nil { + return + } + if !p.IsSetDescription() { + return PromptBasic_Description_DEFAULT + } + return *p.Description +} + +var PromptBasic_LatestVersion_DEFAULT string + +func (p *PromptBasic) GetLatestVersion() (v string) { + if p == nil { + return + } + if !p.IsSetLatestVersion() { + return PromptBasic_LatestVersion_DEFAULT + } + return *p.LatestVersion +} + +var PromptBasic_CreatedBy_DEFAULT string + +func (p *PromptBasic) GetCreatedBy() (v string) { + if p == nil { + return + } + if !p.IsSetCreatedBy() { + return PromptBasic_CreatedBy_DEFAULT + } + return *p.CreatedBy +} + +var PromptBasic_UpdatedBy_DEFAULT string + +func (p *PromptBasic) GetUpdatedBy() (v string) { + if p == nil { + return + } + if !p.IsSetUpdatedBy() { + return PromptBasic_UpdatedBy_DEFAULT + } + return *p.UpdatedBy +} + +var PromptBasic_CreatedAt_DEFAULT int64 + +func (p *PromptBasic) GetCreatedAt() (v int64) { + if p == nil { + return + } + if !p.IsSetCreatedAt() { + return PromptBasic_CreatedAt_DEFAULT + } + return *p.CreatedAt +} + +var PromptBasic_UpdatedAt_DEFAULT int64 + +func (p *PromptBasic) GetUpdatedAt() (v int64) { + if p == nil { + return + } + if !p.IsSetUpdatedAt() { + return PromptBasic_UpdatedAt_DEFAULT + } + return *p.UpdatedAt +} + +var PromptBasic_LatestCommittedAt_DEFAULT int64 + +func (p *PromptBasic) GetLatestCommittedAt() (v int64) { + if p == nil { + return + } + if !p.IsSetLatestCommittedAt() { + return PromptBasic_LatestCommittedAt_DEFAULT + } + return *p.LatestCommittedAt +} +func (p *PromptBasic) SetID(val *int64) { + p.ID = val +} +func (p *PromptBasic) SetWorkspaceID(val *int64) { + p.WorkspaceID = val +} +func (p *PromptBasic) SetPromptKey(val *string) { + p.PromptKey = val +} +func (p *PromptBasic) SetDisplayName(val *string) { + p.DisplayName = val +} +func (p *PromptBasic) SetDescription(val *string) { + p.Description = val +} +func (p *PromptBasic) SetLatestVersion(val *string) { + p.LatestVersion = val +} +func (p *PromptBasic) SetCreatedBy(val *string) { + p.CreatedBy = val +} +func (p *PromptBasic) SetUpdatedBy(val *string) { + p.UpdatedBy = val +} +func (p *PromptBasic) SetCreatedAt(val *int64) { + p.CreatedAt = val +} +func (p *PromptBasic) SetUpdatedAt(val *int64) { + p.UpdatedAt = val +} +func (p *PromptBasic) SetLatestCommittedAt(val *int64) { + p.LatestCommittedAt = val +} + +var fieldIDToName_PromptBasic = map[int16]string{ + 1: "id", + 2: "workspace_id", + 3: "prompt_key", + 4: "display_name", + 5: "description", + 6: "latest_version", + 7: "created_by", + 8: "updated_by", + 9: "created_at", + 10: "updated_at", + 11: "latest_committed_at", +} + +func (p *PromptBasic) IsSetID() bool { + return p.ID != nil +} + +func (p *PromptBasic) IsSetWorkspaceID() bool { + return p.WorkspaceID != nil +} + +func (p *PromptBasic) IsSetPromptKey() bool { + return p.PromptKey != nil +} + +func (p *PromptBasic) IsSetDisplayName() bool { + return p.DisplayName != nil +} + +func (p *PromptBasic) IsSetDescription() bool { + return p.Description != nil +} + +func (p *PromptBasic) IsSetLatestVersion() bool { + return p.LatestVersion != nil +} + +func (p *PromptBasic) IsSetCreatedBy() bool { + return p.CreatedBy != nil +} + +func (p *PromptBasic) IsSetUpdatedBy() bool { + return p.UpdatedBy != nil +} + +func (p *PromptBasic) IsSetCreatedAt() bool { + return p.CreatedAt != nil +} + +func (p *PromptBasic) IsSetUpdatedAt() bool { + return p.UpdatedAt != nil +} + +func (p *PromptBasic) IsSetLatestCommittedAt() bool { + return p.LatestCommittedAt != nil +} + +func (p *PromptBasic) Read(iprot thrift.TProtocol) (err error) { + var fieldTypeId thrift.TType + var fieldId int16 + + if _, err = iprot.ReadStructBegin(); err != nil { + goto ReadStructBeginError + } + + for { + _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + + switch fieldId { + case 1: + if fieldTypeId == thrift.I64 { + if err = p.ReadField1(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 2: + if fieldTypeId == thrift.I64 { + if err = p.ReadField2(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 3: + if fieldTypeId == thrift.STRING { + if err = p.ReadField3(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 4: + if fieldTypeId == thrift.STRING { + if err = p.ReadField4(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 5: + if fieldTypeId == thrift.STRING { + if err = p.ReadField5(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 6: + if fieldTypeId == thrift.STRING { + if err = p.ReadField6(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 7: + if fieldTypeId == thrift.STRING { + if err = p.ReadField7(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 8: + if fieldTypeId == thrift.STRING { + if err = p.ReadField8(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 9: + if fieldTypeId == thrift.I64 { + if err = p.ReadField9(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 10: + if fieldTypeId == thrift.I64 { + if err = p.ReadField10(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 11: + if fieldTypeId == thrift.I64 { + if err = p.ReadField11(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } + } + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError + } + + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptBasic[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *PromptBasic) ReadField1(iprot thrift.TProtocol) error { + + var _field *int64 + if v, err := iprot.ReadI64(); err != nil { + return err + } else { + _field = &v + } + p.ID = _field + return nil +} +func (p *PromptBasic) ReadField2(iprot thrift.TProtocol) error { + + var _field *int64 + if v, err := iprot.ReadI64(); err != nil { + return err + } else { + _field = &v + } + p.WorkspaceID = _field + return nil +} +func (p *PromptBasic) ReadField3(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.PromptKey = _field + return nil +} +func (p *PromptBasic) ReadField4(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.DisplayName = _field + return nil +} +func (p *PromptBasic) ReadField5(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Description = _field + return nil +} +func (p *PromptBasic) ReadField6(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.LatestVersion = _field + return nil +} +func (p *PromptBasic) ReadField7(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.CreatedBy = _field + return nil +} +func (p *PromptBasic) ReadField8(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.UpdatedBy = _field + return nil +} +func (p *PromptBasic) ReadField9(iprot thrift.TProtocol) error { + + var _field *int64 + if v, err := iprot.ReadI64(); err != nil { + return err + } else { + _field = &v + } + p.CreatedAt = _field + return nil +} +func (p *PromptBasic) ReadField10(iprot thrift.TProtocol) error { + + var _field *int64 + if v, err := iprot.ReadI64(); err != nil { + return err + } else { + _field = &v + } + p.UpdatedAt = _field + return nil +} +func (p *PromptBasic) ReadField11(iprot thrift.TProtocol) error { + + var _field *int64 + if v, err := iprot.ReadI64(); err != nil { + return err + } else { + _field = &v + } + p.LatestCommittedAt = _field + return nil +} + +func (p *PromptBasic) Write(oprot thrift.TProtocol) (err error) { + var fieldId int16 + if err = oprot.WriteStructBegin("PromptBasic"); err != nil { + goto WriteStructBeginError + } + if p != nil { + if err = p.writeField1(oprot); err != nil { + fieldId = 1 + goto WriteFieldError + } + if err = p.writeField2(oprot); err != nil { + fieldId = 2 + goto WriteFieldError + } + if err = p.writeField3(oprot); err != nil { + fieldId = 3 + goto WriteFieldError + } + if err = p.writeField4(oprot); err != nil { + fieldId = 4 + goto WriteFieldError + } + if err = p.writeField5(oprot); err != nil { + fieldId = 5 + goto WriteFieldError + } + if err = p.writeField6(oprot); err != nil { + fieldId = 6 + goto WriteFieldError + } + if err = p.writeField7(oprot); err != nil { + fieldId = 7 + goto WriteFieldError + } + if err = p.writeField8(oprot); err != nil { + fieldId = 8 + goto WriteFieldError + } + if err = p.writeField9(oprot); err != nil { + fieldId = 9 + goto WriteFieldError + } + if err = p.writeField10(oprot); err != nil { + fieldId = 10 + goto WriteFieldError + } + if err = p.writeField11(oprot); err != nil { + fieldId = 11 + goto WriteFieldError + } + } + if err = oprot.WriteFieldStop(); err != nil { + goto WriteFieldStopError + } + if err = oprot.WriteStructEnd(); err != nil { + goto WriteStructEndError + } + return nil +WriteStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) +WriteFieldError: + return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) +WriteFieldStopError: + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) +WriteStructEndError: + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) +} + +func (p *PromptBasic) writeField1(oprot thrift.TProtocol) (err error) { + if p.IsSetID() { + if err = oprot.WriteFieldBegin("id", thrift.I64, 1); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteI64(*p.ID); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) +} +func (p *PromptBasic) writeField2(oprot thrift.TProtocol) (err error) { + if p.IsSetWorkspaceID() { + if err = oprot.WriteFieldBegin("workspace_id", thrift.I64, 2); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteI64(*p.WorkspaceID); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) +} +func (p *PromptBasic) writeField3(oprot thrift.TProtocol) (err error) { + if p.IsSetPromptKey() { + if err = oprot.WriteFieldBegin("prompt_key", thrift.STRING, 3); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.PromptKey); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) +} +func (p *PromptBasic) writeField4(oprot thrift.TProtocol) (err error) { + if p.IsSetDisplayName() { + if err = oprot.WriteFieldBegin("display_name", thrift.STRING, 4); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.DisplayName); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 4 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 4 end error: ", p), err) +} +func (p *PromptBasic) writeField5(oprot thrift.TProtocol) (err error) { + if p.IsSetDescription() { + if err = oprot.WriteFieldBegin("description", thrift.STRING, 5); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Description); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 5 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 5 end error: ", p), err) +} +func (p *PromptBasic) writeField6(oprot thrift.TProtocol) (err error) { + if p.IsSetLatestVersion() { + if err = oprot.WriteFieldBegin("latest_version", thrift.STRING, 6); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.LatestVersion); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 6 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 6 end error: ", p), err) +} +func (p *PromptBasic) writeField7(oprot thrift.TProtocol) (err error) { + if p.IsSetCreatedBy() { + if err = oprot.WriteFieldBegin("created_by", thrift.STRING, 7); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.CreatedBy); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 7 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 7 end error: ", p), err) +} +func (p *PromptBasic) writeField8(oprot thrift.TProtocol) (err error) { + if p.IsSetUpdatedBy() { + if err = oprot.WriteFieldBegin("updated_by", thrift.STRING, 8); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.UpdatedBy); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 8 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 8 end error: ", p), err) +} +func (p *PromptBasic) writeField9(oprot thrift.TProtocol) (err error) { + if p.IsSetCreatedAt() { + if err = oprot.WriteFieldBegin("created_at", thrift.I64, 9); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteI64(*p.CreatedAt); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 9 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 9 end error: ", p), err) +} +func (p *PromptBasic) writeField10(oprot thrift.TProtocol) (err error) { + if p.IsSetUpdatedAt() { + if err = oprot.WriteFieldBegin("updated_at", thrift.I64, 10); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteI64(*p.UpdatedAt); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 10 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 10 end error: ", p), err) +} +func (p *PromptBasic) writeField11(oprot thrift.TProtocol) (err error) { + if p.IsSetLatestCommittedAt() { + if err = oprot.WriteFieldBegin("latest_committed_at", thrift.I64, 11); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteI64(*p.LatestCommittedAt); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 11 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 11 end error: ", p), err) +} + +func (p *PromptBasic) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("PromptBasic(%+v)", *p) + +} + +func (p *PromptBasic) DeepEqual(ano *PromptBasic) bool { + if p == ano { + return true + } else if p == nil || ano == nil { + return false + } + if !p.Field1DeepEqual(ano.ID) { + return false + } + if !p.Field2DeepEqual(ano.WorkspaceID) { + return false + } + if !p.Field3DeepEqual(ano.PromptKey) { + return false + } + if !p.Field4DeepEqual(ano.DisplayName) { + return false + } + if !p.Field5DeepEqual(ano.Description) { + return false + } + if !p.Field6DeepEqual(ano.LatestVersion) { + return false + } + if !p.Field7DeepEqual(ano.CreatedBy) { + return false + } + if !p.Field8DeepEqual(ano.UpdatedBy) { + return false + } + if !p.Field9DeepEqual(ano.CreatedAt) { + return false + } + if !p.Field10DeepEqual(ano.UpdatedAt) { + return false + } + if !p.Field11DeepEqual(ano.LatestCommittedAt) { + return false + } + return true +} + +func (p *PromptBasic) Field1DeepEqual(src *int64) bool { + + if p.ID == src { + return true + } else if p.ID == nil || src == nil { + return false + } + if *p.ID != *src { + return false + } + return true +} +func (p *PromptBasic) Field2DeepEqual(src *int64) bool { + + if p.WorkspaceID == src { + return true + } else if p.WorkspaceID == nil || src == nil { + return false + } + if *p.WorkspaceID != *src { + return false + } + return true +} +func (p *PromptBasic) Field3DeepEqual(src *string) bool { + + if p.PromptKey == src { + return true + } else if p.PromptKey == nil || src == nil { + return false + } + if strings.Compare(*p.PromptKey, *src) != 0 { + return false + } + return true +} +func (p *PromptBasic) Field4DeepEqual(src *string) bool { + + if p.DisplayName == src { + return true + } else if p.DisplayName == nil || src == nil { + return false + } + if strings.Compare(*p.DisplayName, *src) != 0 { + return false + } + return true +} +func (p *PromptBasic) Field5DeepEqual(src *string) bool { + + if p.Description == src { + return true + } else if p.Description == nil || src == nil { + return false + } + if strings.Compare(*p.Description, *src) != 0 { + return false + } + return true +} +func (p *PromptBasic) Field6DeepEqual(src *string) bool { + + if p.LatestVersion == src { + return true + } else if p.LatestVersion == nil || src == nil { + return false + } + if strings.Compare(*p.LatestVersion, *src) != 0 { + return false + } + return true +} +func (p *PromptBasic) Field7DeepEqual(src *string) bool { + + if p.CreatedBy == src { + return true + } else if p.CreatedBy == nil || src == nil { + return false + } + if strings.Compare(*p.CreatedBy, *src) != 0 { + return false + } + return true +} +func (p *PromptBasic) Field8DeepEqual(src *string) bool { + + if p.UpdatedBy == src { + return true + } else if p.UpdatedBy == nil || src == nil { + return false + } + if strings.Compare(*p.UpdatedBy, *src) != 0 { + return false + } + return true +} +func (p *PromptBasic) Field9DeepEqual(src *int64) bool { + + if p.CreatedAt == src { + return true + } else if p.CreatedAt == nil || src == nil { + return false + } + if *p.CreatedAt != *src { + return false + } + return true +} +func (p *PromptBasic) Field10DeepEqual(src *int64) bool { + + if p.UpdatedAt == src { + return true + } else if p.UpdatedAt == nil || src == nil { + return false + } + if *p.UpdatedAt != *src { + return false + } + return true +} +func (p *PromptBasic) Field11DeepEqual(src *int64) bool { + + if p.LatestCommittedAt == src { + return true + } else if p.LatestCommittedAt == nil || src == nil { + return false + } + if *p.LatestCommittedAt != *src { + return false + } + return true +} + +type ListPromptBasicData struct { + // Prompt列表 + Prompts []*PromptBasic `thrift:"prompts,1,optional" frugal:"1,optional,list" form:"prompts" json:"prompts,omitempty" query:"prompts"` + Total *int32 `thrift:"total,2,optional" frugal:"2,optional,i32" form:"total" json:"total,omitempty" query:"total"` +} + +func NewListPromptBasicData() *ListPromptBasicData { + return &ListPromptBasicData{} +} + +func (p *ListPromptBasicData) InitDefault() { +} + +var ListPromptBasicData_Prompts_DEFAULT []*PromptBasic + +func (p *ListPromptBasicData) GetPrompts() (v []*PromptBasic) { + if p == nil { + return + } + if !p.IsSetPrompts() { + return ListPromptBasicData_Prompts_DEFAULT + } + return p.Prompts +} + +var ListPromptBasicData_Total_DEFAULT int32 + +func (p *ListPromptBasicData) GetTotal() (v int32) { + if p == nil { + return + } + if !p.IsSetTotal() { + return ListPromptBasicData_Total_DEFAULT + } + return *p.Total +} +func (p *ListPromptBasicData) SetPrompts(val []*PromptBasic) { + p.Prompts = val +} +func (p *ListPromptBasicData) SetTotal(val *int32) { + p.Total = val +} + +var fieldIDToName_ListPromptBasicData = map[int16]string{ + 1: "prompts", + 2: "total", +} + +func (p *ListPromptBasicData) IsSetPrompts() bool { + return p.Prompts != nil +} + +func (p *ListPromptBasicData) IsSetTotal() bool { + return p.Total != nil +} + +func (p *ListPromptBasicData) Read(iprot thrift.TProtocol) (err error) { + var fieldTypeId thrift.TType + var fieldId int16 + + if _, err = iprot.ReadStructBegin(); err != nil { + goto ReadStructBeginError + } + + for { + _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + + switch fieldId { + case 1: + if fieldTypeId == thrift.LIST { + if err = p.ReadField1(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 2: + if fieldTypeId == thrift.I32 { + if err = p.ReadField2(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } + } + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError + } + + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_ListPromptBasicData[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *ListPromptBasicData) ReadField1(iprot thrift.TProtocol) error { + _, size, err := iprot.ReadListBegin() + if err != nil { + return err + } + _field := make([]*PromptBasic, 0, size) + values := make([]PromptBasic, size) + for i := 0; i < size; i++ { + _elem := &values[i] + _elem.InitDefault() + + if err := _elem.Read(iprot); err != nil { + return err + } + + _field = append(_field, _elem) + } + if err := iprot.ReadListEnd(); err != nil { + return err + } + p.Prompts = _field + return nil +} +func (p *ListPromptBasicData) ReadField2(iprot thrift.TProtocol) error { + + var _field *int32 + if v, err := iprot.ReadI32(); err != nil { + return err + } else { + _field = &v + } + p.Total = _field + return nil +} + +func (p *ListPromptBasicData) Write(oprot thrift.TProtocol) (err error) { + var fieldId int16 + if err = oprot.WriteStructBegin("ListPromptBasicData"); err != nil { + goto WriteStructBeginError + } + if p != nil { + if err = p.writeField1(oprot); err != nil { + fieldId = 1 + goto WriteFieldError + } + if err = p.writeField2(oprot); err != nil { + fieldId = 2 + goto WriteFieldError + } + } + if err = oprot.WriteFieldStop(); err != nil { + goto WriteFieldStopError + } + if err = oprot.WriteStructEnd(); err != nil { + goto WriteStructEndError + } + return nil +WriteStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) +WriteFieldError: + return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) +WriteFieldStopError: + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) +WriteStructEndError: + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) +} + +func (p *ListPromptBasicData) writeField1(oprot thrift.TProtocol) (err error) { + if p.IsSetPrompts() { + if err = oprot.WriteFieldBegin("prompts", thrift.LIST, 1); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteListBegin(thrift.STRUCT, len(p.Prompts)); err != nil { + return err + } + for _, v := range p.Prompts { + if err := v.Write(oprot); err != nil { + return err + } + } + if err := oprot.WriteListEnd(); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) +} +func (p *ListPromptBasicData) writeField2(oprot thrift.TProtocol) (err error) { + if p.IsSetTotal() { + if err = oprot.WriteFieldBegin("total", thrift.I32, 2); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteI32(*p.Total); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) +} + +func (p *ListPromptBasicData) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("ListPromptBasicData(%+v)", *p) + +} + +func (p *ListPromptBasicData) DeepEqual(ano *ListPromptBasicData) bool { + if p == ano { + return true + } else if p == nil || ano == nil { + return false + } + if !p.Field1DeepEqual(ano.Prompts) { + return false + } + if !p.Field2DeepEqual(ano.Total) { + return false + } + return true +} + +func (p *ListPromptBasicData) Field1DeepEqual(src []*PromptBasic) bool { + + if len(p.Prompts) != len(src) { + return false + } + for i, v := range p.Prompts { + _src := src[i] + if !v.DeepEqual(_src) { + return false + } + } + return true +} +func (p *ListPromptBasicData) Field2DeepEqual(src *int32) bool { + + if p.Total == src { + return true + } else if p.Total == nil || src == nil { + return false + } + if *p.Total != *src { + return false + } + return true +} + +type PromptOpenAPIService interface { + BatchGetPromptByPromptKey(ctx context.Context, req *BatchGetPromptByPromptKeyRequest) (r *BatchGetPromptByPromptKeyResponse, err error) + + Execute(ctx context.Context, req *ExecuteRequest) (r *ExecuteResponse, err error) + + ExecuteStreaming(ctx context.Context, req *ExecuteRequest, stream PromptOpenAPIService_ExecuteStreamingServer) (err error) + + ListPromptBasic(ctx context.Context, req *ListPromptBasicRequest) (r *ListPromptBasicResponse, err error) +} + +type PromptOpenAPIServiceClient struct { + c thrift.TClient +} + +func NewPromptOpenAPIServiceClientFactory(t thrift.TTransport, f thrift.TProtocolFactory) *PromptOpenAPIServiceClient { + return &PromptOpenAPIServiceClient{ + c: thrift.NewTStandardClient(f.GetProtocol(t), f.GetProtocol(t)), + } +} + +func NewPromptOpenAPIServiceClientProtocol(t thrift.TTransport, iprot thrift.TProtocol, oprot thrift.TProtocol) *PromptOpenAPIServiceClient { + return &PromptOpenAPIServiceClient{ + c: thrift.NewTStandardClient(iprot, oprot), + } +} + +func NewPromptOpenAPIServiceClient(c thrift.TClient) *PromptOpenAPIServiceClient { + return &PromptOpenAPIServiceClient{ + c: c, + } +} + +func (p *PromptOpenAPIServiceClient) Client_() thrift.TClient { + return p.c +} + +func (p *PromptOpenAPIServiceClient) BatchGetPromptByPromptKey(ctx context.Context, req *BatchGetPromptByPromptKeyRequest) (r *BatchGetPromptByPromptKeyResponse, err error) { + var _args PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs + _args.Req = req + var _result PromptOpenAPIServiceBatchGetPromptByPromptKeyResult + if err = p.Client_().Call(ctx, "BatchGetPromptByPromptKey", &_args, &_result); err != nil { + return + } + return _result.GetSuccess(), nil +} +func (p *PromptOpenAPIServiceClient) Execute(ctx context.Context, req *ExecuteRequest) (r *ExecuteResponse, err error) { + var _args PromptOpenAPIServiceExecuteArgs + _args.Req = req + var _result PromptOpenAPIServiceExecuteResult + if err = p.Client_().Call(ctx, "Execute", &_args, &_result); err != nil { + return + } + return _result.GetSuccess(), nil +} +func (p *PromptOpenAPIServiceClient) ExecuteStreaming(ctx context.Context, req *ExecuteRequest, stream PromptOpenAPIService_ExecuteStreamingServer) (err error) { + panic("streaming method PromptOpenAPIService.ExecuteStreaming(mode = server) not available, please use Kitex Thrift Streaming Client.") +} +func (p *PromptOpenAPIServiceClient) ListPromptBasic(ctx context.Context, req *ListPromptBasicRequest) (r *ListPromptBasicResponse, err error) { + var _args PromptOpenAPIServiceListPromptBasicArgs + _args.Req = req + var _result PromptOpenAPIServiceListPromptBasicResult + if err = p.Client_().Call(ctx, "ListPromptBasic", &_args, &_result); err != nil { + return + } + return _result.GetSuccess(), nil +} + +type PromptOpenAPIService_ExecuteStreamingServer streaming.ServerStreamingServer[ExecuteStreamingResponse] + +type PromptOpenAPIServiceProcessor struct { + processorMap map[string]thrift.TProcessorFunction + handler PromptOpenAPIService +} + +func (p *PromptOpenAPIServiceProcessor) AddToProcessorMap(key string, processor thrift.TProcessorFunction) { + p.processorMap[key] = processor +} + +func (p *PromptOpenAPIServiceProcessor) GetProcessorFunction(key string) (processor thrift.TProcessorFunction, ok bool) { + processor, ok = p.processorMap[key] + return processor, ok +} + +func (p *PromptOpenAPIServiceProcessor) ProcessorMap() map[string]thrift.TProcessorFunction { + return p.processorMap +} + +func NewPromptOpenAPIServiceProcessor(handler PromptOpenAPIService) *PromptOpenAPIServiceProcessor { + self := &PromptOpenAPIServiceProcessor{handler: handler, processorMap: make(map[string]thrift.TProcessorFunction)} + self.AddToProcessorMap("BatchGetPromptByPromptKey", &promptOpenAPIServiceProcessorBatchGetPromptByPromptKey{handler: handler}) + self.AddToProcessorMap("Execute", &promptOpenAPIServiceProcessorExecute{handler: handler}) + self.AddToProcessorMap("ExecuteStreaming", &promptOpenAPIServiceProcessorExecuteStreaming{handler: handler}) + self.AddToProcessorMap("ListPromptBasic", &promptOpenAPIServiceProcessorListPromptBasic{handler: handler}) + return self +} +func (p *PromptOpenAPIServiceProcessor) Process(ctx context.Context, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { + name, _, seqId, err := iprot.ReadMessageBegin() + if err != nil { + return false, err + } + if processor, ok := p.GetProcessorFunction(name); ok { + return processor.Process(ctx, seqId, iprot, oprot) + } + iprot.Skip(thrift.STRUCT) + iprot.ReadMessageEnd() + x := thrift.NewTApplicationException(thrift.UNKNOWN_METHOD, "Unknown function "+name) + oprot.WriteMessageBegin(name, thrift.EXCEPTION, seqId) + x.Write(oprot) + oprot.WriteMessageEnd() + oprot.Flush(ctx) + return false, x +} + +type promptOpenAPIServiceProcessorBatchGetPromptByPromptKey struct { + handler PromptOpenAPIService +} + +func (p *promptOpenAPIServiceProcessorBatchGetPromptByPromptKey) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { + args := PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs{} + if err = args.Read(iprot); err != nil { + iprot.ReadMessageEnd() + x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) + oprot.WriteMessageBegin("BatchGetPromptByPromptKey", thrift.EXCEPTION, seqId) + x.Write(oprot) + oprot.WriteMessageEnd() + oprot.Flush(ctx) + return false, err + } + + iprot.ReadMessageEnd() + var err2 error + result := PromptOpenAPIServiceBatchGetPromptByPromptKeyResult{} + var retval *BatchGetPromptByPromptKeyResponse + if retval, err2 = p.handler.BatchGetPromptByPromptKey(ctx, args.Req); err2 != nil { + x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing BatchGetPromptByPromptKey: "+err2.Error()) + oprot.WriteMessageBegin("BatchGetPromptByPromptKey", thrift.EXCEPTION, seqId) + x.Write(oprot) + oprot.WriteMessageEnd() + oprot.Flush(ctx) + return true, err2 + } else { + result.Success = retval + } + if err2 = oprot.WriteMessageBegin("BatchGetPromptByPromptKey", thrift.REPLY, seqId); err2 != nil { + err = err2 + } + if err2 = result.Write(oprot); err == nil && err2 != nil { + err = err2 + } + if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil { + err = err2 + } + if err2 = oprot.Flush(ctx); err == nil && err2 != nil { + err = err2 + } + if err != nil { + return + } + return true, err +} + +type promptOpenAPIServiceProcessorExecute struct { + handler PromptOpenAPIService +} + +func (p *promptOpenAPIServiceProcessorExecute) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { + args := PromptOpenAPIServiceExecuteArgs{} + if err = args.Read(iprot); err != nil { + iprot.ReadMessageEnd() + x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) + oprot.WriteMessageBegin("Execute", thrift.EXCEPTION, seqId) + x.Write(oprot) + oprot.WriteMessageEnd() + oprot.Flush(ctx) + return false, err + } + + iprot.ReadMessageEnd() + var err2 error + result := PromptOpenAPIServiceExecuteResult{} + var retval *ExecuteResponse + if retval, err2 = p.handler.Execute(ctx, args.Req); err2 != nil { + x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing Execute: "+err2.Error()) + oprot.WriteMessageBegin("Execute", thrift.EXCEPTION, seqId) + x.Write(oprot) + oprot.WriteMessageEnd() + oprot.Flush(ctx) + return true, err2 + } else { + result.Success = retval + } + if err2 = oprot.WriteMessageBegin("Execute", thrift.REPLY, seqId); err2 != nil { + err = err2 + } + if err2 = result.Write(oprot); err == nil && err2 != nil { + err = err2 + } + if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil { + err = err2 + } + if err2 = oprot.Flush(ctx); err == nil && err2 != nil { + err = err2 + } + if err != nil { + return + } + return true, err +} + +type promptOpenAPIServiceProcessorExecuteStreaming struct { + handler PromptOpenAPIService +} + +func (p *promptOpenAPIServiceProcessorExecuteStreaming) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { + panic("streaming method PromptOpenAPIService.ExecuteStreaming(mode = server) not available, please use Kitex Thrift Streaming Client.") +} + +type promptOpenAPIServiceProcessorListPromptBasic struct { + handler PromptOpenAPIService +} + +func (p *promptOpenAPIServiceProcessorListPromptBasic) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { + args := PromptOpenAPIServiceListPromptBasicArgs{} + if err = args.Read(iprot); err != nil { + iprot.ReadMessageEnd() + x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) + oprot.WriteMessageBegin("ListPromptBasic", thrift.EXCEPTION, seqId) + x.Write(oprot) + oprot.WriteMessageEnd() + oprot.Flush(ctx) + return false, err + } + + iprot.ReadMessageEnd() + var err2 error + result := PromptOpenAPIServiceListPromptBasicResult{} + var retval *ListPromptBasicResponse + if retval, err2 = p.handler.ListPromptBasic(ctx, args.Req); err2 != nil { + x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing ListPromptBasic: "+err2.Error()) + oprot.WriteMessageBegin("ListPromptBasic", thrift.EXCEPTION, seqId) + x.Write(oprot) + oprot.WriteMessageEnd() + oprot.Flush(ctx) + return true, err2 + } else { + result.Success = retval + } + if err2 = oprot.WriteMessageBegin("ListPromptBasic", thrift.REPLY, seqId); err2 != nil { + err = err2 + } + if err2 = result.Write(oprot); err == nil && err2 != nil { + err = err2 + } + if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil { + err = err2 + } + if err2 = oprot.Flush(ctx); err == nil && err2 != nil { + err = err2 + } + if err != nil { + return + } + return true, err +} + +type PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs struct { + Req *BatchGetPromptByPromptKeyRequest `thrift:"req,1" frugal:"1,default,BatchGetPromptByPromptKeyRequest"` +} + +func NewPromptOpenAPIServiceBatchGetPromptByPromptKeyArgs() *PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs { + return &PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs{} +} + +func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs) InitDefault() { +} + +var PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs_Req_DEFAULT *BatchGetPromptByPromptKeyRequest + +func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs) GetReq() (v *BatchGetPromptByPromptKeyRequest) { + if p == nil { + return + } + if !p.IsSetReq() { + return PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs_Req_DEFAULT + } + return p.Req +} +func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs) SetReq(val *BatchGetPromptByPromptKeyRequest) { + p.Req = val +} + +var fieldIDToName_PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs = map[int16]string{ + 1: "req", +} + +func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs) IsSetReq() bool { + return p.Req != nil +} + +func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs) Read(iprot thrift.TProtocol) (err error) { + var fieldTypeId thrift.TType + var fieldId int16 + + if _, err = iprot.ReadStructBegin(); err != nil { + goto ReadStructBeginError + } + + for { + _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + + switch fieldId { + case 1: + if fieldTypeId == thrift.STRUCT { + if err = p.ReadField1(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } + } + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError + } + + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs) ReadField1(iprot thrift.TProtocol) error { + _field := NewBatchGetPromptByPromptKeyRequest() + if err := _field.Read(iprot); err != nil { + return err + } + p.Req = _field + return nil +} + +func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs) Write(oprot thrift.TProtocol) (err error) { + var fieldId int16 + if err = oprot.WriteStructBegin("BatchGetPromptByPromptKey_args"); err != nil { + goto WriteStructBeginError + } + if p != nil { + if err = p.writeField1(oprot); err != nil { + fieldId = 1 + goto WriteFieldError + } + } + if err = oprot.WriteFieldStop(); err != nil { + goto WriteFieldStopError + } + if err = oprot.WriteStructEnd(); err != nil { + goto WriteStructEndError + } + return nil +WriteStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) +WriteFieldError: + return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) +WriteFieldStopError: + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) +WriteStructEndError: + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) +} + +func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs) writeField1(oprot thrift.TProtocol) (err error) { + if err = oprot.WriteFieldBegin("req", thrift.STRUCT, 1); err != nil { + goto WriteFieldBeginError + } + if err := p.Req.Write(oprot); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) +} + +func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs(%+v)", *p) + +} + +func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs) DeepEqual(ano *PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs) bool { + if p == ano { + return true + } else if p == nil || ano == nil { + return false + } + if !p.Field1DeepEqual(ano.Req) { + return false + } + return true +} + +func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs) Field1DeepEqual(src *BatchGetPromptByPromptKeyRequest) bool { + + if !p.Req.DeepEqual(src) { + return false + } + return true +} + +type PromptOpenAPIServiceBatchGetPromptByPromptKeyResult struct { + Success *BatchGetPromptByPromptKeyResponse `thrift:"success,0,optional" frugal:"0,optional,BatchGetPromptByPromptKeyResponse"` +} + +func NewPromptOpenAPIServiceBatchGetPromptByPromptKeyResult() *PromptOpenAPIServiceBatchGetPromptByPromptKeyResult { + return &PromptOpenAPIServiceBatchGetPromptByPromptKeyResult{} +} + +func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyResult) InitDefault() { +} + +var PromptOpenAPIServiceBatchGetPromptByPromptKeyResult_Success_DEFAULT *BatchGetPromptByPromptKeyResponse + +func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyResult) GetSuccess() (v *BatchGetPromptByPromptKeyResponse) { + if p == nil { + return + } + if !p.IsSetSuccess() { + return PromptOpenAPIServiceBatchGetPromptByPromptKeyResult_Success_DEFAULT + } + return p.Success +} +func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyResult) SetSuccess(x interface{}) { + p.Success = x.(*BatchGetPromptByPromptKeyResponse) +} + +var fieldIDToName_PromptOpenAPIServiceBatchGetPromptByPromptKeyResult = map[int16]string{ + 0: "success", +} + +func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyResult) IsSetSuccess() bool { + return p.Success != nil +} + +func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyResult) Read(iprot thrift.TProtocol) (err error) { + var fieldTypeId thrift.TType + var fieldId int16 + + if _, err = iprot.ReadStructBegin(); err != nil { + goto ReadStructBeginError + } + + for { + _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + + switch fieldId { + case 0: + if fieldTypeId == thrift.STRUCT { + if err = p.ReadField0(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } + } + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError + } + + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptOpenAPIServiceBatchGetPromptByPromptKeyResult[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyResult) ReadField0(iprot thrift.TProtocol) error { + _field := NewBatchGetPromptByPromptKeyResponse() + if err := _field.Read(iprot); err != nil { + return err + } + p.Success = _field + return nil +} + +func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyResult) Write(oprot thrift.TProtocol) (err error) { + var fieldId int16 + if err = oprot.WriteStructBegin("BatchGetPromptByPromptKey_result"); err != nil { + goto WriteStructBeginError + } + if p != nil { + if err = p.writeField0(oprot); err != nil { + fieldId = 0 + goto WriteFieldError + } + } + if err = oprot.WriteFieldStop(); err != nil { + goto WriteFieldStopError + } + if err = oprot.WriteStructEnd(); err != nil { + goto WriteStructEndError + } + return nil +WriteStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) +WriteFieldError: + return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) +WriteFieldStopError: + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) +WriteStructEndError: + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) +} + +func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyResult) writeField0(oprot thrift.TProtocol) (err error) { + if p.IsSetSuccess() { + if err = oprot.WriteFieldBegin("success", thrift.STRUCT, 0); err != nil { + goto WriteFieldBeginError + } + if err := p.Success.Write(oprot); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } } - if err != nil { - return + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 0 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 0 end error: ", p), err) +} + +func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyResult) String() string { + if p == nil { + return "" } - return true, err + return fmt.Sprintf("PromptOpenAPIServiceBatchGetPromptByPromptKeyResult(%+v)", *p) + } -type promptOpenAPIServiceProcessorExecuteStreaming struct { - handler PromptOpenAPIService +func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyResult) DeepEqual(ano *PromptOpenAPIServiceBatchGetPromptByPromptKeyResult) bool { + if p == ano { + return true + } else if p == nil || ano == nil { + return false + } + if !p.Field0DeepEqual(ano.Success) { + return false + } + return true } -func (p *promptOpenAPIServiceProcessorExecuteStreaming) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { - panic("streaming method PromptOpenAPIService.ExecuteStreaming(mode = server) not available, please use Kitex Thrift Streaming Client.") +func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyResult) Field0DeepEqual(src *BatchGetPromptByPromptKeyResponse) bool { + + if !p.Success.DeepEqual(src) { + return false + } + return true } -type PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs struct { - Req *BatchGetPromptByPromptKeyRequest `thrift:"req,1" frugal:"1,default,BatchGetPromptByPromptKeyRequest"` +type PromptOpenAPIServiceExecuteArgs struct { + Req *ExecuteRequest `thrift:"req,1" frugal:"1,default,ExecuteRequest"` } -func NewPromptOpenAPIServiceBatchGetPromptByPromptKeyArgs() *PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs { - return &PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs{} +func NewPromptOpenAPIServiceExecuteArgs() *PromptOpenAPIServiceExecuteArgs { + return &PromptOpenAPIServiceExecuteArgs{} } -func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs) InitDefault() { +func (p *PromptOpenAPIServiceExecuteArgs) InitDefault() { } -var PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs_Req_DEFAULT *BatchGetPromptByPromptKeyRequest +var PromptOpenAPIServiceExecuteArgs_Req_DEFAULT *ExecuteRequest -func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs) GetReq() (v *BatchGetPromptByPromptKeyRequest) { +func (p *PromptOpenAPIServiceExecuteArgs) GetReq() (v *ExecuteRequest) { if p == nil { return } if !p.IsSetReq() { - return PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs_Req_DEFAULT + return PromptOpenAPIServiceExecuteArgs_Req_DEFAULT } return p.Req } -func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs) SetReq(val *BatchGetPromptByPromptKeyRequest) { +func (p *PromptOpenAPIServiceExecuteArgs) SetReq(val *ExecuteRequest) { p.Req = val } -var fieldIDToName_PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs = map[int16]string{ +var fieldIDToName_PromptOpenAPIServiceExecuteArgs = map[int16]string{ 1: "req", } -func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs) IsSetReq() bool { +func (p *PromptOpenAPIServiceExecuteArgs) IsSetReq() bool { return p.Req != nil } -func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs) Read(iprot thrift.TProtocol) (err error) { +func (p *PromptOpenAPIServiceExecuteArgs) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -9253,7 +11855,7 @@ ReadStructBeginError: ReadFieldBeginError: return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs[fieldId]), err) + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptOpenAPIServiceExecuteArgs[fieldId]), err) SkipFieldError: return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) @@ -9263,8 +11865,8 @@ ReadStructEndError: return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } -func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs) ReadField1(iprot thrift.TProtocol) error { - _field := NewBatchGetPromptByPromptKeyRequest() +func (p *PromptOpenAPIServiceExecuteArgs) ReadField1(iprot thrift.TProtocol) error { + _field := NewExecuteRequest() if err := _field.Read(iprot); err != nil { return err } @@ -9272,9 +11874,9 @@ func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs) ReadField1(iprot thr return nil } -func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs) Write(oprot thrift.TProtocol) (err error) { +func (p *PromptOpenAPIServiceExecuteArgs) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 - if err = oprot.WriteStructBegin("BatchGetPromptByPromptKey_args"); err != nil { + if err = oprot.WriteStructBegin("Execute_args"); err != nil { goto WriteStructBeginError } if p != nil { @@ -9300,7 +11902,7 @@ WriteStructEndError: return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) } -func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs) writeField1(oprot thrift.TProtocol) (err error) { +func (p *PromptOpenAPIServiceExecuteArgs) writeField1(oprot thrift.TProtocol) (err error) { if err = oprot.WriteFieldBegin("req", thrift.STRUCT, 1); err != nil { goto WriteFieldBeginError } @@ -9317,15 +11919,15 @@ WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) } -func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs) String() string { +func (p *PromptOpenAPIServiceExecuteArgs) String() string { if p == nil { return "" } - return fmt.Sprintf("PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs(%+v)", *p) + return fmt.Sprintf("PromptOpenAPIServiceExecuteArgs(%+v)", *p) } -func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs) DeepEqual(ano *PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs) bool { +func (p *PromptOpenAPIServiceExecuteArgs) DeepEqual(ano *PromptOpenAPIServiceExecuteArgs) bool { if p == ano { return true } else if p == nil || ano == nil { @@ -9337,7 +11939,7 @@ func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs) DeepEqual(ano *Promp return true } -func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs) Field1DeepEqual(src *BatchGetPromptByPromptKeyRequest) bool { +func (p *PromptOpenAPIServiceExecuteArgs) Field1DeepEqual(src *ExecuteRequest) bool { if !p.Req.DeepEqual(src) { return false @@ -9345,41 +11947,41 @@ func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs) Field1DeepEqual(src return true } -type PromptOpenAPIServiceBatchGetPromptByPromptKeyResult struct { - Success *BatchGetPromptByPromptKeyResponse `thrift:"success,0,optional" frugal:"0,optional,BatchGetPromptByPromptKeyResponse"` +type PromptOpenAPIServiceExecuteResult struct { + Success *ExecuteResponse `thrift:"success,0,optional" frugal:"0,optional,ExecuteResponse"` } -func NewPromptOpenAPIServiceBatchGetPromptByPromptKeyResult() *PromptOpenAPIServiceBatchGetPromptByPromptKeyResult { - return &PromptOpenAPIServiceBatchGetPromptByPromptKeyResult{} +func NewPromptOpenAPIServiceExecuteResult() *PromptOpenAPIServiceExecuteResult { + return &PromptOpenAPIServiceExecuteResult{} } -func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyResult) InitDefault() { +func (p *PromptOpenAPIServiceExecuteResult) InitDefault() { } -var PromptOpenAPIServiceBatchGetPromptByPromptKeyResult_Success_DEFAULT *BatchGetPromptByPromptKeyResponse +var PromptOpenAPIServiceExecuteResult_Success_DEFAULT *ExecuteResponse -func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyResult) GetSuccess() (v *BatchGetPromptByPromptKeyResponse) { +func (p *PromptOpenAPIServiceExecuteResult) GetSuccess() (v *ExecuteResponse) { if p == nil { return } if !p.IsSetSuccess() { - return PromptOpenAPIServiceBatchGetPromptByPromptKeyResult_Success_DEFAULT + return PromptOpenAPIServiceExecuteResult_Success_DEFAULT } return p.Success } -func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyResult) SetSuccess(x interface{}) { - p.Success = x.(*BatchGetPromptByPromptKeyResponse) +func (p *PromptOpenAPIServiceExecuteResult) SetSuccess(x interface{}) { + p.Success = x.(*ExecuteResponse) } -var fieldIDToName_PromptOpenAPIServiceBatchGetPromptByPromptKeyResult = map[int16]string{ +var fieldIDToName_PromptOpenAPIServiceExecuteResult = map[int16]string{ 0: "success", } -func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyResult) IsSetSuccess() bool { +func (p *PromptOpenAPIServiceExecuteResult) IsSetSuccess() bool { return p.Success != nil } -func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyResult) Read(iprot thrift.TProtocol) (err error) { +func (p *PromptOpenAPIServiceExecuteResult) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -9424,7 +12026,7 @@ ReadStructBeginError: ReadFieldBeginError: return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptOpenAPIServiceBatchGetPromptByPromptKeyResult[fieldId]), err) + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptOpenAPIServiceExecuteResult[fieldId]), err) SkipFieldError: return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) @@ -9434,8 +12036,8 @@ ReadStructEndError: return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } -func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyResult) ReadField0(iprot thrift.TProtocol) error { - _field := NewBatchGetPromptByPromptKeyResponse() +func (p *PromptOpenAPIServiceExecuteResult) ReadField0(iprot thrift.TProtocol) error { + _field := NewExecuteResponse() if err := _field.Read(iprot); err != nil { return err } @@ -9443,9 +12045,9 @@ func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyResult) ReadField0(iprot t return nil } -func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyResult) Write(oprot thrift.TProtocol) (err error) { +func (p *PromptOpenAPIServiceExecuteResult) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 - if err = oprot.WriteStructBegin("BatchGetPromptByPromptKey_result"); err != nil { + if err = oprot.WriteStructBegin("Execute_result"); err != nil { goto WriteStructBeginError } if p != nil { @@ -9471,7 +12073,7 @@ WriteStructEndError: return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) } -func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyResult) writeField0(oprot thrift.TProtocol) (err error) { +func (p *PromptOpenAPIServiceExecuteResult) writeField0(oprot thrift.TProtocol) (err error) { if p.IsSetSuccess() { if err = oprot.WriteFieldBegin("success", thrift.STRUCT, 0); err != nil { goto WriteFieldBeginError @@ -9490,15 +12092,15 @@ WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 0 end error: ", p), err) } -func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyResult) String() string { +func (p *PromptOpenAPIServiceExecuteResult) String() string { if p == nil { return "" } - return fmt.Sprintf("PromptOpenAPIServiceBatchGetPromptByPromptKeyResult(%+v)", *p) + return fmt.Sprintf("PromptOpenAPIServiceExecuteResult(%+v)", *p) } -func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyResult) DeepEqual(ano *PromptOpenAPIServiceBatchGetPromptByPromptKeyResult) bool { +func (p *PromptOpenAPIServiceExecuteResult) DeepEqual(ano *PromptOpenAPIServiceExecuteResult) bool { if p == ano { return true } else if p == nil || ano == nil { @@ -9510,7 +12112,7 @@ func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyResult) DeepEqual(ano *Pro return true } -func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyResult) Field0DeepEqual(src *BatchGetPromptByPromptKeyResponse) bool { +func (p *PromptOpenAPIServiceExecuteResult) Field0DeepEqual(src *ExecuteResponse) bool { if !p.Success.DeepEqual(src) { return false @@ -9518,41 +12120,41 @@ func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyResult) Field0DeepEqual(sr return true } -type PromptOpenAPIServiceExecuteArgs struct { +type PromptOpenAPIServiceExecuteStreamingArgs struct { Req *ExecuteRequest `thrift:"req,1" frugal:"1,default,ExecuteRequest"` } -func NewPromptOpenAPIServiceExecuteArgs() *PromptOpenAPIServiceExecuteArgs { - return &PromptOpenAPIServiceExecuteArgs{} +func NewPromptOpenAPIServiceExecuteStreamingArgs() *PromptOpenAPIServiceExecuteStreamingArgs { + return &PromptOpenAPIServiceExecuteStreamingArgs{} } -func (p *PromptOpenAPIServiceExecuteArgs) InitDefault() { +func (p *PromptOpenAPIServiceExecuteStreamingArgs) InitDefault() { } -var PromptOpenAPIServiceExecuteArgs_Req_DEFAULT *ExecuteRequest +var PromptOpenAPIServiceExecuteStreamingArgs_Req_DEFAULT *ExecuteRequest -func (p *PromptOpenAPIServiceExecuteArgs) GetReq() (v *ExecuteRequest) { +func (p *PromptOpenAPIServiceExecuteStreamingArgs) GetReq() (v *ExecuteRequest) { if p == nil { return } if !p.IsSetReq() { - return PromptOpenAPIServiceExecuteArgs_Req_DEFAULT + return PromptOpenAPIServiceExecuteStreamingArgs_Req_DEFAULT } return p.Req } -func (p *PromptOpenAPIServiceExecuteArgs) SetReq(val *ExecuteRequest) { +func (p *PromptOpenAPIServiceExecuteStreamingArgs) SetReq(val *ExecuteRequest) { p.Req = val } -var fieldIDToName_PromptOpenAPIServiceExecuteArgs = map[int16]string{ +var fieldIDToName_PromptOpenAPIServiceExecuteStreamingArgs = map[int16]string{ 1: "req", } -func (p *PromptOpenAPIServiceExecuteArgs) IsSetReq() bool { +func (p *PromptOpenAPIServiceExecuteStreamingArgs) IsSetReq() bool { return p.Req != nil } -func (p *PromptOpenAPIServiceExecuteArgs) Read(iprot thrift.TProtocol) (err error) { +func (p *PromptOpenAPIServiceExecuteStreamingArgs) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -9597,7 +12199,7 @@ ReadStructBeginError: ReadFieldBeginError: return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptOpenAPIServiceExecuteArgs[fieldId]), err) + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptOpenAPIServiceExecuteStreamingArgs[fieldId]), err) SkipFieldError: return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) @@ -9607,7 +12209,7 @@ ReadStructEndError: return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } -func (p *PromptOpenAPIServiceExecuteArgs) ReadField1(iprot thrift.TProtocol) error { +func (p *PromptOpenAPIServiceExecuteStreamingArgs) ReadField1(iprot thrift.TProtocol) error { _field := NewExecuteRequest() if err := _field.Read(iprot); err != nil { return err @@ -9616,9 +12218,9 @@ func (p *PromptOpenAPIServiceExecuteArgs) ReadField1(iprot thrift.TProtocol) err return nil } -func (p *PromptOpenAPIServiceExecuteArgs) Write(oprot thrift.TProtocol) (err error) { +func (p *PromptOpenAPIServiceExecuteStreamingArgs) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 - if err = oprot.WriteStructBegin("Execute_args"); err != nil { + if err = oprot.WriteStructBegin("ExecuteStreaming_args"); err != nil { goto WriteStructBeginError } if p != nil { @@ -9644,7 +12246,7 @@ WriteStructEndError: return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) } -func (p *PromptOpenAPIServiceExecuteArgs) writeField1(oprot thrift.TProtocol) (err error) { +func (p *PromptOpenAPIServiceExecuteStreamingArgs) writeField1(oprot thrift.TProtocol) (err error) { if err = oprot.WriteFieldBegin("req", thrift.STRUCT, 1); err != nil { goto WriteFieldBeginError } @@ -9661,15 +12263,15 @@ WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) } -func (p *PromptOpenAPIServiceExecuteArgs) String() string { +func (p *PromptOpenAPIServiceExecuteStreamingArgs) String() string { if p == nil { return "" } - return fmt.Sprintf("PromptOpenAPIServiceExecuteArgs(%+v)", *p) + return fmt.Sprintf("PromptOpenAPIServiceExecuteStreamingArgs(%+v)", *p) } -func (p *PromptOpenAPIServiceExecuteArgs) DeepEqual(ano *PromptOpenAPIServiceExecuteArgs) bool { +func (p *PromptOpenAPIServiceExecuteStreamingArgs) DeepEqual(ano *PromptOpenAPIServiceExecuteStreamingArgs) bool { if p == ano { return true } else if p == nil || ano == nil { @@ -9681,7 +12283,7 @@ func (p *PromptOpenAPIServiceExecuteArgs) DeepEqual(ano *PromptOpenAPIServiceExe return true } -func (p *PromptOpenAPIServiceExecuteArgs) Field1DeepEqual(src *ExecuteRequest) bool { +func (p *PromptOpenAPIServiceExecuteStreamingArgs) Field1DeepEqual(src *ExecuteRequest) bool { if !p.Req.DeepEqual(src) { return false @@ -9689,41 +12291,41 @@ func (p *PromptOpenAPIServiceExecuteArgs) Field1DeepEqual(src *ExecuteRequest) b return true } -type PromptOpenAPIServiceExecuteResult struct { - Success *ExecuteResponse `thrift:"success,0,optional" frugal:"0,optional,ExecuteResponse"` +type PromptOpenAPIServiceExecuteStreamingResult struct { + Success *ExecuteStreamingResponse `thrift:"success,0,optional" frugal:"0,optional,ExecuteStreamingResponse"` } -func NewPromptOpenAPIServiceExecuteResult() *PromptOpenAPIServiceExecuteResult { - return &PromptOpenAPIServiceExecuteResult{} +func NewPromptOpenAPIServiceExecuteStreamingResult() *PromptOpenAPIServiceExecuteStreamingResult { + return &PromptOpenAPIServiceExecuteStreamingResult{} } -func (p *PromptOpenAPIServiceExecuteResult) InitDefault() { +func (p *PromptOpenAPIServiceExecuteStreamingResult) InitDefault() { } -var PromptOpenAPIServiceExecuteResult_Success_DEFAULT *ExecuteResponse +var PromptOpenAPIServiceExecuteStreamingResult_Success_DEFAULT *ExecuteStreamingResponse -func (p *PromptOpenAPIServiceExecuteResult) GetSuccess() (v *ExecuteResponse) { +func (p *PromptOpenAPIServiceExecuteStreamingResult) GetSuccess() (v *ExecuteStreamingResponse) { if p == nil { return } if !p.IsSetSuccess() { - return PromptOpenAPIServiceExecuteResult_Success_DEFAULT + return PromptOpenAPIServiceExecuteStreamingResult_Success_DEFAULT } return p.Success } -func (p *PromptOpenAPIServiceExecuteResult) SetSuccess(x interface{}) { - p.Success = x.(*ExecuteResponse) +func (p *PromptOpenAPIServiceExecuteStreamingResult) SetSuccess(x interface{}) { + p.Success = x.(*ExecuteStreamingResponse) } -var fieldIDToName_PromptOpenAPIServiceExecuteResult = map[int16]string{ +var fieldIDToName_PromptOpenAPIServiceExecuteStreamingResult = map[int16]string{ 0: "success", } -func (p *PromptOpenAPIServiceExecuteResult) IsSetSuccess() bool { +func (p *PromptOpenAPIServiceExecuteStreamingResult) IsSetSuccess() bool { return p.Success != nil } -func (p *PromptOpenAPIServiceExecuteResult) Read(iprot thrift.TProtocol) (err error) { +func (p *PromptOpenAPIServiceExecuteStreamingResult) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -9768,7 +12370,7 @@ ReadStructBeginError: ReadFieldBeginError: return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptOpenAPIServiceExecuteResult[fieldId]), err) + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptOpenAPIServiceExecuteStreamingResult[fieldId]), err) SkipFieldError: return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) @@ -9778,8 +12380,8 @@ ReadStructEndError: return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } -func (p *PromptOpenAPIServiceExecuteResult) ReadField0(iprot thrift.TProtocol) error { - _field := NewExecuteResponse() +func (p *PromptOpenAPIServiceExecuteStreamingResult) ReadField0(iprot thrift.TProtocol) error { + _field := NewExecuteStreamingResponse() if err := _field.Read(iprot); err != nil { return err } @@ -9787,9 +12389,9 @@ func (p *PromptOpenAPIServiceExecuteResult) ReadField0(iprot thrift.TProtocol) e return nil } -func (p *PromptOpenAPIServiceExecuteResult) Write(oprot thrift.TProtocol) (err error) { +func (p *PromptOpenAPIServiceExecuteStreamingResult) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 - if err = oprot.WriteStructBegin("Execute_result"); err != nil { + if err = oprot.WriteStructBegin("ExecuteStreaming_result"); err != nil { goto WriteStructBeginError } if p != nil { @@ -9815,7 +12417,7 @@ WriteStructEndError: return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) } -func (p *PromptOpenAPIServiceExecuteResult) writeField0(oprot thrift.TProtocol) (err error) { +func (p *PromptOpenAPIServiceExecuteStreamingResult) writeField0(oprot thrift.TProtocol) (err error) { if p.IsSetSuccess() { if err = oprot.WriteFieldBegin("success", thrift.STRUCT, 0); err != nil { goto WriteFieldBeginError @@ -9834,15 +12436,15 @@ WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 0 end error: ", p), err) } -func (p *PromptOpenAPIServiceExecuteResult) String() string { +func (p *PromptOpenAPIServiceExecuteStreamingResult) String() string { if p == nil { return "" } - return fmt.Sprintf("PromptOpenAPIServiceExecuteResult(%+v)", *p) + return fmt.Sprintf("PromptOpenAPIServiceExecuteStreamingResult(%+v)", *p) } -func (p *PromptOpenAPIServiceExecuteResult) DeepEqual(ano *PromptOpenAPIServiceExecuteResult) bool { +func (p *PromptOpenAPIServiceExecuteStreamingResult) DeepEqual(ano *PromptOpenAPIServiceExecuteStreamingResult) bool { if p == ano { return true } else if p == nil || ano == nil { @@ -9854,7 +12456,7 @@ func (p *PromptOpenAPIServiceExecuteResult) DeepEqual(ano *PromptOpenAPIServiceE return true } -func (p *PromptOpenAPIServiceExecuteResult) Field0DeepEqual(src *ExecuteResponse) bool { +func (p *PromptOpenAPIServiceExecuteStreamingResult) Field0DeepEqual(src *ExecuteStreamingResponse) bool { if !p.Success.DeepEqual(src) { return false @@ -9862,41 +12464,41 @@ func (p *PromptOpenAPIServiceExecuteResult) Field0DeepEqual(src *ExecuteResponse return true } -type PromptOpenAPIServiceExecuteStreamingArgs struct { - Req *ExecuteRequest `thrift:"req,1" frugal:"1,default,ExecuteRequest"` +type PromptOpenAPIServiceListPromptBasicArgs struct { + Req *ListPromptBasicRequest `thrift:"req,1" frugal:"1,default,ListPromptBasicRequest"` } -func NewPromptOpenAPIServiceExecuteStreamingArgs() *PromptOpenAPIServiceExecuteStreamingArgs { - return &PromptOpenAPIServiceExecuteStreamingArgs{} +func NewPromptOpenAPIServiceListPromptBasicArgs() *PromptOpenAPIServiceListPromptBasicArgs { + return &PromptOpenAPIServiceListPromptBasicArgs{} } -func (p *PromptOpenAPIServiceExecuteStreamingArgs) InitDefault() { +func (p *PromptOpenAPIServiceListPromptBasicArgs) InitDefault() { } -var PromptOpenAPIServiceExecuteStreamingArgs_Req_DEFAULT *ExecuteRequest +var PromptOpenAPIServiceListPromptBasicArgs_Req_DEFAULT *ListPromptBasicRequest -func (p *PromptOpenAPIServiceExecuteStreamingArgs) GetReq() (v *ExecuteRequest) { +func (p *PromptOpenAPIServiceListPromptBasicArgs) GetReq() (v *ListPromptBasicRequest) { if p == nil { return } if !p.IsSetReq() { - return PromptOpenAPIServiceExecuteStreamingArgs_Req_DEFAULT + return PromptOpenAPIServiceListPromptBasicArgs_Req_DEFAULT } return p.Req } -func (p *PromptOpenAPIServiceExecuteStreamingArgs) SetReq(val *ExecuteRequest) { +func (p *PromptOpenAPIServiceListPromptBasicArgs) SetReq(val *ListPromptBasicRequest) { p.Req = val } -var fieldIDToName_PromptOpenAPIServiceExecuteStreamingArgs = map[int16]string{ +var fieldIDToName_PromptOpenAPIServiceListPromptBasicArgs = map[int16]string{ 1: "req", } -func (p *PromptOpenAPIServiceExecuteStreamingArgs) IsSetReq() bool { +func (p *PromptOpenAPIServiceListPromptBasicArgs) IsSetReq() bool { return p.Req != nil } -func (p *PromptOpenAPIServiceExecuteStreamingArgs) Read(iprot thrift.TProtocol) (err error) { +func (p *PromptOpenAPIServiceListPromptBasicArgs) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -9941,7 +12543,7 @@ ReadStructBeginError: ReadFieldBeginError: return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptOpenAPIServiceExecuteStreamingArgs[fieldId]), err) + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptOpenAPIServiceListPromptBasicArgs[fieldId]), err) SkipFieldError: return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) @@ -9951,8 +12553,8 @@ ReadStructEndError: return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } -func (p *PromptOpenAPIServiceExecuteStreamingArgs) ReadField1(iprot thrift.TProtocol) error { - _field := NewExecuteRequest() +func (p *PromptOpenAPIServiceListPromptBasicArgs) ReadField1(iprot thrift.TProtocol) error { + _field := NewListPromptBasicRequest() if err := _field.Read(iprot); err != nil { return err } @@ -9960,9 +12562,9 @@ func (p *PromptOpenAPIServiceExecuteStreamingArgs) ReadField1(iprot thrift.TProt return nil } -func (p *PromptOpenAPIServiceExecuteStreamingArgs) Write(oprot thrift.TProtocol) (err error) { +func (p *PromptOpenAPIServiceListPromptBasicArgs) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 - if err = oprot.WriteStructBegin("ExecuteStreaming_args"); err != nil { + if err = oprot.WriteStructBegin("ListPromptBasic_args"); err != nil { goto WriteStructBeginError } if p != nil { @@ -9988,7 +12590,7 @@ WriteStructEndError: return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) } -func (p *PromptOpenAPIServiceExecuteStreamingArgs) writeField1(oprot thrift.TProtocol) (err error) { +func (p *PromptOpenAPIServiceListPromptBasicArgs) writeField1(oprot thrift.TProtocol) (err error) { if err = oprot.WriteFieldBegin("req", thrift.STRUCT, 1); err != nil { goto WriteFieldBeginError } @@ -10005,15 +12607,15 @@ WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) } -func (p *PromptOpenAPIServiceExecuteStreamingArgs) String() string { +func (p *PromptOpenAPIServiceListPromptBasicArgs) String() string { if p == nil { return "" } - return fmt.Sprintf("PromptOpenAPIServiceExecuteStreamingArgs(%+v)", *p) + return fmt.Sprintf("PromptOpenAPIServiceListPromptBasicArgs(%+v)", *p) } -func (p *PromptOpenAPIServiceExecuteStreamingArgs) DeepEqual(ano *PromptOpenAPIServiceExecuteStreamingArgs) bool { +func (p *PromptOpenAPIServiceListPromptBasicArgs) DeepEqual(ano *PromptOpenAPIServiceListPromptBasicArgs) bool { if p == ano { return true } else if p == nil || ano == nil { @@ -10025,7 +12627,7 @@ func (p *PromptOpenAPIServiceExecuteStreamingArgs) DeepEqual(ano *PromptOpenAPIS return true } -func (p *PromptOpenAPIServiceExecuteStreamingArgs) Field1DeepEqual(src *ExecuteRequest) bool { +func (p *PromptOpenAPIServiceListPromptBasicArgs) Field1DeepEqual(src *ListPromptBasicRequest) bool { if !p.Req.DeepEqual(src) { return false @@ -10033,41 +12635,41 @@ func (p *PromptOpenAPIServiceExecuteStreamingArgs) Field1DeepEqual(src *ExecuteR return true } -type PromptOpenAPIServiceExecuteStreamingResult struct { - Success *ExecuteStreamingResponse `thrift:"success,0,optional" frugal:"0,optional,ExecuteStreamingResponse"` +type PromptOpenAPIServiceListPromptBasicResult struct { + Success *ListPromptBasicResponse `thrift:"success,0,optional" frugal:"0,optional,ListPromptBasicResponse"` } -func NewPromptOpenAPIServiceExecuteStreamingResult() *PromptOpenAPIServiceExecuteStreamingResult { - return &PromptOpenAPIServiceExecuteStreamingResult{} +func NewPromptOpenAPIServiceListPromptBasicResult() *PromptOpenAPIServiceListPromptBasicResult { + return &PromptOpenAPIServiceListPromptBasicResult{} } -func (p *PromptOpenAPIServiceExecuteStreamingResult) InitDefault() { +func (p *PromptOpenAPIServiceListPromptBasicResult) InitDefault() { } -var PromptOpenAPIServiceExecuteStreamingResult_Success_DEFAULT *ExecuteStreamingResponse +var PromptOpenAPIServiceListPromptBasicResult_Success_DEFAULT *ListPromptBasicResponse -func (p *PromptOpenAPIServiceExecuteStreamingResult) GetSuccess() (v *ExecuteStreamingResponse) { +func (p *PromptOpenAPIServiceListPromptBasicResult) GetSuccess() (v *ListPromptBasicResponse) { if p == nil { return } if !p.IsSetSuccess() { - return PromptOpenAPIServiceExecuteStreamingResult_Success_DEFAULT + return PromptOpenAPIServiceListPromptBasicResult_Success_DEFAULT } return p.Success } -func (p *PromptOpenAPIServiceExecuteStreamingResult) SetSuccess(x interface{}) { - p.Success = x.(*ExecuteStreamingResponse) +func (p *PromptOpenAPIServiceListPromptBasicResult) SetSuccess(x interface{}) { + p.Success = x.(*ListPromptBasicResponse) } -var fieldIDToName_PromptOpenAPIServiceExecuteStreamingResult = map[int16]string{ +var fieldIDToName_PromptOpenAPIServiceListPromptBasicResult = map[int16]string{ 0: "success", } -func (p *PromptOpenAPIServiceExecuteStreamingResult) IsSetSuccess() bool { +func (p *PromptOpenAPIServiceListPromptBasicResult) IsSetSuccess() bool { return p.Success != nil } -func (p *PromptOpenAPIServiceExecuteStreamingResult) Read(iprot thrift.TProtocol) (err error) { +func (p *PromptOpenAPIServiceListPromptBasicResult) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -10112,7 +12714,7 @@ ReadStructBeginError: ReadFieldBeginError: return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptOpenAPIServiceExecuteStreamingResult[fieldId]), err) + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptOpenAPIServiceListPromptBasicResult[fieldId]), err) SkipFieldError: return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) @@ -10122,8 +12724,8 @@ ReadStructEndError: return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } -func (p *PromptOpenAPIServiceExecuteStreamingResult) ReadField0(iprot thrift.TProtocol) error { - _field := NewExecuteStreamingResponse() +func (p *PromptOpenAPIServiceListPromptBasicResult) ReadField0(iprot thrift.TProtocol) error { + _field := NewListPromptBasicResponse() if err := _field.Read(iprot); err != nil { return err } @@ -10131,9 +12733,9 @@ func (p *PromptOpenAPIServiceExecuteStreamingResult) ReadField0(iprot thrift.TPr return nil } -func (p *PromptOpenAPIServiceExecuteStreamingResult) Write(oprot thrift.TProtocol) (err error) { +func (p *PromptOpenAPIServiceListPromptBasicResult) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 - if err = oprot.WriteStructBegin("ExecuteStreaming_result"); err != nil { + if err = oprot.WriteStructBegin("ListPromptBasic_result"); err != nil { goto WriteStructBeginError } if p != nil { @@ -10159,7 +12761,7 @@ WriteStructEndError: return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) } -func (p *PromptOpenAPIServiceExecuteStreamingResult) writeField0(oprot thrift.TProtocol) (err error) { +func (p *PromptOpenAPIServiceListPromptBasicResult) writeField0(oprot thrift.TProtocol) (err error) { if p.IsSetSuccess() { if err = oprot.WriteFieldBegin("success", thrift.STRUCT, 0); err != nil { goto WriteFieldBeginError @@ -10178,15 +12780,15 @@ WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 0 end error: ", p), err) } -func (p *PromptOpenAPIServiceExecuteStreamingResult) String() string { +func (p *PromptOpenAPIServiceListPromptBasicResult) String() string { if p == nil { return "" } - return fmt.Sprintf("PromptOpenAPIServiceExecuteStreamingResult(%+v)", *p) + return fmt.Sprintf("PromptOpenAPIServiceListPromptBasicResult(%+v)", *p) } -func (p *PromptOpenAPIServiceExecuteStreamingResult) DeepEqual(ano *PromptOpenAPIServiceExecuteStreamingResult) bool { +func (p *PromptOpenAPIServiceListPromptBasicResult) DeepEqual(ano *PromptOpenAPIServiceListPromptBasicResult) bool { if p == ano { return true } else if p == nil || ano == nil { @@ -10198,7 +12800,7 @@ func (p *PromptOpenAPIServiceExecuteStreamingResult) DeepEqual(ano *PromptOpenAP return true } -func (p *PromptOpenAPIServiceExecuteStreamingResult) Field0DeepEqual(src *ExecuteStreamingResponse) bool { +func (p *PromptOpenAPIServiceListPromptBasicResult) Field0DeepEqual(src *ListPromptBasicResponse) bool { if !p.Success.DeepEqual(src) { return false diff --git a/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi_validator.go b/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi_validator.go index 92804fba6..c2a982f6d 100644 --- a/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi_validator.go +++ b/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi_validator.go @@ -190,3 +190,43 @@ func (p *VariableVal) IsValid() error { func (p *TokenUsage) IsValid() error { return nil } +func (p *ListPromptBasicRequest) IsValid() error { + if p.PageNumber != nil { + if *p.PageNumber <= int32(0) { + return fmt.Errorf("field PageNumber gt rule failed, current value: %v", *p.PageNumber) + } + } + if p.PageSize != nil { + if *p.PageSize <= int32(0) { + return fmt.Errorf("field PageSize gt rule failed, current value: %v", *p.PageSize) + } + if *p.PageSize > int32(200) { + return fmt.Errorf("field PageSize le rule failed, current value: %v", *p.PageSize) + } + } + if p.Base != nil { + if err := p.Base.IsValid(); err != nil { + return fmt.Errorf("field Base not valid, %w", err) + } + } + return nil +} +func (p *ListPromptBasicResponse) IsValid() error { + if p.Data != nil { + if err := p.Data.IsValid(); err != nil { + return fmt.Errorf("field Data not valid, %w", err) + } + } + if p.BaseResp != nil { + if err := p.BaseResp.IsValid(); err != nil { + return fmt.Errorf("field BaseResp not valid, %w", err) + } + } + return nil +} +func (p *PromptBasic) IsValid() error { + return nil +} +func (p *ListPromptBasicData) IsValid() error { + return nil +} diff --git a/backend/kitex_gen/coze/loop/prompt/openapi/k-coze.loop.prompt.openapi.go b/backend/kitex_gen/coze/loop/prompt/openapi/k-coze.loop.prompt.openapi.go index 878e2b3a9..6e29aaf3e 100644 --- a/backend/kitex_gen/coze/loop/prompt/openapi/k-coze.loop.prompt.openapi.go +++ b/backend/kitex_gen/coze/loop/prompt/openapi/k-coze.loop.prompt.openapi.go @@ -6376,6 +6376,1546 @@ func (p *TokenUsage) DeepCopy(s interface{}) error { return nil } +func (p *ListPromptBasicRequest) FastRead(buf []byte) (int, error) { + + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + for { + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 1: + if fieldTypeId == thrift.I64 { + l, err = p.FastReadField1(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 2: + if fieldTypeId == thrift.I32 { + l, err = p.FastReadField2(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 3: + if fieldTypeId == thrift.I32 { + l, err = p.FastReadField3(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 4: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField4(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 5: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField5(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 255: + if fieldTypeId == thrift.STRUCT { + l, err = p.FastReadField255(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + } + + return offset, nil +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_ListPromptBasicRequest[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +} + +func (p *ListPromptBasicRequest) FastReadField1(buf []byte) (int, error) { + offset := 0 + + var _field *int64 + if v, l, err := thrift.Binary.ReadI64(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.WorkspaceID = _field + return offset, nil +} + +func (p *ListPromptBasicRequest) FastReadField2(buf []byte) (int, error) { + offset := 0 + + var _field *int32 + if v, l, err := thrift.Binary.ReadI32(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.PageNumber = _field + return offset, nil +} + +func (p *ListPromptBasicRequest) FastReadField3(buf []byte) (int, error) { + offset := 0 + + var _field *int32 + if v, l, err := thrift.Binary.ReadI32(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.PageSize = _field + return offset, nil +} + +func (p *ListPromptBasicRequest) FastReadField4(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.KeyWord = _field + return offset, nil +} + +func (p *ListPromptBasicRequest) FastReadField5(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Creator = _field + return offset, nil +} + +func (p *ListPromptBasicRequest) FastReadField255(buf []byte) (int, error) { + offset := 0 + _field := base.NewBase() + if l, err := _field.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + p.Base = _field + return offset, nil +} + +func (p *ListPromptBasicRequest) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *ListPromptBasicRequest) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField1(buf[offset:], w) + offset += p.fastWriteField2(buf[offset:], w) + offset += p.fastWriteField3(buf[offset:], w) + offset += p.fastWriteField4(buf[offset:], w) + offset += p.fastWriteField5(buf[offset:], w) + offset += p.fastWriteField255(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *ListPromptBasicRequest) BLength() int { + l := 0 + if p != nil { + l += p.field1Length() + l += p.field2Length() + l += p.field3Length() + l += p.field4Length() + l += p.field5Length() + l += p.field255Length() + } + l += thrift.Binary.FieldStopLength() + return l +} + +func (p *ListPromptBasicRequest) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetWorkspaceID() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.I64, 1) + offset += thrift.Binary.WriteI64(buf[offset:], *p.WorkspaceID) + } + return offset +} + +func (p *ListPromptBasicRequest) fastWriteField2(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetPageNumber() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.I32, 2) + offset += thrift.Binary.WriteI32(buf[offset:], *p.PageNumber) + } + return offset +} + +func (p *ListPromptBasicRequest) fastWriteField3(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetPageSize() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.I32, 3) + offset += thrift.Binary.WriteI32(buf[offset:], *p.PageSize) + } + return offset +} + +func (p *ListPromptBasicRequest) fastWriteField4(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetKeyWord() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 4) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.KeyWord) + } + return offset +} + +func (p *ListPromptBasicRequest) fastWriteField5(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetCreator() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 5) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Creator) + } + return offset +} + +func (p *ListPromptBasicRequest) fastWriteField255(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetBase() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 255) + offset += p.Base.FastWriteNocopy(buf[offset:], w) + } + return offset +} + +func (p *ListPromptBasicRequest) field1Length() int { + l := 0 + if p.IsSetWorkspaceID() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.I64Length() + } + return l +} + +func (p *ListPromptBasicRequest) field2Length() int { + l := 0 + if p.IsSetPageNumber() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.I32Length() + } + return l +} + +func (p *ListPromptBasicRequest) field3Length() int { + l := 0 + if p.IsSetPageSize() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.I32Length() + } + return l +} + +func (p *ListPromptBasicRequest) field4Length() int { + l := 0 + if p.IsSetKeyWord() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.KeyWord) + } + return l +} + +func (p *ListPromptBasicRequest) field5Length() int { + l := 0 + if p.IsSetCreator() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Creator) + } + return l +} + +func (p *ListPromptBasicRequest) field255Length() int { + l := 0 + if p.IsSetBase() { + l += thrift.Binary.FieldBeginLength() + l += p.Base.BLength() + } + return l +} + +func (p *ListPromptBasicRequest) DeepCopy(s interface{}) error { + src, ok := s.(*ListPromptBasicRequest) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + if src.WorkspaceID != nil { + tmp := *src.WorkspaceID + p.WorkspaceID = &tmp + } + + if src.PageNumber != nil { + tmp := *src.PageNumber + p.PageNumber = &tmp + } + + if src.PageSize != nil { + tmp := *src.PageSize + p.PageSize = &tmp + } + + if src.KeyWord != nil { + var tmp string + if *src.KeyWord != "" { + tmp = kutils.StringDeepCopy(*src.KeyWord) + } + p.KeyWord = &tmp + } + + if src.Creator != nil { + var tmp string + if *src.Creator != "" { + tmp = kutils.StringDeepCopy(*src.Creator) + } + p.Creator = &tmp + } + + var _base *base.Base + if src.Base != nil { + _base = &base.Base{} + if err := _base.DeepCopy(src.Base); err != nil { + return err + } + } + p.Base = _base + + return nil +} + +func (p *ListPromptBasicResponse) FastRead(buf []byte) (int, error) { + + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + for { + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 1: + if fieldTypeId == thrift.I32 { + l, err = p.FastReadField1(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 2: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField2(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 3: + if fieldTypeId == thrift.STRUCT { + l, err = p.FastReadField3(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 255: + if fieldTypeId == thrift.STRUCT { + l, err = p.FastReadField255(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + } + + return offset, nil +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_ListPromptBasicResponse[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +} + +func (p *ListPromptBasicResponse) FastReadField1(buf []byte) (int, error) { + offset := 0 + + var _field *int32 + if v, l, err := thrift.Binary.ReadI32(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Code = _field + return offset, nil +} + +func (p *ListPromptBasicResponse) FastReadField2(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Msg = _field + return offset, nil +} + +func (p *ListPromptBasicResponse) FastReadField3(buf []byte) (int, error) { + offset := 0 + _field := NewListPromptBasicData() + if l, err := _field.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + p.Data = _field + return offset, nil +} + +func (p *ListPromptBasicResponse) FastReadField255(buf []byte) (int, error) { + offset := 0 + _field := base.NewBaseResp() + if l, err := _field.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + p.BaseResp = _field + return offset, nil +} + +func (p *ListPromptBasicResponse) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *ListPromptBasicResponse) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField1(buf[offset:], w) + offset += p.fastWriteField2(buf[offset:], w) + offset += p.fastWriteField3(buf[offset:], w) + offset += p.fastWriteField255(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *ListPromptBasicResponse) BLength() int { + l := 0 + if p != nil { + l += p.field1Length() + l += p.field2Length() + l += p.field3Length() + l += p.field255Length() + } + l += thrift.Binary.FieldStopLength() + return l +} + +func (p *ListPromptBasicResponse) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetCode() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.I32, 1) + offset += thrift.Binary.WriteI32(buf[offset:], *p.Code) + } + return offset +} + +func (p *ListPromptBasicResponse) fastWriteField2(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetMsg() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 2) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Msg) + } + return offset +} + +func (p *ListPromptBasicResponse) fastWriteField3(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetData() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 3) + offset += p.Data.FastWriteNocopy(buf[offset:], w) + } + return offset +} + +func (p *ListPromptBasicResponse) fastWriteField255(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetBaseResp() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 255) + offset += p.BaseResp.FastWriteNocopy(buf[offset:], w) + } + return offset +} + +func (p *ListPromptBasicResponse) field1Length() int { + l := 0 + if p.IsSetCode() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.I32Length() + } + return l +} + +func (p *ListPromptBasicResponse) field2Length() int { + l := 0 + if p.IsSetMsg() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Msg) + } + return l +} + +func (p *ListPromptBasicResponse) field3Length() int { + l := 0 + if p.IsSetData() { + l += thrift.Binary.FieldBeginLength() + l += p.Data.BLength() + } + return l +} + +func (p *ListPromptBasicResponse) field255Length() int { + l := 0 + if p.IsSetBaseResp() { + l += thrift.Binary.FieldBeginLength() + l += p.BaseResp.BLength() + } + return l +} + +func (p *ListPromptBasicResponse) DeepCopy(s interface{}) error { + src, ok := s.(*ListPromptBasicResponse) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + if src.Code != nil { + tmp := *src.Code + p.Code = &tmp + } + + if src.Msg != nil { + var tmp string + if *src.Msg != "" { + tmp = kutils.StringDeepCopy(*src.Msg) + } + p.Msg = &tmp + } + + var _data *ListPromptBasicData + if src.Data != nil { + _data = &ListPromptBasicData{} + if err := _data.DeepCopy(src.Data); err != nil { + return err + } + } + p.Data = _data + + var _baseResp *base.BaseResp + if src.BaseResp != nil { + _baseResp = &base.BaseResp{} + if err := _baseResp.DeepCopy(src.BaseResp); err != nil { + return err + } + } + p.BaseResp = _baseResp + + return nil +} + +func (p *PromptBasic) FastRead(buf []byte) (int, error) { + + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + for { + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 1: + if fieldTypeId == thrift.I64 { + l, err = p.FastReadField1(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 2: + if fieldTypeId == thrift.I64 { + l, err = p.FastReadField2(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 3: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField3(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 4: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField4(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 5: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField5(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 6: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField6(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 7: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField7(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 8: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField8(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 9: + if fieldTypeId == thrift.I64 { + l, err = p.FastReadField9(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 10: + if fieldTypeId == thrift.I64 { + l, err = p.FastReadField10(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 11: + if fieldTypeId == thrift.I64 { + l, err = p.FastReadField11(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + } + + return offset, nil +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptBasic[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +} + +func (p *PromptBasic) FastReadField1(buf []byte) (int, error) { + offset := 0 + + var _field *int64 + if v, l, err := thrift.Binary.ReadI64(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.ID = _field + return offset, nil +} + +func (p *PromptBasic) FastReadField2(buf []byte) (int, error) { + offset := 0 + + var _field *int64 + if v, l, err := thrift.Binary.ReadI64(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.WorkspaceID = _field + return offset, nil +} + +func (p *PromptBasic) FastReadField3(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.PromptKey = _field + return offset, nil +} + +func (p *PromptBasic) FastReadField4(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.DisplayName = _field + return offset, nil +} + +func (p *PromptBasic) FastReadField5(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Description = _field + return offset, nil +} + +func (p *PromptBasic) FastReadField6(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.LatestVersion = _field + return offset, nil +} + +func (p *PromptBasic) FastReadField7(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.CreatedBy = _field + return offset, nil +} + +func (p *PromptBasic) FastReadField8(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.UpdatedBy = _field + return offset, nil +} + +func (p *PromptBasic) FastReadField9(buf []byte) (int, error) { + offset := 0 + + var _field *int64 + if v, l, err := thrift.Binary.ReadI64(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.CreatedAt = _field + return offset, nil +} + +func (p *PromptBasic) FastReadField10(buf []byte) (int, error) { + offset := 0 + + var _field *int64 + if v, l, err := thrift.Binary.ReadI64(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.UpdatedAt = _field + return offset, nil +} + +func (p *PromptBasic) FastReadField11(buf []byte) (int, error) { + offset := 0 + + var _field *int64 + if v, l, err := thrift.Binary.ReadI64(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.LatestCommittedAt = _field + return offset, nil +} + +func (p *PromptBasic) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *PromptBasic) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField1(buf[offset:], w) + offset += p.fastWriteField2(buf[offset:], w) + offset += p.fastWriteField9(buf[offset:], w) + offset += p.fastWriteField10(buf[offset:], w) + offset += p.fastWriteField11(buf[offset:], w) + offset += p.fastWriteField3(buf[offset:], w) + offset += p.fastWriteField4(buf[offset:], w) + offset += p.fastWriteField5(buf[offset:], w) + offset += p.fastWriteField6(buf[offset:], w) + offset += p.fastWriteField7(buf[offset:], w) + offset += p.fastWriteField8(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *PromptBasic) BLength() int { + l := 0 + if p != nil { + l += p.field1Length() + l += p.field2Length() + l += p.field3Length() + l += p.field4Length() + l += p.field5Length() + l += p.field6Length() + l += p.field7Length() + l += p.field8Length() + l += p.field9Length() + l += p.field10Length() + l += p.field11Length() + } + l += thrift.Binary.FieldStopLength() + return l +} + +func (p *PromptBasic) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetID() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.I64, 1) + offset += thrift.Binary.WriteI64(buf[offset:], *p.ID) + } + return offset +} + +func (p *PromptBasic) fastWriteField2(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetWorkspaceID() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.I64, 2) + offset += thrift.Binary.WriteI64(buf[offset:], *p.WorkspaceID) + } + return offset +} + +func (p *PromptBasic) fastWriteField3(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetPromptKey() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 3) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.PromptKey) + } + return offset +} + +func (p *PromptBasic) fastWriteField4(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetDisplayName() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 4) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.DisplayName) + } + return offset +} + +func (p *PromptBasic) fastWriteField5(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetDescription() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 5) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Description) + } + return offset +} + +func (p *PromptBasic) fastWriteField6(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetLatestVersion() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 6) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.LatestVersion) + } + return offset +} + +func (p *PromptBasic) fastWriteField7(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetCreatedBy() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 7) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.CreatedBy) + } + return offset +} + +func (p *PromptBasic) fastWriteField8(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetUpdatedBy() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 8) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.UpdatedBy) + } + return offset +} + +func (p *PromptBasic) fastWriteField9(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetCreatedAt() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.I64, 9) + offset += thrift.Binary.WriteI64(buf[offset:], *p.CreatedAt) + } + return offset +} + +func (p *PromptBasic) fastWriteField10(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetUpdatedAt() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.I64, 10) + offset += thrift.Binary.WriteI64(buf[offset:], *p.UpdatedAt) + } + return offset +} + +func (p *PromptBasic) fastWriteField11(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetLatestCommittedAt() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.I64, 11) + offset += thrift.Binary.WriteI64(buf[offset:], *p.LatestCommittedAt) + } + return offset +} + +func (p *PromptBasic) field1Length() int { + l := 0 + if p.IsSetID() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.I64Length() + } + return l +} + +func (p *PromptBasic) field2Length() int { + l := 0 + if p.IsSetWorkspaceID() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.I64Length() + } + return l +} + +func (p *PromptBasic) field3Length() int { + l := 0 + if p.IsSetPromptKey() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.PromptKey) + } + return l +} + +func (p *PromptBasic) field4Length() int { + l := 0 + if p.IsSetDisplayName() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.DisplayName) + } + return l +} + +func (p *PromptBasic) field5Length() int { + l := 0 + if p.IsSetDescription() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Description) + } + return l +} + +func (p *PromptBasic) field6Length() int { + l := 0 + if p.IsSetLatestVersion() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.LatestVersion) + } + return l +} + +func (p *PromptBasic) field7Length() int { + l := 0 + if p.IsSetCreatedBy() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.CreatedBy) + } + return l +} + +func (p *PromptBasic) field8Length() int { + l := 0 + if p.IsSetUpdatedBy() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.UpdatedBy) + } + return l +} + +func (p *PromptBasic) field9Length() int { + l := 0 + if p.IsSetCreatedAt() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.I64Length() + } + return l +} + +func (p *PromptBasic) field10Length() int { + l := 0 + if p.IsSetUpdatedAt() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.I64Length() + } + return l +} + +func (p *PromptBasic) field11Length() int { + l := 0 + if p.IsSetLatestCommittedAt() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.I64Length() + } + return l +} + +func (p *PromptBasic) DeepCopy(s interface{}) error { + src, ok := s.(*PromptBasic) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + if src.ID != nil { + tmp := *src.ID + p.ID = &tmp + } + + if src.WorkspaceID != nil { + tmp := *src.WorkspaceID + p.WorkspaceID = &tmp + } + + if src.PromptKey != nil { + var tmp string + if *src.PromptKey != "" { + tmp = kutils.StringDeepCopy(*src.PromptKey) + } + p.PromptKey = &tmp + } + + if src.DisplayName != nil { + var tmp string + if *src.DisplayName != "" { + tmp = kutils.StringDeepCopy(*src.DisplayName) + } + p.DisplayName = &tmp + } + + if src.Description != nil { + var tmp string + if *src.Description != "" { + tmp = kutils.StringDeepCopy(*src.Description) + } + p.Description = &tmp + } + + if src.LatestVersion != nil { + var tmp string + if *src.LatestVersion != "" { + tmp = kutils.StringDeepCopy(*src.LatestVersion) + } + p.LatestVersion = &tmp + } + + if src.CreatedBy != nil { + var tmp string + if *src.CreatedBy != "" { + tmp = kutils.StringDeepCopy(*src.CreatedBy) + } + p.CreatedBy = &tmp + } + + if src.UpdatedBy != nil { + var tmp string + if *src.UpdatedBy != "" { + tmp = kutils.StringDeepCopy(*src.UpdatedBy) + } + p.UpdatedBy = &tmp + } + + if src.CreatedAt != nil { + tmp := *src.CreatedAt + p.CreatedAt = &tmp + } + + if src.UpdatedAt != nil { + tmp := *src.UpdatedAt + p.UpdatedAt = &tmp + } + + if src.LatestCommittedAt != nil { + tmp := *src.LatestCommittedAt + p.LatestCommittedAt = &tmp + } + + return nil +} + +func (p *ListPromptBasicData) FastRead(buf []byte) (int, error) { + + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + for { + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 1: + if fieldTypeId == thrift.LIST { + l, err = p.FastReadField1(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 2: + if fieldTypeId == thrift.I32 { + l, err = p.FastReadField2(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + } + + return offset, nil +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_ListPromptBasicData[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +} + +func (p *ListPromptBasicData) FastReadField1(buf []byte) (int, error) { + offset := 0 + + _, size, l, err := thrift.Binary.ReadListBegin(buf[offset:]) + offset += l + if err != nil { + return offset, err + } + _field := make([]*PromptBasic, 0, size) + values := make([]PromptBasic, size) + for i := 0; i < size; i++ { + _elem := &values[i] + _elem.InitDefault() + if l, err := _elem.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + + _field = append(_field, _elem) + } + p.Prompts = _field + return offset, nil +} + +func (p *ListPromptBasicData) FastReadField2(buf []byte) (int, error) { + offset := 0 + + var _field *int32 + if v, l, err := thrift.Binary.ReadI32(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Total = _field + return offset, nil +} + +func (p *ListPromptBasicData) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *ListPromptBasicData) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField2(buf[offset:], w) + offset += p.fastWriteField1(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *ListPromptBasicData) BLength() int { + l := 0 + if p != nil { + l += p.field1Length() + l += p.field2Length() + } + l += thrift.Binary.FieldStopLength() + return l +} + +func (p *ListPromptBasicData) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetPrompts() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.LIST, 1) + listBeginOffset := offset + offset += thrift.Binary.ListBeginLength() + var length int + for _, v := range p.Prompts { + length++ + offset += v.FastWriteNocopy(buf[offset:], w) + } + thrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.STRUCT, length) + } + return offset +} + +func (p *ListPromptBasicData) fastWriteField2(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetTotal() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.I32, 2) + offset += thrift.Binary.WriteI32(buf[offset:], *p.Total) + } + return offset +} + +func (p *ListPromptBasicData) field1Length() int { + l := 0 + if p.IsSetPrompts() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.ListBeginLength() + for _, v := range p.Prompts { + _ = v + l += v.BLength() + } + } + return l +} + +func (p *ListPromptBasicData) field2Length() int { + l := 0 + if p.IsSetTotal() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.I32Length() + } + return l +} + +func (p *ListPromptBasicData) DeepCopy(s interface{}) error { + src, ok := s.(*ListPromptBasicData) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + if src.Prompts != nil { + p.Prompts = make([]*PromptBasic, 0, len(src.Prompts)) + for _, elem := range src.Prompts { + var _elem *PromptBasic + if elem != nil { + _elem = &PromptBasic{} + if err := _elem.DeepCopy(elem); err != nil { + return err + } + } + + p.Prompts = append(p.Prompts, _elem) + } + } + + if src.Total != nil { + tmp := *src.Total + p.Total = &tmp + } + + return nil +} + func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs) FastRead(buf []byte) (int, error) { var err error @@ -7078,6 +8618,240 @@ func (p *PromptOpenAPIServiceExecuteStreamingResult) DeepCopy(s interface{}) err return nil } +func (p *PromptOpenAPIServiceListPromptBasicArgs) FastRead(buf []byte) (int, error) { + + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + for { + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 1: + if fieldTypeId == thrift.STRUCT { + l, err = p.FastReadField1(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + } + + return offset, nil +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptOpenAPIServiceListPromptBasicArgs[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +} + +func (p *PromptOpenAPIServiceListPromptBasicArgs) FastReadField1(buf []byte) (int, error) { + offset := 0 + _field := NewListPromptBasicRequest() + if l, err := _field.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + p.Req = _field + return offset, nil +} + +func (p *PromptOpenAPIServiceListPromptBasicArgs) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *PromptOpenAPIServiceListPromptBasicArgs) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField1(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *PromptOpenAPIServiceListPromptBasicArgs) BLength() int { + l := 0 + if p != nil { + l += p.field1Length() + } + l += thrift.Binary.FieldStopLength() + return l +} + +func (p *PromptOpenAPIServiceListPromptBasicArgs) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 1) + offset += p.Req.FastWriteNocopy(buf[offset:], w) + return offset +} + +func (p *PromptOpenAPIServiceListPromptBasicArgs) field1Length() int { + l := 0 + l += thrift.Binary.FieldBeginLength() + l += p.Req.BLength() + return l +} + +func (p *PromptOpenAPIServiceListPromptBasicArgs) DeepCopy(s interface{}) error { + src, ok := s.(*PromptOpenAPIServiceListPromptBasicArgs) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + var _req *ListPromptBasicRequest + if src.Req != nil { + _req = &ListPromptBasicRequest{} + if err := _req.DeepCopy(src.Req); err != nil { + return err + } + } + p.Req = _req + + return nil +} + +func (p *PromptOpenAPIServiceListPromptBasicResult) FastRead(buf []byte) (int, error) { + + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + for { + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 0: + if fieldTypeId == thrift.STRUCT { + l, err = p.FastReadField0(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + } + + return offset, nil +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptOpenAPIServiceListPromptBasicResult[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +} + +func (p *PromptOpenAPIServiceListPromptBasicResult) FastReadField0(buf []byte) (int, error) { + offset := 0 + _field := NewListPromptBasicResponse() + if l, err := _field.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + p.Success = _field + return offset, nil +} + +func (p *PromptOpenAPIServiceListPromptBasicResult) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *PromptOpenAPIServiceListPromptBasicResult) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField0(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *PromptOpenAPIServiceListPromptBasicResult) BLength() int { + l := 0 + if p != nil { + l += p.field0Length() + } + l += thrift.Binary.FieldStopLength() + return l +} + +func (p *PromptOpenAPIServiceListPromptBasicResult) fastWriteField0(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetSuccess() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 0) + offset += p.Success.FastWriteNocopy(buf[offset:], w) + } + return offset +} + +func (p *PromptOpenAPIServiceListPromptBasicResult) field0Length() int { + l := 0 + if p.IsSetSuccess() { + l += thrift.Binary.FieldBeginLength() + l += p.Success.BLength() + } + return l +} + +func (p *PromptOpenAPIServiceListPromptBasicResult) DeepCopy(s interface{}) error { + src, ok := s.(*PromptOpenAPIServiceListPromptBasicResult) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + var _success *ListPromptBasicResponse + if src.Success != nil { + _success = &ListPromptBasicResponse{} + if err := _success.DeepCopy(src.Success); err != nil { + return err + } + } + p.Success = _success + + return nil +} + func (p *PromptOpenAPIServiceBatchGetPromptByPromptKeyArgs) GetFirstArgument() interface{} { return p.Req } @@ -7101,3 +8875,11 @@ func (p *PromptOpenAPIServiceExecuteStreamingArgs) GetFirstArgument() interface{ func (p *PromptOpenAPIServiceExecuteStreamingResult) GetResult() interface{} { return p.Success } + +func (p *PromptOpenAPIServiceListPromptBasicArgs) GetFirstArgument() interface{} { + return p.Req +} + +func (p *PromptOpenAPIServiceListPromptBasicResult) GetResult() interface{} { + return p.Success +} diff --git a/backend/kitex_gen/coze/loop/prompt/openapi/promptopenapiservice/client.go b/backend/kitex_gen/coze/loop/prompt/openapi/promptopenapiservice/client.go index 32724c2e7..408451a6f 100644 --- a/backend/kitex_gen/coze/loop/prompt/openapi/promptopenapiservice/client.go +++ b/backend/kitex_gen/coze/loop/prompt/openapi/promptopenapiservice/client.go @@ -17,6 +17,7 @@ type Client interface { BatchGetPromptByPromptKey(ctx context.Context, req *openapi.BatchGetPromptByPromptKeyRequest, callOptions ...callopt.Option) (r *openapi.BatchGetPromptByPromptKeyResponse, err error) Execute(ctx context.Context, req *openapi.ExecuteRequest, callOptions ...callopt.Option) (r *openapi.ExecuteResponse, err error) ExecuteStreaming(ctx context.Context, req *openapi.ExecuteRequest, callOptions ...streamcall.Option) (stream PromptOpenAPIService_ExecuteStreamingClient, err error) + ListPromptBasic(ctx context.Context, req *openapi.ListPromptBasicRequest, callOptions ...callopt.Option) (r *openapi.ListPromptBasicResponse, err error) } type PromptOpenAPIService_ExecuteStreamingClient streaming.ServerStreamingClient[openapi.ExecuteStreamingResponse] @@ -66,3 +67,8 @@ func (p *kPromptOpenAPIServiceClient) ExecuteStreaming(ctx context.Context, req ctx = client.NewCtxWithCallOptions(ctx, streamcall.GetCallOptions(callOptions)) return p.kClient.ExecuteStreaming(ctx, req) } + +func (p *kPromptOpenAPIServiceClient) ListPromptBasic(ctx context.Context, req *openapi.ListPromptBasicRequest, callOptions ...callopt.Option) (r *openapi.ListPromptBasicResponse, err error) { + ctx = client.NewCtxWithCallOptions(ctx, callOptions) + return p.kClient.ListPromptBasic(ctx, req) +} diff --git a/backend/kitex_gen/coze/loop/prompt/openapi/promptopenapiservice/promptopenapiservice.go b/backend/kitex_gen/coze/loop/prompt/openapi/promptopenapiservice/promptopenapiservice.go index 9c47b8a42..cdf924c0a 100644 --- a/backend/kitex_gen/coze/loop/prompt/openapi/promptopenapiservice/promptopenapiservice.go +++ b/backend/kitex_gen/coze/loop/prompt/openapi/promptopenapiservice/promptopenapiservice.go @@ -35,6 +35,13 @@ var serviceMethods = map[string]kitex.MethodInfo{ false, kitex.WithStreamingMode(kitex.StreamingServer), ), + "ListPromptBasic": kitex.NewMethodInfo( + listPromptBasicHandler, + newPromptOpenAPIServiceListPromptBasicArgs, + newPromptOpenAPIServiceListPromptBasicResult, + false, + kitex.WithStreamingMode(kitex.StreamingNone), + ), } var ( @@ -127,6 +134,25 @@ func newPromptOpenAPIServiceExecuteStreamingResult() interface{} { return openapi.NewPromptOpenAPIServiceExecuteStreamingResult() } +func listPromptBasicHandler(ctx context.Context, handler interface{}, arg, result interface{}) error { + realArg := arg.(*openapi.PromptOpenAPIServiceListPromptBasicArgs) + realResult := result.(*openapi.PromptOpenAPIServiceListPromptBasicResult) + success, err := handler.(openapi.PromptOpenAPIService).ListPromptBasic(ctx, realArg.Req) + if err != nil { + return err + } + realResult.Success = success + return nil +} + +func newPromptOpenAPIServiceListPromptBasicArgs() interface{} { + return openapi.NewPromptOpenAPIServiceListPromptBasicArgs() +} + +func newPromptOpenAPIServiceListPromptBasicResult() interface{} { + return openapi.NewPromptOpenAPIServiceListPromptBasicResult() +} + type kClient struct { c client.Client sc client.Streaming @@ -173,3 +199,13 @@ func (p *kClient) ExecuteStreaming(ctx context.Context, req *openapi.ExecuteRequ } return stream, nil } + +func (p *kClient) ListPromptBasic(ctx context.Context, req *openapi.ListPromptBasicRequest) (r *openapi.ListPromptBasicResponse, err error) { + var _args openapi.PromptOpenAPIServiceListPromptBasicArgs + _args.Req = req + var _result openapi.PromptOpenAPIServiceListPromptBasicResult + if err = p.c.Call(ctx, "ListPromptBasic", &_args, &_result); err != nil { + return + } + return _result.GetSuccess(), nil +} diff --git a/backend/kitex_gen/coze/loop/prompt/promptopenapiservice/client.go b/backend/kitex_gen/coze/loop/prompt/promptopenapiservice/client.go index 32724c2e7..408451a6f 100644 --- a/backend/kitex_gen/coze/loop/prompt/promptopenapiservice/client.go +++ b/backend/kitex_gen/coze/loop/prompt/promptopenapiservice/client.go @@ -17,6 +17,7 @@ type Client interface { BatchGetPromptByPromptKey(ctx context.Context, req *openapi.BatchGetPromptByPromptKeyRequest, callOptions ...callopt.Option) (r *openapi.BatchGetPromptByPromptKeyResponse, err error) Execute(ctx context.Context, req *openapi.ExecuteRequest, callOptions ...callopt.Option) (r *openapi.ExecuteResponse, err error) ExecuteStreaming(ctx context.Context, req *openapi.ExecuteRequest, callOptions ...streamcall.Option) (stream PromptOpenAPIService_ExecuteStreamingClient, err error) + ListPromptBasic(ctx context.Context, req *openapi.ListPromptBasicRequest, callOptions ...callopt.Option) (r *openapi.ListPromptBasicResponse, err error) } type PromptOpenAPIService_ExecuteStreamingClient streaming.ServerStreamingClient[openapi.ExecuteStreamingResponse] @@ -66,3 +67,8 @@ func (p *kPromptOpenAPIServiceClient) ExecuteStreaming(ctx context.Context, req ctx = client.NewCtxWithCallOptions(ctx, streamcall.GetCallOptions(callOptions)) return p.kClient.ExecuteStreaming(ctx, req) } + +func (p *kPromptOpenAPIServiceClient) ListPromptBasic(ctx context.Context, req *openapi.ListPromptBasicRequest, callOptions ...callopt.Option) (r *openapi.ListPromptBasicResponse, err error) { + ctx = client.NewCtxWithCallOptions(ctx, callOptions) + return p.kClient.ListPromptBasic(ctx, req) +} diff --git a/backend/kitex_gen/coze/loop/prompt/promptopenapiservice/promptopenapiservice.go b/backend/kitex_gen/coze/loop/prompt/promptopenapiservice/promptopenapiservice.go index 28331e999..a0a872052 100644 --- a/backend/kitex_gen/coze/loop/prompt/promptopenapiservice/promptopenapiservice.go +++ b/backend/kitex_gen/coze/loop/prompt/promptopenapiservice/promptopenapiservice.go @@ -36,6 +36,13 @@ var serviceMethods = map[string]kitex.MethodInfo{ false, kitex.WithStreamingMode(kitex.StreamingServer), ), + "ListPromptBasic": kitex.NewMethodInfo( + listPromptBasicHandler, + newPromptOpenAPIServiceListPromptBasicArgs, + newPromptOpenAPIServiceListPromptBasicResult, + false, + kitex.WithStreamingMode(kitex.StreamingNone), + ), } var ( @@ -128,6 +135,25 @@ func newPromptOpenAPIServiceExecuteStreamingResult() interface{} { return openapi.NewPromptOpenAPIServiceExecuteStreamingResult() } +func listPromptBasicHandler(ctx context.Context, handler interface{}, arg, result interface{}) error { + realArg := arg.(*openapi.PromptOpenAPIServiceListPromptBasicArgs) + realResult := result.(*openapi.PromptOpenAPIServiceListPromptBasicResult) + success, err := handler.(openapi.PromptOpenAPIService).ListPromptBasic(ctx, realArg.Req) + if err != nil { + return err + } + realResult.Success = success + return nil +} + +func newPromptOpenAPIServiceListPromptBasicArgs() interface{} { + return openapi.NewPromptOpenAPIServiceListPromptBasicArgs() +} + +func newPromptOpenAPIServiceListPromptBasicResult() interface{} { + return openapi.NewPromptOpenAPIServiceListPromptBasicResult() +} + type kClient struct { c client.Client sc client.Streaming @@ -174,3 +200,13 @@ func (p *kClient) ExecuteStreaming(ctx context.Context, req *openapi.ExecuteRequ } return stream, nil } + +func (p *kClient) ListPromptBasic(ctx context.Context, req *openapi.ListPromptBasicRequest) (r *openapi.ListPromptBasicResponse, err error) { + var _args openapi.PromptOpenAPIServiceListPromptBasicArgs + _args.Req = req + var _result openapi.PromptOpenAPIServiceListPromptBasicResult + if err = p.c.Call(ctx, "ListPromptBasic", &_args, &_result); err != nil { + return + } + return _result.GetSuccess(), nil +} diff --git a/backend/loop_gen/coze/loop/prompt/loopenapi/local_promptopenapiservice.go b/backend/loop_gen/coze/loop/prompt/loopenapi/local_promptopenapiservice.go index db5f7cf57..ddf3bdb02 100644 --- a/backend/loop_gen/coze/loop/prompt/loopenapi/local_promptopenapiservice.go +++ b/backend/loop_gen/coze/loop/prompt/loopenapi/local_promptopenapiservice.go @@ -90,6 +90,27 @@ func (l *LocalPromptOpenAPIService) ExecuteStreaming(ctx context.Context, req *o return ls, nil } +func (l *LocalPromptOpenAPIService) ListPromptBasic(ctx context.Context, req *openapi.ListPromptBasicRequest, callOptions ...callopt.Option) (*openapi.ListPromptBasicResponse, error) { + chain := l.mds(func(ctx context.Context, in, out interface{}) error { + arg := in.(*openapi.PromptOpenAPIServiceListPromptBasicArgs) + result := out.(*openapi.PromptOpenAPIServiceListPromptBasicResult) + resp, err := l.impl.ListPromptBasic(ctx, arg.Req) + if err != nil { + return err + } + result.SetSuccess(resp) + return nil + }) + + arg := &openapi.PromptOpenAPIServiceListPromptBasicArgs{Req: req} + result := &openapi.PromptOpenAPIServiceListPromptBasicResult{} + ctx = l.injectRPCInfo(ctx, "ListPromptBasic") + if err := chain(ctx, arg, result); err != nil { + return nil, err + } + return result.GetSuccess(), nil +} + func (l *LocalPromptOpenAPIService) injectRPCInfo(ctx context.Context, method string) context.Context { rpcStats := rpcinfo.AsMutableRPCStats(rpcinfo.NewRPCStats()) ri := rpcinfo.NewRPCInfo( diff --git a/backend/modules/prompt/application/convertor/openapi.go b/backend/modules/prompt/application/convertor/openapi.go index f0b2836d4..39481c431 100644 --- a/backend/modules/prompt/application/convertor/openapi.go +++ b/backend/modules/prompt/application/convertor/openapi.go @@ -416,3 +416,28 @@ func OpenAPIFunctionCallDTO2DO(dto *openapi.FunctionCall) *entity.FunctionCall { Arguments: dto.Arguments, } } + +// OpenAPIPromptBasicDO2DTO 将entity Prompt转换为openapi PromptBasic +func OpenAPIPromptBasicDO2DTO(do *entity.Prompt) *openapi.PromptBasic { + if do == nil || do.PromptBasic == nil { + return nil + } + return &openapi.PromptBasic{ + ID: ptr.Of(do.ID), + WorkspaceID: ptr.Of(do.SpaceID), + PromptKey: ptr.Of(do.PromptKey), + DisplayName: ptr.Of(do.PromptBasic.DisplayName), + Description: ptr.Of(do.PromptBasic.Description), + LatestVersion: ptr.Of(do.PromptBasic.LatestVersion), + CreatedBy: ptr.Of(do.PromptBasic.CreatedBy), + UpdatedBy: ptr.Of(do.PromptBasic.UpdatedBy), + CreatedAt: ptr.Of(do.PromptBasic.CreatedAt.UnixMilli()), + UpdatedAt: ptr.Of(do.PromptBasic.UpdatedAt.UnixMilli()), + LatestCommittedAt: func() *int64 { + if do.PromptBasic.LatestCommittedAt == nil { + return nil + } + return ptr.Of(do.PromptBasic.LatestCommittedAt.UnixMilli()) + }(), + } +} diff --git a/backend/modules/prompt/application/convertor/openapi_test.go b/backend/modules/prompt/application/convertor/openapi_test.go index 176bb4956..b6982c2f4 100755 --- a/backend/modules/prompt/application/convertor/openapi_test.go +++ b/backend/modules/prompt/application/convertor/openapi_test.go @@ -5,6 +5,7 @@ package convertor import ( "testing" + "time" "github.com/stretchr/testify/assert" @@ -1712,3 +1713,181 @@ func TestOpenAPIBatchToolCallDTO2DO(t *testing.T) { }) } } + +func TestOpenAPIPromptBasicDO2DTO(t *testing.T) { + createdAt := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC) + updatedAt := time.Date(2024, 1, 2, 12, 0, 0, 0, time.UTC) + latestCommittedAt := time.Date(2024, 1, 3, 12, 0, 0, 0, time.UTC) + + tests := []struct { + name string + do *entity.Prompt + want *openapi.PromptBasic + }{ + { + name: "nil input", + do: nil, + want: nil, + }, + { + name: "nil prompt basic", + do: &entity.Prompt{ + ID: 123, + SpaceID: 456, + PromptKey: "test_prompt", + PromptBasic: nil, + }, + want: nil, + }, + { + name: "empty prompt basic", + do: &entity.Prompt{ + ID: 123, + SpaceID: 456, + PromptKey: "test_prompt", + PromptBasic: &entity.PromptBasic{}, + }, + want: &openapi.PromptBasic{ + ID: ptr.Of(int64(123)), + WorkspaceID: ptr.Of(int64(456)), + PromptKey: ptr.Of("test_prompt"), + DisplayName: ptr.Of(""), + Description: ptr.Of(""), + LatestVersion: ptr.Of(""), + CreatedBy: ptr.Of(""), + UpdatedBy: ptr.Of(""), + CreatedAt: ptr.Of(time.Time{}.UnixMilli()), // zero value time + UpdatedAt: ptr.Of(time.Time{}.UnixMilli()), // zero value time + LatestCommittedAt: nil, + }, + }, + { + name: "complete prompt basic without latest committed at", + do: &entity.Prompt{ + ID: 123, + SpaceID: 456, + PromptKey: "test_prompt", + PromptBasic: &entity.PromptBasic{ + DisplayName: "Test Prompt", + Description: "A test prompt for testing", + LatestVersion: "1.0.0", + CreatedBy: "user123", + UpdatedBy: "user456", + CreatedAt: createdAt, + UpdatedAt: updatedAt, + LatestCommittedAt: nil, + }, + }, + want: &openapi.PromptBasic{ + ID: ptr.Of(int64(123)), + WorkspaceID: ptr.Of(int64(456)), + PromptKey: ptr.Of("test_prompt"), + DisplayName: ptr.Of("Test Prompt"), + Description: ptr.Of("A test prompt for testing"), + LatestVersion: ptr.Of("1.0.0"), + CreatedBy: ptr.Of("user123"), + UpdatedBy: ptr.Of("user456"), + CreatedAt: ptr.Of(createdAt.UnixMilli()), + UpdatedAt: ptr.Of(updatedAt.UnixMilli()), + LatestCommittedAt: nil, + }, + }, + { + name: "complete prompt basic with latest committed at", + do: &entity.Prompt{ + ID: 123, + SpaceID: 456, + PromptKey: "test_prompt", + PromptBasic: &entity.PromptBasic{ + DisplayName: "Test Prompt", + Description: "A test prompt for testing", + LatestVersion: "1.0.0", + CreatedBy: "user123", + UpdatedBy: "user456", + CreatedAt: createdAt, + UpdatedAt: updatedAt, + LatestCommittedAt: &latestCommittedAt, + }, + }, + want: &openapi.PromptBasic{ + ID: ptr.Of(int64(123)), + WorkspaceID: ptr.Of(int64(456)), + PromptKey: ptr.Of("test_prompt"), + DisplayName: ptr.Of("Test Prompt"), + Description: ptr.Of("A test prompt for testing"), + LatestVersion: ptr.Of("1.0.0"), + CreatedBy: ptr.Of("user123"), + UpdatedBy: ptr.Of("user456"), + CreatedAt: ptr.Of(createdAt.UnixMilli()), + UpdatedAt: ptr.Of(updatedAt.UnixMilli()), + LatestCommittedAt: ptr.Of(latestCommittedAt.UnixMilli()), + }, + }, + { + name: "prompt basic with zero IDs", + do: &entity.Prompt{ + ID: 0, + SpaceID: 0, + PromptKey: "", + PromptBasic: &entity.PromptBasic{ + DisplayName: "New Prompt", + Description: "A newly created prompt", + LatestVersion: "", + CreatedBy: "user789", + UpdatedBy: "user789", + CreatedAt: createdAt, + UpdatedAt: createdAt, + }, + }, + want: &openapi.PromptBasic{ + ID: ptr.Of(int64(0)), + WorkspaceID: ptr.Of(int64(0)), + PromptKey: ptr.Of(""), + DisplayName: ptr.Of("New Prompt"), + Description: ptr.Of("A newly created prompt"), + LatestVersion: ptr.Of(""), + CreatedBy: ptr.Of("user789"), + UpdatedBy: ptr.Of("user789"), + CreatedAt: ptr.Of(createdAt.UnixMilli()), + UpdatedAt: ptr.Of(createdAt.UnixMilli()), + LatestCommittedAt: nil, + }, + }, + { + name: "prompt basic with special characters in text fields", + do: &entity.Prompt{ + ID: 999, + SpaceID: 888, + PromptKey: "prompt_with_special_chars_@#$", + PromptBasic: &entity.PromptBasic{ + DisplayName: "Prompt with 中文 and émojis 🎉", + Description: "Description with\nnewlines\tand\ttabs", + LatestVersion: "2.3.1-beta", + CreatedBy: "user@example.com", + UpdatedBy: "another.user@example.com", + CreatedAt: createdAt, + UpdatedAt: updatedAt, + }, + }, + want: &openapi.PromptBasic{ + ID: ptr.Of(int64(999)), + WorkspaceID: ptr.Of(int64(888)), + PromptKey: ptr.Of("prompt_with_special_chars_@#$"), + DisplayName: ptr.Of("Prompt with 中文 and émojis 🎉"), + Description: ptr.Of("Description with\nnewlines\tand\ttabs"), + LatestVersion: ptr.Of("2.3.1-beta"), + CreatedBy: ptr.Of("user@example.com"), + UpdatedBy: ptr.Of("another.user@example.com"), + CreatedAt: ptr.Of(createdAt.UnixMilli()), + UpdatedAt: ptr.Of(updatedAt.UnixMilli()), + LatestCommittedAt: nil, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.want, OpenAPIPromptBasicDO2DTO(tt.do)) + }) + } +} diff --git a/backend/modules/prompt/application/openapi.go b/backend/modules/prompt/application/openapi.go index 3ec9810f7..774349b84 100644 --- a/backend/modules/prompt/application/openapi.go +++ b/backend/modules/prompt/application/openapi.go @@ -67,6 +67,65 @@ type PromptOpenAPIApplicationImpl struct { collector collector.ICollectorProvider } +func (p *PromptOpenAPIApplicationImpl) ListPromptBasic(ctx context.Context, req *openapi.ListPromptBasicRequest) (r *openapi.ListPromptBasicResponse, err error) { + r = openapi.NewListPromptBasicResponse() + if req.GetWorkspaceID() == 0 { + return r, errorx.NewByCode(prompterr.CommonInvalidParamCode, errorx.WithExtra(map[string]string{"invalid_param": "workspace_id参数为空"})) + } + defer func() { + if err != nil { + logs.CtxError(ctx, "openapi list prompt basic failed, err=%v", err) + } + }() + + // 限流检查 + if !p.promptHubAllowBySpace(ctx, req.GetWorkspaceID()) { + return r, errorx.NewByCode(prompterr.PromptHubQPSLimitCode, errorx.WithExtraMsg("qps limit exceeded")) + } + + // 构建查询参数 + param := repo.ListPromptParam{ + SpaceID: req.GetWorkspaceID(), + KeyWord: req.GetKeyWord(), + CommittedOnly: true, // 只查询已提交的prompts + PageNum: int(req.GetPageNumber()), + PageSize: int(req.GetPageSize()), + } + if req.GetCreator() != "" { + param.CreatedBys = []string{req.GetCreator()} + } + + // 查询prompts + result, err := p.promptManageRepo.ListPrompt(ctx, param) + if err != nil { + return nil, err + } + + // 执行权限检查 + var promptIDs []int64 + for _, prompt := range result.PromptDOs { + promptIDs = append(promptIDs, prompt.ID) + } + if len(promptIDs) > 0 { + if err = p.auth.MCheckPromptPermissionForOpenAPI(ctx, req.GetWorkspaceID(), promptIDs, consts.ActionLoopPromptRead); err != nil { + return nil, err + } + } + + // 构建响应 + r.Data = openapi.NewListPromptBasicData() + r.Data.Total = ptr.Of(int32(result.Total)) + r.Data.Prompts = make([]*openapi.PromptBasic, 0, len(result.PromptDOs)) + for _, promptDO := range result.PromptDOs { + promptBasic := convertor.OpenAPIPromptBasicDO2DTO(promptDO) + if promptBasic != nil { + r.Data.Prompts = append(r.Data.Prompts, promptBasic) + } + } + + return r, nil +} + func (p *PromptOpenAPIApplicationImpl) BatchGetPromptByPromptKey(ctx context.Context, req *openapi.BatchGetPromptByPromptKeyRequest) (r *openapi.BatchGetPromptByPromptKeyResponse, err error) { r = openapi.NewBatchGetPromptByPromptKeyResponse() if req.GetWorkspaceID() == 0 { diff --git a/backend/modules/prompt/application/openapi_test.go b/backend/modules/prompt/application/openapi_test.go index a3c7ad6a1..1fa1e40ac 100644 --- a/backend/modules/prompt/application/openapi_test.go +++ b/backend/modules/prompt/application/openapi_test.go @@ -4076,3 +4076,510 @@ func TestPromptOpenAPIApplicationImpl_ExecuteStreaming(t *testing.T) { }) } } + +func TestPromptOpenAPIApplicationImpl_ListPromptBasic(t *testing.T) { + t.Parallel() + + type fields struct { + promptManageRepo repo.IManageRepo + config conf.IConfigProvider + auth rpc.IAuthProvider + rateLimiter limiter.IRateLimiter + } + type args struct { + ctx context.Context + req *openapi.ListPromptBasicRequest + } + + tests := []struct { + name string + fieldsGetter func(ctrl *gomock.Controller) fields + args args + wantR *openapi.ListPromptBasicResponse + wantErr error + }{ + { + name: "success: list prompts basic info", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockManageRepo := repomocks.NewMockIManageRepo(ctrl) + startTime := time.Now() + mockManageRepo.EXPECT().ListPrompt(gomock.Any(), repo.ListPromptParam{ + SpaceID: 123456, + CommittedOnly: true, + PageNum: 1, + PageSize: 10, + }).Return(&repo.ListPromptResult{ + Total: 2, + PromptDOs: []*entity.Prompt{ + { + ID: 123, + SpaceID: 123456, + PromptKey: "test_prompt1", + PromptBasic: &entity.PromptBasic{ + DisplayName: "Test Prompt 1", + Description: "Test Description 1", + LatestVersion: "1.0.0", + CreatedBy: "test_user", + UpdatedBy: "test_user", + CreatedAt: startTime, + UpdatedAt: startTime, + }, + }, + { + ID: 456, + SpaceID: 123456, + PromptKey: "test_prompt2", + PromptBasic: &entity.PromptBasic{ + DisplayName: "Test Prompt 2", + Description: "Test Description 2", + LatestVersion: "2.0.0", + CreatedBy: "test_user", + UpdatedBy: "test_user", + CreatedAt: startTime, + UpdatedAt: startTime, + }, + }, + }, + }, nil) + + mockConfig := confmocks.NewMockIConfigProvider(ctrl) + mockConfig.EXPECT().GetPromptHubMaxQPSBySpace(gomock.Any(), int64(123456)).Return(100, nil) + + mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) + mockAuth.EXPECT().MCheckPromptPermissionForOpenAPI(gomock.Any(), int64(123456), []int64{123, 456}, consts.ActionLoopPromptRead).Return(nil) + + mockRateLimiter := limitermocks.NewMockIRateLimiter(ctrl) + mockRateLimiter.EXPECT().AllowN(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&limiter.Result{ + Allowed: true, + }, nil) + + return fields{ + promptManageRepo: mockManageRepo, + config: mockConfig, + auth: mockAuth, + rateLimiter: mockRateLimiter, + } + }, + args: args{ + ctx: context.Background(), + req: &openapi.ListPromptBasicRequest{ + WorkspaceID: ptr.Of(int64(123456)), + PageNumber: ptr.Of(int32(1)), + PageSize: ptr.Of(int32(10)), + }, + }, + wantR: &openapi.ListPromptBasicResponse{ + Data: &openapi.ListPromptBasicData{ + Total: ptr.Of(int32(2)), + Prompts: []*openapi.PromptBasic{ + { + ID: ptr.Of(int64(123)), + WorkspaceID: ptr.Of(int64(123456)), + PromptKey: ptr.Of("test_prompt1"), + DisplayName: ptr.Of("Test Prompt 1"), + Description: ptr.Of("Test Description 1"), + LatestVersion: ptr.Of("1.0.0"), + CreatedBy: ptr.Of("test_user"), + UpdatedBy: ptr.Of("test_user"), + }, + { + ID: ptr.Of(int64(456)), + WorkspaceID: ptr.Of(int64(123456)), + PromptKey: ptr.Of("test_prompt2"), + DisplayName: ptr.Of("Test Prompt 2"), + Description: ptr.Of("Test Description 2"), + LatestVersion: ptr.Of("2.0.0"), + CreatedBy: ptr.Of("test_user"), + UpdatedBy: ptr.Of("test_user"), + }, + }, + }, + }, + wantErr: nil, + }, + { + name: "success: with keyword filter", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockManageRepo := repomocks.NewMockIManageRepo(ctrl) + startTime := time.Now() + mockManageRepo.EXPECT().ListPrompt(gomock.Any(), repo.ListPromptParam{ + SpaceID: 123456, + KeyWord: "test", + CommittedOnly: true, + PageNum: 1, + PageSize: 10, + }).Return(&repo.ListPromptResult{ + Total: 1, + PromptDOs: []*entity.Prompt{ + { + ID: 123, + SpaceID: 123456, + PromptKey: "test_prompt1", + PromptBasic: &entity.PromptBasic{ + DisplayName: "Test Prompt 1", + Description: "Test Description 1", + LatestVersion: "1.0.0", + CreatedBy: "test_user", + UpdatedBy: "test_user", + CreatedAt: startTime, + UpdatedAt: startTime, + }, + }, + }, + }, nil) + + mockConfig := confmocks.NewMockIConfigProvider(ctrl) + mockConfig.EXPECT().GetPromptHubMaxQPSBySpace(gomock.Any(), int64(123456)).Return(100, nil) + + mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) + mockAuth.EXPECT().MCheckPromptPermissionForOpenAPI(gomock.Any(), int64(123456), []int64{123}, consts.ActionLoopPromptRead).Return(nil) + + mockRateLimiter := limitermocks.NewMockIRateLimiter(ctrl) + mockRateLimiter.EXPECT().AllowN(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&limiter.Result{ + Allowed: true, + }, nil) + + return fields{ + promptManageRepo: mockManageRepo, + config: mockConfig, + auth: mockAuth, + rateLimiter: mockRateLimiter, + } + }, + args: args{ + ctx: context.Background(), + req: &openapi.ListPromptBasicRequest{ + WorkspaceID: ptr.Of(int64(123456)), + PageNumber: ptr.Of(int32(1)), + PageSize: ptr.Of(int32(10)), + KeyWord: ptr.Of("test"), + }, + }, + wantR: &openapi.ListPromptBasicResponse{ + Data: &openapi.ListPromptBasicData{ + Total: ptr.Of(int32(1)), + Prompts: []*openapi.PromptBasic{ + { + ID: ptr.Of(int64(123)), + WorkspaceID: ptr.Of(int64(123456)), + PromptKey: ptr.Of("test_prompt1"), + DisplayName: ptr.Of("Test Prompt 1"), + Description: ptr.Of("Test Description 1"), + LatestVersion: ptr.Of("1.0.0"), + CreatedBy: ptr.Of("test_user"), + UpdatedBy: ptr.Of("test_user"), + }, + }, + }, + }, + wantErr: nil, + }, + { + name: "success: with creator filter", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockManageRepo := repomocks.NewMockIManageRepo(ctrl) + startTime := time.Now() + mockManageRepo.EXPECT().ListPrompt(gomock.Any(), repo.ListPromptParam{ + SpaceID: 123456, + CreatedBys: []string{"specific_user"}, + CommittedOnly: true, + PageNum: 1, + PageSize: 10, + }).Return(&repo.ListPromptResult{ + Total: 1, + PromptDOs: []*entity.Prompt{ + { + ID: 123, + SpaceID: 123456, + PromptKey: "user_prompt", + PromptBasic: &entity.PromptBasic{ + DisplayName: "User Prompt", + Description: "User Description", + LatestVersion: "1.0.0", + CreatedBy: "specific_user", + UpdatedBy: "specific_user", + CreatedAt: startTime, + UpdatedAt: startTime, + }, + }, + }, + }, nil) + + mockConfig := confmocks.NewMockIConfigProvider(ctrl) + mockConfig.EXPECT().GetPromptHubMaxQPSBySpace(gomock.Any(), int64(123456)).Return(100, nil) + + mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) + mockAuth.EXPECT().MCheckPromptPermissionForOpenAPI(gomock.Any(), int64(123456), []int64{123}, consts.ActionLoopPromptRead).Return(nil) + + mockRateLimiter := limitermocks.NewMockIRateLimiter(ctrl) + mockRateLimiter.EXPECT().AllowN(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&limiter.Result{ + Allowed: true, + }, nil) + + return fields{ + promptManageRepo: mockManageRepo, + config: mockConfig, + auth: mockAuth, + rateLimiter: mockRateLimiter, + } + }, + args: args{ + ctx: context.Background(), + req: &openapi.ListPromptBasicRequest{ + WorkspaceID: ptr.Of(int64(123456)), + PageNumber: ptr.Of(int32(1)), + PageSize: ptr.Of(int32(10)), + Creator: ptr.Of("specific_user"), + }, + }, + wantR: &openapi.ListPromptBasicResponse{ + Data: &openapi.ListPromptBasicData{ + Total: ptr.Of(int32(1)), + Prompts: []*openapi.PromptBasic{ + { + ID: ptr.Of(int64(123)), + WorkspaceID: ptr.Of(int64(123456)), + PromptKey: ptr.Of("user_prompt"), + DisplayName: ptr.Of("User Prompt"), + Description: ptr.Of("User Description"), + LatestVersion: ptr.Of("1.0.0"), + CreatedBy: ptr.Of("specific_user"), + UpdatedBy: ptr.Of("specific_user"), + }, + }, + }, + }, + wantErr: nil, + }, + { + name: "success: empty result", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockManageRepo := repomocks.NewMockIManageRepo(ctrl) + mockManageRepo.EXPECT().ListPrompt(gomock.Any(), repo.ListPromptParam{ + SpaceID: 123456, + CommittedOnly: true, + PageNum: 1, + PageSize: 10, + }).Return(&repo.ListPromptResult{ + Total: 0, + PromptDOs: []*entity.Prompt{}, + }, nil) + + mockConfig := confmocks.NewMockIConfigProvider(ctrl) + mockConfig.EXPECT().GetPromptHubMaxQPSBySpace(gomock.Any(), int64(123456)).Return(100, nil) + + mockRateLimiter := limitermocks.NewMockIRateLimiter(ctrl) + mockRateLimiter.EXPECT().AllowN(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&limiter.Result{ + Allowed: true, + }, nil) + + return fields{ + promptManageRepo: mockManageRepo, + config: mockConfig, + rateLimiter: mockRateLimiter, + } + }, + args: args{ + ctx: context.Background(), + req: &openapi.ListPromptBasicRequest{ + WorkspaceID: ptr.Of(int64(123456)), + PageNumber: ptr.Of(int32(1)), + PageSize: ptr.Of(int32(10)), + }, + }, + wantR: &openapi.ListPromptBasicResponse{ + Data: &openapi.ListPromptBasicData{ + Total: ptr.Of(int32(0)), + Prompts: []*openapi.PromptBasic{}, + }, + }, + wantErr: nil, + }, + { + name: "error: workspace_id is zero", + fieldsGetter: func(ctrl *gomock.Controller) fields { + return fields{} + }, + args: args{ + ctx: context.Background(), + req: &openapi.ListPromptBasicRequest{ + WorkspaceID: ptr.Of(int64(0)), + PageNumber: ptr.Of(int32(1)), + PageSize: ptr.Of(int32(10)), + }, + }, + wantR: openapi.NewListPromptBasicResponse(), + wantErr: errorx.NewByCode(prompterr.CommonInvalidParamCode, errorx.WithExtra(map[string]string{"invalid_param": "workspace_id参数为空"})), + }, + { + name: "error: workspace_id is nil", + fieldsGetter: func(ctrl *gomock.Controller) fields { + return fields{} + }, + args: args{ + ctx: context.Background(), + req: &openapi.ListPromptBasicRequest{ + WorkspaceID: nil, + PageNumber: ptr.Of(int32(1)), + PageSize: ptr.Of(int32(10)), + }, + }, + wantR: openapi.NewListPromptBasicResponse(), + wantErr: errorx.NewByCode(prompterr.CommonInvalidParamCode, errorx.WithExtra(map[string]string{"invalid_param": "workspace_id参数为空"})), + }, + { + name: "error: rate limit exceeded", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockConfig := confmocks.NewMockIConfigProvider(ctrl) + mockConfig.EXPECT().GetPromptHubMaxQPSBySpace(gomock.Any(), int64(123456)).Return(1, nil) + + mockRateLimiter := limitermocks.NewMockIRateLimiter(ctrl) + mockRateLimiter.EXPECT().AllowN(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&limiter.Result{ + Allowed: false, + }, nil) + + return fields{ + config: mockConfig, + rateLimiter: mockRateLimiter, + } + }, + args: args{ + ctx: context.Background(), + req: &openapi.ListPromptBasicRequest{ + WorkspaceID: ptr.Of(int64(123456)), + PageNumber: ptr.Of(int32(1)), + PageSize: ptr.Of(int32(10)), + }, + }, + wantR: openapi.NewListPromptBasicResponse(), + wantErr: errorx.NewByCode(prompterr.PromptHubQPSLimitCode, errorx.WithExtraMsg("qps limit exceeded")), + }, + { + name: "error: list prompt failed", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockManageRepo := repomocks.NewMockIManageRepo(ctrl) + mockManageRepo.EXPECT().ListPrompt(gomock.Any(), gomock.Any()).Return(nil, errors.New("database error")) + + mockConfig := confmocks.NewMockIConfigProvider(ctrl) + mockConfig.EXPECT().GetPromptHubMaxQPSBySpace(gomock.Any(), int64(123456)).Return(100, nil) + + mockRateLimiter := limitermocks.NewMockIRateLimiter(ctrl) + mockRateLimiter.EXPECT().AllowN(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&limiter.Result{ + Allowed: true, + }, nil) + + return fields{ + promptManageRepo: mockManageRepo, + config: mockConfig, + rateLimiter: mockRateLimiter, + } + }, + args: args{ + ctx: context.Background(), + req: &openapi.ListPromptBasicRequest{ + WorkspaceID: ptr.Of(int64(123456)), + PageNumber: ptr.Of(int32(1)), + PageSize: ptr.Of(int32(10)), + }, + }, + wantR: nil, + wantErr: errors.New("database error"), + }, + { + name: "error: permission check failed", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockManageRepo := repomocks.NewMockIManageRepo(ctrl) + startTime := time.Now() + mockManageRepo.EXPECT().ListPrompt(gomock.Any(), gomock.Any()).Return(&repo.ListPromptResult{ + Total: 1, + PromptDOs: []*entity.Prompt{ + { + ID: 123, + SpaceID: 123456, + PromptKey: "test_prompt1", + PromptBasic: &entity.PromptBasic{ + DisplayName: "Test Prompt 1", + Description: "Test Description 1", + LatestVersion: "1.0.0", + CreatedBy: "test_user", + UpdatedBy: "test_user", + CreatedAt: startTime, + UpdatedAt: startTime, + }, + }, + }, + }, nil) + + mockConfig := confmocks.NewMockIConfigProvider(ctrl) + mockConfig.EXPECT().GetPromptHubMaxQPSBySpace(gomock.Any(), int64(123456)).Return(100, nil) + + mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) + mockAuth.EXPECT().MCheckPromptPermissionForOpenAPI(gomock.Any(), int64(123456), []int64{123}, consts.ActionLoopPromptRead).Return(errorx.NewByCode(prompterr.CommonNoPermissionCode)) + + mockRateLimiter := limitermocks.NewMockIRateLimiter(ctrl) + mockRateLimiter.EXPECT().AllowN(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&limiter.Result{ + Allowed: true, + }, nil) + + return fields{ + promptManageRepo: mockManageRepo, + config: mockConfig, + auth: mockAuth, + rateLimiter: mockRateLimiter, + } + }, + args: args{ + ctx: context.Background(), + req: &openapi.ListPromptBasicRequest{ + WorkspaceID: ptr.Of(int64(123456)), + PageNumber: ptr.Of(int32(1)), + PageSize: ptr.Of(int32(10)), + }, + }, + wantR: nil, + wantErr: errorx.NewByCode(prompterr.CommonNoPermissionCode), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // 移除 t.Parallel() 以避免数据竞争 + ctrl := gomock.NewController(t) + defer ctrl.Finish() + ttFields := tt.fieldsGetter(ctrl) + p := &PromptOpenAPIApplicationImpl{ + promptManageRepo: ttFields.promptManageRepo, + config: ttFields.config, + auth: ttFields.auth, + rateLimiter: ttFields.rateLimiter, + } + gotR, err := p.ListPromptBasic(tt.args.ctx, tt.args.req) + unittest.AssertErrorEqual(t, tt.wantErr, err) + + // 对于成功的测试用例,需要处理时间戳比较 + if err == nil && tt.wantR != nil && gotR != nil && gotR.Data != nil && tt.wantR.Data != nil { + // 比较除时间戳外的其他字段 + assert.Equal(t, tt.wantR.Data.Total, gotR.Data.Total) + assert.Equal(t, len(tt.wantR.Data.Prompts), len(gotR.Data.Prompts)) + + for i, expected := range tt.wantR.Data.Prompts { + if i < len(gotR.Data.Prompts) { + actual := gotR.Data.Prompts[i] + assert.Equal(t, expected.ID, actual.ID) + assert.Equal(t, expected.WorkspaceID, actual.WorkspaceID) + assert.Equal(t, expected.PromptKey, actual.PromptKey) + assert.Equal(t, expected.DisplayName, actual.DisplayName) + assert.Equal(t, expected.Description, actual.Description) + assert.Equal(t, expected.LatestVersion, actual.LatestVersion) + assert.Equal(t, expected.CreatedBy, actual.CreatedBy) + assert.Equal(t, expected.UpdatedBy, actual.UpdatedBy) + // 时间戳字段只检查是否不为nil + assert.NotNil(t, actual.CreatedAt) + assert.NotNil(t, actual.UpdatedAt) + } + } + } else { + assert.Equal(t, tt.wantR, gotR) + } + }) + } +} diff --git a/idl/thrift/coze/loop/prompt/coze.loop.prompt.openapi.thrift b/idl/thrift/coze/loop/prompt/coze.loop.prompt.openapi.thrift index e3dd110b7..44c183b99 100644 --- a/idl/thrift/coze/loop/prompt/coze.loop.prompt.openapi.thrift +++ b/idl/thrift/coze/loop/prompt/coze.loop.prompt.openapi.thrift @@ -7,6 +7,7 @@ service PromptOpenAPIService { BatchGetPromptByPromptKeyResponse BatchGetPromptByPromptKey(1: BatchGetPromptByPromptKeyRequest req) (api.tag="openapi", api.post='/v1/loop/prompts/mget') ExecuteResponse Execute(1: ExecuteRequest req) (api.tag="openapi", api.post="/v1/loop/prompts/execute") ExecuteStreamingResponse ExecuteStreaming(1: ExecuteRequest req) (api.tag="openapi", api.post="/v1/loop/prompts/execute_streaming", streaming.mode='server') + ListPromptBasicResponse ListPromptBasic(1: ListPromptBasicRequest req) (api.tag="openapi", api.post='/v1/loop/prompts/list') } struct BatchGetPromptByPromptKeyRequest { @@ -204,4 +205,41 @@ struct VariableVal { struct TokenUsage { 1: optional i32 input_tokens // 输入消耗 2: optional i32 output_tokens // 输出消耗 -} \ No newline at end of file +} + +struct ListPromptBasicRequest { + 1: optional i64 workspace_id (api.body="workspace_id", api.js_conv='true', go.tag='json:"workspace_id"') + 2: optional i32 page_number (api.body="page_number", vt.gt = "0") + 3: optional i32 page_size (api.body="page_size", vt.gt = "0", vt.le = "200") + 4: optional string key_word (api.body="key_word") // name/key前缀匹配 + 5: optional string creator (api.body="creator") // 创建人 + + 255: optional base.Base Base +} + +struct ListPromptBasicResponse { + 1: optional i32 code + 2: optional string msg + 3: optional ListPromptBasicData data + + 255: optional base.BaseResp BaseResp +} + +struct PromptBasic { + 1: optional i64 id (api.js_conv='true', go.tag='json:"id"') // Prompt ID + 2: optional i64 workspace_id (api.js_conv='true', go.tag='json:"workspace_id"') // 工作空间ID + 3: optional string prompt_key // 唯一标识 + 4: optional string display_name // Prompt名称 + 5: optional string description // Prompt描述 + 6: optional string latest_version // 最新版本 + 7: optional string created_by // 创建者 + 8: optional string updated_by // 更新者 + 9: optional i64 created_at (api.js_conv='true', go.tag='json:"created_at"') // 创建时间 + 10: optional i64 updated_at (api.js_conv='true', go.tag='json:"updated_at"') // 更新时间 + 11: optional i64 latest_committed_at (api.js_conv='true', go.tag='json:"latest_committed_at"') // 最后提交时间 +} + +struct ListPromptBasicData { + 1: optional list prompts // Prompt列表 + 2: optional i32 total +} From d06d85c03ed6ea596db23750721492f29c097e51 Mon Sep 17 00:00:00 2001 From: kasarolzzw <39260341+kasarolzzw@users.noreply.github.com> Date: Tue, 14 Oct 2025 20:58:19 +0800 Subject: [PATCH 02/12] [feat][prompt] metadata (#224) * feat: [Coda] add prompt metadata pipelines Change-Id: I142040fd89ef89429df790dc6e4dfc5d04e53609 * sql Change-Id: I44b779e1c2a749217cec05dfbc609ae74474a217 * sql Change-Id: Ie46a446e3eca59cee9c3170d85b751c7c2aa630b --- .../loop/prompt/domain/prompt/k-prompt.go | 194 ++++++++++++++ .../coze/loop/prompt/domain/prompt/prompt.go | 250 ++++++++++++++++-- .../openapi/coze.loop.prompt.openapi.go | 234 +++++++++++++++- .../openapi/k-coze.loop.prompt.openapi.go | 194 ++++++++++++++ .../prompt/application/convertor/openapi.go | 3 + .../application/convertor/openapi_test.go | 59 +++++ .../prompt/application/convertor/prompt.go | 4 + .../application/convertor/prompt_test.go | 46 ++++ .../prompt/domain/entity/prompt_detail.go | 9 +- .../infra/repo/mysql/convertor/manage.go | 17 ++ .../infra/repo/mysql/convertor/manage_test.go | 49 ++++ .../mysql/gorm_gen/model/prompt_commit.gen.go | 1 + .../gorm_gen/model/prompt_user_draft.gen.go | 1 + .../mysql/gorm_gen/query/prompt_commit.gen.go | 6 +- .../gorm_gen/query/prompt_user_draft.gen.go | 6 +- .../infra/repo/mysql/prompt_user_draft.go | 1 + .../prompt/coze.loop.prompt.openapi.thrift | 4 + .../coze/loop/prompt/domain/prompt.thrift | 4 + .../mysql-init/init-sql/prompt_commit.sql | 1 + .../mysql-init/init-sql/prompt_user_draft.sql | 1 + .../patch-sql/prompt_commit_alter.sql | 1 + .../patch-sql/prompt_user_draft_alter.sql | 1 + .../init/mysql/init-sql/prompt_commit.sql | 1 + .../mysql/init-sql/prompt_commit_alter.sql | 1 + .../init/mysql/init-sql/prompt_user_draft.sql | 1 + .../init-sql/prompt_user_draft_alter.sql | 1 + 26 files changed, 1058 insertions(+), 32 deletions(-) diff --git a/backend/kitex_gen/coze/loop/prompt/domain/prompt/k-prompt.go b/backend/kitex_gen/coze/loop/prompt/domain/prompt/k-prompt.go index 90905c282..bf6daead7 100644 --- a/backend/kitex_gen/coze/loop/prompt/domain/prompt/k-prompt.go +++ b/backend/kitex_gen/coze/loop/prompt/domain/prompt/k-prompt.go @@ -2409,6 +2409,20 @@ func (p *PromptTemplate) FastRead(buf []byte) (int, error) { goto SkipFieldError } } + case 100: + if fieldTypeId == thrift.MAP { + l, err = p.FastReadField100(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } default: l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l @@ -2491,6 +2505,38 @@ func (p *PromptTemplate) FastReadField3(buf []byte) (int, error) { return offset, nil } +func (p *PromptTemplate) FastReadField100(buf []byte) (int, error) { + offset := 0 + + _, _, size, l, err := thrift.Binary.ReadMapBegin(buf[offset:]) + offset += l + if err != nil { + return offset, err + } + _field := make(map[string]string, size) + for i := 0; i < size; i++ { + var _key string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _key = v + } + + var _val string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _val = v + } + + _field[_key] = _val + } + p.Metadata = _field + return offset, nil +} + func (p *PromptTemplate) FastWrite(buf []byte) int { return p.FastWriteNocopy(buf, nil) } @@ -2501,6 +2547,7 @@ func (p *PromptTemplate) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int offset += p.fastWriteField1(buf[offset:], w) offset += p.fastWriteField2(buf[offset:], w) offset += p.fastWriteField3(buf[offset:], w) + offset += p.fastWriteField100(buf[offset:], w) } offset += thrift.Binary.WriteFieldStop(buf[offset:]) return offset @@ -2512,6 +2559,7 @@ func (p *PromptTemplate) BLength() int { l += p.field1Length() l += p.field2Length() l += p.field3Length() + l += p.field100Length() } l += thrift.Binary.FieldStopLength() return l @@ -2558,6 +2606,23 @@ func (p *PromptTemplate) fastWriteField3(buf []byte, w thrift.NocopyWriter) int return offset } +func (p *PromptTemplate) fastWriteField100(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetMetadata() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.MAP, 100) + mapBeginOffset := offset + offset += thrift.Binary.MapBeginLength() + var length int + for k, v := range p.Metadata { + length++ + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, k) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, v) + } + thrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift.STRING, thrift.STRING, length) + } + return offset +} + func (p *PromptTemplate) field1Length() int { l := 0 if p.IsSetTemplateType() { @@ -2593,6 +2658,21 @@ func (p *PromptTemplate) field3Length() int { return l } +func (p *PromptTemplate) field100Length() int { + l := 0 + if p.IsSetMetadata() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.MapBeginLength() + for k, v := range p.Metadata { + _, _ = k, v + + l += thrift.Binary.StringLengthNocopy(k) + l += thrift.Binary.StringLengthNocopy(v) + } + } + return l +} + func (p *PromptTemplate) DeepCopy(s interface{}) error { src, ok := s.(*PromptTemplate) if !ok { @@ -2634,6 +2714,23 @@ func (p *PromptTemplate) DeepCopy(s interface{}) error { } } + if src.Metadata != nil { + p.Metadata = make(map[string]string, len(src.Metadata)) + for key, val := range src.Metadata { + var _key string + if key != "" { + _key = kutils.StringDeepCopy(key) + } + + var _val string + if val != "" { + _val = kutils.StringDeepCopy(val) + } + + p.Metadata[_key] = _val + } + } + return nil } @@ -3747,6 +3844,20 @@ func (p *Message) FastRead(buf []byte) (int, error) { goto SkipFieldError } } + case 100: + if fieldTypeId == thrift.MAP { + l, err = p.FastReadField100(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } default: l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l @@ -3871,6 +3982,38 @@ func (p *Message) FastReadField6(buf []byte) (int, error) { return offset, nil } +func (p *Message) FastReadField100(buf []byte) (int, error) { + offset := 0 + + _, _, size, l, err := thrift.Binary.ReadMapBegin(buf[offset:]) + offset += l + if err != nil { + return offset, err + } + _field := make(map[string]string, size) + for i := 0; i < size; i++ { + var _key string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _key = v + } + + var _val string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _val = v + } + + _field[_key] = _val + } + p.Metadata = _field + return offset, nil +} + func (p *Message) FastWrite(buf []byte) int { return p.FastWriteNocopy(buf, nil) } @@ -3884,6 +4027,7 @@ func (p *Message) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { offset += p.fastWriteField4(buf[offset:], w) offset += p.fastWriteField5(buf[offset:], w) offset += p.fastWriteField6(buf[offset:], w) + offset += p.fastWriteField100(buf[offset:], w) } offset += thrift.Binary.WriteFieldStop(buf[offset:]) return offset @@ -3898,6 +4042,7 @@ func (p *Message) BLength() int { l += p.field4Length() l += p.field5Length() l += p.field6Length() + l += p.field100Length() } l += thrift.Binary.FieldStopLength() return l @@ -3971,6 +4116,23 @@ func (p *Message) fastWriteField6(buf []byte, w thrift.NocopyWriter) int { return offset } +func (p *Message) fastWriteField100(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetMetadata() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.MAP, 100) + mapBeginOffset := offset + offset += thrift.Binary.MapBeginLength() + var length int + for k, v := range p.Metadata { + length++ + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, k) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, v) + } + thrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift.STRING, thrift.STRING, length) + } + return offset +} + func (p *Message) field1Length() int { l := 0 if p.IsSetRole() { @@ -4033,6 +4195,21 @@ func (p *Message) field6Length() int { return l } +func (p *Message) field100Length() int { + l := 0 + if p.IsSetMetadata() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.MapBeginLength() + for k, v := range p.Metadata { + _, _ = k, v + + l += thrift.Binary.StringLengthNocopy(k) + l += thrift.Binary.StringLengthNocopy(v) + } + } + return l +} + func (p *Message) DeepCopy(s interface{}) error { src, ok := s.(*Message) if !ok { @@ -4098,6 +4275,23 @@ func (p *Message) DeepCopy(s interface{}) error { } } + if src.Metadata != nil { + p.Metadata = make(map[string]string, len(src.Metadata)) + for key, val := range src.Metadata { + var _key string + if key != "" { + _key = kutils.StringDeepCopy(key) + } + + var _val string + if val != "" { + _val = kutils.StringDeepCopy(val) + } + + p.Metadata[_key] = _val + } + } + return nil } diff --git a/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go b/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go index ed908b4aa..6476b21cb 100644 --- a/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go +++ b/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go @@ -3319,9 +3319,10 @@ func (p *PromptDetail) Field255DeepEqual(src map[string]string) bool { } type PromptTemplate struct { - TemplateType *TemplateType `thrift:"template_type,1,optional" frugal:"1,optional,string" form:"template_type" json:"template_type,omitempty" query:"template_type"` - Messages []*Message `thrift:"messages,2,optional" frugal:"2,optional,list" form:"messages" json:"messages,omitempty" query:"messages"` - VariableDefs []*VariableDef `thrift:"variable_defs,3,optional" frugal:"3,optional,list" form:"variable_defs" json:"variable_defs,omitempty" query:"variable_defs"` + TemplateType *TemplateType `thrift:"template_type,1,optional" frugal:"1,optional,string" form:"template_type" json:"template_type,omitempty" query:"template_type"` + Messages []*Message `thrift:"messages,2,optional" frugal:"2,optional,list" form:"messages" json:"messages,omitempty" query:"messages"` + VariableDefs []*VariableDef `thrift:"variable_defs,3,optional" frugal:"3,optional,list" form:"variable_defs" json:"variable_defs,omitempty" query:"variable_defs"` + Metadata map[string]string `thrift:"metadata,100,optional" frugal:"100,optional,map" form:"metadata" json:"metadata,omitempty" query:"metadata"` } func NewPromptTemplate() *PromptTemplate { @@ -3366,6 +3367,18 @@ func (p *PromptTemplate) GetVariableDefs() (v []*VariableDef) { } return p.VariableDefs } + +var PromptTemplate_Metadata_DEFAULT map[string]string + +func (p *PromptTemplate) GetMetadata() (v map[string]string) { + if p == nil { + return + } + if !p.IsSetMetadata() { + return PromptTemplate_Metadata_DEFAULT + } + return p.Metadata +} func (p *PromptTemplate) SetTemplateType(val *TemplateType) { p.TemplateType = val } @@ -3375,11 +3388,15 @@ func (p *PromptTemplate) SetMessages(val []*Message) { func (p *PromptTemplate) SetVariableDefs(val []*VariableDef) { p.VariableDefs = val } +func (p *PromptTemplate) SetMetadata(val map[string]string) { + p.Metadata = val +} var fieldIDToName_PromptTemplate = map[int16]string{ - 1: "template_type", - 2: "messages", - 3: "variable_defs", + 1: "template_type", + 2: "messages", + 3: "variable_defs", + 100: "metadata", } func (p *PromptTemplate) IsSetTemplateType() bool { @@ -3394,6 +3411,10 @@ func (p *PromptTemplate) IsSetVariableDefs() bool { return p.VariableDefs != nil } +func (p *PromptTemplate) IsSetMetadata() bool { + return p.Metadata != nil +} + func (p *PromptTemplate) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -3436,6 +3457,14 @@ func (p *PromptTemplate) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 100: + if fieldTypeId == thrift.MAP { + if err = p.ReadField100(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } default: if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError @@ -3522,6 +3551,35 @@ func (p *PromptTemplate) ReadField3(iprot thrift.TProtocol) error { p.VariableDefs = _field return nil } +func (p *PromptTemplate) ReadField100(iprot thrift.TProtocol) error { + _, _, size, err := iprot.ReadMapBegin() + if err != nil { + return err + } + _field := make(map[string]string, size) + for i := 0; i < size; i++ { + var _key string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _key = v + } + + var _val string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _val = v + } + + _field[_key] = _val + } + if err := iprot.ReadMapEnd(); err != nil { + return err + } + p.Metadata = _field + return nil +} func (p *PromptTemplate) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 @@ -3541,6 +3599,10 @@ func (p *PromptTemplate) Write(oprot thrift.TProtocol) (err error) { fieldId = 3 goto WriteFieldError } + if err = p.writeField100(oprot); err != nil { + fieldId = 100 + goto WriteFieldError + } } if err = oprot.WriteFieldStop(); err != nil { goto WriteFieldStopError @@ -3629,6 +3691,35 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) } +func (p *PromptTemplate) writeField100(oprot thrift.TProtocol) (err error) { + if p.IsSetMetadata() { + if err = oprot.WriteFieldBegin("metadata", thrift.MAP, 100); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteMapBegin(thrift.STRING, thrift.STRING, len(p.Metadata)); err != nil { + return err + } + for k, v := range p.Metadata { + if err := oprot.WriteString(k); err != nil { + return err + } + if err := oprot.WriteString(v); err != nil { + return err + } + } + if err := oprot.WriteMapEnd(); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 100 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 100 end error: ", p), err) +} func (p *PromptTemplate) String() string { if p == nil { @@ -3653,6 +3744,9 @@ func (p *PromptTemplate) DeepEqual(ano *PromptTemplate) bool { if !p.Field3DeepEqual(ano.VariableDefs) { return false } + if !p.Field100DeepEqual(ano.Metadata) { + return false + } return true } @@ -3694,6 +3788,19 @@ func (p *PromptTemplate) Field3DeepEqual(src []*VariableDef) bool { } return true } +func (p *PromptTemplate) Field100DeepEqual(src map[string]string) bool { + + if len(p.Metadata) != len(src) { + return false + } + for k, v := range p.Metadata { + _src := src[k] + if strings.Compare(v, _src) != 0 { + return false + } + } + return true +} type Tool struct { Type *ToolType `thrift:"type,1,optional" frugal:"1,optional,string" form:"type" json:"type,omitempty" query:"type"` @@ -5182,12 +5289,13 @@ func (p *ModelConfig) Field8DeepEqual(src *bool) bool { } type Message struct { - Role *Role `thrift:"role,1,optional" frugal:"1,optional,string" form:"role" json:"role,omitempty" query:"role"` - ReasoningContent *string `thrift:"reasoning_content,2,optional" frugal:"2,optional,string" form:"reasoning_content" json:"reasoning_content,omitempty" query:"reasoning_content"` - Content *string `thrift:"content,3,optional" frugal:"3,optional,string" form:"content" json:"content,omitempty" query:"content"` - Parts []*ContentPart `thrift:"parts,4,optional" frugal:"4,optional,list" form:"parts" json:"parts,omitempty" query:"parts"` - ToolCallID *string `thrift:"tool_call_id,5,optional" frugal:"5,optional,string" form:"tool_call_id" json:"tool_call_id,omitempty" query:"tool_call_id"` - ToolCalls []*ToolCall `thrift:"tool_calls,6,optional" frugal:"6,optional,list" form:"tool_calls" json:"tool_calls,omitempty" query:"tool_calls"` + Role *Role `thrift:"role,1,optional" frugal:"1,optional,string" form:"role" json:"role,omitempty" query:"role"` + ReasoningContent *string `thrift:"reasoning_content,2,optional" frugal:"2,optional,string" form:"reasoning_content" json:"reasoning_content,omitempty" query:"reasoning_content"` + Content *string `thrift:"content,3,optional" frugal:"3,optional,string" form:"content" json:"content,omitempty" query:"content"` + Parts []*ContentPart `thrift:"parts,4,optional" frugal:"4,optional,list" form:"parts" json:"parts,omitempty" query:"parts"` + ToolCallID *string `thrift:"tool_call_id,5,optional" frugal:"5,optional,string" form:"tool_call_id" json:"tool_call_id,omitempty" query:"tool_call_id"` + ToolCalls []*ToolCall `thrift:"tool_calls,6,optional" frugal:"6,optional,list" form:"tool_calls" json:"tool_calls,omitempty" query:"tool_calls"` + Metadata map[string]string `thrift:"metadata,100,optional" frugal:"100,optional,map" form:"metadata" json:"metadata,omitempty" query:"metadata"` } func NewMessage() *Message { @@ -5268,6 +5376,18 @@ func (p *Message) GetToolCalls() (v []*ToolCall) { } return p.ToolCalls } + +var Message_Metadata_DEFAULT map[string]string + +func (p *Message) GetMetadata() (v map[string]string) { + if p == nil { + return + } + if !p.IsSetMetadata() { + return Message_Metadata_DEFAULT + } + return p.Metadata +} func (p *Message) SetRole(val *Role) { p.Role = val } @@ -5286,14 +5406,18 @@ func (p *Message) SetToolCallID(val *string) { func (p *Message) SetToolCalls(val []*ToolCall) { p.ToolCalls = val } +func (p *Message) SetMetadata(val map[string]string) { + p.Metadata = val +} var fieldIDToName_Message = map[int16]string{ - 1: "role", - 2: "reasoning_content", - 3: "content", - 4: "parts", - 5: "tool_call_id", - 6: "tool_calls", + 1: "role", + 2: "reasoning_content", + 3: "content", + 4: "parts", + 5: "tool_call_id", + 6: "tool_calls", + 100: "metadata", } func (p *Message) IsSetRole() bool { @@ -5320,6 +5444,10 @@ func (p *Message) IsSetToolCalls() bool { return p.ToolCalls != nil } +func (p *Message) IsSetMetadata() bool { + return p.Metadata != nil +} + func (p *Message) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -5386,6 +5514,14 @@ func (p *Message) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 100: + if fieldTypeId == thrift.MAP { + if err = p.ReadField100(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } default: if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError @@ -5505,6 +5641,35 @@ func (p *Message) ReadField6(iprot thrift.TProtocol) error { p.ToolCalls = _field return nil } +func (p *Message) ReadField100(iprot thrift.TProtocol) error { + _, _, size, err := iprot.ReadMapBegin() + if err != nil { + return err + } + _field := make(map[string]string, size) + for i := 0; i < size; i++ { + var _key string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _key = v + } + + var _val string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _val = v + } + + _field[_key] = _val + } + if err := iprot.ReadMapEnd(); err != nil { + return err + } + p.Metadata = _field + return nil +} func (p *Message) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 @@ -5536,6 +5701,10 @@ func (p *Message) Write(oprot thrift.TProtocol) (err error) { fieldId = 6 goto WriteFieldError } + if err = p.writeField100(oprot); err != nil { + fieldId = 100 + goto WriteFieldError + } } if err = oprot.WriteFieldStop(); err != nil { goto WriteFieldStopError @@ -5678,6 +5847,35 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 6 end error: ", p), err) } +func (p *Message) writeField100(oprot thrift.TProtocol) (err error) { + if p.IsSetMetadata() { + if err = oprot.WriteFieldBegin("metadata", thrift.MAP, 100); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteMapBegin(thrift.STRING, thrift.STRING, len(p.Metadata)); err != nil { + return err + } + for k, v := range p.Metadata { + if err := oprot.WriteString(k); err != nil { + return err + } + if err := oprot.WriteString(v); err != nil { + return err + } + } + if err := oprot.WriteMapEnd(); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 100 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 100 end error: ", p), err) +} func (p *Message) String() string { if p == nil { @@ -5711,6 +5909,9 @@ func (p *Message) DeepEqual(ano *Message) bool { if !p.Field6DeepEqual(ano.ToolCalls) { return false } + if !p.Field100DeepEqual(ano.Metadata) { + return false + } return true } @@ -5788,6 +5989,19 @@ func (p *Message) Field6DeepEqual(src []*ToolCall) bool { } return true } +func (p *Message) Field100DeepEqual(src map[string]string) bool { + + if len(p.Metadata) != len(src) { + return false + } + for k, v := range p.Metadata { + _src := src[k] + if strings.Compare(v, _src) != 0 { + return false + } + } + return true +} type ContentPart struct { Type *ContentType `thrift:"type,1,optional" frugal:"1,optional,string" form:"type" json:"type,omitempty" query:"type"` diff --git a/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi.go b/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi.go index b4a924e17..53c8147f1 100644 --- a/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi.go +++ b/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi.go @@ -4443,6 +4443,8 @@ type PromptTemplate struct { Messages []*Message `thrift:"messages,2,optional" frugal:"2,optional,list" form:"messages" json:"messages,omitempty" query:"messages"` // 变量定义 VariableDefs []*VariableDef `thrift:"variable_defs,3,optional" frugal:"3,optional,list" form:"variable_defs" json:"variable_defs,omitempty" query:"variable_defs"` + // 模板级元信息 + Metadata map[string]string `thrift:"metadata,100,optional" frugal:"100,optional,map" form:"metadata" json:"metadata,omitempty" query:"metadata"` } func NewPromptTemplate() *PromptTemplate { @@ -4487,6 +4489,18 @@ func (p *PromptTemplate) GetVariableDefs() (v []*VariableDef) { } return p.VariableDefs } + +var PromptTemplate_Metadata_DEFAULT map[string]string + +func (p *PromptTemplate) GetMetadata() (v map[string]string) { + if p == nil { + return + } + if !p.IsSetMetadata() { + return PromptTemplate_Metadata_DEFAULT + } + return p.Metadata +} func (p *PromptTemplate) SetTemplateType(val *TemplateType) { p.TemplateType = val } @@ -4496,11 +4510,15 @@ func (p *PromptTemplate) SetMessages(val []*Message) { func (p *PromptTemplate) SetVariableDefs(val []*VariableDef) { p.VariableDefs = val } +func (p *PromptTemplate) SetMetadata(val map[string]string) { + p.Metadata = val +} var fieldIDToName_PromptTemplate = map[int16]string{ - 1: "template_type", - 2: "messages", - 3: "variable_defs", + 1: "template_type", + 2: "messages", + 3: "variable_defs", + 100: "metadata", } func (p *PromptTemplate) IsSetTemplateType() bool { @@ -4515,6 +4533,10 @@ func (p *PromptTemplate) IsSetVariableDefs() bool { return p.VariableDefs != nil } +func (p *PromptTemplate) IsSetMetadata() bool { + return p.Metadata != nil +} + func (p *PromptTemplate) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -4557,6 +4579,14 @@ func (p *PromptTemplate) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 100: + if fieldTypeId == thrift.MAP { + if err = p.ReadField100(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } default: if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError @@ -4643,6 +4673,35 @@ func (p *PromptTemplate) ReadField3(iprot thrift.TProtocol) error { p.VariableDefs = _field return nil } +func (p *PromptTemplate) ReadField100(iprot thrift.TProtocol) error { + _, _, size, err := iprot.ReadMapBegin() + if err != nil { + return err + } + _field := make(map[string]string, size) + for i := 0; i < size; i++ { + var _key string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _key = v + } + + var _val string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _val = v + } + + _field[_key] = _val + } + if err := iprot.ReadMapEnd(); err != nil { + return err + } + p.Metadata = _field + return nil +} func (p *PromptTemplate) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 @@ -4662,6 +4721,10 @@ func (p *PromptTemplate) Write(oprot thrift.TProtocol) (err error) { fieldId = 3 goto WriteFieldError } + if err = p.writeField100(oprot); err != nil { + fieldId = 100 + goto WriteFieldError + } } if err = oprot.WriteFieldStop(); err != nil { goto WriteFieldStopError @@ -4750,6 +4813,35 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) } +func (p *PromptTemplate) writeField100(oprot thrift.TProtocol) (err error) { + if p.IsSetMetadata() { + if err = oprot.WriteFieldBegin("metadata", thrift.MAP, 100); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteMapBegin(thrift.STRING, thrift.STRING, len(p.Metadata)); err != nil { + return err + } + for k, v := range p.Metadata { + if err := oprot.WriteString(k); err != nil { + return err + } + if err := oprot.WriteString(v); err != nil { + return err + } + } + if err := oprot.WriteMapEnd(); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 100 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 100 end error: ", p), err) +} func (p *PromptTemplate) String() string { if p == nil { @@ -4774,6 +4866,9 @@ func (p *PromptTemplate) DeepEqual(ano *PromptTemplate) bool { if !p.Field3DeepEqual(ano.VariableDefs) { return false } + if !p.Field100DeepEqual(ano.Metadata) { + return false + } return true } @@ -4815,6 +4910,19 @@ func (p *PromptTemplate) Field3DeepEqual(src []*VariableDef) bool { } return true } +func (p *PromptTemplate) Field100DeepEqual(src map[string]string) bool { + + if len(p.Metadata) != len(src) { + return false + } + for k, v := range p.Metadata { + _src := src[k] + if strings.Compare(v, _src) != 0 { + return false + } + } + return true +} type ToolCallConfig struct { ToolChoice *ToolChoiceType `thrift:"tool_choice,1,optional" frugal:"1,optional,string" form:"tool_choice" json:"tool_choice,omitempty" query:"tool_choice"` @@ -5010,6 +5118,8 @@ type Message struct { ToolCallID *string `thrift:"tool_call_id,5,optional" frugal:"5,optional,string" form:"tool_call_id" json:"tool_call_id,omitempty" query:"tool_call_id"` // tool调用(role为assistant时有效) ToolCalls []*ToolCall `thrift:"tool_calls,6,optional" frugal:"6,optional,list" form:"tool_calls" json:"tool_calls,omitempty" query:"tool_calls"` + // 消息元信息 + Metadata map[string]string `thrift:"metadata,100,optional" frugal:"100,optional,map" form:"metadata" json:"metadata,omitempty" query:"metadata"` } func NewMessage() *Message { @@ -5090,6 +5200,18 @@ func (p *Message) GetToolCalls() (v []*ToolCall) { } return p.ToolCalls } + +var Message_Metadata_DEFAULT map[string]string + +func (p *Message) GetMetadata() (v map[string]string) { + if p == nil { + return + } + if !p.IsSetMetadata() { + return Message_Metadata_DEFAULT + } + return p.Metadata +} func (p *Message) SetRole(val *Role) { p.Role = val } @@ -5108,14 +5230,18 @@ func (p *Message) SetToolCallID(val *string) { func (p *Message) SetToolCalls(val []*ToolCall) { p.ToolCalls = val } +func (p *Message) SetMetadata(val map[string]string) { + p.Metadata = val +} var fieldIDToName_Message = map[int16]string{ - 1: "role", - 2: "content", - 3: "parts", - 4: "reasoning_content", - 5: "tool_call_id", - 6: "tool_calls", + 1: "role", + 2: "content", + 3: "parts", + 4: "reasoning_content", + 5: "tool_call_id", + 6: "tool_calls", + 100: "metadata", } func (p *Message) IsSetRole() bool { @@ -5142,6 +5268,10 @@ func (p *Message) IsSetToolCalls() bool { return p.ToolCalls != nil } +func (p *Message) IsSetMetadata() bool { + return p.Metadata != nil +} + func (p *Message) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -5208,6 +5338,14 @@ func (p *Message) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 100: + if fieldTypeId == thrift.MAP { + if err = p.ReadField100(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } default: if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError @@ -5327,6 +5465,35 @@ func (p *Message) ReadField6(iprot thrift.TProtocol) error { p.ToolCalls = _field return nil } +func (p *Message) ReadField100(iprot thrift.TProtocol) error { + _, _, size, err := iprot.ReadMapBegin() + if err != nil { + return err + } + _field := make(map[string]string, size) + for i := 0; i < size; i++ { + var _key string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _key = v + } + + var _val string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _val = v + } + + _field[_key] = _val + } + if err := iprot.ReadMapEnd(); err != nil { + return err + } + p.Metadata = _field + return nil +} func (p *Message) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 @@ -5358,6 +5525,10 @@ func (p *Message) Write(oprot thrift.TProtocol) (err error) { fieldId = 6 goto WriteFieldError } + if err = p.writeField100(oprot); err != nil { + fieldId = 100 + goto WriteFieldError + } } if err = oprot.WriteFieldStop(); err != nil { goto WriteFieldStopError @@ -5500,6 +5671,35 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 6 end error: ", p), err) } +func (p *Message) writeField100(oprot thrift.TProtocol) (err error) { + if p.IsSetMetadata() { + if err = oprot.WriteFieldBegin("metadata", thrift.MAP, 100); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteMapBegin(thrift.STRING, thrift.STRING, len(p.Metadata)); err != nil { + return err + } + for k, v := range p.Metadata { + if err := oprot.WriteString(k); err != nil { + return err + } + if err := oprot.WriteString(v); err != nil { + return err + } + } + if err := oprot.WriteMapEnd(); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 100 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 100 end error: ", p), err) +} func (p *Message) String() string { if p == nil { @@ -5533,6 +5733,9 @@ func (p *Message) DeepEqual(ano *Message) bool { if !p.Field6DeepEqual(ano.ToolCalls) { return false } + if !p.Field100DeepEqual(ano.Metadata) { + return false + } return true } @@ -5610,6 +5813,19 @@ func (p *Message) Field6DeepEqual(src []*ToolCall) bool { } return true } +func (p *Message) Field100DeepEqual(src map[string]string) bool { + + if len(p.Metadata) != len(src) { + return false + } + for k, v := range p.Metadata { + _src := src[k] + if strings.Compare(v, _src) != 0 { + return false + } + } + return true +} type ContentPart struct { Type *ContentType `thrift:"type,1,optional" frugal:"1,optional,string" form:"type" json:"type,omitempty" query:"type"` diff --git a/backend/kitex_gen/coze/loop/prompt/openapi/k-coze.loop.prompt.openapi.go b/backend/kitex_gen/coze/loop/prompt/openapi/k-coze.loop.prompt.openapi.go index 6e29aaf3e..c8ed71fa6 100644 --- a/backend/kitex_gen/coze/loop/prompt/openapi/k-coze.loop.prompt.openapi.go +++ b/backend/kitex_gen/coze/loop/prompt/openapi/k-coze.loop.prompt.openapi.go @@ -3250,6 +3250,20 @@ func (p *PromptTemplate) FastRead(buf []byte) (int, error) { goto SkipFieldError } } + case 100: + if fieldTypeId == thrift.MAP { + l, err = p.FastReadField100(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } default: l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l @@ -3332,6 +3346,38 @@ func (p *PromptTemplate) FastReadField3(buf []byte) (int, error) { return offset, nil } +func (p *PromptTemplate) FastReadField100(buf []byte) (int, error) { + offset := 0 + + _, _, size, l, err := thrift.Binary.ReadMapBegin(buf[offset:]) + offset += l + if err != nil { + return offset, err + } + _field := make(map[string]string, size) + for i := 0; i < size; i++ { + var _key string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _key = v + } + + var _val string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _val = v + } + + _field[_key] = _val + } + p.Metadata = _field + return offset, nil +} + func (p *PromptTemplate) FastWrite(buf []byte) int { return p.FastWriteNocopy(buf, nil) } @@ -3342,6 +3388,7 @@ func (p *PromptTemplate) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int offset += p.fastWriteField1(buf[offset:], w) offset += p.fastWriteField2(buf[offset:], w) offset += p.fastWriteField3(buf[offset:], w) + offset += p.fastWriteField100(buf[offset:], w) } offset += thrift.Binary.WriteFieldStop(buf[offset:]) return offset @@ -3353,6 +3400,7 @@ func (p *PromptTemplate) BLength() int { l += p.field1Length() l += p.field2Length() l += p.field3Length() + l += p.field100Length() } l += thrift.Binary.FieldStopLength() return l @@ -3399,6 +3447,23 @@ func (p *PromptTemplate) fastWriteField3(buf []byte, w thrift.NocopyWriter) int return offset } +func (p *PromptTemplate) fastWriteField100(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetMetadata() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.MAP, 100) + mapBeginOffset := offset + offset += thrift.Binary.MapBeginLength() + var length int + for k, v := range p.Metadata { + length++ + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, k) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, v) + } + thrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift.STRING, thrift.STRING, length) + } + return offset +} + func (p *PromptTemplate) field1Length() int { l := 0 if p.IsSetTemplateType() { @@ -3434,6 +3499,21 @@ func (p *PromptTemplate) field3Length() int { return l } +func (p *PromptTemplate) field100Length() int { + l := 0 + if p.IsSetMetadata() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.MapBeginLength() + for k, v := range p.Metadata { + _, _ = k, v + + l += thrift.Binary.StringLengthNocopy(k) + l += thrift.Binary.StringLengthNocopy(v) + } + } + return l +} + func (p *PromptTemplate) DeepCopy(s interface{}) error { src, ok := s.(*PromptTemplate) if !ok { @@ -3475,6 +3555,23 @@ func (p *PromptTemplate) DeepCopy(s interface{}) error { } } + if src.Metadata != nil { + p.Metadata = make(map[string]string, len(src.Metadata)) + for key, val := range src.Metadata { + var _key string + if key != "" { + _key = kutils.StringDeepCopy(key) + } + + var _val string + if val != "" { + _val = kutils.StringDeepCopy(val) + } + + p.Metadata[_key] = _val + } + } + return nil } @@ -3696,6 +3793,20 @@ func (p *Message) FastRead(buf []byte) (int, error) { goto SkipFieldError } } + case 100: + if fieldTypeId == thrift.MAP { + l, err = p.FastReadField100(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } default: l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l @@ -3820,6 +3931,38 @@ func (p *Message) FastReadField6(buf []byte) (int, error) { return offset, nil } +func (p *Message) FastReadField100(buf []byte) (int, error) { + offset := 0 + + _, _, size, l, err := thrift.Binary.ReadMapBegin(buf[offset:]) + offset += l + if err != nil { + return offset, err + } + _field := make(map[string]string, size) + for i := 0; i < size; i++ { + var _key string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _key = v + } + + var _val string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _val = v + } + + _field[_key] = _val + } + p.Metadata = _field + return offset, nil +} + func (p *Message) FastWrite(buf []byte) int { return p.FastWriteNocopy(buf, nil) } @@ -3833,6 +3976,7 @@ func (p *Message) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { offset += p.fastWriteField4(buf[offset:], w) offset += p.fastWriteField5(buf[offset:], w) offset += p.fastWriteField6(buf[offset:], w) + offset += p.fastWriteField100(buf[offset:], w) } offset += thrift.Binary.WriteFieldStop(buf[offset:]) return offset @@ -3847,6 +3991,7 @@ func (p *Message) BLength() int { l += p.field4Length() l += p.field5Length() l += p.field6Length() + l += p.field100Length() } l += thrift.Binary.FieldStopLength() return l @@ -3920,6 +4065,23 @@ func (p *Message) fastWriteField6(buf []byte, w thrift.NocopyWriter) int { return offset } +func (p *Message) fastWriteField100(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetMetadata() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.MAP, 100) + mapBeginOffset := offset + offset += thrift.Binary.MapBeginLength() + var length int + for k, v := range p.Metadata { + length++ + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, k) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, v) + } + thrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift.STRING, thrift.STRING, length) + } + return offset +} + func (p *Message) field1Length() int { l := 0 if p.IsSetRole() { @@ -3982,6 +4144,21 @@ func (p *Message) field6Length() int { return l } +func (p *Message) field100Length() int { + l := 0 + if p.IsSetMetadata() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.MapBeginLength() + for k, v := range p.Metadata { + _, _ = k, v + + l += thrift.Binary.StringLengthNocopy(k) + l += thrift.Binary.StringLengthNocopy(v) + } + } + return l +} + func (p *Message) DeepCopy(s interface{}) error { src, ok := s.(*Message) if !ok { @@ -4047,6 +4224,23 @@ func (p *Message) DeepCopy(s interface{}) error { } } + if src.Metadata != nil { + p.Metadata = make(map[string]string, len(src.Metadata)) + for key, val := range src.Metadata { + var _key string + if key != "" { + _key = kutils.StringDeepCopy(key) + } + + var _val string + if val != "" { + _val = kutils.StringDeepCopy(val) + } + + p.Metadata[_key] = _val + } + } + return nil } diff --git a/backend/modules/prompt/application/convertor/openapi.go b/backend/modules/prompt/application/convertor/openapi.go index 39481c431..ca5af15df 100644 --- a/backend/modules/prompt/application/convertor/openapi.go +++ b/backend/modules/prompt/application/convertor/openapi.go @@ -43,6 +43,7 @@ func OpenAPIPromptTemplateDO2DTO(do *entity.PromptTemplate) *openapi.PromptTempl TemplateType: ptr.Of(prompt.TemplateType(do.TemplateType)), Messages: OpenAPIBatchMessageDO2DTO(do.Messages), VariableDefs: OpenAPIBatchVariableDefDO2DTO(do.VariableDefs), + Metadata: do.Metadata, } } @@ -71,6 +72,7 @@ func OpenAPIMessageDO2DTO(do *entity.Message) *openapi.Message { Parts: OpenAPIBatchContentPartDO2DTO(do.Parts), ToolCallID: do.ToolCallID, ToolCalls: OpenAPIBatchToolCallDO2DTO(do.ToolCalls), + Metadata: do.Metadata, } } @@ -227,6 +229,7 @@ func OpenAPIMessageDTO2DO(dto *openapi.Message) *entity.Message { Parts: OpenAPIBatchContentPartDTO2DO(dto.Parts), ToolCallID: dto.ToolCallID, ToolCalls: OpenAPIBatchToolCallDTO2DO(dto.ToolCalls), + Metadata: dto.Metadata, } } diff --git a/backend/modules/prompt/application/convertor/openapi_test.go b/backend/modules/prompt/application/convertor/openapi_test.go index b6982c2f4..9d05c72b7 100755 --- a/backend/modules/prompt/application/convertor/openapi_test.go +++ b/backend/modules/prompt/application/convertor/openapi_test.go @@ -350,6 +350,30 @@ func mockOpenAPIPromptCases() []openAPIPromptTestCase { Version: ptr.Of("1.0.0"), }, }, + { + name: "prompt template metadata", + do: &entity.Prompt{ + ID: 123, + SpaceID: 456, + PromptKey: "test_prompt", + PromptCommit: &entity.PromptCommit{ + CommitInfo: &entity.CommitInfo{Version: "1.0.0"}, + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + Metadata: map[string]string{"commit": "meta"}, + }, + }, + }, + }, + want: &openapi.Prompt{ + WorkspaceID: ptr.Of(int64(456)), + PromptKey: ptr.Of("test_prompt"), + Version: ptr.Of("1.0.0"), + PromptTemplate: &openapi.PromptTemplate{ + Metadata: map[string]string{"commit": "meta"}, + }, + }, + }, } } @@ -410,6 +434,15 @@ func TestOpenAPIPromptTemplateDO2DTO(t *testing.T) { }, }, }, + { + name: "template with metadata", + do: &entity.PromptTemplate{ + Metadata: map[string]string{"k": "v"}, + }, + want: &openapi.PromptTemplate{ + Metadata: map[string]string{"k": "v"}, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -824,6 +857,17 @@ func TestOpenAPIMessageDO2DTO_NewFields(t *testing.T) { }, }, }, + { + name: "message with metadata", + do: &entity.Message{ + Role: entity.RoleAssistant, + Metadata: map[string]string{"meta": "value"}, + }, + want: &openapi.Message{ + Role: ptr.Of(prompt.RoleAssistant), + Metadata: map[string]string{"meta": "value"}, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -1105,6 +1149,21 @@ func TestOpenAPIBatchMessageDTO2DO(t *testing.T) { }, }, }, + { + name: "messages with metadata", + dtos: []*openapi.Message{ + { + Role: ptr.Of(prompt.RoleAssistant), + Metadata: map[string]string{"meta": "value"}, + }, + }, + want: []*entity.Message{ + { + Role: entity.RoleAssistant, + Metadata: map[string]string{"meta": "value"}, + }, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/backend/modules/prompt/application/convertor/prompt.go b/backend/modules/prompt/application/convertor/prompt.go index 716d33777..5d87c2843 100644 --- a/backend/modules/prompt/application/convertor/prompt.go +++ b/backend/modules/prompt/application/convertor/prompt.go @@ -109,6 +109,7 @@ func PromptTemplateDTO2DO(dto *prompt.PromptTemplate) *entity.PromptTemplate { TemplateType: TemplateTypeDTO2DO(dto.GetTemplateType()), Messages: BatchMessageDTO2DO(dto.Messages), VariableDefs: BatchVariableDefDTO2DO(dto.VariableDefs), + Metadata: dto.Metadata, } } @@ -149,6 +150,7 @@ func MessageDTO2DO(dto *prompt.Message) *entity.Message { Parts: BatchContentPartDTO2DO(dto.Parts), ToolCallID: dto.ToolCallID, ToolCalls: BatchToolCallDTO2DO(dto.ToolCalls), + Metadata: dto.Metadata, } } @@ -628,6 +630,7 @@ func MessageDO2DTO(do *entity.Message) *prompt.Message { Parts: BatchContentPartDO2DTO(do.Parts), ToolCallID: do.ToolCallID, ToolCalls: BatchToolCallDO2DTO(do.ToolCalls), + Metadata: do.Metadata, } } @@ -828,6 +831,7 @@ func PromptTemplateDO2DTO(do *entity.PromptTemplate) *prompt.PromptTemplate { TemplateType: ptr.Of(prompt.TemplateType(do.TemplateType)), Messages: BatchMessageDO2DTO(do.Messages), VariableDefs: BatchVariableDefDO2DTO(do.VariableDefs), + Metadata: do.Metadata, } } diff --git a/backend/modules/prompt/application/convertor/prompt_test.go b/backend/modules/prompt/application/convertor/prompt_test.go index c2ce98bc4..bbe5d8e8b 100644 --- a/backend/modules/prompt/application/convertor/prompt_test.go +++ b/backend/modules/prompt/application/convertor/prompt_test.go @@ -342,6 +342,41 @@ func mockPromptCases() []promptTestCase { }, }, }, + { + name: "prompt template metadata", + dto: &prompt.Prompt{ + PromptCommit: &prompt.PromptCommit{ + Detail: &prompt.PromptDetail{ + PromptTemplate: &prompt.PromptTemplate{ + Metadata: map[string]string{"commit-meta": "value"}, + }, + }, + }, + PromptDraft: &prompt.PromptDraft{ + Detail: &prompt.PromptDetail{ + PromptTemplate: &prompt.PromptTemplate{ + Metadata: map[string]string{"draft-meta": "value"}, + }, + }, + }, + }, + do: &entity.Prompt{ + PromptCommit: &entity.PromptCommit{ + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + Metadata: map[string]string{"commit-meta": "value"}, + }, + }, + }, + PromptDraft: &entity.PromptDraft{ + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + Metadata: map[string]string{"draft-meta": "value"}, + }, + }, + }, + }, + }, } } @@ -521,6 +556,17 @@ func mockMessageCases() []messageTestCase { ReasoningContent: ptr.Of("This is my reasoning process..."), }, }, + { + name: "message with metadata", + dto: &prompt.Message{ + Role: ptr.Of(prompt.RoleAssistant), + Metadata: map[string]string{"key": "value"}, + }, + do: &entity.Message{ + Role: entity.RoleAssistant, + Metadata: map[string]string{"key": "value"}, + }, + }, } } diff --git a/backend/modules/prompt/domain/entity/prompt_detail.go b/backend/modules/prompt/domain/entity/prompt_detail.go index bfd4b34f0..e17ab4c5e 100644 --- a/backend/modules/prompt/domain/entity/prompt_detail.go +++ b/backend/modules/prompt/domain/entity/prompt_detail.go @@ -34,9 +34,10 @@ type PromptDetail struct { } type PromptTemplate struct { - TemplateType TemplateType `json:"template_type"` - Messages []*Message `json:"messages,omitempty"` - VariableDefs []*VariableDef `json:"variable_defs,omitempty"` + TemplateType TemplateType `json:"template_type"` + Messages []*Message `json:"messages,omitempty"` + VariableDefs []*VariableDef `json:"variable_defs,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` } type TemplateType string @@ -53,6 +54,8 @@ type Message struct { Parts []*ContentPart `json:"parts,omitempty"` ToolCallID *string `json:"tool_call_id,omitempty"` ToolCalls []*ToolCall `json:"tool_calls,omitempty"` + + Metadata map[string]string `json:"metadata,omitempty"` } type Role string diff --git a/backend/modules/prompt/infra/repo/mysql/convertor/manage.go b/backend/modules/prompt/infra/repo/mysql/convertor/manage.go index a8d41dca4..6aea313fa 100644 --- a/backend/modules/prompt/infra/repo/mysql/convertor/manage.go +++ b/backend/modules/prompt/infra/repo/mysql/convertor/manage.go @@ -188,6 +188,9 @@ func PromptDO2CommitPO(do *entity.Prompt) *model.PromptCommit { if do.PromptCommit.PromptDetail.PromptTemplate.VariableDefs != nil { po.VariableDefs = ptr.Of(json.Jsonify(do.PromptCommit.PromptDetail.PromptTemplate.VariableDefs)) } + if do.PromptCommit.PromptDetail.PromptTemplate.Metadata != nil { + po.Metadata = ptr.Of(json.Jsonify(do.PromptCommit.PromptDetail.PromptTemplate.Metadata)) + } } // 序列化ExtInfos到ExtInfo字段 if do.PromptCommit.PromptDetail.ExtInfos != nil { @@ -219,6 +222,9 @@ func PromptDO2DraftPO(promptDO *entity.Prompt) *model.PromptUserDraft { if detailDO.PromptTemplate.VariableDefs != nil { po.VariableDefs = ptr.Of(json.Jsonify(detailDO.PromptTemplate.VariableDefs)) } + if detailDO.PromptTemplate.Metadata != nil { + po.Metadata = ptr.Of(json.Jsonify(detailDO.PromptTemplate.Metadata)) + } } if detailDO.ModelConfig != nil { po.ModelConfig = ptr.Of(json.Jsonify(detailDO.ModelConfig)) @@ -269,6 +275,7 @@ func PromptUserDraftPO2PromptDetailDO(draftPO *model.PromptUserDraft) *entity.Pr Messages: UnmarshalMessageDOs(draftPO.Messages), VariableDefs: UnmarshalVariableDefDOs(draftPO.VariableDefs), TemplateType: UnmarshalTemplateType(draftPO.TemplateType), + Metadata: UnmarshalMetadata(draftPO.Metadata), }, Tools: UnmarshalToolDOs(draftPO.Tools), ToolCallConfig: UnmarshalToolCallConfig(draftPO.ToolCallConfig), @@ -286,6 +293,7 @@ func PromptCommitPO2PromptDetailDO(commitPO *model.PromptCommit) *entity.PromptD Messages: UnmarshalMessageDOs(commitPO.Messages), VariableDefs: UnmarshalVariableDefDOs(commitPO.VariableDefs), TemplateType: UnmarshalTemplateType(commitPO.TemplateType), + Metadata: UnmarshalMetadata(commitPO.Metadata), }, Tools: UnmarshalToolDOs(commitPO.Tools), ToolCallConfig: UnmarshalToolCallConfig(commitPO.ToolCallConfig), @@ -346,6 +354,15 @@ func UnmarshalToolDOs(text *string) []*entity.Tool { return tools } +func UnmarshalMetadata(text *string) map[string]string { + if text == nil { + return nil + } + metadata := make(map[string]string) + _ = json.Unmarshal([]byte(*text), &metadata) + return metadata +} + func UnmarshalExtInfos(text *string) map[string]string { if text == nil { return nil diff --git a/backend/modules/prompt/infra/repo/mysql/convertor/manage_test.go b/backend/modules/prompt/infra/repo/mysql/convertor/manage_test.go index 0a3d487ec..003b072bd 100644 --- a/backend/modules/prompt/infra/repo/mysql/convertor/manage_test.go +++ b/backend/modules/prompt/infra/repo/mysql/convertor/manage_test.go @@ -712,3 +712,52 @@ func TestPromptDO2DraftPO(t *testing.T) { }) } } + +func TestPromptTemplateMetadataRoundTrip(t *testing.T) { + t.Parallel() + + commitMetadata := map[string]string{"commit": "meta"} + draftMetadata := map[string]string{"draft": "meta"} + prompt := &entity.Prompt{ + ID: 1, + SpaceID: 2, + PromptKey: "test_key", + PromptCommit: &entity.PromptCommit{ + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + Metadata: commitMetadata, + }, + }, + }, + PromptDraft: &entity.PromptDraft{ + DraftInfo: &entity.DraftInfo{UserID: "user"}, + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + Metadata: draftMetadata, + }, + }, + }, + } + + commitPO := PromptDO2CommitPO(prompt) + if assert.NotNil(t, commitPO.Metadata) { + commitDO := CommitPO2DO(&model.PromptCommit{Metadata: commitPO.Metadata}) + assert.NotNil(t, commitDO) + assert.NotNil(t, commitDO.PromptDetail) + assert.NotNil(t, commitDO.PromptDetail.PromptTemplate) + assert.Equal(t, commitMetadata, commitDO.PromptDetail.PromptTemplate.Metadata) + } + + draftPO := PromptDO2DraftPO(prompt) + if assert.NotNil(t, draftPO.Metadata) { + draftModel := &model.PromptUserDraft{ + UserID: "user", + Metadata: draftPO.Metadata, + } + draftDO := DraftPO2DO(draftModel) + assert.NotNil(t, draftDO) + assert.NotNil(t, draftDO.PromptDetail) + assert.NotNil(t, draftDO.PromptDetail.PromptTemplate) + assert.Equal(t, draftMetadata, draftDO.PromptDetail.PromptTemplate.Metadata) + } +} diff --git a/backend/modules/prompt/infra/repo/mysql/gorm_gen/model/prompt_commit.gen.go b/backend/modules/prompt/infra/repo/mysql/gorm_gen/model/prompt_commit.gen.go index 3ae54896c..ab30250ed 100644 --- a/backend/modules/prompt/infra/repo/mysql/gorm_gen/model/prompt_commit.gen.go +++ b/backend/modules/prompt/infra/repo/mysql/gorm_gen/model/prompt_commit.gen.go @@ -22,6 +22,7 @@ type PromptCommit struct { VariableDefs *string `gorm:"column:variable_defs;type:text;comment:变量定义" json:"variable_defs"` // 变量定义 Tools *string `gorm:"column:tools;type:longtext;comment:tools" json:"tools"` // tools ToolCallConfig *string `gorm:"column:tool_call_config;type:text;comment:tool调用配置" json:"tool_call_config"` // tool调用配置 + Metadata *string `gorm:"column:metadata;type:text;comment:模板元信息" json:"metadata"` // 模板元信息 Version string `gorm:"column:version;type:varchar(128);not null;uniqueIndex:uniq_prompt_id_version,priority:2;index:idx_prompt_key_version,priority:2;comment:版本" json:"version"` // 版本 BaseVersion string `gorm:"column:base_version;type:varchar(128);not null;comment:来源版本" json:"base_version"` // 来源版本 CommittedBy string `gorm:"column:committed_by;type:varchar(128);not null;comment:提交人" json:"committed_by"` // 提交人 diff --git a/backend/modules/prompt/infra/repo/mysql/gorm_gen/model/prompt_user_draft.gen.go b/backend/modules/prompt/infra/repo/mysql/gorm_gen/model/prompt_user_draft.gen.go index ca606d16b..14d1faeb8 100644 --- a/backend/modules/prompt/infra/repo/mysql/gorm_gen/model/prompt_user_draft.gen.go +++ b/backend/modules/prompt/infra/repo/mysql/gorm_gen/model/prompt_user_draft.gen.go @@ -24,6 +24,7 @@ type PromptUserDraft struct { VariableDefs *string `gorm:"column:variable_defs;type:text;comment:变量定义" json:"variable_defs"` // 变量定义 Tools *string `gorm:"column:tools;type:longtext;comment:tools" json:"tools"` // tools ToolCallConfig *string `gorm:"column:tool_call_config;type:text;comment:tool调用配置" json:"tool_call_config"` // tool调用配置 + Metadata *string `gorm:"column:metadata;type:text;comment:模板元信息" json:"metadata"` // 模板元信息 BaseVersion string `gorm:"column:base_version;type:varchar(128);not null;comment:草稿关联版本" json:"base_version"` // 草稿关联版本 IsDraftEdited int32 `gorm:"column:is_draft_edited;type:tinyint(4);not null;comment:草稿内容是否基于BaseVersion有变更" json:"is_draft_edited"` // 草稿内容是否基于BaseVersion有变更 ExtInfo *string `gorm:"column:ext_info;type:text;comment:扩展字段" json:"ext_info"` // 扩展字段 diff --git a/backend/modules/prompt/infra/repo/mysql/gorm_gen/query/prompt_commit.gen.go b/backend/modules/prompt/infra/repo/mysql/gorm_gen/query/prompt_commit.gen.go index 3595f3f23..eb6bdb327 100644 --- a/backend/modules/prompt/infra/repo/mysql/gorm_gen/query/prompt_commit.gen.go +++ b/backend/modules/prompt/infra/repo/mysql/gorm_gen/query/prompt_commit.gen.go @@ -37,6 +37,7 @@ func newPromptCommit(db *gorm.DB, opts ...gen.DOOption) promptCommit { _promptCommit.VariableDefs = field.NewString(tableName, "variable_defs") _promptCommit.Tools = field.NewString(tableName, "tools") _promptCommit.ToolCallConfig = field.NewString(tableName, "tool_call_config") + _promptCommit.Metadata = field.NewString(tableName, "metadata") _promptCommit.Version = field.NewString(tableName, "version") _promptCommit.BaseVersion = field.NewString(tableName, "base_version") _promptCommit.CommittedBy = field.NewString(tableName, "committed_by") @@ -65,6 +66,7 @@ type promptCommit struct { VariableDefs field.String // 变量定义 Tools field.String // tools ToolCallConfig field.String // tool调用配置 + Metadata field.String // 模板元信息 Version field.String // 版本 BaseVersion field.String // 来源版本 CommittedBy field.String // 提交人 @@ -98,6 +100,7 @@ func (p *promptCommit) updateTableName(table string) *promptCommit { p.VariableDefs = field.NewString(table, "variable_defs") p.Tools = field.NewString(table, "tools") p.ToolCallConfig = field.NewString(table, "tool_call_config") + p.Metadata = field.NewString(table, "metadata") p.Version = field.NewString(table, "version") p.BaseVersion = field.NewString(table, "base_version") p.CommittedBy = field.NewString(table, "committed_by") @@ -133,7 +136,7 @@ func (p *promptCommit) GetFieldByName(fieldName string) (field.OrderExpr, bool) } func (p *promptCommit) fillFieldMap() { - p.fieldMap = make(map[string]field.Expr, 17) + p.fieldMap = make(map[string]field.Expr, 18) p.fieldMap["id"] = p.ID p.fieldMap["space_id"] = p.SpaceID p.fieldMap["prompt_id"] = p.PromptID @@ -144,6 +147,7 @@ func (p *promptCommit) fillFieldMap() { p.fieldMap["variable_defs"] = p.VariableDefs p.fieldMap["tools"] = p.Tools p.fieldMap["tool_call_config"] = p.ToolCallConfig + p.fieldMap["metadata"] = p.Metadata p.fieldMap["version"] = p.Version p.fieldMap["base_version"] = p.BaseVersion p.fieldMap["committed_by"] = p.CommittedBy diff --git a/backend/modules/prompt/infra/repo/mysql/gorm_gen/query/prompt_user_draft.gen.go b/backend/modules/prompt/infra/repo/mysql/gorm_gen/query/prompt_user_draft.gen.go index 6f5f60d21..b625d760f 100644 --- a/backend/modules/prompt/infra/repo/mysql/gorm_gen/query/prompt_user_draft.gen.go +++ b/backend/modules/prompt/infra/repo/mysql/gorm_gen/query/prompt_user_draft.gen.go @@ -37,6 +37,7 @@ func newPromptUserDraft(db *gorm.DB, opts ...gen.DOOption) promptUserDraft { _promptUserDraft.VariableDefs = field.NewString(tableName, "variable_defs") _promptUserDraft.Tools = field.NewString(tableName, "tools") _promptUserDraft.ToolCallConfig = field.NewString(tableName, "tool_call_config") + _promptUserDraft.Metadata = field.NewString(tableName, "metadata") _promptUserDraft.BaseVersion = field.NewString(tableName, "base_version") _promptUserDraft.IsDraftEdited = field.NewInt32(tableName, "is_draft_edited") _promptUserDraft.ExtInfo = field.NewString(tableName, "ext_info") @@ -64,6 +65,7 @@ type promptUserDraft struct { VariableDefs field.String // 变量定义 Tools field.String // tools ToolCallConfig field.String // tool调用配置 + Metadata field.String // 模板元信息 BaseVersion field.String // 草稿关联版本 IsDraftEdited field.Int32 // 草稿内容是否基于BaseVersion有变更 ExtInfo field.String // 扩展字段 @@ -96,6 +98,7 @@ func (p *promptUserDraft) updateTableName(table string) *promptUserDraft { p.VariableDefs = field.NewString(table, "variable_defs") p.Tools = field.NewString(table, "tools") p.ToolCallConfig = field.NewString(table, "tool_call_config") + p.Metadata = field.NewString(table, "metadata") p.BaseVersion = field.NewString(table, "base_version") p.IsDraftEdited = field.NewInt32(table, "is_draft_edited") p.ExtInfo = field.NewString(table, "ext_info") @@ -130,7 +133,7 @@ func (p *promptUserDraft) GetFieldByName(fieldName string) (field.OrderExpr, boo } func (p *promptUserDraft) fillFieldMap() { - p.fieldMap = make(map[string]field.Expr, 16) + p.fieldMap = make(map[string]field.Expr, 17) p.fieldMap["id"] = p.ID p.fieldMap["space_id"] = p.SpaceID p.fieldMap["prompt_id"] = p.PromptID @@ -141,6 +144,7 @@ func (p *promptUserDraft) fillFieldMap() { p.fieldMap["variable_defs"] = p.VariableDefs p.fieldMap["tools"] = p.Tools p.fieldMap["tool_call_config"] = p.ToolCallConfig + p.fieldMap["metadata"] = p.Metadata p.fieldMap["base_version"] = p.BaseVersion p.fieldMap["is_draft_edited"] = p.IsDraftEdited p.fieldMap["ext_info"] = p.ExtInfo diff --git a/backend/modules/prompt/infra/repo/mysql/prompt_user_draft.go b/backend/modules/prompt/infra/repo/mysql/prompt_user_draft.go index 1b9f64056..ff0df1e1e 100644 --- a/backend/modules/prompt/infra/repo/mysql/prompt_user_draft.go +++ b/backend/modules/prompt/infra/repo/mysql/prompt_user_draft.go @@ -151,6 +151,7 @@ func (d *PromptUserDraftDAOImpl) Update(ctx context.Context, promptDraftPO *mode q.PromptUserDraft.ToolCallConfig.ColumnName().String(): promptDraftPO.ToolCallConfig, q.PromptUserDraft.TemplateType.ColumnName().String(): promptDraftPO.TemplateType, q.PromptUserDraft.VariableDefs.ColumnName().String(): promptDraftPO.VariableDefs, + q.PromptUserDraft.Metadata.ColumnName().String(): promptDraftPO.Metadata, q.PromptUserDraft.IsDraftEdited.ColumnName().String(): promptDraftPO.IsDraftEdited, }) if err != nil { diff --git a/idl/thrift/coze/loop/prompt/coze.loop.prompt.openapi.thrift b/idl/thrift/coze/loop/prompt/coze.loop.prompt.openapi.thrift index 44c183b99..0917d6018 100644 --- a/idl/thrift/coze/loop/prompt/coze.loop.prompt.openapi.thrift +++ b/idl/thrift/coze/loop/prompt/coze.loop.prompt.openapi.thrift @@ -95,6 +95,8 @@ struct PromptTemplate { 1: optional TemplateType template_type // 模板类型 2: optional list messages // 只支持message list形式托管 3: optional list variable_defs // 变量定义 + + 100: optional map metadata // 模板级元信息 } typedef string TemplateType @@ -117,6 +119,8 @@ struct Message { 4: optional string reasoning_content // 推理思考内容 5: optional string tool_call_id // tool调用ID(role为tool时有效) 6: optional list tool_calls // tool调用(role为assistant时有效) + + 100: optional map metadata // 消息元信息 } struct ContentPart { diff --git a/idl/thrift/coze/loop/prompt/domain/prompt.thrift b/idl/thrift/coze/loop/prompt/domain/prompt.thrift index f70477467..c78f0a95c 100644 --- a/idl/thrift/coze/loop/prompt/domain/prompt.thrift +++ b/idl/thrift/coze/loop/prompt/domain/prompt.thrift @@ -61,6 +61,8 @@ struct PromptTemplate { 1: optional TemplateType template_type 2: optional list messages 3: optional list variable_defs + + 100: optional map metadata } typedef string TemplateType (ts.enum="true") @@ -107,6 +109,8 @@ struct Message { 4: optional list parts 5: optional string tool_call_id 6: optional list tool_calls + + 100: optional map metadata } typedef string Role (ts.enum="true") diff --git a/release/deployment/docker-compose/bootstrap/mysql-init/init-sql/prompt_commit.sql b/release/deployment/docker-compose/bootstrap/mysql-init/init-sql/prompt_commit.sql index e871fc6b4..a8b4623c8 100644 --- a/release/deployment/docker-compose/bootstrap/mysql-init/init-sql/prompt_commit.sql +++ b/release/deployment/docker-compose/bootstrap/mysql-init/init-sql/prompt_commit.sql @@ -10,6 +10,7 @@ CREATE TABLE IF NOT EXISTS `prompt_commit` `variable_defs` text COLLATE utf8mb4_general_ci COMMENT '变量定义', `tools` longtext COLLATE utf8mb4_general_ci COMMENT 'tools', `tool_call_config` text COLLATE utf8mb4_general_ci COMMENT 'tool调用配置', + `metadata` text COLLATE utf8mb4_general_ci COMMENT '模板元信息', `version` varchar(128) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '版本', `base_version` varchar(128) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '来源版本', `committed_by` varchar(128) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '提交人', diff --git a/release/deployment/docker-compose/bootstrap/mysql-init/init-sql/prompt_user_draft.sql b/release/deployment/docker-compose/bootstrap/mysql-init/init-sql/prompt_user_draft.sql index 787bffdf8..804bf143e 100644 --- a/release/deployment/docker-compose/bootstrap/mysql-init/init-sql/prompt_user_draft.sql +++ b/release/deployment/docker-compose/bootstrap/mysql-init/init-sql/prompt_user_draft.sql @@ -10,6 +10,7 @@ CREATE TABLE IF NOT EXISTS `prompt_user_draft` `variable_defs` text COLLATE utf8mb4_general_ci COMMENT '变量定义', `tools` longtext COLLATE utf8mb4_general_ci COMMENT 'tools', `tool_call_config` text COLLATE utf8mb4_general_ci COMMENT 'tool调用配置', + `metadata` text COLLATE utf8mb4_general_ci COMMENT '模板元信息', `base_version` varchar(128) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '草稿关联版本', `is_draft_edited` tinyint NOT NULL DEFAULT '0' COMMENT '草稿内容是否基于BaseVersion有变更', `ext_info` text COLLATE utf8mb4_general_ci COMMENT '扩展字段', diff --git a/release/deployment/docker-compose/bootstrap/mysql-init/patch-sql/prompt_commit_alter.sql b/release/deployment/docker-compose/bootstrap/mysql-init/patch-sql/prompt_commit_alter.sql index 5cdac00f2..4b622b352 100644 --- a/release/deployment/docker-compose/bootstrap/mysql-init/patch-sql/prompt_commit_alter.sql +++ b/release/deployment/docker-compose/bootstrap/mysql-init/patch-sql/prompt_commit_alter.sql @@ -1 +1,2 @@ ALTER TABLE `prompt_commit` ADD COLUMN `ext_info` text COLLATE utf8mb4_general_ci COMMENT 'Extended information field'; +ALTER TABLE `prompt_commit` ADD COLUMN `metadata` text COLLATE utf8mb4_general_ci COMMENT 'Template metadata field'; diff --git a/release/deployment/docker-compose/bootstrap/mysql-init/patch-sql/prompt_user_draft_alter.sql b/release/deployment/docker-compose/bootstrap/mysql-init/patch-sql/prompt_user_draft_alter.sql index 479517c91..54323e1e0 100644 --- a/release/deployment/docker-compose/bootstrap/mysql-init/patch-sql/prompt_user_draft_alter.sql +++ b/release/deployment/docker-compose/bootstrap/mysql-init/patch-sql/prompt_user_draft_alter.sql @@ -1 +1,2 @@ ALTER TABLE `prompt_user_draft` ADD COLUMN `ext_info` text COLLATE utf8mb4_general_ci COMMENT 'Extended information field'; +ALTER TABLE `prompt_user_draft` ADD COLUMN `metadata` text COLLATE utf8mb4_general_ci COMMENT 'Template metadata field'; diff --git a/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_commit.sql b/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_commit.sql index e871fc6b4..a8b4623c8 100644 --- a/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_commit.sql +++ b/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_commit.sql @@ -10,6 +10,7 @@ CREATE TABLE IF NOT EXISTS `prompt_commit` `variable_defs` text COLLATE utf8mb4_general_ci COMMENT '变量定义', `tools` longtext COLLATE utf8mb4_general_ci COMMENT 'tools', `tool_call_config` text COLLATE utf8mb4_general_ci COMMENT 'tool调用配置', + `metadata` text COLLATE utf8mb4_general_ci COMMENT '模板元信息', `version` varchar(128) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '版本', `base_version` varchar(128) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '来源版本', `committed_by` varchar(128) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '提交人', diff --git a/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_commit_alter.sql b/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_commit_alter.sql index 5cdac00f2..4b622b352 100644 --- a/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_commit_alter.sql +++ b/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_commit_alter.sql @@ -1 +1,2 @@ ALTER TABLE `prompt_commit` ADD COLUMN `ext_info` text COLLATE utf8mb4_general_ci COMMENT 'Extended information field'; +ALTER TABLE `prompt_commit` ADD COLUMN `metadata` text COLLATE utf8mb4_general_ci COMMENT 'Template metadata field'; diff --git a/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_user_draft.sql b/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_user_draft.sql index 787bffdf8..804bf143e 100644 --- a/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_user_draft.sql +++ b/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_user_draft.sql @@ -10,6 +10,7 @@ CREATE TABLE IF NOT EXISTS `prompt_user_draft` `variable_defs` text COLLATE utf8mb4_general_ci COMMENT '变量定义', `tools` longtext COLLATE utf8mb4_general_ci COMMENT 'tools', `tool_call_config` text COLLATE utf8mb4_general_ci COMMENT 'tool调用配置', + `metadata` text COLLATE utf8mb4_general_ci COMMENT '模板元信息', `base_version` varchar(128) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '草稿关联版本', `is_draft_edited` tinyint NOT NULL DEFAULT '0' COMMENT '草稿内容是否基于BaseVersion有变更', `ext_info` text COLLATE utf8mb4_general_ci COMMENT '扩展字段', diff --git a/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_user_draft_alter.sql b/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_user_draft_alter.sql index 479517c91..54323e1e0 100644 --- a/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_user_draft_alter.sql +++ b/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_user_draft_alter.sql @@ -1 +1,2 @@ ALTER TABLE `prompt_user_draft` ADD COLUMN `ext_info` text COLLATE utf8mb4_general_ci COMMENT 'Extended information field'; +ALTER TABLE `prompt_user_draft` ADD COLUMN `metadata` text COLLATE utf8mb4_general_ci COMMENT 'Template metadata field'; From fdb03c50e6a63cbb88f1447a2005fc6e0b25b77c Mon Sep 17 00:00:00 2001 From: zhongzhiwei Date: Wed, 22 Oct 2025 14:36:46 +0800 Subject: [PATCH 03/12] git ignore Change-Id: I59dd9b1517ded4a40e3fb3efcae608164adf9678 --- .gitignore | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 0f5fa2d27..0a9b64c6e 100644 --- a/.gitignore +++ b/.gitignore @@ -50,7 +50,12 @@ release/deployment/helm-chart/umbrella/charts/ release/deployment/helm-chart/umbrella/Chart.lock **/kitex_remote_config.json -.coda/ backend/script/errorx/.env +.coda/ +.claude/ +.coco/ .cursor/ -AGENTS.md \ No newline at end of file +CLAUDE.md +AGENTS.md +.specify/ +specs/ \ No newline at end of file From bcdde7d73d1023df54fd781ccc4397af384b9563 Mon Sep 17 00:00:00 2001 From: zhongzhiwei Date: Wed, 22 Oct 2025 14:50:27 +0800 Subject: [PATCH 04/12] ci Change-Id: I75a355cb6ec1700a54108f7e16a4717563cb7046 --- .github/workflows/backend-ci.yaml | 4 ++-- .github/workflows/license-check.yaml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/backend-ci.yaml b/.github/workflows/backend-ci.yaml index 0b3a485d5..b8cef43f2 100644 --- a/.github/workflows/backend-ci.yaml +++ b/.github/workflows/backend-ci.yaml @@ -2,12 +2,12 @@ name: CI@backend on: push: - branches: ["main", "release/**"] + branches: ["main", "release/**", "feat/prompt_sync"] paths: - 'backend/**' - '.github/workflows/backend-ci.yaml' pull_request: - branches: ["main", "release/**"] + branches: ["main", "release/**", "feat/prompt_sync"] paths: - 'backend/**' - '.github/workflows/backend-ci.yaml' diff --git a/.github/workflows/license-check.yaml b/.github/workflows/license-check.yaml index 01d27033b..d811639a3 100644 --- a/.github/workflows/license-check.yaml +++ b/.github/workflows/license-check.yaml @@ -2,9 +2,9 @@ name: License Check on: push: - branches: ['main', "release/**"] + branches: ['main', "release/**", "feat/prompt_sync"] pull_request: - branches: ['main', "release/**"] + branches: ['main', "release/**", "feat/prompt_sync"] workflow_dispatch: permissions: From 02085834475df1648de851030a89a485e484cc57 Mon Sep 17 00:00:00 2001 From: caijialin0626 <61818131+caijialin0626@users.noreply.github.com> Date: Fri, 24 Oct 2025 14:34:06 +0800 Subject: [PATCH 05/12] [feat][prompt] prompt img video file upload (#244) --- .gitignore | 6 +- .../api/handler/coze/loop/apis/wire_gen.go | 1 + .../loop/apis/foundationfileservice/client.go | 6 + .../foundationfileservice.go | 36 + .../file/coze.loop.foundation.file.go | 1821 +++++++++++++++-- .../coze.loop.foundation.file_validator.go | 29 + .../foundation/file/fileservice/client.go | 6 + .../file/fileservice/fileservice.go | 36 + .../file/k-coze.loop.foundation.file.go | 1079 ++++++++++ .../loop/foundation/fileservice/client.go | 6 + .../foundation/fileservice/fileservice.go | 36 + .../coze/loop/llm/domain/runtime/k-runtime.go | 403 ++++ .../coze/loop/llm/domain/runtime/runtime.go | 588 +++++- .../llm/domain/runtime/runtime_validator.go | 24 + .../loop/prompt/domain/prompt/k-prompt.go | 403 ++++ .../coze/loop/prompt/domain/prompt/prompt.go | 585 +++++- .../prompt/domain/prompt/prompt_validator.go | 24 + .../openapi/coze.loop.prompt.openapi.go | 329 +++ .../coze.loop.prompt.openapi_validator.go | 16 + .../openapi/k-coze.loop.prompt.openapi.go | 228 +++ .../foundation/lofile/local_fileservice.go | 21 + .../infra/rpc/foundation/file_test.go | 4 + .../foundation/mocks/fileservice_client.go | 20 + .../modules/foundation/application/file.go | 31 + .../foundation/application/file_test.go | 189 ++ .../domain/file/service/mocks/file_service.go | 15 + .../foundation/domain/file/service/server.go | 65 + .../domain/file/service/server_test.go | 151 ++ .../prompt/application/convertor/openapi.go | 48 +- .../application/convertor/openapi_test.go | 140 +- .../prompt/application/convertor/prompt.go | 68 +- .../application/convertor/prompt_test.go | 48 +- backend/modules/prompt/application/debug.go | 7 + .../modules/prompt/application/debug_test.go | 70 + backend/modules/prompt/application/execute.go | 6 + .../prompt/application/execute_test.go | 31 + backend/modules/prompt/application/openapi.go | 15 + .../prompt/application/openapi_test.go | 273 ++- .../prompt/domain/component/rpc/file.go | 1 + .../component/rpc/mocks/file_provider.go | 15 + .../prompt/domain/entity/prompt_detail.go | 22 +- .../modules/prompt/domain/service/manage.go | 139 +- .../prompt/domain/service/manage_test.go | 438 +++- .../domain/service/mocks/prompt_service.go | 28 + .../modules/prompt/domain/service/service.go | 2 + .../prompt/infra/rpc/convertor/chat.go | 142 +- .../prompt/infra/rpc/convertor/chat_test.go | 92 + backend/modules/prompt/infra/rpc/file.go | 17 + backend/modules/prompt/infra/rpc/file_test.go | 201 ++ .../infra/rpc/mocks/fileservice_mock.go | 123 ++ .../coze.loop.foundation.file.thrift | 23 + .../coze/loop/llm/domain/runtime.thrift | 17 +- .../prompt/coze.loop.prompt.openapi.thrift | 7 + .../coze/loop/prompt/domain/prompt.thrift | 12 + 54 files changed, 7919 insertions(+), 224 deletions(-) create mode 100644 backend/modules/foundation/application/file_test.go create mode 100644 backend/modules/prompt/infra/rpc/file_test.go create mode 100644 backend/modules/prompt/infra/rpc/mocks/fileservice_mock.go diff --git a/.gitignore b/.gitignore index 0a9b64c6e..c8d5f94cd 100644 --- a/.gitignore +++ b/.gitignore @@ -50,12 +50,12 @@ release/deployment/helm-chart/umbrella/charts/ release/deployment/helm-chart/umbrella/Chart.lock **/kitex_remote_config.json -backend/script/errorx/.env .coda/ .claude/ .coco/ -.cursor/ CLAUDE.md AGENTS.md +backend/script/errorx/.env +.cursor/ .specify/ -specs/ \ No newline at end of file +specs/ diff --git a/backend/api/handler/coze/loop/apis/wire_gen.go b/backend/api/handler/coze/loop/apis/wire_gen.go index 7b5ae44ac..f090ab564 100644 --- a/backend/api/handler/coze/loop/apis/wire_gen.go +++ b/backend/api/handler/coze/loop/apis/wire_gen.go @@ -8,6 +8,7 @@ package apis import ( "context" + "github.com/cloudwego/kitex/pkg/endpoint" "github.com/coze-dev/coze-loop/backend/infra/ck" "github.com/coze-dev/coze-loop/backend/infra/db" diff --git a/backend/kitex_gen/coze/loop/apis/foundationfileservice/client.go b/backend/kitex_gen/coze/loop/apis/foundationfileservice/client.go index f62249864..3477f5c4d 100644 --- a/backend/kitex_gen/coze/loop/apis/foundationfileservice/client.go +++ b/backend/kitex_gen/coze/loop/apis/foundationfileservice/client.go @@ -12,6 +12,7 @@ import ( // Client is designed to provide IDL-compatible methods with call-option parameter for kitex framework. type Client interface { UploadLoopFileInner(ctx context.Context, req *file.UploadLoopFileInnerRequest, callOptions ...callopt.Option) (r *file.UploadLoopFileInnerResponse, err error) + UploadFileForServer(ctx context.Context, req *file.UploadFileForServerRequest, callOptions ...callopt.Option) (r *file.UploadFileForServerResponse, err error) SignUploadFile(ctx context.Context, req *file.SignUploadFileRequest, callOptions ...callopt.Option) (r *file.SignUploadFileResponse, err error) SignDownloadFile(ctx context.Context, req *file.SignDownloadFileRequest, callOptions ...callopt.Option) (r *file.SignDownloadFileResponse, err error) } @@ -50,6 +51,11 @@ func (p *kFoundationFileServiceClient) UploadLoopFileInner(ctx context.Context, return p.kClient.UploadLoopFileInner(ctx, req) } +func (p *kFoundationFileServiceClient) UploadFileForServer(ctx context.Context, req *file.UploadFileForServerRequest, callOptions ...callopt.Option) (r *file.UploadFileForServerResponse, err error) { + ctx = client.NewCtxWithCallOptions(ctx, callOptions) + return p.kClient.UploadFileForServer(ctx, req) +} + func (p *kFoundationFileServiceClient) SignUploadFile(ctx context.Context, req *file.SignUploadFileRequest, callOptions ...callopt.Option) (r *file.SignUploadFileResponse, err error) { ctx = client.NewCtxWithCallOptions(ctx, callOptions) return p.kClient.SignUploadFile(ctx, req) diff --git a/backend/kitex_gen/coze/loop/apis/foundationfileservice/foundationfileservice.go b/backend/kitex_gen/coze/loop/apis/foundationfileservice/foundationfileservice.go index aafa01b86..ba3608ca2 100644 --- a/backend/kitex_gen/coze/loop/apis/foundationfileservice/foundationfileservice.go +++ b/backend/kitex_gen/coze/loop/apis/foundationfileservice/foundationfileservice.go @@ -21,6 +21,13 @@ var serviceMethods = map[string]kitex.MethodInfo{ false, kitex.WithStreamingMode(kitex.StreamingNone), ), + "UploadFileForServer": kitex.NewMethodInfo( + uploadFileForServerHandler, + newFileServiceUploadFileForServerArgs, + newFileServiceUploadFileForServerResult, + false, + kitex.WithStreamingMode(kitex.StreamingNone), + ), "SignUploadFile": kitex.NewMethodInfo( signUploadFileHandler, newFileServiceSignUploadFileArgs, @@ -87,6 +94,25 @@ func newFileServiceUploadLoopFileInnerResult() interface{} { return file.NewFileServiceUploadLoopFileInnerResult() } +func uploadFileForServerHandler(ctx context.Context, handler interface{}, arg, result interface{}) error { + realArg := arg.(*file.FileServiceUploadFileForServerArgs) + realResult := result.(*file.FileServiceUploadFileForServerResult) + success, err := handler.(file.FileService).UploadFileForServer(ctx, realArg.Req) + if err != nil { + return err + } + realResult.Success = success + return nil +} + +func newFileServiceUploadFileForServerArgs() interface{} { + return file.NewFileServiceUploadFileForServerArgs() +} + +func newFileServiceUploadFileForServerResult() interface{} { + return file.NewFileServiceUploadFileForServerResult() +} + func signUploadFileHandler(ctx context.Context, handler interface{}, arg, result interface{}) error { realArg := arg.(*file.FileServiceSignUploadFileArgs) realResult := result.(*file.FileServiceSignUploadFileResult) @@ -147,6 +173,16 @@ func (p *kClient) UploadLoopFileInner(ctx context.Context, req *file.UploadLoopF return _result.GetSuccess(), nil } +func (p *kClient) UploadFileForServer(ctx context.Context, req *file.UploadFileForServerRequest) (r *file.UploadFileForServerResponse, err error) { + var _args file.FileServiceUploadFileForServerArgs + _args.Req = req + var _result file.FileServiceUploadFileForServerResult + if err = p.c.Call(ctx, "UploadFileForServer", &_args, &_result); err != nil { + return + } + return _result.GetSuccess(), nil +} + func (p *kClient) SignUploadFile(ctx context.Context, req *file.SignUploadFileRequest) (r *file.SignUploadFileResponse, err error) { var _args file.FileServiceSignUploadFileArgs _args.Req = req diff --git a/backend/kitex_gen/coze/loop/foundation/file/coze.loop.foundation.file.go b/backend/kitex_gen/coze/loop/foundation/file/coze.loop.foundation.file.go index 4343090bd..bfd664b2f 100644 --- a/backend/kitex_gen/coze/loop/foundation/file/coze.loop.foundation.file.go +++ b/backend/kitex_gen/coze/loop/foundation/file/coze.loop.foundation.file.go @@ -3989,9 +3989,1146 @@ func (p *SignDownloadFileResponse) Field255DeepEqual(src *base.BaseResp) bool { return true } +type UploadFileOption struct { + // file name + FileName *string `thrift:"file_name,1,optional" frugal:"1,optional,string" form:"file_name" json:"file_name,omitempty" query:"file_name"` + // custom mimetype -> ext mapping + MimeTypeExtMapping map[string]string `thrift:"mime_type_ext_mapping,2,optional" frugal:"2,optional,map" form:"mime_type_ext_mapping" json:"mime_type_ext_mapping,omitempty" query:"mime_type_ext_mapping"` +} + +func NewUploadFileOption() *UploadFileOption { + return &UploadFileOption{} +} + +func (p *UploadFileOption) InitDefault() { +} + +var UploadFileOption_FileName_DEFAULT string + +func (p *UploadFileOption) GetFileName() (v string) { + if p == nil { + return + } + if !p.IsSetFileName() { + return UploadFileOption_FileName_DEFAULT + } + return *p.FileName +} + +var UploadFileOption_MimeTypeExtMapping_DEFAULT map[string]string + +func (p *UploadFileOption) GetMimeTypeExtMapping() (v map[string]string) { + if p == nil { + return + } + if !p.IsSetMimeTypeExtMapping() { + return UploadFileOption_MimeTypeExtMapping_DEFAULT + } + return p.MimeTypeExtMapping +} +func (p *UploadFileOption) SetFileName(val *string) { + p.FileName = val +} +func (p *UploadFileOption) SetMimeTypeExtMapping(val map[string]string) { + p.MimeTypeExtMapping = val +} + +var fieldIDToName_UploadFileOption = map[int16]string{ + 1: "file_name", + 2: "mime_type_ext_mapping", +} + +func (p *UploadFileOption) IsSetFileName() bool { + return p.FileName != nil +} + +func (p *UploadFileOption) IsSetMimeTypeExtMapping() bool { + return p.MimeTypeExtMapping != nil +} + +func (p *UploadFileOption) Read(iprot thrift.TProtocol) (err error) { + var fieldTypeId thrift.TType + var fieldId int16 + + if _, err = iprot.ReadStructBegin(); err != nil { + goto ReadStructBeginError + } + + for { + _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + + switch fieldId { + case 1: + if fieldTypeId == thrift.STRING { + if err = p.ReadField1(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 2: + if fieldTypeId == thrift.MAP { + if err = p.ReadField2(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } + } + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError + } + + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_UploadFileOption[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *UploadFileOption) ReadField1(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.FileName = _field + return nil +} +func (p *UploadFileOption) ReadField2(iprot thrift.TProtocol) error { + _, _, size, err := iprot.ReadMapBegin() + if err != nil { + return err + } + _field := make(map[string]string, size) + for i := 0; i < size; i++ { + var _key string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _key = v + } + + var _val string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _val = v + } + + _field[_key] = _val + } + if err := iprot.ReadMapEnd(); err != nil { + return err + } + p.MimeTypeExtMapping = _field + return nil +} + +func (p *UploadFileOption) Write(oprot thrift.TProtocol) (err error) { + var fieldId int16 + if err = oprot.WriteStructBegin("UploadFileOption"); err != nil { + goto WriteStructBeginError + } + if p != nil { + if err = p.writeField1(oprot); err != nil { + fieldId = 1 + goto WriteFieldError + } + if err = p.writeField2(oprot); err != nil { + fieldId = 2 + goto WriteFieldError + } + } + if err = oprot.WriteFieldStop(); err != nil { + goto WriteFieldStopError + } + if err = oprot.WriteStructEnd(); err != nil { + goto WriteStructEndError + } + return nil +WriteStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) +WriteFieldError: + return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) +WriteFieldStopError: + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) +WriteStructEndError: + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) +} + +func (p *UploadFileOption) writeField1(oprot thrift.TProtocol) (err error) { + if p.IsSetFileName() { + if err = oprot.WriteFieldBegin("file_name", thrift.STRING, 1); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.FileName); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) +} +func (p *UploadFileOption) writeField2(oprot thrift.TProtocol) (err error) { + if p.IsSetMimeTypeExtMapping() { + if err = oprot.WriteFieldBegin("mime_type_ext_mapping", thrift.MAP, 2); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteMapBegin(thrift.STRING, thrift.STRING, len(p.MimeTypeExtMapping)); err != nil { + return err + } + for k, v := range p.MimeTypeExtMapping { + if err := oprot.WriteString(k); err != nil { + return err + } + if err := oprot.WriteString(v); err != nil { + return err + } + } + if err := oprot.WriteMapEnd(); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) +} + +func (p *UploadFileOption) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("UploadFileOption(%+v)", *p) + +} + +func (p *UploadFileOption) DeepEqual(ano *UploadFileOption) bool { + if p == ano { + return true + } else if p == nil || ano == nil { + return false + } + if !p.Field1DeepEqual(ano.FileName) { + return false + } + if !p.Field2DeepEqual(ano.MimeTypeExtMapping) { + return false + } + return true +} + +func (p *UploadFileOption) Field1DeepEqual(src *string) bool { + + if p.FileName == src { + return true + } else if p.FileName == nil || src == nil { + return false + } + if strings.Compare(*p.FileName, *src) != 0 { + return false + } + return true +} +func (p *UploadFileOption) Field2DeepEqual(src map[string]string) bool { + + if len(p.MimeTypeExtMapping) != len(src) { + return false + } + for k, v := range p.MimeTypeExtMapping { + _src := src[k] + if strings.Compare(v, _src) != 0 { + return false + } + } + return true +} + +type UploadFileForServerRequest struct { + // file mime type + MimeType string `thrift:"mime_type,1,required" frugal:"1,required,string" form:"mime_type,required" json:"mime_type,required" query:"mime_type,required"` + // file binary data + Body []byte `thrift:"body,2,required" frugal:"2,required,binary" form:"body,required" json:"body,required" query:"body,required"` + // workspace id + WorkspaceID int64 `thrift:"workspace_id,3,required" frugal:"3,required,i64" json:"workspace_id" form:"workspace_id,required" query:"workspace_id,required"` + // upload options + Option *UploadFileOption `thrift:"option,4,optional" frugal:"4,optional,UploadFileOption" form:"option" json:"option,omitempty" query:"option"` + Base *base.Base `thrift:"Base,255,optional" frugal:"255,optional,base.Base" form:"Base" json:"Base,omitempty" query:"Base"` +} + +func NewUploadFileForServerRequest() *UploadFileForServerRequest { + return &UploadFileForServerRequest{} +} + +func (p *UploadFileForServerRequest) InitDefault() { +} + +func (p *UploadFileForServerRequest) GetMimeType() (v string) { + if p != nil { + return p.MimeType + } + return +} + +func (p *UploadFileForServerRequest) GetBody() (v []byte) { + if p != nil { + return p.Body + } + return +} + +func (p *UploadFileForServerRequest) GetWorkspaceID() (v int64) { + if p != nil { + return p.WorkspaceID + } + return +} + +var UploadFileForServerRequest_Option_DEFAULT *UploadFileOption + +func (p *UploadFileForServerRequest) GetOption() (v *UploadFileOption) { + if p == nil { + return + } + if !p.IsSetOption() { + return UploadFileForServerRequest_Option_DEFAULT + } + return p.Option +} + +var UploadFileForServerRequest_Base_DEFAULT *base.Base + +func (p *UploadFileForServerRequest) GetBase() (v *base.Base) { + if p == nil { + return + } + if !p.IsSetBase() { + return UploadFileForServerRequest_Base_DEFAULT + } + return p.Base +} +func (p *UploadFileForServerRequest) SetMimeType(val string) { + p.MimeType = val +} +func (p *UploadFileForServerRequest) SetBody(val []byte) { + p.Body = val +} +func (p *UploadFileForServerRequest) SetWorkspaceID(val int64) { + p.WorkspaceID = val +} +func (p *UploadFileForServerRequest) SetOption(val *UploadFileOption) { + p.Option = val +} +func (p *UploadFileForServerRequest) SetBase(val *base.Base) { + p.Base = val +} + +var fieldIDToName_UploadFileForServerRequest = map[int16]string{ + 1: "mime_type", + 2: "body", + 3: "workspace_id", + 4: "option", + 255: "Base", +} + +func (p *UploadFileForServerRequest) IsSetOption() bool { + return p.Option != nil +} + +func (p *UploadFileForServerRequest) IsSetBase() bool { + return p.Base != nil +} + +func (p *UploadFileForServerRequest) Read(iprot thrift.TProtocol) (err error) { + var fieldTypeId thrift.TType + var fieldId int16 + var issetMimeType bool = false + var issetBody bool = false + var issetWorkspaceID bool = false + + if _, err = iprot.ReadStructBegin(); err != nil { + goto ReadStructBeginError + } + + for { + _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + + switch fieldId { + case 1: + if fieldTypeId == thrift.STRING { + if err = p.ReadField1(iprot); err != nil { + goto ReadFieldError + } + issetMimeType = true + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 2: + if fieldTypeId == thrift.STRING { + if err = p.ReadField2(iprot); err != nil { + goto ReadFieldError + } + issetBody = true + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 3: + if fieldTypeId == thrift.I64 { + if err = p.ReadField3(iprot); err != nil { + goto ReadFieldError + } + issetWorkspaceID = true + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 4: + if fieldTypeId == thrift.STRUCT { + if err = p.ReadField4(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 255: + if fieldTypeId == thrift.STRUCT { + if err = p.ReadField255(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } + } + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError + } + + if !issetMimeType { + fieldId = 1 + goto RequiredFieldNotSetError + } + + if !issetBody { + fieldId = 2 + goto RequiredFieldNotSetError + } + + if !issetWorkspaceID { + fieldId = 3 + goto RequiredFieldNotSetError + } + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_UploadFileForServerRequest[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +RequiredFieldNotSetError: + return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("required field %s is not set", fieldIDToName_UploadFileForServerRequest[fieldId])) +} + +func (p *UploadFileForServerRequest) ReadField1(iprot thrift.TProtocol) error { + + var _field string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = v + } + p.MimeType = _field + return nil +} +func (p *UploadFileForServerRequest) ReadField2(iprot thrift.TProtocol) error { + + var _field []byte + if v, err := iprot.ReadBinary(); err != nil { + return err + } else { + _field = []byte(v) + } + p.Body = _field + return nil +} +func (p *UploadFileForServerRequest) ReadField3(iprot thrift.TProtocol) error { + + var _field int64 + if v, err := iprot.ReadI64(); err != nil { + return err + } else { + _field = v + } + p.WorkspaceID = _field + return nil +} +func (p *UploadFileForServerRequest) ReadField4(iprot thrift.TProtocol) error { + _field := NewUploadFileOption() + if err := _field.Read(iprot); err != nil { + return err + } + p.Option = _field + return nil +} +func (p *UploadFileForServerRequest) ReadField255(iprot thrift.TProtocol) error { + _field := base.NewBase() + if err := _field.Read(iprot); err != nil { + return err + } + p.Base = _field + return nil +} + +func (p *UploadFileForServerRequest) Write(oprot thrift.TProtocol) (err error) { + var fieldId int16 + if err = oprot.WriteStructBegin("UploadFileForServerRequest"); err != nil { + goto WriteStructBeginError + } + if p != nil { + if err = p.writeField1(oprot); err != nil { + fieldId = 1 + goto WriteFieldError + } + if err = p.writeField2(oprot); err != nil { + fieldId = 2 + goto WriteFieldError + } + if err = p.writeField3(oprot); err != nil { + fieldId = 3 + goto WriteFieldError + } + if err = p.writeField4(oprot); err != nil { + fieldId = 4 + goto WriteFieldError + } + if err = p.writeField255(oprot); err != nil { + fieldId = 255 + goto WriteFieldError + } + } + if err = oprot.WriteFieldStop(); err != nil { + goto WriteFieldStopError + } + if err = oprot.WriteStructEnd(); err != nil { + goto WriteStructEndError + } + return nil +WriteStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) +WriteFieldError: + return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) +WriteFieldStopError: + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) +WriteStructEndError: + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) +} + +func (p *UploadFileForServerRequest) writeField1(oprot thrift.TProtocol) (err error) { + if err = oprot.WriteFieldBegin("mime_type", thrift.STRING, 1); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(p.MimeType); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) +} +func (p *UploadFileForServerRequest) writeField2(oprot thrift.TProtocol) (err error) { + if err = oprot.WriteFieldBegin("body", thrift.STRING, 2); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteBinary([]byte(p.Body)); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) +} +func (p *UploadFileForServerRequest) writeField3(oprot thrift.TProtocol) (err error) { + if err = oprot.WriteFieldBegin("workspace_id", thrift.I64, 3); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteI64(p.WorkspaceID); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) +} +func (p *UploadFileForServerRequest) writeField4(oprot thrift.TProtocol) (err error) { + if p.IsSetOption() { + if err = oprot.WriteFieldBegin("option", thrift.STRUCT, 4); err != nil { + goto WriteFieldBeginError + } + if err := p.Option.Write(oprot); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 4 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 4 end error: ", p), err) +} +func (p *UploadFileForServerRequest) writeField255(oprot thrift.TProtocol) (err error) { + if p.IsSetBase() { + if err = oprot.WriteFieldBegin("Base", thrift.STRUCT, 255); err != nil { + goto WriteFieldBeginError + } + if err := p.Base.Write(oprot); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 255 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 255 end error: ", p), err) +} + +func (p *UploadFileForServerRequest) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("UploadFileForServerRequest(%+v)", *p) + +} + +func (p *UploadFileForServerRequest) DeepEqual(ano *UploadFileForServerRequest) bool { + if p == ano { + return true + } else if p == nil || ano == nil { + return false + } + if !p.Field1DeepEqual(ano.MimeType) { + return false + } + if !p.Field2DeepEqual(ano.Body) { + return false + } + if !p.Field3DeepEqual(ano.WorkspaceID) { + return false + } + if !p.Field4DeepEqual(ano.Option) { + return false + } + if !p.Field255DeepEqual(ano.Base) { + return false + } + return true +} + +func (p *UploadFileForServerRequest) Field1DeepEqual(src string) bool { + + if strings.Compare(p.MimeType, src) != 0 { + return false + } + return true +} +func (p *UploadFileForServerRequest) Field2DeepEqual(src []byte) bool { + + if bytes.Compare(p.Body, src) != 0 { + return false + } + return true +} +func (p *UploadFileForServerRequest) Field3DeepEqual(src int64) bool { + + if p.WorkspaceID != src { + return false + } + return true +} +func (p *UploadFileForServerRequest) Field4DeepEqual(src *UploadFileOption) bool { + + if !p.Option.DeepEqual(src) { + return false + } + return true +} +func (p *UploadFileForServerRequest) Field255DeepEqual(src *base.Base) bool { + + if !p.Base.DeepEqual(src) { + return false + } + return true +} + +type UploadFileForServerResponse struct { + Code *int32 `thrift:"code,1,optional" frugal:"1,optional,i32" form:"code" json:"code,omitempty" query:"code"` + Msg *string `thrift:"msg,2,optional" frugal:"2,optional,string" form:"msg" json:"msg,omitempty" query:"msg"` + Data *FileData `thrift:"data,3,optional" frugal:"3,optional,FileData" form:"data" json:"data,omitempty" query:"data"` + BaseResp *base.BaseResp `thrift:"BaseResp,255" frugal:"255,default,base.BaseResp" form:"BaseResp" json:"BaseResp" query:"BaseResp"` +} + +func NewUploadFileForServerResponse() *UploadFileForServerResponse { + return &UploadFileForServerResponse{} +} + +func (p *UploadFileForServerResponse) InitDefault() { +} + +var UploadFileForServerResponse_Code_DEFAULT int32 + +func (p *UploadFileForServerResponse) GetCode() (v int32) { + if p == nil { + return + } + if !p.IsSetCode() { + return UploadFileForServerResponse_Code_DEFAULT + } + return *p.Code +} + +var UploadFileForServerResponse_Msg_DEFAULT string + +func (p *UploadFileForServerResponse) GetMsg() (v string) { + if p == nil { + return + } + if !p.IsSetMsg() { + return UploadFileForServerResponse_Msg_DEFAULT + } + return *p.Msg +} + +var UploadFileForServerResponse_Data_DEFAULT *FileData + +func (p *UploadFileForServerResponse) GetData() (v *FileData) { + if p == nil { + return + } + if !p.IsSetData() { + return UploadFileForServerResponse_Data_DEFAULT + } + return p.Data +} + +var UploadFileForServerResponse_BaseResp_DEFAULT *base.BaseResp + +func (p *UploadFileForServerResponse) GetBaseResp() (v *base.BaseResp) { + if p == nil { + return + } + if !p.IsSetBaseResp() { + return UploadFileForServerResponse_BaseResp_DEFAULT + } + return p.BaseResp +} +func (p *UploadFileForServerResponse) SetCode(val *int32) { + p.Code = val +} +func (p *UploadFileForServerResponse) SetMsg(val *string) { + p.Msg = val +} +func (p *UploadFileForServerResponse) SetData(val *FileData) { + p.Data = val +} +func (p *UploadFileForServerResponse) SetBaseResp(val *base.BaseResp) { + p.BaseResp = val +} + +var fieldIDToName_UploadFileForServerResponse = map[int16]string{ + 1: "code", + 2: "msg", + 3: "data", + 255: "BaseResp", +} + +func (p *UploadFileForServerResponse) IsSetCode() bool { + return p.Code != nil +} + +func (p *UploadFileForServerResponse) IsSetMsg() bool { + return p.Msg != nil +} + +func (p *UploadFileForServerResponse) IsSetData() bool { + return p.Data != nil +} + +func (p *UploadFileForServerResponse) IsSetBaseResp() bool { + return p.BaseResp != nil +} + +func (p *UploadFileForServerResponse) Read(iprot thrift.TProtocol) (err error) { + var fieldTypeId thrift.TType + var fieldId int16 + + if _, err = iprot.ReadStructBegin(); err != nil { + goto ReadStructBeginError + } + + for { + _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + + switch fieldId { + case 1: + if fieldTypeId == thrift.I32 { + if err = p.ReadField1(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 2: + if fieldTypeId == thrift.STRING { + if err = p.ReadField2(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 3: + if fieldTypeId == thrift.STRUCT { + if err = p.ReadField3(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 255: + if fieldTypeId == thrift.STRUCT { + if err = p.ReadField255(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } + } + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError + } + + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_UploadFileForServerResponse[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *UploadFileForServerResponse) ReadField1(iprot thrift.TProtocol) error { + + var _field *int32 + if v, err := iprot.ReadI32(); err != nil { + return err + } else { + _field = &v + } + p.Code = _field + return nil +} +func (p *UploadFileForServerResponse) ReadField2(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Msg = _field + return nil +} +func (p *UploadFileForServerResponse) ReadField3(iprot thrift.TProtocol) error { + _field := NewFileData() + if err := _field.Read(iprot); err != nil { + return err + } + p.Data = _field + return nil +} +func (p *UploadFileForServerResponse) ReadField255(iprot thrift.TProtocol) error { + _field := base.NewBaseResp() + if err := _field.Read(iprot); err != nil { + return err + } + p.BaseResp = _field + return nil +} + +func (p *UploadFileForServerResponse) Write(oprot thrift.TProtocol) (err error) { + var fieldId int16 + if err = oprot.WriteStructBegin("UploadFileForServerResponse"); err != nil { + goto WriteStructBeginError + } + if p != nil { + if err = p.writeField1(oprot); err != nil { + fieldId = 1 + goto WriteFieldError + } + if err = p.writeField2(oprot); err != nil { + fieldId = 2 + goto WriteFieldError + } + if err = p.writeField3(oprot); err != nil { + fieldId = 3 + goto WriteFieldError + } + if err = p.writeField255(oprot); err != nil { + fieldId = 255 + goto WriteFieldError + } + } + if err = oprot.WriteFieldStop(); err != nil { + goto WriteFieldStopError + } + if err = oprot.WriteStructEnd(); err != nil { + goto WriteStructEndError + } + return nil +WriteStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) +WriteFieldError: + return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) +WriteFieldStopError: + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) +WriteStructEndError: + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) +} + +func (p *UploadFileForServerResponse) writeField1(oprot thrift.TProtocol) (err error) { + if p.IsSetCode() { + if err = oprot.WriteFieldBegin("code", thrift.I32, 1); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteI32(*p.Code); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) +} +func (p *UploadFileForServerResponse) writeField2(oprot thrift.TProtocol) (err error) { + if p.IsSetMsg() { + if err = oprot.WriteFieldBegin("msg", thrift.STRING, 2); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Msg); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) +} +func (p *UploadFileForServerResponse) writeField3(oprot thrift.TProtocol) (err error) { + if p.IsSetData() { + if err = oprot.WriteFieldBegin("data", thrift.STRUCT, 3); err != nil { + goto WriteFieldBeginError + } + if err := p.Data.Write(oprot); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) +} +func (p *UploadFileForServerResponse) writeField255(oprot thrift.TProtocol) (err error) { + if err = oprot.WriteFieldBegin("BaseResp", thrift.STRUCT, 255); err != nil { + goto WriteFieldBeginError + } + if err := p.BaseResp.Write(oprot); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 255 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 255 end error: ", p), err) +} + +func (p *UploadFileForServerResponse) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("UploadFileForServerResponse(%+v)", *p) + +} + +func (p *UploadFileForServerResponse) DeepEqual(ano *UploadFileForServerResponse) bool { + if p == ano { + return true + } else if p == nil || ano == nil { + return false + } + if !p.Field1DeepEqual(ano.Code) { + return false + } + if !p.Field2DeepEqual(ano.Msg) { + return false + } + if !p.Field3DeepEqual(ano.Data) { + return false + } + if !p.Field255DeepEqual(ano.BaseResp) { + return false + } + return true +} + +func (p *UploadFileForServerResponse) Field1DeepEqual(src *int32) bool { + + if p.Code == src { + return true + } else if p.Code == nil || src == nil { + return false + } + if *p.Code != *src { + return false + } + return true +} +func (p *UploadFileForServerResponse) Field2DeepEqual(src *string) bool { + + if p.Msg == src { + return true + } else if p.Msg == nil || src == nil { + return false + } + if strings.Compare(*p.Msg, *src) != 0 { + return false + } + return true +} +func (p *UploadFileForServerResponse) Field3DeepEqual(src *FileData) bool { + + if !p.Data.DeepEqual(src) { + return false + } + return true +} +func (p *UploadFileForServerResponse) Field255DeepEqual(src *base.BaseResp) bool { + + if !p.BaseResp.DeepEqual(src) { + return false + } + return true +} + type FileService interface { UploadLoopFileInner(ctx context.Context, req *UploadLoopFileInnerRequest) (r *UploadLoopFileInnerResponse, err error) + UploadFileForServer(ctx context.Context, req *UploadFileForServerRequest) (r *UploadFileForServerResponse, err error) + SignUploadFile(ctx context.Context, req *SignUploadFileRequest) (r *SignUploadFileResponse, err error) SignDownloadFile(ctx context.Context, req *SignDownloadFileRequest) (r *SignDownloadFileResponse, err error) @@ -4032,6 +5169,15 @@ func (p *FileServiceClient) UploadLoopFileInner(ctx context.Context, req *Upload } return _result.GetSuccess(), nil } +func (p *FileServiceClient) UploadFileForServer(ctx context.Context, req *UploadFileForServerRequest) (r *UploadFileForServerResponse, err error) { + var _args FileServiceUploadFileForServerArgs + _args.Req = req + var _result FileServiceUploadFileForServerResult + if err = p.Client_().Call(ctx, "UploadFileForServer", &_args, &_result); err != nil { + return + } + return _result.GetSuccess(), nil +} func (p *FileServiceClient) SignUploadFile(ctx context.Context, req *SignUploadFileRequest) (r *SignUploadFileResponse, err error) { var _args FileServiceSignUploadFileArgs _args.Req = req @@ -4072,6 +5218,7 @@ func (p *FileServiceProcessor) ProcessorMap() map[string]thrift.TProcessorFuncti func NewFileServiceProcessor(handler FileService) *FileServiceProcessor { self := &FileServiceProcessor{handler: handler, processorMap: make(map[string]thrift.TProcessorFunction)} self.AddToProcessorMap("UploadLoopFileInner", &fileServiceProcessorUploadLoopFileInner{handler: handler}) + self.AddToProcessorMap("UploadFileForServer", &fileServiceProcessorUploadFileForServer{handler: handler}) self.AddToProcessorMap("SignUploadFile", &fileServiceProcessorSignUploadFile{handler: handler}) self.AddToProcessorMap("SignDownloadFile", &fileServiceProcessorSignDownloadFile{handler: handler}) return self @@ -4081,29 +5228,173 @@ func (p *FileServiceProcessor) Process(ctx context.Context, iprot, oprot thrift. if err != nil { return false, err } - if processor, ok := p.GetProcessorFunction(name); ok { - return processor.Process(ctx, seqId, iprot, oprot) - } - iprot.Skip(thrift.STRUCT) + if processor, ok := p.GetProcessorFunction(name); ok { + return processor.Process(ctx, seqId, iprot, oprot) + } + iprot.Skip(thrift.STRUCT) + iprot.ReadMessageEnd() + x := thrift.NewTApplicationException(thrift.UNKNOWN_METHOD, "Unknown function "+name) + oprot.WriteMessageBegin(name, thrift.EXCEPTION, seqId) + x.Write(oprot) + oprot.WriteMessageEnd() + oprot.Flush(ctx) + return false, x +} + +type fileServiceProcessorUploadLoopFileInner struct { + handler FileService +} + +func (p *fileServiceProcessorUploadLoopFileInner) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { + args := FileServiceUploadLoopFileInnerArgs{} + if err = args.Read(iprot); err != nil { + iprot.ReadMessageEnd() + x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) + oprot.WriteMessageBegin("UploadLoopFileInner", thrift.EXCEPTION, seqId) + x.Write(oprot) + oprot.WriteMessageEnd() + oprot.Flush(ctx) + return false, err + } + + iprot.ReadMessageEnd() + var err2 error + result := FileServiceUploadLoopFileInnerResult{} + var retval *UploadLoopFileInnerResponse + if retval, err2 = p.handler.UploadLoopFileInner(ctx, args.Req); err2 != nil { + x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing UploadLoopFileInner: "+err2.Error()) + oprot.WriteMessageBegin("UploadLoopFileInner", thrift.EXCEPTION, seqId) + x.Write(oprot) + oprot.WriteMessageEnd() + oprot.Flush(ctx) + return true, err2 + } else { + result.Success = retval + } + if err2 = oprot.WriteMessageBegin("UploadLoopFileInner", thrift.REPLY, seqId); err2 != nil { + err = err2 + } + if err2 = result.Write(oprot); err == nil && err2 != nil { + err = err2 + } + if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil { + err = err2 + } + if err2 = oprot.Flush(ctx); err == nil && err2 != nil { + err = err2 + } + if err != nil { + return + } + return true, err +} + +type fileServiceProcessorUploadFileForServer struct { + handler FileService +} + +func (p *fileServiceProcessorUploadFileForServer) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { + args := FileServiceUploadFileForServerArgs{} + if err = args.Read(iprot); err != nil { + iprot.ReadMessageEnd() + x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) + oprot.WriteMessageBegin("UploadFileForServer", thrift.EXCEPTION, seqId) + x.Write(oprot) + oprot.WriteMessageEnd() + oprot.Flush(ctx) + return false, err + } + + iprot.ReadMessageEnd() + var err2 error + result := FileServiceUploadFileForServerResult{} + var retval *UploadFileForServerResponse + if retval, err2 = p.handler.UploadFileForServer(ctx, args.Req); err2 != nil { + x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing UploadFileForServer: "+err2.Error()) + oprot.WriteMessageBegin("UploadFileForServer", thrift.EXCEPTION, seqId) + x.Write(oprot) + oprot.WriteMessageEnd() + oprot.Flush(ctx) + return true, err2 + } else { + result.Success = retval + } + if err2 = oprot.WriteMessageBegin("UploadFileForServer", thrift.REPLY, seqId); err2 != nil { + err = err2 + } + if err2 = result.Write(oprot); err == nil && err2 != nil { + err = err2 + } + if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil { + err = err2 + } + if err2 = oprot.Flush(ctx); err == nil && err2 != nil { + err = err2 + } + if err != nil { + return + } + return true, err +} + +type fileServiceProcessorSignUploadFile struct { + handler FileService +} + +func (p *fileServiceProcessorSignUploadFile) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { + args := FileServiceSignUploadFileArgs{} + if err = args.Read(iprot); err != nil { + iprot.ReadMessageEnd() + x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) + oprot.WriteMessageBegin("SignUploadFile", thrift.EXCEPTION, seqId) + x.Write(oprot) + oprot.WriteMessageEnd() + oprot.Flush(ctx) + return false, err + } + iprot.ReadMessageEnd() - x := thrift.NewTApplicationException(thrift.UNKNOWN_METHOD, "Unknown function "+name) - oprot.WriteMessageBegin(name, thrift.EXCEPTION, seqId) - x.Write(oprot) - oprot.WriteMessageEnd() - oprot.Flush(ctx) - return false, x + var err2 error + result := FileServiceSignUploadFileResult{} + var retval *SignUploadFileResponse + if retval, err2 = p.handler.SignUploadFile(ctx, args.Req); err2 != nil { + x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing SignUploadFile: "+err2.Error()) + oprot.WriteMessageBegin("SignUploadFile", thrift.EXCEPTION, seqId) + x.Write(oprot) + oprot.WriteMessageEnd() + oprot.Flush(ctx) + return true, err2 + } else { + result.Success = retval + } + if err2 = oprot.WriteMessageBegin("SignUploadFile", thrift.REPLY, seqId); err2 != nil { + err = err2 + } + if err2 = result.Write(oprot); err == nil && err2 != nil { + err = err2 + } + if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil { + err = err2 + } + if err2 = oprot.Flush(ctx); err == nil && err2 != nil { + err = err2 + } + if err != nil { + return + } + return true, err } -type fileServiceProcessorUploadLoopFileInner struct { +type fileServiceProcessorSignDownloadFile struct { handler FileService } -func (p *fileServiceProcessorUploadLoopFileInner) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { - args := FileServiceUploadLoopFileInnerArgs{} +func (p *fileServiceProcessorSignDownloadFile) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { + args := FileServiceSignDownloadFileArgs{} if err = args.Read(iprot); err != nil { iprot.ReadMessageEnd() x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) - oprot.WriteMessageBegin("UploadLoopFileInner", thrift.EXCEPTION, seqId) + oprot.WriteMessageBegin("SignDownloadFile", thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() oprot.Flush(ctx) @@ -4112,11 +5403,11 @@ func (p *fileServiceProcessorUploadLoopFileInner) Process(ctx context.Context, s iprot.ReadMessageEnd() var err2 error - result := FileServiceUploadLoopFileInnerResult{} - var retval *UploadLoopFileInnerResponse - if retval, err2 = p.handler.UploadLoopFileInner(ctx, args.Req); err2 != nil { - x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing UploadLoopFileInner: "+err2.Error()) - oprot.WriteMessageBegin("UploadLoopFileInner", thrift.EXCEPTION, seqId) + result := FileServiceSignDownloadFileResult{} + var retval *SignDownloadFileResponse + if retval, err2 = p.handler.SignDownloadFile(ctx, args.Req); err2 != nil { + x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing SignDownloadFile: "+err2.Error()) + oprot.WriteMessageBegin("SignDownloadFile", thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() oprot.Flush(ctx) @@ -4124,7 +5415,7 @@ func (p *fileServiceProcessorUploadLoopFileInner) Process(ctx context.Context, s } else { result.Success = retval } - if err2 = oprot.WriteMessageBegin("UploadLoopFileInner", thrift.REPLY, seqId); err2 != nil { + if err2 = oprot.WriteMessageBegin("SignDownloadFile", thrift.REPLY, seqId); err2 != nil { err = err2 } if err2 = result.Write(oprot); err == nil && err2 != nil { @@ -4139,140 +5430,388 @@ func (p *fileServiceProcessorUploadLoopFileInner) Process(ctx context.Context, s if err != nil { return } - return true, err + return true, err +} + +type FileServiceUploadLoopFileInnerArgs struct { + Req *UploadLoopFileInnerRequest `thrift:"req,1" frugal:"1,default,UploadLoopFileInnerRequest"` +} + +func NewFileServiceUploadLoopFileInnerArgs() *FileServiceUploadLoopFileInnerArgs { + return &FileServiceUploadLoopFileInnerArgs{} +} + +func (p *FileServiceUploadLoopFileInnerArgs) InitDefault() { +} + +var FileServiceUploadLoopFileInnerArgs_Req_DEFAULT *UploadLoopFileInnerRequest + +func (p *FileServiceUploadLoopFileInnerArgs) GetReq() (v *UploadLoopFileInnerRequest) { + if p == nil { + return + } + if !p.IsSetReq() { + return FileServiceUploadLoopFileInnerArgs_Req_DEFAULT + } + return p.Req +} +func (p *FileServiceUploadLoopFileInnerArgs) SetReq(val *UploadLoopFileInnerRequest) { + p.Req = val +} + +var fieldIDToName_FileServiceUploadLoopFileInnerArgs = map[int16]string{ + 1: "req", +} + +func (p *FileServiceUploadLoopFileInnerArgs) IsSetReq() bool { + return p.Req != nil +} + +func (p *FileServiceUploadLoopFileInnerArgs) Read(iprot thrift.TProtocol) (err error) { + var fieldTypeId thrift.TType + var fieldId int16 + + if _, err = iprot.ReadStructBegin(); err != nil { + goto ReadStructBeginError + } + + for { + _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + + switch fieldId { + case 1: + if fieldTypeId == thrift.STRUCT { + if err = p.ReadField1(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } + } + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError + } + + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_FileServiceUploadLoopFileInnerArgs[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *FileServiceUploadLoopFileInnerArgs) ReadField1(iprot thrift.TProtocol) error { + _field := NewUploadLoopFileInnerRequest() + if err := _field.Read(iprot); err != nil { + return err + } + p.Req = _field + return nil +} + +func (p *FileServiceUploadLoopFileInnerArgs) Write(oprot thrift.TProtocol) (err error) { + var fieldId int16 + if err = oprot.WriteStructBegin("UploadLoopFileInner_args"); err != nil { + goto WriteStructBeginError + } + if p != nil { + if err = p.writeField1(oprot); err != nil { + fieldId = 1 + goto WriteFieldError + } + } + if err = oprot.WriteFieldStop(); err != nil { + goto WriteFieldStopError + } + if err = oprot.WriteStructEnd(); err != nil { + goto WriteStructEndError + } + return nil +WriteStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) +WriteFieldError: + return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) +WriteFieldStopError: + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) +WriteStructEndError: + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) +} + +func (p *FileServiceUploadLoopFileInnerArgs) writeField1(oprot thrift.TProtocol) (err error) { + if err = oprot.WriteFieldBegin("req", thrift.STRUCT, 1); err != nil { + goto WriteFieldBeginError + } + if err := p.Req.Write(oprot); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) +} + +func (p *FileServiceUploadLoopFileInnerArgs) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("FileServiceUploadLoopFileInnerArgs(%+v)", *p) + +} + +func (p *FileServiceUploadLoopFileInnerArgs) DeepEqual(ano *FileServiceUploadLoopFileInnerArgs) bool { + if p == ano { + return true + } else if p == nil || ano == nil { + return false + } + if !p.Field1DeepEqual(ano.Req) { + return false + } + return true +} + +func (p *FileServiceUploadLoopFileInnerArgs) Field1DeepEqual(src *UploadLoopFileInnerRequest) bool { + + if !p.Req.DeepEqual(src) { + return false + } + return true +} + +type FileServiceUploadLoopFileInnerResult struct { + Success *UploadLoopFileInnerResponse `thrift:"success,0,optional" frugal:"0,optional,UploadLoopFileInnerResponse"` +} + +func NewFileServiceUploadLoopFileInnerResult() *FileServiceUploadLoopFileInnerResult { + return &FileServiceUploadLoopFileInnerResult{} +} + +func (p *FileServiceUploadLoopFileInnerResult) InitDefault() { +} + +var FileServiceUploadLoopFileInnerResult_Success_DEFAULT *UploadLoopFileInnerResponse + +func (p *FileServiceUploadLoopFileInnerResult) GetSuccess() (v *UploadLoopFileInnerResponse) { + if p == nil { + return + } + if !p.IsSetSuccess() { + return FileServiceUploadLoopFileInnerResult_Success_DEFAULT + } + return p.Success +} +func (p *FileServiceUploadLoopFileInnerResult) SetSuccess(x interface{}) { + p.Success = x.(*UploadLoopFileInnerResponse) } -type fileServiceProcessorSignUploadFile struct { - handler FileService +var fieldIDToName_FileServiceUploadLoopFileInnerResult = map[int16]string{ + 0: "success", } -func (p *fileServiceProcessorSignUploadFile) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { - args := FileServiceSignUploadFileArgs{} - if err = args.Read(iprot); err != nil { - iprot.ReadMessageEnd() - x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) - oprot.WriteMessageBegin("SignUploadFile", thrift.EXCEPTION, seqId) - x.Write(oprot) - oprot.WriteMessageEnd() - oprot.Flush(ctx) - return false, err +func (p *FileServiceUploadLoopFileInnerResult) IsSetSuccess() bool { + return p.Success != nil +} + +func (p *FileServiceUploadLoopFileInnerResult) Read(iprot thrift.TProtocol) (err error) { + var fieldTypeId thrift.TType + var fieldId int16 + + if _, err = iprot.ReadStructBegin(); err != nil { + goto ReadStructBeginError } - iprot.ReadMessageEnd() - var err2 error - result := FileServiceSignUploadFileResult{} - var retval *SignUploadFileResponse - if retval, err2 = p.handler.SignUploadFile(ctx, args.Req); err2 != nil { - x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing SignUploadFile: "+err2.Error()) - oprot.WriteMessageBegin("SignUploadFile", thrift.EXCEPTION, seqId) - x.Write(oprot) - oprot.WriteMessageEnd() - oprot.Flush(ctx) - return true, err2 - } else { - result.Success = retval + for { + _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + + switch fieldId { + case 0: + if fieldTypeId == thrift.STRUCT { + if err = p.ReadField0(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } } - if err2 = oprot.WriteMessageBegin("SignUploadFile", thrift.REPLY, seqId); err2 != nil { - err = err2 + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError } - if err2 = result.Write(oprot); err == nil && err2 != nil { - err = err2 + + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_FileServiceUploadLoopFileInnerResult[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *FileServiceUploadLoopFileInnerResult) ReadField0(iprot thrift.TProtocol) error { + _field := NewUploadLoopFileInnerResponse() + if err := _field.Read(iprot); err != nil { + return err } - if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil { - err = err2 + p.Success = _field + return nil +} + +func (p *FileServiceUploadLoopFileInnerResult) Write(oprot thrift.TProtocol) (err error) { + var fieldId int16 + if err = oprot.WriteStructBegin("UploadLoopFileInner_result"); err != nil { + goto WriteStructBeginError } - if err2 = oprot.Flush(ctx); err == nil && err2 != nil { - err = err2 + if p != nil { + if err = p.writeField0(oprot); err != nil { + fieldId = 0 + goto WriteFieldError + } } - if err != nil { - return + if err = oprot.WriteFieldStop(); err != nil { + goto WriteFieldStopError } - return true, err + if err = oprot.WriteStructEnd(); err != nil { + goto WriteStructEndError + } + return nil +WriteStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) +WriteFieldError: + return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) +WriteFieldStopError: + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) +WriteStructEndError: + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) } -type fileServiceProcessorSignDownloadFile struct { - handler FileService +func (p *FileServiceUploadLoopFileInnerResult) writeField0(oprot thrift.TProtocol) (err error) { + if p.IsSetSuccess() { + if err = oprot.WriteFieldBegin("success", thrift.STRUCT, 0); err != nil { + goto WriteFieldBeginError + } + if err := p.Success.Write(oprot); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 0 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 0 end error: ", p), err) } -func (p *fileServiceProcessorSignDownloadFile) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { - args := FileServiceSignDownloadFileArgs{} - if err = args.Read(iprot); err != nil { - iprot.ReadMessageEnd() - x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) - oprot.WriteMessageBegin("SignDownloadFile", thrift.EXCEPTION, seqId) - x.Write(oprot) - oprot.WriteMessageEnd() - oprot.Flush(ctx) - return false, err +func (p *FileServiceUploadLoopFileInnerResult) String() string { + if p == nil { + return "" } + return fmt.Sprintf("FileServiceUploadLoopFileInnerResult(%+v)", *p) - iprot.ReadMessageEnd() - var err2 error - result := FileServiceSignDownloadFileResult{} - var retval *SignDownloadFileResponse - if retval, err2 = p.handler.SignDownloadFile(ctx, args.Req); err2 != nil { - x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing SignDownloadFile: "+err2.Error()) - oprot.WriteMessageBegin("SignDownloadFile", thrift.EXCEPTION, seqId) - x.Write(oprot) - oprot.WriteMessageEnd() - oprot.Flush(ctx) - return true, err2 - } else { - result.Success = retval - } - if err2 = oprot.WriteMessageBegin("SignDownloadFile", thrift.REPLY, seqId); err2 != nil { - err = err2 - } - if err2 = result.Write(oprot); err == nil && err2 != nil { - err = err2 - } - if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil { - err = err2 +} + +func (p *FileServiceUploadLoopFileInnerResult) DeepEqual(ano *FileServiceUploadLoopFileInnerResult) bool { + if p == ano { + return true + } else if p == nil || ano == nil { + return false } - if err2 = oprot.Flush(ctx); err == nil && err2 != nil { - err = err2 + if !p.Field0DeepEqual(ano.Success) { + return false } - if err != nil { - return + return true +} + +func (p *FileServiceUploadLoopFileInnerResult) Field0DeepEqual(src *UploadLoopFileInnerResponse) bool { + + if !p.Success.DeepEqual(src) { + return false } - return true, err + return true } -type FileServiceUploadLoopFileInnerArgs struct { - Req *UploadLoopFileInnerRequest `thrift:"req,1" frugal:"1,default,UploadLoopFileInnerRequest"` +type FileServiceUploadFileForServerArgs struct { + Req *UploadFileForServerRequest `thrift:"req,1" frugal:"1,default,UploadFileForServerRequest"` } -func NewFileServiceUploadLoopFileInnerArgs() *FileServiceUploadLoopFileInnerArgs { - return &FileServiceUploadLoopFileInnerArgs{} +func NewFileServiceUploadFileForServerArgs() *FileServiceUploadFileForServerArgs { + return &FileServiceUploadFileForServerArgs{} } -func (p *FileServiceUploadLoopFileInnerArgs) InitDefault() { +func (p *FileServiceUploadFileForServerArgs) InitDefault() { } -var FileServiceUploadLoopFileInnerArgs_Req_DEFAULT *UploadLoopFileInnerRequest +var FileServiceUploadFileForServerArgs_Req_DEFAULT *UploadFileForServerRequest -func (p *FileServiceUploadLoopFileInnerArgs) GetReq() (v *UploadLoopFileInnerRequest) { +func (p *FileServiceUploadFileForServerArgs) GetReq() (v *UploadFileForServerRequest) { if p == nil { return } if !p.IsSetReq() { - return FileServiceUploadLoopFileInnerArgs_Req_DEFAULT + return FileServiceUploadFileForServerArgs_Req_DEFAULT } return p.Req } -func (p *FileServiceUploadLoopFileInnerArgs) SetReq(val *UploadLoopFileInnerRequest) { +func (p *FileServiceUploadFileForServerArgs) SetReq(val *UploadFileForServerRequest) { p.Req = val } -var fieldIDToName_FileServiceUploadLoopFileInnerArgs = map[int16]string{ +var fieldIDToName_FileServiceUploadFileForServerArgs = map[int16]string{ 1: "req", } -func (p *FileServiceUploadLoopFileInnerArgs) IsSetReq() bool { +func (p *FileServiceUploadFileForServerArgs) IsSetReq() bool { return p.Req != nil } -func (p *FileServiceUploadLoopFileInnerArgs) Read(iprot thrift.TProtocol) (err error) { +func (p *FileServiceUploadFileForServerArgs) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -4317,7 +5856,7 @@ ReadStructBeginError: ReadFieldBeginError: return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_FileServiceUploadLoopFileInnerArgs[fieldId]), err) + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_FileServiceUploadFileForServerArgs[fieldId]), err) SkipFieldError: return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) @@ -4327,8 +5866,8 @@ ReadStructEndError: return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } -func (p *FileServiceUploadLoopFileInnerArgs) ReadField1(iprot thrift.TProtocol) error { - _field := NewUploadLoopFileInnerRequest() +func (p *FileServiceUploadFileForServerArgs) ReadField1(iprot thrift.TProtocol) error { + _field := NewUploadFileForServerRequest() if err := _field.Read(iprot); err != nil { return err } @@ -4336,9 +5875,9 @@ func (p *FileServiceUploadLoopFileInnerArgs) ReadField1(iprot thrift.TProtocol) return nil } -func (p *FileServiceUploadLoopFileInnerArgs) Write(oprot thrift.TProtocol) (err error) { +func (p *FileServiceUploadFileForServerArgs) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 - if err = oprot.WriteStructBegin("UploadLoopFileInner_args"); err != nil { + if err = oprot.WriteStructBegin("UploadFileForServer_args"); err != nil { goto WriteStructBeginError } if p != nil { @@ -4364,7 +5903,7 @@ WriteStructEndError: return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) } -func (p *FileServiceUploadLoopFileInnerArgs) writeField1(oprot thrift.TProtocol) (err error) { +func (p *FileServiceUploadFileForServerArgs) writeField1(oprot thrift.TProtocol) (err error) { if err = oprot.WriteFieldBegin("req", thrift.STRUCT, 1); err != nil { goto WriteFieldBeginError } @@ -4381,15 +5920,15 @@ WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) } -func (p *FileServiceUploadLoopFileInnerArgs) String() string { +func (p *FileServiceUploadFileForServerArgs) String() string { if p == nil { return "" } - return fmt.Sprintf("FileServiceUploadLoopFileInnerArgs(%+v)", *p) + return fmt.Sprintf("FileServiceUploadFileForServerArgs(%+v)", *p) } -func (p *FileServiceUploadLoopFileInnerArgs) DeepEqual(ano *FileServiceUploadLoopFileInnerArgs) bool { +func (p *FileServiceUploadFileForServerArgs) DeepEqual(ano *FileServiceUploadFileForServerArgs) bool { if p == ano { return true } else if p == nil || ano == nil { @@ -4401,7 +5940,7 @@ func (p *FileServiceUploadLoopFileInnerArgs) DeepEqual(ano *FileServiceUploadLoo return true } -func (p *FileServiceUploadLoopFileInnerArgs) Field1DeepEqual(src *UploadLoopFileInnerRequest) bool { +func (p *FileServiceUploadFileForServerArgs) Field1DeepEqual(src *UploadFileForServerRequest) bool { if !p.Req.DeepEqual(src) { return false @@ -4409,41 +5948,41 @@ func (p *FileServiceUploadLoopFileInnerArgs) Field1DeepEqual(src *UploadLoopFile return true } -type FileServiceUploadLoopFileInnerResult struct { - Success *UploadLoopFileInnerResponse `thrift:"success,0,optional" frugal:"0,optional,UploadLoopFileInnerResponse"` +type FileServiceUploadFileForServerResult struct { + Success *UploadFileForServerResponse `thrift:"success,0,optional" frugal:"0,optional,UploadFileForServerResponse"` } -func NewFileServiceUploadLoopFileInnerResult() *FileServiceUploadLoopFileInnerResult { - return &FileServiceUploadLoopFileInnerResult{} +func NewFileServiceUploadFileForServerResult() *FileServiceUploadFileForServerResult { + return &FileServiceUploadFileForServerResult{} } -func (p *FileServiceUploadLoopFileInnerResult) InitDefault() { +func (p *FileServiceUploadFileForServerResult) InitDefault() { } -var FileServiceUploadLoopFileInnerResult_Success_DEFAULT *UploadLoopFileInnerResponse +var FileServiceUploadFileForServerResult_Success_DEFAULT *UploadFileForServerResponse -func (p *FileServiceUploadLoopFileInnerResult) GetSuccess() (v *UploadLoopFileInnerResponse) { +func (p *FileServiceUploadFileForServerResult) GetSuccess() (v *UploadFileForServerResponse) { if p == nil { return } if !p.IsSetSuccess() { - return FileServiceUploadLoopFileInnerResult_Success_DEFAULT + return FileServiceUploadFileForServerResult_Success_DEFAULT } return p.Success } -func (p *FileServiceUploadLoopFileInnerResult) SetSuccess(x interface{}) { - p.Success = x.(*UploadLoopFileInnerResponse) +func (p *FileServiceUploadFileForServerResult) SetSuccess(x interface{}) { + p.Success = x.(*UploadFileForServerResponse) } -var fieldIDToName_FileServiceUploadLoopFileInnerResult = map[int16]string{ +var fieldIDToName_FileServiceUploadFileForServerResult = map[int16]string{ 0: "success", } -func (p *FileServiceUploadLoopFileInnerResult) IsSetSuccess() bool { +func (p *FileServiceUploadFileForServerResult) IsSetSuccess() bool { return p.Success != nil } -func (p *FileServiceUploadLoopFileInnerResult) Read(iprot thrift.TProtocol) (err error) { +func (p *FileServiceUploadFileForServerResult) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -4488,7 +6027,7 @@ ReadStructBeginError: ReadFieldBeginError: return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_FileServiceUploadLoopFileInnerResult[fieldId]), err) + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_FileServiceUploadFileForServerResult[fieldId]), err) SkipFieldError: return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) @@ -4498,8 +6037,8 @@ ReadStructEndError: return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } -func (p *FileServiceUploadLoopFileInnerResult) ReadField0(iprot thrift.TProtocol) error { - _field := NewUploadLoopFileInnerResponse() +func (p *FileServiceUploadFileForServerResult) ReadField0(iprot thrift.TProtocol) error { + _field := NewUploadFileForServerResponse() if err := _field.Read(iprot); err != nil { return err } @@ -4507,9 +6046,9 @@ func (p *FileServiceUploadLoopFileInnerResult) ReadField0(iprot thrift.TProtocol return nil } -func (p *FileServiceUploadLoopFileInnerResult) Write(oprot thrift.TProtocol) (err error) { +func (p *FileServiceUploadFileForServerResult) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 - if err = oprot.WriteStructBegin("UploadLoopFileInner_result"); err != nil { + if err = oprot.WriteStructBegin("UploadFileForServer_result"); err != nil { goto WriteStructBeginError } if p != nil { @@ -4535,7 +6074,7 @@ WriteStructEndError: return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) } -func (p *FileServiceUploadLoopFileInnerResult) writeField0(oprot thrift.TProtocol) (err error) { +func (p *FileServiceUploadFileForServerResult) writeField0(oprot thrift.TProtocol) (err error) { if p.IsSetSuccess() { if err = oprot.WriteFieldBegin("success", thrift.STRUCT, 0); err != nil { goto WriteFieldBeginError @@ -4554,15 +6093,15 @@ WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 0 end error: ", p), err) } -func (p *FileServiceUploadLoopFileInnerResult) String() string { +func (p *FileServiceUploadFileForServerResult) String() string { if p == nil { return "" } - return fmt.Sprintf("FileServiceUploadLoopFileInnerResult(%+v)", *p) + return fmt.Sprintf("FileServiceUploadFileForServerResult(%+v)", *p) } -func (p *FileServiceUploadLoopFileInnerResult) DeepEqual(ano *FileServiceUploadLoopFileInnerResult) bool { +func (p *FileServiceUploadFileForServerResult) DeepEqual(ano *FileServiceUploadFileForServerResult) bool { if p == ano { return true } else if p == nil || ano == nil { @@ -4574,7 +6113,7 @@ func (p *FileServiceUploadLoopFileInnerResult) DeepEqual(ano *FileServiceUploadL return true } -func (p *FileServiceUploadLoopFileInnerResult) Field0DeepEqual(src *UploadLoopFileInnerResponse) bool { +func (p *FileServiceUploadFileForServerResult) Field0DeepEqual(src *UploadFileForServerResponse) bool { if !p.Success.DeepEqual(src) { return false diff --git a/backend/kitex_gen/coze/loop/foundation/file/coze.loop.foundation.file_validator.go b/backend/kitex_gen/coze/loop/foundation/file/coze.loop.foundation.file_validator.go index 9788fed88..04fa376c2 100644 --- a/backend/kitex_gen/coze/loop/foundation/file/coze.loop.foundation.file_validator.go +++ b/backend/kitex_gen/coze/loop/foundation/file/coze.loop.foundation.file_validator.go @@ -114,3 +114,32 @@ func (p *SignDownloadFileResponse) IsValid() error { } return nil } +func (p *UploadFileOption) IsValid() error { + return nil +} +func (p *UploadFileForServerRequest) IsValid() error { + if p.Option != nil { + if err := p.Option.IsValid(); err != nil { + return fmt.Errorf("field Option not valid, %w", err) + } + } + if p.Base != nil { + if err := p.Base.IsValid(); err != nil { + return fmt.Errorf("field Base not valid, %w", err) + } + } + return nil +} +func (p *UploadFileForServerResponse) IsValid() error { + if p.Data != nil { + if err := p.Data.IsValid(); err != nil { + return fmt.Errorf("field Data not valid, %w", err) + } + } + if p.BaseResp != nil { + if err := p.BaseResp.IsValid(); err != nil { + return fmt.Errorf("field BaseResp not valid, %w", err) + } + } + return nil +} diff --git a/backend/kitex_gen/coze/loop/foundation/file/fileservice/client.go b/backend/kitex_gen/coze/loop/foundation/file/fileservice/client.go index 23e9a4e3f..ceda72c8b 100644 --- a/backend/kitex_gen/coze/loop/foundation/file/fileservice/client.go +++ b/backend/kitex_gen/coze/loop/foundation/file/fileservice/client.go @@ -12,6 +12,7 @@ import ( // Client is designed to provide IDL-compatible methods with call-option parameter for kitex framework. type Client interface { UploadLoopFileInner(ctx context.Context, req *file.UploadLoopFileInnerRequest, callOptions ...callopt.Option) (r *file.UploadLoopFileInnerResponse, err error) + UploadFileForServer(ctx context.Context, req *file.UploadFileForServerRequest, callOptions ...callopt.Option) (r *file.UploadFileForServerResponse, err error) SignUploadFile(ctx context.Context, req *file.SignUploadFileRequest, callOptions ...callopt.Option) (r *file.SignUploadFileResponse, err error) SignDownloadFile(ctx context.Context, req *file.SignDownloadFileRequest, callOptions ...callopt.Option) (r *file.SignDownloadFileResponse, err error) } @@ -50,6 +51,11 @@ func (p *kFileServiceClient) UploadLoopFileInner(ctx context.Context, req *file. return p.kClient.UploadLoopFileInner(ctx, req) } +func (p *kFileServiceClient) UploadFileForServer(ctx context.Context, req *file.UploadFileForServerRequest, callOptions ...callopt.Option) (r *file.UploadFileForServerResponse, err error) { + ctx = client.NewCtxWithCallOptions(ctx, callOptions) + return p.kClient.UploadFileForServer(ctx, req) +} + func (p *kFileServiceClient) SignUploadFile(ctx context.Context, req *file.SignUploadFileRequest, callOptions ...callopt.Option) (r *file.SignUploadFileResponse, err error) { ctx = client.NewCtxWithCallOptions(ctx, callOptions) return p.kClient.SignUploadFile(ctx, req) diff --git a/backend/kitex_gen/coze/loop/foundation/file/fileservice/fileservice.go b/backend/kitex_gen/coze/loop/foundation/file/fileservice/fileservice.go index 58c27823e..ddb329e89 100644 --- a/backend/kitex_gen/coze/loop/foundation/file/fileservice/fileservice.go +++ b/backend/kitex_gen/coze/loop/foundation/file/fileservice/fileservice.go @@ -20,6 +20,13 @@ var serviceMethods = map[string]kitex.MethodInfo{ false, kitex.WithStreamingMode(kitex.StreamingNone), ), + "UploadFileForServer": kitex.NewMethodInfo( + uploadFileForServerHandler, + newFileServiceUploadFileForServerArgs, + newFileServiceUploadFileForServerResult, + false, + kitex.WithStreamingMode(kitex.StreamingNone), + ), "SignUploadFile": kitex.NewMethodInfo( signUploadFileHandler, newFileServiceSignUploadFileArgs, @@ -86,6 +93,25 @@ func newFileServiceUploadLoopFileInnerResult() interface{} { return file.NewFileServiceUploadLoopFileInnerResult() } +func uploadFileForServerHandler(ctx context.Context, handler interface{}, arg, result interface{}) error { + realArg := arg.(*file.FileServiceUploadFileForServerArgs) + realResult := result.(*file.FileServiceUploadFileForServerResult) + success, err := handler.(file.FileService).UploadFileForServer(ctx, realArg.Req) + if err != nil { + return err + } + realResult.Success = success + return nil +} + +func newFileServiceUploadFileForServerArgs() interface{} { + return file.NewFileServiceUploadFileForServerArgs() +} + +func newFileServiceUploadFileForServerResult() interface{} { + return file.NewFileServiceUploadFileForServerResult() +} + func signUploadFileHandler(ctx context.Context, handler interface{}, arg, result interface{}) error { realArg := arg.(*file.FileServiceSignUploadFileArgs) realResult := result.(*file.FileServiceSignUploadFileResult) @@ -146,6 +172,16 @@ func (p *kClient) UploadLoopFileInner(ctx context.Context, req *file.UploadLoopF return _result.GetSuccess(), nil } +func (p *kClient) UploadFileForServer(ctx context.Context, req *file.UploadFileForServerRequest) (r *file.UploadFileForServerResponse, err error) { + var _args file.FileServiceUploadFileForServerArgs + _args.Req = req + var _result file.FileServiceUploadFileForServerResult + if err = p.c.Call(ctx, "UploadFileForServer", &_args, &_result); err != nil { + return + } + return _result.GetSuccess(), nil +} + func (p *kClient) SignUploadFile(ctx context.Context, req *file.SignUploadFileRequest) (r *file.SignUploadFileResponse, err error) { var _args file.FileServiceSignUploadFileArgs _args.Req = req diff --git a/backend/kitex_gen/coze/loop/foundation/file/k-coze.loop.foundation.file.go b/backend/kitex_gen/coze/loop/foundation/file/k-coze.loop.foundation.file.go index d24e0c817..d596c405c 100644 --- a/backend/kitex_gen/coze/loop/foundation/file/k-coze.loop.foundation.file.go +++ b/backend/kitex_gen/coze/loop/foundation/file/k-coze.loop.foundation.file.go @@ -2890,6 +2890,843 @@ func (p *SignDownloadFileResponse) DeepCopy(s interface{}) error { return nil } +func (p *UploadFileOption) FastRead(buf []byte) (int, error) { + + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + for { + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 1: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField1(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 2: + if fieldTypeId == thrift.MAP { + l, err = p.FastReadField2(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + } + + return offset, nil +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_UploadFileOption[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +} + +func (p *UploadFileOption) FastReadField1(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.FileName = _field + return offset, nil +} + +func (p *UploadFileOption) FastReadField2(buf []byte) (int, error) { + offset := 0 + + _, _, size, l, err := thrift.Binary.ReadMapBegin(buf[offset:]) + offset += l + if err != nil { + return offset, err + } + _field := make(map[string]string, size) + for i := 0; i < size; i++ { + var _key string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _key = v + } + + var _val string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _val = v + } + + _field[_key] = _val + } + p.MimeTypeExtMapping = _field + return offset, nil +} + +func (p *UploadFileOption) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *UploadFileOption) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField1(buf[offset:], w) + offset += p.fastWriteField2(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *UploadFileOption) BLength() int { + l := 0 + if p != nil { + l += p.field1Length() + l += p.field2Length() + } + l += thrift.Binary.FieldStopLength() + return l +} + +func (p *UploadFileOption) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetFileName() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 1) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.FileName) + } + return offset +} + +func (p *UploadFileOption) fastWriteField2(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetMimeTypeExtMapping() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.MAP, 2) + mapBeginOffset := offset + offset += thrift.Binary.MapBeginLength() + var length int + for k, v := range p.MimeTypeExtMapping { + length++ + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, k) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, v) + } + thrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift.STRING, thrift.STRING, length) + } + return offset +} + +func (p *UploadFileOption) field1Length() int { + l := 0 + if p.IsSetFileName() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.FileName) + } + return l +} + +func (p *UploadFileOption) field2Length() int { + l := 0 + if p.IsSetMimeTypeExtMapping() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.MapBeginLength() + for k, v := range p.MimeTypeExtMapping { + _, _ = k, v + + l += thrift.Binary.StringLengthNocopy(k) + l += thrift.Binary.StringLengthNocopy(v) + } + } + return l +} + +func (p *UploadFileOption) DeepCopy(s interface{}) error { + src, ok := s.(*UploadFileOption) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + if src.FileName != nil { + var tmp string + if *src.FileName != "" { + tmp = kutils.StringDeepCopy(*src.FileName) + } + p.FileName = &tmp + } + + if src.MimeTypeExtMapping != nil { + p.MimeTypeExtMapping = make(map[string]string, len(src.MimeTypeExtMapping)) + for key, val := range src.MimeTypeExtMapping { + var _key string + if key != "" { + _key = kutils.StringDeepCopy(key) + } + + var _val string + if val != "" { + _val = kutils.StringDeepCopy(val) + } + + p.MimeTypeExtMapping[_key] = _val + } + } + + return nil +} + +func (p *UploadFileForServerRequest) FastRead(buf []byte) (int, error) { + + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + var issetMimeType bool = false + var issetBody bool = false + var issetWorkspaceID bool = false + for { + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 1: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField1(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + issetMimeType = true + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 2: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField2(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + issetBody = true + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 3: + if fieldTypeId == thrift.I64 { + l, err = p.FastReadField3(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + issetWorkspaceID = true + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 4: + if fieldTypeId == thrift.STRUCT { + l, err = p.FastReadField4(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 255: + if fieldTypeId == thrift.STRUCT { + l, err = p.FastReadField255(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + } + + if !issetMimeType { + fieldId = 1 + goto RequiredFieldNotSetError + } + + if !issetBody { + fieldId = 2 + goto RequiredFieldNotSetError + } + + if !issetWorkspaceID { + fieldId = 3 + goto RequiredFieldNotSetError + } + return offset, nil +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_UploadFileForServerRequest[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +RequiredFieldNotSetError: + return offset, thrift.NewProtocolException(thrift.INVALID_DATA, fmt.Sprintf("required field %s is not set", fieldIDToName_UploadFileForServerRequest[fieldId])) +} + +func (p *UploadFileForServerRequest) FastReadField1(buf []byte) (int, error) { + offset := 0 + + var _field string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = v + } + p.MimeType = _field + return offset, nil +} + +func (p *UploadFileForServerRequest) FastReadField2(buf []byte) (int, error) { + offset := 0 + + var _field []byte + if v, l, err := thrift.Binary.ReadBinary(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + + _field = []byte(v) + } + p.Body = _field + return offset, nil +} + +func (p *UploadFileForServerRequest) FastReadField3(buf []byte) (int, error) { + offset := 0 + + var _field int64 + if v, l, err := thrift.Binary.ReadI64(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = v + } + p.WorkspaceID = _field + return offset, nil +} + +func (p *UploadFileForServerRequest) FastReadField4(buf []byte) (int, error) { + offset := 0 + _field := NewUploadFileOption() + if l, err := _field.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + p.Option = _field + return offset, nil +} + +func (p *UploadFileForServerRequest) FastReadField255(buf []byte) (int, error) { + offset := 0 + _field := base.NewBase() + if l, err := _field.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + p.Base = _field + return offset, nil +} + +func (p *UploadFileForServerRequest) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *UploadFileForServerRequest) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField3(buf[offset:], w) + offset += p.fastWriteField1(buf[offset:], w) + offset += p.fastWriteField2(buf[offset:], w) + offset += p.fastWriteField4(buf[offset:], w) + offset += p.fastWriteField255(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *UploadFileForServerRequest) BLength() int { + l := 0 + if p != nil { + l += p.field1Length() + l += p.field2Length() + l += p.field3Length() + l += p.field4Length() + l += p.field255Length() + } + l += thrift.Binary.FieldStopLength() + return l +} + +func (p *UploadFileForServerRequest) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 1) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, p.MimeType) + return offset +} + +func (p *UploadFileForServerRequest) fastWriteField2(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 2) + offset += thrift.Binary.WriteBinaryNocopy(buf[offset:], w, []byte(p.Body)) + return offset +} + +func (p *UploadFileForServerRequest) fastWriteField3(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.I64, 3) + offset += thrift.Binary.WriteI64(buf[offset:], p.WorkspaceID) + return offset +} + +func (p *UploadFileForServerRequest) fastWriteField4(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetOption() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 4) + offset += p.Option.FastWriteNocopy(buf[offset:], w) + } + return offset +} + +func (p *UploadFileForServerRequest) fastWriteField255(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetBase() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 255) + offset += p.Base.FastWriteNocopy(buf[offset:], w) + } + return offset +} + +func (p *UploadFileForServerRequest) field1Length() int { + l := 0 + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(p.MimeType) + return l +} + +func (p *UploadFileForServerRequest) field2Length() int { + l := 0 + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.BinaryLengthNocopy([]byte(p.Body)) + return l +} + +func (p *UploadFileForServerRequest) field3Length() int { + l := 0 + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.I64Length() + return l +} + +func (p *UploadFileForServerRequest) field4Length() int { + l := 0 + if p.IsSetOption() { + l += thrift.Binary.FieldBeginLength() + l += p.Option.BLength() + } + return l +} + +func (p *UploadFileForServerRequest) field255Length() int { + l := 0 + if p.IsSetBase() { + l += thrift.Binary.FieldBeginLength() + l += p.Base.BLength() + } + return l +} + +func (p *UploadFileForServerRequest) DeepCopy(s interface{}) error { + src, ok := s.(*UploadFileForServerRequest) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + if src.MimeType != "" { + p.MimeType = kutils.StringDeepCopy(src.MimeType) + } + + if len(src.Body) != 0 { + tmp := make([]byte, len(src.Body)) + copy(tmp, src.Body) + p.Body = tmp + } + + p.WorkspaceID = src.WorkspaceID + + var _option *UploadFileOption + if src.Option != nil { + _option = &UploadFileOption{} + if err := _option.DeepCopy(src.Option); err != nil { + return err + } + } + p.Option = _option + + var _base *base.Base + if src.Base != nil { + _base = &base.Base{} + if err := _base.DeepCopy(src.Base); err != nil { + return err + } + } + p.Base = _base + + return nil +} + +func (p *UploadFileForServerResponse) FastRead(buf []byte) (int, error) { + + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + for { + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 1: + if fieldTypeId == thrift.I32 { + l, err = p.FastReadField1(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 2: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField2(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 3: + if fieldTypeId == thrift.STRUCT { + l, err = p.FastReadField3(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 255: + if fieldTypeId == thrift.STRUCT { + l, err = p.FastReadField255(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + } + + return offset, nil +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_UploadFileForServerResponse[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +} + +func (p *UploadFileForServerResponse) FastReadField1(buf []byte) (int, error) { + offset := 0 + + var _field *int32 + if v, l, err := thrift.Binary.ReadI32(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Code = _field + return offset, nil +} + +func (p *UploadFileForServerResponse) FastReadField2(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Msg = _field + return offset, nil +} + +func (p *UploadFileForServerResponse) FastReadField3(buf []byte) (int, error) { + offset := 0 + _field := NewFileData() + if l, err := _field.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + p.Data = _field + return offset, nil +} + +func (p *UploadFileForServerResponse) FastReadField255(buf []byte) (int, error) { + offset := 0 + _field := base.NewBaseResp() + if l, err := _field.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + p.BaseResp = _field + return offset, nil +} + +func (p *UploadFileForServerResponse) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *UploadFileForServerResponse) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField1(buf[offset:], w) + offset += p.fastWriteField2(buf[offset:], w) + offset += p.fastWriteField3(buf[offset:], w) + offset += p.fastWriteField255(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *UploadFileForServerResponse) BLength() int { + l := 0 + if p != nil { + l += p.field1Length() + l += p.field2Length() + l += p.field3Length() + l += p.field255Length() + } + l += thrift.Binary.FieldStopLength() + return l +} + +func (p *UploadFileForServerResponse) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetCode() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.I32, 1) + offset += thrift.Binary.WriteI32(buf[offset:], *p.Code) + } + return offset +} + +func (p *UploadFileForServerResponse) fastWriteField2(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetMsg() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 2) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Msg) + } + return offset +} + +func (p *UploadFileForServerResponse) fastWriteField3(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetData() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 3) + offset += p.Data.FastWriteNocopy(buf[offset:], w) + } + return offset +} + +func (p *UploadFileForServerResponse) fastWriteField255(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 255) + offset += p.BaseResp.FastWriteNocopy(buf[offset:], w) + return offset +} + +func (p *UploadFileForServerResponse) field1Length() int { + l := 0 + if p.IsSetCode() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.I32Length() + } + return l +} + +func (p *UploadFileForServerResponse) field2Length() int { + l := 0 + if p.IsSetMsg() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Msg) + } + return l +} + +func (p *UploadFileForServerResponse) field3Length() int { + l := 0 + if p.IsSetData() { + l += thrift.Binary.FieldBeginLength() + l += p.Data.BLength() + } + return l +} + +func (p *UploadFileForServerResponse) field255Length() int { + l := 0 + l += thrift.Binary.FieldBeginLength() + l += p.BaseResp.BLength() + return l +} + +func (p *UploadFileForServerResponse) DeepCopy(s interface{}) error { + src, ok := s.(*UploadFileForServerResponse) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + if src.Code != nil { + tmp := *src.Code + p.Code = &tmp + } + + if src.Msg != nil { + var tmp string + if *src.Msg != "" { + tmp = kutils.StringDeepCopy(*src.Msg) + } + p.Msg = &tmp + } + + var _data *FileData + if src.Data != nil { + _data = &FileData{} + if err := _data.DeepCopy(src.Data); err != nil { + return err + } + } + p.Data = _data + + var _baseResp *base.BaseResp + if src.BaseResp != nil { + _baseResp = &base.BaseResp{} + if err := _baseResp.DeepCopy(src.BaseResp); err != nil { + return err + } + } + p.BaseResp = _baseResp + + return nil +} + func (p *FileServiceUploadLoopFileInnerArgs) FastRead(buf []byte) (int, error) { var err error @@ -3124,6 +3961,240 @@ func (p *FileServiceUploadLoopFileInnerResult) DeepCopy(s interface{}) error { return nil } +func (p *FileServiceUploadFileForServerArgs) FastRead(buf []byte) (int, error) { + + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + for { + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 1: + if fieldTypeId == thrift.STRUCT { + l, err = p.FastReadField1(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + } + + return offset, nil +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_FileServiceUploadFileForServerArgs[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +} + +func (p *FileServiceUploadFileForServerArgs) FastReadField1(buf []byte) (int, error) { + offset := 0 + _field := NewUploadFileForServerRequest() + if l, err := _field.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + p.Req = _field + return offset, nil +} + +func (p *FileServiceUploadFileForServerArgs) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *FileServiceUploadFileForServerArgs) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField1(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *FileServiceUploadFileForServerArgs) BLength() int { + l := 0 + if p != nil { + l += p.field1Length() + } + l += thrift.Binary.FieldStopLength() + return l +} + +func (p *FileServiceUploadFileForServerArgs) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 1) + offset += p.Req.FastWriteNocopy(buf[offset:], w) + return offset +} + +func (p *FileServiceUploadFileForServerArgs) field1Length() int { + l := 0 + l += thrift.Binary.FieldBeginLength() + l += p.Req.BLength() + return l +} + +func (p *FileServiceUploadFileForServerArgs) DeepCopy(s interface{}) error { + src, ok := s.(*FileServiceUploadFileForServerArgs) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + var _req *UploadFileForServerRequest + if src.Req != nil { + _req = &UploadFileForServerRequest{} + if err := _req.DeepCopy(src.Req); err != nil { + return err + } + } + p.Req = _req + + return nil +} + +func (p *FileServiceUploadFileForServerResult) FastRead(buf []byte) (int, error) { + + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + for { + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 0: + if fieldTypeId == thrift.STRUCT { + l, err = p.FastReadField0(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + } + + return offset, nil +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_FileServiceUploadFileForServerResult[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +} + +func (p *FileServiceUploadFileForServerResult) FastReadField0(buf []byte) (int, error) { + offset := 0 + _field := NewUploadFileForServerResponse() + if l, err := _field.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + p.Success = _field + return offset, nil +} + +func (p *FileServiceUploadFileForServerResult) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *FileServiceUploadFileForServerResult) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField0(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *FileServiceUploadFileForServerResult) BLength() int { + l := 0 + if p != nil { + l += p.field0Length() + } + l += thrift.Binary.FieldStopLength() + return l +} + +func (p *FileServiceUploadFileForServerResult) fastWriteField0(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetSuccess() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 0) + offset += p.Success.FastWriteNocopy(buf[offset:], w) + } + return offset +} + +func (p *FileServiceUploadFileForServerResult) field0Length() int { + l := 0 + if p.IsSetSuccess() { + l += thrift.Binary.FieldBeginLength() + l += p.Success.BLength() + } + return l +} + +func (p *FileServiceUploadFileForServerResult) DeepCopy(s interface{}) error { + src, ok := s.(*FileServiceUploadFileForServerResult) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + var _success *UploadFileForServerResponse + if src.Success != nil { + _success = &UploadFileForServerResponse{} + if err := _success.DeepCopy(src.Success); err != nil { + return err + } + } + p.Success = _success + + return nil +} + func (p *FileServiceSignUploadFileArgs) FastRead(buf []byte) (int, error) { var err error @@ -3600,6 +4671,14 @@ func (p *FileServiceUploadLoopFileInnerResult) GetResult() interface{} { return p.Success } +func (p *FileServiceUploadFileForServerArgs) GetFirstArgument() interface{} { + return p.Req +} + +func (p *FileServiceUploadFileForServerResult) GetResult() interface{} { + return p.Success +} + func (p *FileServiceSignUploadFileArgs) GetFirstArgument() interface{} { return p.Req } diff --git a/backend/kitex_gen/coze/loop/foundation/fileservice/client.go b/backend/kitex_gen/coze/loop/foundation/fileservice/client.go index 23e9a4e3f..ceda72c8b 100644 --- a/backend/kitex_gen/coze/loop/foundation/fileservice/client.go +++ b/backend/kitex_gen/coze/loop/foundation/fileservice/client.go @@ -12,6 +12,7 @@ import ( // Client is designed to provide IDL-compatible methods with call-option parameter for kitex framework. type Client interface { UploadLoopFileInner(ctx context.Context, req *file.UploadLoopFileInnerRequest, callOptions ...callopt.Option) (r *file.UploadLoopFileInnerResponse, err error) + UploadFileForServer(ctx context.Context, req *file.UploadFileForServerRequest, callOptions ...callopt.Option) (r *file.UploadFileForServerResponse, err error) SignUploadFile(ctx context.Context, req *file.SignUploadFileRequest, callOptions ...callopt.Option) (r *file.SignUploadFileResponse, err error) SignDownloadFile(ctx context.Context, req *file.SignDownloadFileRequest, callOptions ...callopt.Option) (r *file.SignDownloadFileResponse, err error) } @@ -50,6 +51,11 @@ func (p *kFileServiceClient) UploadLoopFileInner(ctx context.Context, req *file. return p.kClient.UploadLoopFileInner(ctx, req) } +func (p *kFileServiceClient) UploadFileForServer(ctx context.Context, req *file.UploadFileForServerRequest, callOptions ...callopt.Option) (r *file.UploadFileForServerResponse, err error) { + ctx = client.NewCtxWithCallOptions(ctx, callOptions) + return p.kClient.UploadFileForServer(ctx, req) +} + func (p *kFileServiceClient) SignUploadFile(ctx context.Context, req *file.SignUploadFileRequest, callOptions ...callopt.Option) (r *file.SignUploadFileResponse, err error) { ctx = client.NewCtxWithCallOptions(ctx, callOptions) return p.kClient.SignUploadFile(ctx, req) diff --git a/backend/kitex_gen/coze/loop/foundation/fileservice/fileservice.go b/backend/kitex_gen/coze/loop/foundation/fileservice/fileservice.go index c692e829b..f51db04e9 100644 --- a/backend/kitex_gen/coze/loop/foundation/fileservice/fileservice.go +++ b/backend/kitex_gen/coze/loop/foundation/fileservice/fileservice.go @@ -21,6 +21,13 @@ var serviceMethods = map[string]kitex.MethodInfo{ false, kitex.WithStreamingMode(kitex.StreamingNone), ), + "UploadFileForServer": kitex.NewMethodInfo( + uploadFileForServerHandler, + newFileServiceUploadFileForServerArgs, + newFileServiceUploadFileForServerResult, + false, + kitex.WithStreamingMode(kitex.StreamingNone), + ), "SignUploadFile": kitex.NewMethodInfo( signUploadFileHandler, newFileServiceSignUploadFileArgs, @@ -87,6 +94,25 @@ func newFileServiceUploadLoopFileInnerResult() interface{} { return file.NewFileServiceUploadLoopFileInnerResult() } +func uploadFileForServerHandler(ctx context.Context, handler interface{}, arg, result interface{}) error { + realArg := arg.(*file.FileServiceUploadFileForServerArgs) + realResult := result.(*file.FileServiceUploadFileForServerResult) + success, err := handler.(file.FileService).UploadFileForServer(ctx, realArg.Req) + if err != nil { + return err + } + realResult.Success = success + return nil +} + +func newFileServiceUploadFileForServerArgs() interface{} { + return file.NewFileServiceUploadFileForServerArgs() +} + +func newFileServiceUploadFileForServerResult() interface{} { + return file.NewFileServiceUploadFileForServerResult() +} + func signUploadFileHandler(ctx context.Context, handler interface{}, arg, result interface{}) error { realArg := arg.(*file.FileServiceSignUploadFileArgs) realResult := result.(*file.FileServiceSignUploadFileResult) @@ -147,6 +173,16 @@ func (p *kClient) UploadLoopFileInner(ctx context.Context, req *file.UploadLoopF return _result.GetSuccess(), nil } +func (p *kClient) UploadFileForServer(ctx context.Context, req *file.UploadFileForServerRequest) (r *file.UploadFileForServerResponse, err error) { + var _args file.FileServiceUploadFileForServerArgs + _args.Req = req + var _result file.FileServiceUploadFileForServerResult + if err = p.c.Call(ctx, "UploadFileForServer", &_args, &_result); err != nil { + return + } + return _result.GetSuccess(), nil +} + func (p *kClient) SignUploadFile(ctx context.Context, req *file.SignUploadFileRequest) (r *file.SignUploadFileResponse, err error) { var _args file.FileServiceSignUploadFileArgs _args.Req = req diff --git a/backend/kitex_gen/coze/loop/llm/domain/runtime/k-runtime.go b/backend/kitex_gen/coze/loop/llm/domain/runtime/k-runtime.go index a977efdd3..0fea0fa18 100644 --- a/backend/kitex_gen/coze/loop/llm/domain/runtime/k-runtime.go +++ b/backend/kitex_gen/coze/loop/llm/domain/runtime/k-runtime.go @@ -1221,6 +1221,20 @@ func (p *ChatMessagePart) FastRead(buf []byte) (int, error) { goto SkipFieldError } } + case 5: + if fieldTypeId == thrift.STRUCT { + l, err = p.FastReadField5(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } default: l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l @@ -1279,6 +1293,18 @@ func (p *ChatMessagePart) FastReadField3(buf []byte) (int, error) { return offset, nil } +func (p *ChatMessagePart) FastReadField5(buf []byte) (int, error) { + offset := 0 + _field := NewChatMessageVideoURL() + if l, err := _field.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + p.VideoURL = _field + return offset, nil +} + func (p *ChatMessagePart) FastWrite(buf []byte) int { return p.FastWriteNocopy(buf, nil) } @@ -1289,6 +1315,7 @@ func (p *ChatMessagePart) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int offset += p.fastWriteField1(buf[offset:], w) offset += p.fastWriteField2(buf[offset:], w) offset += p.fastWriteField3(buf[offset:], w) + offset += p.fastWriteField5(buf[offset:], w) } offset += thrift.Binary.WriteFieldStop(buf[offset:]) return offset @@ -1300,6 +1327,7 @@ func (p *ChatMessagePart) BLength() int { l += p.field1Length() l += p.field2Length() l += p.field3Length() + l += p.field5Length() } l += thrift.Binary.FieldStopLength() return l @@ -1332,6 +1360,15 @@ func (p *ChatMessagePart) fastWriteField3(buf []byte, w thrift.NocopyWriter) int return offset } +func (p *ChatMessagePart) fastWriteField5(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetVideoURL() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 5) + offset += p.VideoURL.FastWriteNocopy(buf[offset:], w) + } + return offset +} + func (p *ChatMessagePart) field1Length() int { l := 0 if p.IsSetType() { @@ -1359,6 +1396,15 @@ func (p *ChatMessagePart) field3Length() int { return l } +func (p *ChatMessagePart) field5Length() int { + l := 0 + if p.IsSetVideoURL() { + l += thrift.Binary.FieldBeginLength() + l += p.VideoURL.BLength() + } + return l +} + func (p *ChatMessagePart) DeepCopy(s interface{}) error { src, ok := s.(*ChatMessagePart) if !ok { @@ -1387,6 +1433,363 @@ func (p *ChatMessagePart) DeepCopy(s interface{}) error { } p.ImageURL = _imageURL + var _videoURL *ChatMessageVideoURL + if src.VideoURL != nil { + _videoURL = &ChatMessageVideoURL{} + if err := _videoURL.DeepCopy(src.VideoURL); err != nil { + return err + } + } + p.VideoURL = _videoURL + + return nil +} + +func (p *ChatMessageVideoURL) FastRead(buf []byte) (int, error) { + + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + for { + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 1: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField1(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 2: + if fieldTypeId == thrift.STRUCT { + l, err = p.FastReadField2(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 3: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField3(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + } + + return offset, nil +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_ChatMessageVideoURL[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +} + +func (p *ChatMessageVideoURL) FastReadField1(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.URL = _field + return offset, nil +} + +func (p *ChatMessageVideoURL) FastReadField2(buf []byte) (int, error) { + offset := 0 + _field := NewVideoURLDetail() + if l, err := _field.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + p.Detail = _field + return offset, nil +} + +func (p *ChatMessageVideoURL) FastReadField3(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.MimeType = _field + return offset, nil +} + +func (p *ChatMessageVideoURL) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *ChatMessageVideoURL) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField1(buf[offset:], w) + offset += p.fastWriteField2(buf[offset:], w) + offset += p.fastWriteField3(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *ChatMessageVideoURL) BLength() int { + l := 0 + if p != nil { + l += p.field1Length() + l += p.field2Length() + l += p.field3Length() + } + l += thrift.Binary.FieldStopLength() + return l +} + +func (p *ChatMessageVideoURL) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetURL() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 1) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.URL) + } + return offset +} + +func (p *ChatMessageVideoURL) fastWriteField2(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetDetail() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 2) + offset += p.Detail.FastWriteNocopy(buf[offset:], w) + } + return offset +} + +func (p *ChatMessageVideoURL) fastWriteField3(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetMimeType() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 3) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.MimeType) + } + return offset +} + +func (p *ChatMessageVideoURL) field1Length() int { + l := 0 + if p.IsSetURL() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.URL) + } + return l +} + +func (p *ChatMessageVideoURL) field2Length() int { + l := 0 + if p.IsSetDetail() { + l += thrift.Binary.FieldBeginLength() + l += p.Detail.BLength() + } + return l +} + +func (p *ChatMessageVideoURL) field3Length() int { + l := 0 + if p.IsSetMimeType() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.MimeType) + } + return l +} + +func (p *ChatMessageVideoURL) DeepCopy(s interface{}) error { + src, ok := s.(*ChatMessageVideoURL) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + if src.URL != nil { + var tmp string + if *src.URL != "" { + tmp = kutils.StringDeepCopy(*src.URL) + } + p.URL = &tmp + } + + var _detail *VideoURLDetail + if src.Detail != nil { + _detail = &VideoURLDetail{} + if err := _detail.DeepCopy(src.Detail); err != nil { + return err + } + } + p.Detail = _detail + + if src.MimeType != nil { + var tmp string + if *src.MimeType != "" { + tmp = kutils.StringDeepCopy(*src.MimeType) + } + p.MimeType = &tmp + } + + return nil +} + +func (p *VideoURLDetail) FastRead(buf []byte) (int, error) { + + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + for { + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 1: + if fieldTypeId == thrift.DOUBLE { + l, err = p.FastReadField1(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + } + + return offset, nil +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_VideoURLDetail[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +} + +func (p *VideoURLDetail) FastReadField1(buf []byte) (int, error) { + offset := 0 + + var _field *float64 + if v, l, err := thrift.Binary.ReadDouble(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Fps = _field + return offset, nil +} + +func (p *VideoURLDetail) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *VideoURLDetail) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField1(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *VideoURLDetail) BLength() int { + l := 0 + if p != nil { + l += p.field1Length() + } + l += thrift.Binary.FieldStopLength() + return l +} + +func (p *VideoURLDetail) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetFps() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.DOUBLE, 1) + offset += thrift.Binary.WriteDouble(buf[offset:], *p.Fps) + } + return offset +} + +func (p *VideoURLDetail) field1Length() int { + l := 0 + if p.IsSetFps() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.DoubleLength() + } + return l +} + +func (p *VideoURLDetail) DeepCopy(s interface{}) error { + src, ok := s.(*VideoURLDetail) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + if src.Fps != nil { + tmp := *src.Fps + p.Fps = &tmp + } + return nil } diff --git a/backend/kitex_gen/coze/loop/llm/domain/runtime/runtime.go b/backend/kitex_gen/coze/loop/llm/domain/runtime/runtime.go index e4129248d..84213022a 100644 --- a/backend/kitex_gen/coze/loop/llm/domain/runtime/runtime.go +++ b/backend/kitex_gen/coze/loop/llm/domain/runtime/runtime.go @@ -35,12 +35,18 @@ const ( ChatMessagePartTypeText = "text" ChatMessagePartTypeImageURL = "image_url" + // const ChatMessagePartType chat_message_part_type_audio_url = "audio_url" + ChatMessagePartTypeVideoURL = "video_url" ImageURLDetailAuto = "auto" ImageURLDetailLow = "low" ImageURLDetailHigh = "high" + + MimePrefixImage = "image/" + + MimePrefixVideo = "video/" ) type ResponseFormatType = string @@ -55,11 +61,11 @@ type ToolType = string type ChatMessagePartType = string -// const ChatMessagePartType chat_message_part_type_audio_url = "audio_url" -// const ChatMessagePartType chat_message_part_type_video_url = "video_url" // const ChatMessagePartType chat_message_part_type_file_url = "file_url" type ImageURLDetail = string +type MimeTypePrefix = string + type ModelConfig struct { // 模型id ModelID int64 `thrift:"model_id,1,required" frugal:"1,required,i64" json:"model_id" form:"model_id,required" query:"model_id,required"` @@ -1618,6 +1624,8 @@ type ChatMessagePart struct { Type *ChatMessagePartType `thrift:"type,1,optional" frugal:"1,optional,string" form:"type" json:"type,omitempty" query:"type"` Text *string `thrift:"text,2,optional" frugal:"2,optional,string" form:"text" json:"text,omitempty" query:"text"` ImageURL *ChatMessageImageURL `thrift:"image_url,3,optional" frugal:"3,optional,ChatMessageImageURL" form:"image_url" json:"image_url,omitempty" query:"image_url"` + // 4: optional ChatMessageAudioURL audio_url 占位,暂不支持 + VideoURL *ChatMessageVideoURL `thrift:"video_url,5,optional" frugal:"5,optional,ChatMessageVideoURL" form:"video_url" json:"video_url,omitempty" query:"video_url"` } func NewChatMessagePart() *ChatMessagePart { @@ -1662,6 +1670,18 @@ func (p *ChatMessagePart) GetImageURL() (v *ChatMessageImageURL) { } return p.ImageURL } + +var ChatMessagePart_VideoURL_DEFAULT *ChatMessageVideoURL + +func (p *ChatMessagePart) GetVideoURL() (v *ChatMessageVideoURL) { + if p == nil { + return + } + if !p.IsSetVideoURL() { + return ChatMessagePart_VideoURL_DEFAULT + } + return p.VideoURL +} func (p *ChatMessagePart) SetType(val *ChatMessagePartType) { p.Type = val } @@ -1671,11 +1691,15 @@ func (p *ChatMessagePart) SetText(val *string) { func (p *ChatMessagePart) SetImageURL(val *ChatMessageImageURL) { p.ImageURL = val } +func (p *ChatMessagePart) SetVideoURL(val *ChatMessageVideoURL) { + p.VideoURL = val +} var fieldIDToName_ChatMessagePart = map[int16]string{ 1: "type", 2: "text", 3: "image_url", + 5: "video_url", } func (p *ChatMessagePart) IsSetType() bool { @@ -1690,6 +1714,10 @@ func (p *ChatMessagePart) IsSetImageURL() bool { return p.ImageURL != nil } +func (p *ChatMessagePart) IsSetVideoURL() bool { + return p.VideoURL != nil +} + func (p *ChatMessagePart) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -1732,6 +1760,14 @@ func (p *ChatMessagePart) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 5: + if fieldTypeId == thrift.STRUCT { + if err = p.ReadField5(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } default: if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError @@ -1791,6 +1827,14 @@ func (p *ChatMessagePart) ReadField3(iprot thrift.TProtocol) error { p.ImageURL = _field return nil } +func (p *ChatMessagePart) ReadField5(iprot thrift.TProtocol) error { + _field := NewChatMessageVideoURL() + if err := _field.Read(iprot); err != nil { + return err + } + p.VideoURL = _field + return nil +} func (p *ChatMessagePart) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 @@ -1810,6 +1854,10 @@ func (p *ChatMessagePart) Write(oprot thrift.TProtocol) (err error) { fieldId = 3 goto WriteFieldError } + if err = p.writeField5(oprot); err != nil { + fieldId = 5 + goto WriteFieldError + } } if err = oprot.WriteFieldStop(); err != nil { goto WriteFieldStopError @@ -1882,6 +1930,24 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) } +func (p *ChatMessagePart) writeField5(oprot thrift.TProtocol) (err error) { + if p.IsSetVideoURL() { + if err = oprot.WriteFieldBegin("video_url", thrift.STRUCT, 5); err != nil { + goto WriteFieldBeginError + } + if err := p.VideoURL.Write(oprot); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 5 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 5 end error: ", p), err) +} func (p *ChatMessagePart) String() string { if p == nil { @@ -1906,6 +1972,9 @@ func (p *ChatMessagePart) DeepEqual(ano *ChatMessagePart) bool { if !p.Field3DeepEqual(ano.ImageURL) { return false } + if !p.Field5DeepEqual(ano.VideoURL) { + return false + } return true } @@ -1940,6 +2009,521 @@ func (p *ChatMessagePart) Field3DeepEqual(src *ChatMessageImageURL) bool { } return true } +func (p *ChatMessagePart) Field5DeepEqual(src *ChatMessageVideoURL) bool { + + if !p.VideoURL.DeepEqual(src) { + return false + } + return true +} + +type ChatMessageVideoURL struct { + URL *string `thrift:"url,1,optional" frugal:"1,optional,string" form:"url" json:"url,omitempty" query:"url"` + Detail *VideoURLDetail `thrift:"detail,2,optional" frugal:"2,optional,VideoURLDetail" form:"detail" json:"detail,omitempty" query:"detail"` + MimeType *string `thrift:"mime_type,3,optional" frugal:"3,optional,string" form:"mime_type" json:"mime_type,omitempty" query:"mime_type"` +} + +func NewChatMessageVideoURL() *ChatMessageVideoURL { + return &ChatMessageVideoURL{} +} + +func (p *ChatMessageVideoURL) InitDefault() { +} + +var ChatMessageVideoURL_URL_DEFAULT string + +func (p *ChatMessageVideoURL) GetURL() (v string) { + if p == nil { + return + } + if !p.IsSetURL() { + return ChatMessageVideoURL_URL_DEFAULT + } + return *p.URL +} + +var ChatMessageVideoURL_Detail_DEFAULT *VideoURLDetail + +func (p *ChatMessageVideoURL) GetDetail() (v *VideoURLDetail) { + if p == nil { + return + } + if !p.IsSetDetail() { + return ChatMessageVideoURL_Detail_DEFAULT + } + return p.Detail +} + +var ChatMessageVideoURL_MimeType_DEFAULT string + +func (p *ChatMessageVideoURL) GetMimeType() (v string) { + if p == nil { + return + } + if !p.IsSetMimeType() { + return ChatMessageVideoURL_MimeType_DEFAULT + } + return *p.MimeType +} +func (p *ChatMessageVideoURL) SetURL(val *string) { + p.URL = val +} +func (p *ChatMessageVideoURL) SetDetail(val *VideoURLDetail) { + p.Detail = val +} +func (p *ChatMessageVideoURL) SetMimeType(val *string) { + p.MimeType = val +} + +var fieldIDToName_ChatMessageVideoURL = map[int16]string{ + 1: "url", + 2: "detail", + 3: "mime_type", +} + +func (p *ChatMessageVideoURL) IsSetURL() bool { + return p.URL != nil +} + +func (p *ChatMessageVideoURL) IsSetDetail() bool { + return p.Detail != nil +} + +func (p *ChatMessageVideoURL) IsSetMimeType() bool { + return p.MimeType != nil +} + +func (p *ChatMessageVideoURL) Read(iprot thrift.TProtocol) (err error) { + var fieldTypeId thrift.TType + var fieldId int16 + + if _, err = iprot.ReadStructBegin(); err != nil { + goto ReadStructBeginError + } + + for { + _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + + switch fieldId { + case 1: + if fieldTypeId == thrift.STRING { + if err = p.ReadField1(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 2: + if fieldTypeId == thrift.STRUCT { + if err = p.ReadField2(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 3: + if fieldTypeId == thrift.STRING { + if err = p.ReadField3(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } + } + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError + } + + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_ChatMessageVideoURL[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *ChatMessageVideoURL) ReadField1(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.URL = _field + return nil +} +func (p *ChatMessageVideoURL) ReadField2(iprot thrift.TProtocol) error { + _field := NewVideoURLDetail() + if err := _field.Read(iprot); err != nil { + return err + } + p.Detail = _field + return nil +} +func (p *ChatMessageVideoURL) ReadField3(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.MimeType = _field + return nil +} + +func (p *ChatMessageVideoURL) Write(oprot thrift.TProtocol) (err error) { + var fieldId int16 + if err = oprot.WriteStructBegin("ChatMessageVideoURL"); err != nil { + goto WriteStructBeginError + } + if p != nil { + if err = p.writeField1(oprot); err != nil { + fieldId = 1 + goto WriteFieldError + } + if err = p.writeField2(oprot); err != nil { + fieldId = 2 + goto WriteFieldError + } + if err = p.writeField3(oprot); err != nil { + fieldId = 3 + goto WriteFieldError + } + } + if err = oprot.WriteFieldStop(); err != nil { + goto WriteFieldStopError + } + if err = oprot.WriteStructEnd(); err != nil { + goto WriteStructEndError + } + return nil +WriteStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) +WriteFieldError: + return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) +WriteFieldStopError: + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) +WriteStructEndError: + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) +} + +func (p *ChatMessageVideoURL) writeField1(oprot thrift.TProtocol) (err error) { + if p.IsSetURL() { + if err = oprot.WriteFieldBegin("url", thrift.STRING, 1); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.URL); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) +} +func (p *ChatMessageVideoURL) writeField2(oprot thrift.TProtocol) (err error) { + if p.IsSetDetail() { + if err = oprot.WriteFieldBegin("detail", thrift.STRUCT, 2); err != nil { + goto WriteFieldBeginError + } + if err := p.Detail.Write(oprot); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) +} +func (p *ChatMessageVideoURL) writeField3(oprot thrift.TProtocol) (err error) { + if p.IsSetMimeType() { + if err = oprot.WriteFieldBegin("mime_type", thrift.STRING, 3); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.MimeType); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) +} + +func (p *ChatMessageVideoURL) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("ChatMessageVideoURL(%+v)", *p) + +} + +func (p *ChatMessageVideoURL) DeepEqual(ano *ChatMessageVideoURL) bool { + if p == ano { + return true + } else if p == nil || ano == nil { + return false + } + if !p.Field1DeepEqual(ano.URL) { + return false + } + if !p.Field2DeepEqual(ano.Detail) { + return false + } + if !p.Field3DeepEqual(ano.MimeType) { + return false + } + return true +} + +func (p *ChatMessageVideoURL) Field1DeepEqual(src *string) bool { + + if p.URL == src { + return true + } else if p.URL == nil || src == nil { + return false + } + if strings.Compare(*p.URL, *src) != 0 { + return false + } + return true +} +func (p *ChatMessageVideoURL) Field2DeepEqual(src *VideoURLDetail) bool { + + if !p.Detail.DeepEqual(src) { + return false + } + return true +} +func (p *ChatMessageVideoURL) Field3DeepEqual(src *string) bool { + + if p.MimeType == src { + return true + } else if p.MimeType == nil || src == nil { + return false + } + if strings.Compare(*p.MimeType, *src) != 0 { + return false + } + return true +} + +type VideoURLDetail struct { + Fps *float64 `thrift:"fps,1,optional" frugal:"1,optional,double" form:"fps" json:"fps,omitempty" query:"fps"` +} + +func NewVideoURLDetail() *VideoURLDetail { + return &VideoURLDetail{} +} + +func (p *VideoURLDetail) InitDefault() { +} + +var VideoURLDetail_Fps_DEFAULT float64 + +func (p *VideoURLDetail) GetFps() (v float64) { + if p == nil { + return + } + if !p.IsSetFps() { + return VideoURLDetail_Fps_DEFAULT + } + return *p.Fps +} +func (p *VideoURLDetail) SetFps(val *float64) { + p.Fps = val +} + +var fieldIDToName_VideoURLDetail = map[int16]string{ + 1: "fps", +} + +func (p *VideoURLDetail) IsSetFps() bool { + return p.Fps != nil +} + +func (p *VideoURLDetail) Read(iprot thrift.TProtocol) (err error) { + var fieldTypeId thrift.TType + var fieldId int16 + + if _, err = iprot.ReadStructBegin(); err != nil { + goto ReadStructBeginError + } + + for { + _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + + switch fieldId { + case 1: + if fieldTypeId == thrift.DOUBLE { + if err = p.ReadField1(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } + } + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError + } + + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_VideoURLDetail[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *VideoURLDetail) ReadField1(iprot thrift.TProtocol) error { + + var _field *float64 + if v, err := iprot.ReadDouble(); err != nil { + return err + } else { + _field = &v + } + p.Fps = _field + return nil +} + +func (p *VideoURLDetail) Write(oprot thrift.TProtocol) (err error) { + var fieldId int16 + if err = oprot.WriteStructBegin("VideoURLDetail"); err != nil { + goto WriteStructBeginError + } + if p != nil { + if err = p.writeField1(oprot); err != nil { + fieldId = 1 + goto WriteFieldError + } + } + if err = oprot.WriteFieldStop(); err != nil { + goto WriteFieldStopError + } + if err = oprot.WriteStructEnd(); err != nil { + goto WriteStructEndError + } + return nil +WriteStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) +WriteFieldError: + return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) +WriteFieldStopError: + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) +WriteStructEndError: + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) +} + +func (p *VideoURLDetail) writeField1(oprot thrift.TProtocol) (err error) { + if p.IsSetFps() { + if err = oprot.WriteFieldBegin("fps", thrift.DOUBLE, 1); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteDouble(*p.Fps); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) +} + +func (p *VideoURLDetail) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("VideoURLDetail(%+v)", *p) + +} + +func (p *VideoURLDetail) DeepEqual(ano *VideoURLDetail) bool { + if p == ano { + return true + } else if p == nil || ano == nil { + return false + } + if !p.Field1DeepEqual(ano.Fps) { + return false + } + return true +} + +func (p *VideoURLDetail) Field1DeepEqual(src *float64) bool { + + if p.Fps == src { + return true + } else if p.Fps == nil || src == nil { + return false + } + if *p.Fps != *src { + return false + } + return true +} type ChatMessageImageURL struct { URL *string `thrift:"url,1,optional" frugal:"1,optional,string" form:"url" json:"url,omitempty" query:"url"` diff --git a/backend/kitex_gen/coze/loop/llm/domain/runtime/runtime_validator.go b/backend/kitex_gen/coze/loop/llm/domain/runtime/runtime_validator.go index 9be9670f5..549088e8d 100644 --- a/backend/kitex_gen/coze/loop/llm/domain/runtime/runtime_validator.go +++ b/backend/kitex_gen/coze/loop/llm/domain/runtime/runtime_validator.go @@ -43,6 +43,30 @@ func (p *ChatMessagePart) IsValid() error { return fmt.Errorf("field ImageURL not valid, %w", err) } } + if p.VideoURL != nil { + if err := p.VideoURL.IsValid(); err != nil { + return fmt.Errorf("field VideoURL not valid, %w", err) + } + } + return nil +} +func (p *ChatMessageVideoURL) IsValid() error { + if p.Detail != nil { + if err := p.Detail.IsValid(); err != nil { + return fmt.Errorf("field Detail not valid, %w", err) + } + } + return nil +} +func (p *VideoURLDetail) IsValid() error { + if p.Fps != nil { + if *p.Fps < float64(0.2) { + return fmt.Errorf("field Fps ge rule failed, current value: %v", *p.Fps) + } + if *p.Fps > float64(5) { + return fmt.Errorf("field Fps le rule failed, current value: %v", *p.Fps) + } + } return nil } func (p *ChatMessageImageURL) IsValid() error { diff --git a/backend/kitex_gen/coze/loop/prompt/domain/prompt/k-prompt.go b/backend/kitex_gen/coze/loop/prompt/domain/prompt/k-prompt.go index bf6daead7..874e621b7 100644 --- a/backend/kitex_gen/coze/loop/prompt/domain/prompt/k-prompt.go +++ b/backend/kitex_gen/coze/loop/prompt/domain/prompt/k-prompt.go @@ -4354,6 +4354,34 @@ func (p *ContentPart) FastRead(buf []byte) (int, error) { goto SkipFieldError } } + case 4: + if fieldTypeId == thrift.STRUCT { + l, err = p.FastReadField4(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 5: + if fieldTypeId == thrift.STRUCT { + l, err = p.FastReadField5(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } default: l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l @@ -4412,6 +4440,30 @@ func (p *ContentPart) FastReadField3(buf []byte) (int, error) { return offset, nil } +func (p *ContentPart) FastReadField4(buf []byte) (int, error) { + offset := 0 + _field := NewVideoURL() + if l, err := _field.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + p.VideoURL = _field + return offset, nil +} + +func (p *ContentPart) FastReadField5(buf []byte) (int, error) { + offset := 0 + _field := NewMediaConfig() + if l, err := _field.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + p.MediaConfig = _field + return offset, nil +} + func (p *ContentPart) FastWrite(buf []byte) int { return p.FastWriteNocopy(buf, nil) } @@ -4422,6 +4474,8 @@ func (p *ContentPart) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { offset += p.fastWriteField1(buf[offset:], w) offset += p.fastWriteField2(buf[offset:], w) offset += p.fastWriteField3(buf[offset:], w) + offset += p.fastWriteField4(buf[offset:], w) + offset += p.fastWriteField5(buf[offset:], w) } offset += thrift.Binary.WriteFieldStop(buf[offset:]) return offset @@ -4433,6 +4487,8 @@ func (p *ContentPart) BLength() int { l += p.field1Length() l += p.field2Length() l += p.field3Length() + l += p.field4Length() + l += p.field5Length() } l += thrift.Binary.FieldStopLength() return l @@ -4465,6 +4521,24 @@ func (p *ContentPart) fastWriteField3(buf []byte, w thrift.NocopyWriter) int { return offset } +func (p *ContentPart) fastWriteField4(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetVideoURL() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 4) + offset += p.VideoURL.FastWriteNocopy(buf[offset:], w) + } + return offset +} + +func (p *ContentPart) fastWriteField5(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetMediaConfig() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 5) + offset += p.MediaConfig.FastWriteNocopy(buf[offset:], w) + } + return offset +} + func (p *ContentPart) field1Length() int { l := 0 if p.IsSetType() { @@ -4492,6 +4566,24 @@ func (p *ContentPart) field3Length() int { return l } +func (p *ContentPart) field4Length() int { + l := 0 + if p.IsSetVideoURL() { + l += thrift.Binary.FieldBeginLength() + l += p.VideoURL.BLength() + } + return l +} + +func (p *ContentPart) field5Length() int { + l := 0 + if p.IsSetMediaConfig() { + l += thrift.Binary.FieldBeginLength() + l += p.MediaConfig.BLength() + } + return l +} + func (p *ContentPart) DeepCopy(s interface{}) error { src, ok := s.(*ContentPart) if !ok { @@ -4520,6 +4612,24 @@ func (p *ContentPart) DeepCopy(s interface{}) error { } p.ImageURL = _imageURL + var _videoURL *VideoURL + if src.VideoURL != nil { + _videoURL = &VideoURL{} + if err := _videoURL.DeepCopy(src.VideoURL); err != nil { + return err + } + } + p.VideoURL = _videoURL + + var _mediaConfig *MediaConfig + if src.MediaConfig != nil { + _mediaConfig = &MediaConfig{} + if err := _mediaConfig.DeepCopy(src.MediaConfig); err != nil { + return err + } + } + p.MediaConfig = _mediaConfig + return nil } @@ -4699,6 +4809,299 @@ func (p *ImageURL) DeepCopy(s interface{}) error { return nil } +func (p *VideoURL) FastRead(buf []byte) (int, error) { + + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + for { + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 1: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField1(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 2: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField2(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + } + + return offset, nil +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_VideoURL[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +} + +func (p *VideoURL) FastReadField1(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.URL = _field + return offset, nil +} + +func (p *VideoURL) FastReadField2(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.URI = _field + return offset, nil +} + +func (p *VideoURL) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *VideoURL) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField1(buf[offset:], w) + offset += p.fastWriteField2(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *VideoURL) BLength() int { + l := 0 + if p != nil { + l += p.field1Length() + l += p.field2Length() + } + l += thrift.Binary.FieldStopLength() + return l +} + +func (p *VideoURL) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetURL() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 1) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.URL) + } + return offset +} + +func (p *VideoURL) fastWriteField2(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetURI() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 2) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.URI) + } + return offset +} + +func (p *VideoURL) field1Length() int { + l := 0 + if p.IsSetURL() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.URL) + } + return l +} + +func (p *VideoURL) field2Length() int { + l := 0 + if p.IsSetURI() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.URI) + } + return l +} + +func (p *VideoURL) DeepCopy(s interface{}) error { + src, ok := s.(*VideoURL) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + if src.URL != nil { + var tmp string + if *src.URL != "" { + tmp = kutils.StringDeepCopy(*src.URL) + } + p.URL = &tmp + } + + if src.URI != nil { + var tmp string + if *src.URI != "" { + tmp = kutils.StringDeepCopy(*src.URI) + } + p.URI = &tmp + } + + return nil +} + +func (p *MediaConfig) FastRead(buf []byte) (int, error) { + + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + for { + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 1: + if fieldTypeId == thrift.DOUBLE { + l, err = p.FastReadField1(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + } + + return offset, nil +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MediaConfig[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +} + +func (p *MediaConfig) FastReadField1(buf []byte) (int, error) { + offset := 0 + + var _field *float64 + if v, l, err := thrift.Binary.ReadDouble(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Fps = _field + return offset, nil +} + +func (p *MediaConfig) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *MediaConfig) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField1(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *MediaConfig) BLength() int { + l := 0 + if p != nil { + l += p.field1Length() + } + l += thrift.Binary.FieldStopLength() + return l +} + +func (p *MediaConfig) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetFps() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.DOUBLE, 1) + offset += thrift.Binary.WriteDouble(buf[offset:], *p.Fps) + } + return offset +} + +func (p *MediaConfig) field1Length() int { + l := 0 + if p.IsSetFps() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.DoubleLength() + } + return l +} + +func (p *MediaConfig) DeepCopy(s interface{}) error { + src, ok := s.(*MediaConfig) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + if src.Fps != nil { + tmp := *src.Fps + p.Fps = &tmp + } + + return nil +} + func (p *ToolCall) FastRead(buf []byte) (int, error) { var err error diff --git a/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go b/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go index 6476b21cb..cb2c99b73 100644 --- a/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go +++ b/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go @@ -33,6 +33,8 @@ const ( ContentTypeImageURL = "image_url" + ContentTypeVideoURL = "video_url" + ContentTypeMultiPartVariable = "multi_part_variable" VariableTypeString = "string" @@ -6004,9 +6006,11 @@ func (p *Message) Field100DeepEqual(src map[string]string) bool { } type ContentPart struct { - Type *ContentType `thrift:"type,1,optional" frugal:"1,optional,string" form:"type" json:"type,omitempty" query:"type"` - Text *string `thrift:"text,2,optional" frugal:"2,optional,string" form:"text" json:"text,omitempty" query:"text"` - ImageURL *ImageURL `thrift:"image_url,3,optional" frugal:"3,optional,ImageURL" form:"image_url" json:"image_url,omitempty" query:"image_url"` + Type *ContentType `thrift:"type,1,optional" frugal:"1,optional,string" form:"type" json:"type,omitempty" query:"type"` + Text *string `thrift:"text,2,optional" frugal:"2,optional,string" form:"text" json:"text,omitempty" query:"text"` + ImageURL *ImageURL `thrift:"image_url,3,optional" frugal:"3,optional,ImageURL" form:"image_url" json:"image_url,omitempty" query:"image_url"` + VideoURL *VideoURL `thrift:"video_url,4,optional" frugal:"4,optional,VideoURL" form:"video_url" json:"video_url,omitempty" query:"video_url"` + MediaConfig *MediaConfig `thrift:"media_config,5,optional" frugal:"5,optional,MediaConfig" form:"media_config" json:"media_config,omitempty" query:"media_config"` } func NewContentPart() *ContentPart { @@ -6051,6 +6055,30 @@ func (p *ContentPart) GetImageURL() (v *ImageURL) { } return p.ImageURL } + +var ContentPart_VideoURL_DEFAULT *VideoURL + +func (p *ContentPart) GetVideoURL() (v *VideoURL) { + if p == nil { + return + } + if !p.IsSetVideoURL() { + return ContentPart_VideoURL_DEFAULT + } + return p.VideoURL +} + +var ContentPart_MediaConfig_DEFAULT *MediaConfig + +func (p *ContentPart) GetMediaConfig() (v *MediaConfig) { + if p == nil { + return + } + if !p.IsSetMediaConfig() { + return ContentPart_MediaConfig_DEFAULT + } + return p.MediaConfig +} func (p *ContentPart) SetType(val *ContentType) { p.Type = val } @@ -6060,11 +6088,19 @@ func (p *ContentPart) SetText(val *string) { func (p *ContentPart) SetImageURL(val *ImageURL) { p.ImageURL = val } +func (p *ContentPart) SetVideoURL(val *VideoURL) { + p.VideoURL = val +} +func (p *ContentPart) SetMediaConfig(val *MediaConfig) { + p.MediaConfig = val +} var fieldIDToName_ContentPart = map[int16]string{ 1: "type", 2: "text", 3: "image_url", + 4: "video_url", + 5: "media_config", } func (p *ContentPart) IsSetType() bool { @@ -6079,6 +6115,14 @@ func (p *ContentPart) IsSetImageURL() bool { return p.ImageURL != nil } +func (p *ContentPart) IsSetVideoURL() bool { + return p.VideoURL != nil +} + +func (p *ContentPart) IsSetMediaConfig() bool { + return p.MediaConfig != nil +} + func (p *ContentPart) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -6121,6 +6165,22 @@ func (p *ContentPart) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 4: + if fieldTypeId == thrift.STRUCT { + if err = p.ReadField4(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 5: + if fieldTypeId == thrift.STRUCT { + if err = p.ReadField5(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } default: if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError @@ -6180,6 +6240,22 @@ func (p *ContentPart) ReadField3(iprot thrift.TProtocol) error { p.ImageURL = _field return nil } +func (p *ContentPart) ReadField4(iprot thrift.TProtocol) error { + _field := NewVideoURL() + if err := _field.Read(iprot); err != nil { + return err + } + p.VideoURL = _field + return nil +} +func (p *ContentPart) ReadField5(iprot thrift.TProtocol) error { + _field := NewMediaConfig() + if err := _field.Read(iprot); err != nil { + return err + } + p.MediaConfig = _field + return nil +} func (p *ContentPart) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 @@ -6199,6 +6275,14 @@ func (p *ContentPart) Write(oprot thrift.TProtocol) (err error) { fieldId = 3 goto WriteFieldError } + if err = p.writeField4(oprot); err != nil { + fieldId = 4 + goto WriteFieldError + } + if err = p.writeField5(oprot); err != nil { + fieldId = 5 + goto WriteFieldError + } } if err = oprot.WriteFieldStop(); err != nil { goto WriteFieldStopError @@ -6271,6 +6355,42 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) } +func (p *ContentPart) writeField4(oprot thrift.TProtocol) (err error) { + if p.IsSetVideoURL() { + if err = oprot.WriteFieldBegin("video_url", thrift.STRUCT, 4); err != nil { + goto WriteFieldBeginError + } + if err := p.VideoURL.Write(oprot); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 4 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 4 end error: ", p), err) +} +func (p *ContentPart) writeField5(oprot thrift.TProtocol) (err error) { + if p.IsSetMediaConfig() { + if err = oprot.WriteFieldBegin("media_config", thrift.STRUCT, 5); err != nil { + goto WriteFieldBeginError + } + if err := p.MediaConfig.Write(oprot); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 5 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 5 end error: ", p), err) +} func (p *ContentPart) String() string { if p == nil { @@ -6295,6 +6415,12 @@ func (p *ContentPart) DeepEqual(ano *ContentPart) bool { if !p.Field3DeepEqual(ano.ImageURL) { return false } + if !p.Field4DeepEqual(ano.VideoURL) { + return false + } + if !p.Field5DeepEqual(ano.MediaConfig) { + return false + } return true } @@ -6329,6 +6455,20 @@ func (p *ContentPart) Field3DeepEqual(src *ImageURL) bool { } return true } +func (p *ContentPart) Field4DeepEqual(src *VideoURL) bool { + + if !p.VideoURL.DeepEqual(src) { + return false + } + return true +} +func (p *ContentPart) Field5DeepEqual(src *MediaConfig) bool { + + if !p.MediaConfig.DeepEqual(src) { + return false + } + return true +} type ImageURL struct { URI *string `thrift:"uri,1,optional" frugal:"1,optional,string" form:"uri" json:"uri,omitempty" query:"uri"` @@ -6588,6 +6728,445 @@ func (p *ImageURL) Field2DeepEqual(src *string) bool { return true } +type VideoURL struct { + URL *string `thrift:"url,1,optional" frugal:"1,optional,string" form:"url" json:"url,omitempty" query:"url"` + URI *string `thrift:"uri,2,optional" frugal:"2,optional,string" form:"uri" json:"uri,omitempty" query:"uri"` +} + +func NewVideoURL() *VideoURL { + return &VideoURL{} +} + +func (p *VideoURL) InitDefault() { +} + +var VideoURL_URL_DEFAULT string + +func (p *VideoURL) GetURL() (v string) { + if p == nil { + return + } + if !p.IsSetURL() { + return VideoURL_URL_DEFAULT + } + return *p.URL +} + +var VideoURL_URI_DEFAULT string + +func (p *VideoURL) GetURI() (v string) { + if p == nil { + return + } + if !p.IsSetURI() { + return VideoURL_URI_DEFAULT + } + return *p.URI +} +func (p *VideoURL) SetURL(val *string) { + p.URL = val +} +func (p *VideoURL) SetURI(val *string) { + p.URI = val +} + +var fieldIDToName_VideoURL = map[int16]string{ + 1: "url", + 2: "uri", +} + +func (p *VideoURL) IsSetURL() bool { + return p.URL != nil +} + +func (p *VideoURL) IsSetURI() bool { + return p.URI != nil +} + +func (p *VideoURL) Read(iprot thrift.TProtocol) (err error) { + var fieldTypeId thrift.TType + var fieldId int16 + + if _, err = iprot.ReadStructBegin(); err != nil { + goto ReadStructBeginError + } + + for { + _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + + switch fieldId { + case 1: + if fieldTypeId == thrift.STRING { + if err = p.ReadField1(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 2: + if fieldTypeId == thrift.STRING { + if err = p.ReadField2(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } + } + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError + } + + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_VideoURL[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *VideoURL) ReadField1(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.URL = _field + return nil +} +func (p *VideoURL) ReadField2(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.URI = _field + return nil +} + +func (p *VideoURL) Write(oprot thrift.TProtocol) (err error) { + var fieldId int16 + if err = oprot.WriteStructBegin("VideoURL"); err != nil { + goto WriteStructBeginError + } + if p != nil { + if err = p.writeField1(oprot); err != nil { + fieldId = 1 + goto WriteFieldError + } + if err = p.writeField2(oprot); err != nil { + fieldId = 2 + goto WriteFieldError + } + } + if err = oprot.WriteFieldStop(); err != nil { + goto WriteFieldStopError + } + if err = oprot.WriteStructEnd(); err != nil { + goto WriteStructEndError + } + return nil +WriteStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) +WriteFieldError: + return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) +WriteFieldStopError: + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) +WriteStructEndError: + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) +} + +func (p *VideoURL) writeField1(oprot thrift.TProtocol) (err error) { + if p.IsSetURL() { + if err = oprot.WriteFieldBegin("url", thrift.STRING, 1); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.URL); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) +} +func (p *VideoURL) writeField2(oprot thrift.TProtocol) (err error) { + if p.IsSetURI() { + if err = oprot.WriteFieldBegin("uri", thrift.STRING, 2); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.URI); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) +} + +func (p *VideoURL) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("VideoURL(%+v)", *p) + +} + +func (p *VideoURL) DeepEqual(ano *VideoURL) bool { + if p == ano { + return true + } else if p == nil || ano == nil { + return false + } + if !p.Field1DeepEqual(ano.URL) { + return false + } + if !p.Field2DeepEqual(ano.URI) { + return false + } + return true +} + +func (p *VideoURL) Field1DeepEqual(src *string) bool { + + if p.URL == src { + return true + } else if p.URL == nil || src == nil { + return false + } + if strings.Compare(*p.URL, *src) != 0 { + return false + } + return true +} +func (p *VideoURL) Field2DeepEqual(src *string) bool { + + if p.URI == src { + return true + } else if p.URI == nil || src == nil { + return false + } + if strings.Compare(*p.URI, *src) != 0 { + return false + } + return true +} + +type MediaConfig struct { + Fps *float64 `thrift:"fps,1,optional" frugal:"1,optional,double" form:"fps" json:"fps,omitempty" query:"fps"` +} + +func NewMediaConfig() *MediaConfig { + return &MediaConfig{} +} + +func (p *MediaConfig) InitDefault() { +} + +var MediaConfig_Fps_DEFAULT float64 + +func (p *MediaConfig) GetFps() (v float64) { + if p == nil { + return + } + if !p.IsSetFps() { + return MediaConfig_Fps_DEFAULT + } + return *p.Fps +} +func (p *MediaConfig) SetFps(val *float64) { + p.Fps = val +} + +var fieldIDToName_MediaConfig = map[int16]string{ + 1: "fps", +} + +func (p *MediaConfig) IsSetFps() bool { + return p.Fps != nil +} + +func (p *MediaConfig) Read(iprot thrift.TProtocol) (err error) { + var fieldTypeId thrift.TType + var fieldId int16 + + if _, err = iprot.ReadStructBegin(); err != nil { + goto ReadStructBeginError + } + + for { + _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + + switch fieldId { + case 1: + if fieldTypeId == thrift.DOUBLE { + if err = p.ReadField1(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } + } + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError + } + + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MediaConfig[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *MediaConfig) ReadField1(iprot thrift.TProtocol) error { + + var _field *float64 + if v, err := iprot.ReadDouble(); err != nil { + return err + } else { + _field = &v + } + p.Fps = _field + return nil +} + +func (p *MediaConfig) Write(oprot thrift.TProtocol) (err error) { + var fieldId int16 + if err = oprot.WriteStructBegin("MediaConfig"); err != nil { + goto WriteStructBeginError + } + if p != nil { + if err = p.writeField1(oprot); err != nil { + fieldId = 1 + goto WriteFieldError + } + } + if err = oprot.WriteFieldStop(); err != nil { + goto WriteFieldStopError + } + if err = oprot.WriteStructEnd(); err != nil { + goto WriteStructEndError + } + return nil +WriteStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) +WriteFieldError: + return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) +WriteFieldStopError: + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) +WriteStructEndError: + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) +} + +func (p *MediaConfig) writeField1(oprot thrift.TProtocol) (err error) { + if p.IsSetFps() { + if err = oprot.WriteFieldBegin("fps", thrift.DOUBLE, 1); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteDouble(*p.Fps); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) +} + +func (p *MediaConfig) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("MediaConfig(%+v)", *p) + +} + +func (p *MediaConfig) DeepEqual(ano *MediaConfig) bool { + if p == ano { + return true + } else if p == nil || ano == nil { + return false + } + if !p.Field1DeepEqual(ano.Fps) { + return false + } + return true +} + +func (p *MediaConfig) Field1DeepEqual(src *float64) bool { + + if p.Fps == src { + return true + } else if p.Fps == nil || src == nil { + return false + } + if *p.Fps != *src { + return false + } + return true +} + type ToolCall struct { Index *int64 `thrift:"index,1,optional" frugal:"1,optional,i64" json:"index" form:"index" query:"index"` ID *string `thrift:"id,2,optional" frugal:"2,optional,string" form:"id" json:"id,omitempty" query:"id"` diff --git a/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt_validator.go b/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt_validator.go index f6639cac5..658007ba5 100644 --- a/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt_validator.go +++ b/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt_validator.go @@ -121,11 +121,35 @@ func (p *ContentPart) IsValid() error { return fmt.Errorf("field ImageURL not valid, %w", err) } } + if p.VideoURL != nil { + if err := p.VideoURL.IsValid(); err != nil { + return fmt.Errorf("field VideoURL not valid, %w", err) + } + } + if p.MediaConfig != nil { + if err := p.MediaConfig.IsValid(); err != nil { + return fmt.Errorf("field MediaConfig not valid, %w", err) + } + } return nil } func (p *ImageURL) IsValid() error { return nil } +func (p *VideoURL) IsValid() error { + return nil +} +func (p *MediaConfig) IsValid() error { + if p.Fps != nil { + if *p.Fps < float64(0.2) { + return fmt.Errorf("field Fps ge rule failed, current value: %v", *p.Fps) + } + if *p.Fps > float64(5) { + return fmt.Errorf("field Fps le rule failed, current value: %v", *p.Fps) + } + } + return nil +} func (p *ToolCall) IsValid() error { if p.FunctionCall != nil { if err := p.FunctionCall.IsValid(); err != nil { diff --git a/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi.go b/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi.go index 53c8147f1..18c4cf28d 100644 --- a/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi.go +++ b/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi.go @@ -24,6 +24,8 @@ const ( ContentTypeImageURL = "image_url" + ContentTypeVideoURL = "video_url" + ContentTypeBase64Data = "base64_data" ContentTypeMultiPartVariable = "multi_part_variable" @@ -5832,6 +5834,8 @@ type ContentPart struct { Text *string `thrift:"text,2,optional" frugal:"2,optional,string" form:"text" json:"text,omitempty" query:"text"` ImageURL *string `thrift:"image_url,3,optional" frugal:"3,optional,string" form:"image_url" json:"image_url,omitempty" query:"image_url"` Base64Data *string `thrift:"base64_data,4,optional" frugal:"4,optional,string" form:"base64_data" json:"base64_data,omitempty" query:"base64_data"` + VideoURL *string `thrift:"video_url,5,optional" frugal:"5,optional,string" form:"video_url" json:"video_url,omitempty" query:"video_url"` + Config *MediaConfig `thrift:"config,6,optional" frugal:"6,optional,MediaConfig" form:"config" json:"config,omitempty" query:"config"` } func NewContentPart() *ContentPart { @@ -5888,6 +5892,30 @@ func (p *ContentPart) GetBase64Data() (v string) { } return *p.Base64Data } + +var ContentPart_VideoURL_DEFAULT string + +func (p *ContentPart) GetVideoURL() (v string) { + if p == nil { + return + } + if !p.IsSetVideoURL() { + return ContentPart_VideoURL_DEFAULT + } + return *p.VideoURL +} + +var ContentPart_Config_DEFAULT *MediaConfig + +func (p *ContentPart) GetConfig() (v *MediaConfig) { + if p == nil { + return + } + if !p.IsSetConfig() { + return ContentPart_Config_DEFAULT + } + return p.Config +} func (p *ContentPart) SetType(val *ContentType) { p.Type = val } @@ -5900,12 +5928,20 @@ func (p *ContentPart) SetImageURL(val *string) { func (p *ContentPart) SetBase64Data(val *string) { p.Base64Data = val } +func (p *ContentPart) SetVideoURL(val *string) { + p.VideoURL = val +} +func (p *ContentPart) SetConfig(val *MediaConfig) { + p.Config = val +} var fieldIDToName_ContentPart = map[int16]string{ 1: "type", 2: "text", 3: "image_url", 4: "base64_data", + 5: "video_url", + 6: "config", } func (p *ContentPart) IsSetType() bool { @@ -5924,6 +5960,14 @@ func (p *ContentPart) IsSetBase64Data() bool { return p.Base64Data != nil } +func (p *ContentPart) IsSetVideoURL() bool { + return p.VideoURL != nil +} + +func (p *ContentPart) IsSetConfig() bool { + return p.Config != nil +} + func (p *ContentPart) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -5974,6 +6018,22 @@ func (p *ContentPart) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 5: + if fieldTypeId == thrift.STRING { + if err = p.ReadField5(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 6: + if fieldTypeId == thrift.STRUCT { + if err = p.ReadField6(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } default: if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError @@ -6047,6 +6107,25 @@ func (p *ContentPart) ReadField4(iprot thrift.TProtocol) error { p.Base64Data = _field return nil } +func (p *ContentPart) ReadField5(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.VideoURL = _field + return nil +} +func (p *ContentPart) ReadField6(iprot thrift.TProtocol) error { + _field := NewMediaConfig() + if err := _field.Read(iprot); err != nil { + return err + } + p.Config = _field + return nil +} func (p *ContentPart) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 @@ -6070,6 +6149,14 @@ func (p *ContentPart) Write(oprot thrift.TProtocol) (err error) { fieldId = 4 goto WriteFieldError } + if err = p.writeField5(oprot); err != nil { + fieldId = 5 + goto WriteFieldError + } + if err = p.writeField6(oprot); err != nil { + fieldId = 6 + goto WriteFieldError + } } if err = oprot.WriteFieldStop(); err != nil { goto WriteFieldStopError @@ -6160,6 +6247,42 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 4 end error: ", p), err) } +func (p *ContentPart) writeField5(oprot thrift.TProtocol) (err error) { + if p.IsSetVideoURL() { + if err = oprot.WriteFieldBegin("video_url", thrift.STRING, 5); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.VideoURL); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 5 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 5 end error: ", p), err) +} +func (p *ContentPart) writeField6(oprot thrift.TProtocol) (err error) { + if p.IsSetConfig() { + if err = oprot.WriteFieldBegin("config", thrift.STRUCT, 6); err != nil { + goto WriteFieldBeginError + } + if err := p.Config.Write(oprot); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 6 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 6 end error: ", p), err) +} func (p *ContentPart) String() string { if p == nil { @@ -6187,6 +6310,12 @@ func (p *ContentPart) DeepEqual(ano *ContentPart) bool { if !p.Field4DeepEqual(ano.Base64Data) { return false } + if !p.Field5DeepEqual(ano.VideoURL) { + return false + } + if !p.Field6DeepEqual(ano.Config) { + return false + } return true } @@ -6238,6 +6367,206 @@ func (p *ContentPart) Field4DeepEqual(src *string) bool { } return true } +func (p *ContentPart) Field5DeepEqual(src *string) bool { + + if p.VideoURL == src { + return true + } else if p.VideoURL == nil || src == nil { + return false + } + if strings.Compare(*p.VideoURL, *src) != 0 { + return false + } + return true +} +func (p *ContentPart) Field6DeepEqual(src *MediaConfig) bool { + + if !p.Config.DeepEqual(src) { + return false + } + return true +} + +type MediaConfig struct { + Fps *float64 `thrift:"fps,1,optional" frugal:"1,optional,double" form:"fps" json:"fps,omitempty" query:"fps"` +} + +func NewMediaConfig() *MediaConfig { + return &MediaConfig{} +} + +func (p *MediaConfig) InitDefault() { +} + +var MediaConfig_Fps_DEFAULT float64 + +func (p *MediaConfig) GetFps() (v float64) { + if p == nil { + return + } + if !p.IsSetFps() { + return MediaConfig_Fps_DEFAULT + } + return *p.Fps +} +func (p *MediaConfig) SetFps(val *float64) { + p.Fps = val +} + +var fieldIDToName_MediaConfig = map[int16]string{ + 1: "fps", +} + +func (p *MediaConfig) IsSetFps() bool { + return p.Fps != nil +} + +func (p *MediaConfig) Read(iprot thrift.TProtocol) (err error) { + var fieldTypeId thrift.TType + var fieldId int16 + + if _, err = iprot.ReadStructBegin(); err != nil { + goto ReadStructBeginError + } + + for { + _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + + switch fieldId { + case 1: + if fieldTypeId == thrift.DOUBLE { + if err = p.ReadField1(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } + } + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError + } + + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MediaConfig[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *MediaConfig) ReadField1(iprot thrift.TProtocol) error { + + var _field *float64 + if v, err := iprot.ReadDouble(); err != nil { + return err + } else { + _field = &v + } + p.Fps = _field + return nil +} + +func (p *MediaConfig) Write(oprot thrift.TProtocol) (err error) { + var fieldId int16 + if err = oprot.WriteStructBegin("MediaConfig"); err != nil { + goto WriteStructBeginError + } + if p != nil { + if err = p.writeField1(oprot); err != nil { + fieldId = 1 + goto WriteFieldError + } + } + if err = oprot.WriteFieldStop(); err != nil { + goto WriteFieldStopError + } + if err = oprot.WriteStructEnd(); err != nil { + goto WriteStructEndError + } + return nil +WriteStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) +WriteFieldError: + return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) +WriteFieldStopError: + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) +WriteStructEndError: + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) +} + +func (p *MediaConfig) writeField1(oprot thrift.TProtocol) (err error) { + if p.IsSetFps() { + if err = oprot.WriteFieldBegin("fps", thrift.DOUBLE, 1); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteDouble(*p.Fps); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) +} + +func (p *MediaConfig) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("MediaConfig(%+v)", *p) + +} + +func (p *MediaConfig) DeepEqual(ano *MediaConfig) bool { + if p == ano { + return true + } else if p == nil || ano == nil { + return false + } + if !p.Field1DeepEqual(ano.Fps) { + return false + } + return true +} + +func (p *MediaConfig) Field1DeepEqual(src *float64) bool { + + if p.Fps == src { + return true + } else if p.Fps == nil || src == nil { + return false + } + if *p.Fps != *src { + return false + } + return true +} type VariableDef struct { // 变量名字 diff --git a/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi_validator.go b/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi_validator.go index c2a982f6d..86bbe7350 100644 --- a/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi_validator.go +++ b/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi_validator.go @@ -154,6 +154,22 @@ func (p *Message) IsValid() error { return nil } func (p *ContentPart) IsValid() error { + if p.Config != nil { + if err := p.Config.IsValid(); err != nil { + return fmt.Errorf("field Config not valid, %w", err) + } + } + return nil +} +func (p *MediaConfig) IsValid() error { + if p.Fps != nil { + if *p.Fps < float64(0.2) { + return fmt.Errorf("field Fps ge rule failed, current value: %v", *p.Fps) + } + if *p.Fps > float64(5) { + return fmt.Errorf("field Fps le rule failed, current value: %v", *p.Fps) + } + } return nil } func (p *VariableDef) IsValid() error { diff --git a/backend/kitex_gen/coze/loop/prompt/openapi/k-coze.loop.prompt.openapi.go b/backend/kitex_gen/coze/loop/prompt/openapi/k-coze.loop.prompt.openapi.go index c8ed71fa6..8d7ce402c 100644 --- a/backend/kitex_gen/coze/loop/prompt/openapi/k-coze.loop.prompt.openapi.go +++ b/backend/kitex_gen/coze/loop/prompt/openapi/k-coze.loop.prompt.openapi.go @@ -4317,6 +4317,34 @@ func (p *ContentPart) FastRead(buf []byte) (int, error) { goto SkipFieldError } } + case 5: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField5(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 6: + if fieldTypeId == thrift.STRUCT { + l, err = p.FastReadField6(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } default: l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l @@ -4391,6 +4419,32 @@ func (p *ContentPart) FastReadField4(buf []byte) (int, error) { return offset, nil } +func (p *ContentPart) FastReadField5(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.VideoURL = _field + return offset, nil +} + +func (p *ContentPart) FastReadField6(buf []byte) (int, error) { + offset := 0 + _field := NewMediaConfig() + if l, err := _field.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + p.Config = _field + return offset, nil +} + func (p *ContentPart) FastWrite(buf []byte) int { return p.FastWriteNocopy(buf, nil) } @@ -4402,6 +4456,8 @@ func (p *ContentPart) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { offset += p.fastWriteField2(buf[offset:], w) offset += p.fastWriteField3(buf[offset:], w) offset += p.fastWriteField4(buf[offset:], w) + offset += p.fastWriteField5(buf[offset:], w) + offset += p.fastWriteField6(buf[offset:], w) } offset += thrift.Binary.WriteFieldStop(buf[offset:]) return offset @@ -4414,6 +4470,8 @@ func (p *ContentPart) BLength() int { l += p.field2Length() l += p.field3Length() l += p.field4Length() + l += p.field5Length() + l += p.field6Length() } l += thrift.Binary.FieldStopLength() return l @@ -4455,6 +4513,24 @@ func (p *ContentPart) fastWriteField4(buf []byte, w thrift.NocopyWriter) int { return offset } +func (p *ContentPart) fastWriteField5(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetVideoURL() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 5) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.VideoURL) + } + return offset +} + +func (p *ContentPart) fastWriteField6(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetConfig() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 6) + offset += p.Config.FastWriteNocopy(buf[offset:], w) + } + return offset +} + func (p *ContentPart) field1Length() int { l := 0 if p.IsSetType() { @@ -4491,6 +4567,24 @@ func (p *ContentPart) field4Length() int { return l } +func (p *ContentPart) field5Length() int { + l := 0 + if p.IsSetVideoURL() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.VideoURL) + } + return l +} + +func (p *ContentPart) field6Length() int { + l := 0 + if p.IsSetConfig() { + l += thrift.Binary.FieldBeginLength() + l += p.Config.BLength() + } + return l +} + func (p *ContentPart) DeepCopy(s interface{}) error { src, ok := s.(*ContentPart) if !ok { @@ -4526,6 +4620,140 @@ func (p *ContentPart) DeepCopy(s interface{}) error { p.Base64Data = &tmp } + if src.VideoURL != nil { + var tmp string + if *src.VideoURL != "" { + tmp = kutils.StringDeepCopy(*src.VideoURL) + } + p.VideoURL = &tmp + } + + var _config *MediaConfig + if src.Config != nil { + _config = &MediaConfig{} + if err := _config.DeepCopy(src.Config); err != nil { + return err + } + } + p.Config = _config + + return nil +} + +func (p *MediaConfig) FastRead(buf []byte) (int, error) { + + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + for { + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 1: + if fieldTypeId == thrift.DOUBLE { + l, err = p.FastReadField1(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + } + + return offset, nil +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MediaConfig[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +} + +func (p *MediaConfig) FastReadField1(buf []byte) (int, error) { + offset := 0 + + var _field *float64 + if v, l, err := thrift.Binary.ReadDouble(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Fps = _field + return offset, nil +} + +func (p *MediaConfig) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *MediaConfig) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField1(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *MediaConfig) BLength() int { + l := 0 + if p != nil { + l += p.field1Length() + } + l += thrift.Binary.FieldStopLength() + return l +} + +func (p *MediaConfig) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetFps() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.DOUBLE, 1) + offset += thrift.Binary.WriteDouble(buf[offset:], *p.Fps) + } + return offset +} + +func (p *MediaConfig) field1Length() int { + l := 0 + if p.IsSetFps() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.DoubleLength() + } + return l +} + +func (p *MediaConfig) DeepCopy(s interface{}) error { + src, ok := s.(*MediaConfig) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + if src.Fps != nil { + tmp := *src.Fps + p.Fps = &tmp + } + return nil } diff --git a/backend/loop_gen/coze/loop/foundation/lofile/local_fileservice.go b/backend/loop_gen/coze/loop/foundation/lofile/local_fileservice.go index e451ab6cc..4404593ae 100644 --- a/backend/loop_gen/coze/loop/foundation/lofile/local_fileservice.go +++ b/backend/loop_gen/coze/loop/foundation/lofile/local_fileservice.go @@ -43,6 +43,27 @@ func (l *LocalFileService) UploadLoopFileInner(ctx context.Context, req *file.Up return result.GetSuccess(), nil } +func (l *LocalFileService) UploadFileForServer(ctx context.Context, req *file.UploadFileForServerRequest, callOptions ...callopt.Option) (*file.UploadFileForServerResponse, error) { + chain := l.mds(func(ctx context.Context, in, out interface{}) error { + arg := in.(*file.FileServiceUploadFileForServerArgs) + result := out.(*file.FileServiceUploadFileForServerResult) + resp, err := l.impl.UploadFileForServer(ctx, arg.Req) + if err != nil { + return err + } + result.SetSuccess(resp) + return nil + }) + + arg := &file.FileServiceUploadFileForServerArgs{Req: req} + result := &file.FileServiceUploadFileForServerResult{} + ctx = l.injectRPCInfo(ctx, "UploadFileForServer") + if err := chain(ctx, arg, result); err != nil { + return nil, err + } + return result.GetSuccess(), nil +} + func (l *LocalFileService) SignUploadFile(ctx context.Context, req *file.SignUploadFileRequest, callOptions ...callopt.Option) (*file.SignUploadFileResponse, error) { chain := l.mds(func(ctx context.Context, in, out interface{}) error { arg := in.(*file.FileServiceSignUploadFileArgs) diff --git a/backend/modules/evaluation/infra/rpc/foundation/file_test.go b/backend/modules/evaluation/infra/rpc/foundation/file_test.go index 11943c9de..6d9658dfa 100755 --- a/backend/modules/evaluation/infra/rpc/foundation/file_test.go +++ b/backend/modules/evaluation/infra/rpc/foundation/file_test.go @@ -188,6 +188,10 @@ type mockFileServiceClient struct { lastRequest *file.SignDownloadFileRequest } +func (m *mockFileServiceClient) UploadFileForServer(ctx context.Context, req *file.UploadFileForServerRequest, callOptions ...callopt.Option) (r *file.UploadFileForServerResponse, err error) { + return nil, errors.New("not implemented") +} + func (m *mockFileServiceClient) UploadLoopFileInner(ctx context.Context, req *file.UploadLoopFileInnerRequest, callOptions ...callopt.Option) (r *file.UploadLoopFileInnerResponse, err error) { return nil, errors.New("not implemented") } diff --git a/backend/modules/evaluation/infra/rpc/foundation/mocks/fileservice_client.go b/backend/modules/evaluation/infra/rpc/foundation/mocks/fileservice_client.go index fbe1552aa..cb342e6fd 100644 --- a/backend/modules/evaluation/infra/rpc/foundation/mocks/fileservice_client.go +++ b/backend/modules/evaluation/infra/rpc/foundation/mocks/fileservice_client.go @@ -82,6 +82,26 @@ func (mr *MockClientMockRecorder) SignUploadFile(ctx, req any, callOptions ...an return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignUploadFile", reflect.TypeOf((*MockClient)(nil).SignUploadFile), varargs...) } +// UploadFileForServer mocks base method. +func (m *MockClient) UploadFileForServer(ctx context.Context, req *file.UploadFileForServerRequest, callOptions ...callopt.Option) (*file.UploadFileForServerResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, req} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "UploadFileForServer", varargs...) + ret0, _ := ret[0].(*file.UploadFileForServerResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UploadFileForServer indicates an expected call of UploadFileForServer. +func (mr *MockClientMockRecorder) UploadFileForServer(ctx, req any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, req}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UploadFileForServer", reflect.TypeOf((*MockClient)(nil).UploadFileForServer), varargs...) +} + // UploadLoopFileInner mocks base method. func (m *MockClient) UploadLoopFileInner(ctx context.Context, req *file.UploadLoopFileInnerRequest, callOptions ...callopt.Option) (*file.UploadLoopFileInnerResponse, error) { m.ctrl.T.Helper() diff --git a/backend/modules/foundation/application/file.go b/backend/modules/foundation/application/file.go index 4c621e972..30716971e 100644 --- a/backend/modules/foundation/application/file.go +++ b/backend/modules/foundation/application/file.go @@ -30,6 +30,37 @@ func NewFileApplication(objectStorage fileserver.BatchObjectStorage, auth rpc.IA } } +func (p *FileApplicationImpl) UploadFileForServer(ctx context.Context, req *file.UploadFileForServerRequest) (r *file.UploadFileForServerResponse, err error) { + if req == nil || req.MimeType == "" || len(req.Body) == 0 || req.WorkspaceID == 0 { + return nil, errorx.NewByCode(errno.CommonInvalidParamCode) + } + + spaceID := strconv.FormatInt(req.WorkspaceID, 10) + + // Extract custom mime type mappings and file name from option + var customMimeTypeExtMap map[string]string + var fileName string + if req.Option != nil { + customMimeTypeExtMap = req.Option.MimeTypeExtMapping + if req.Option.FileName != nil { + fileName = *req.Option.FileName + } + } + + key, err := p.fileService.UploadFileForServer(ctx, req.MimeType, req.Body, spaceID, customMimeTypeExtMap, fileName) + if err != nil { + return nil, err + } + + return &file.UploadFileForServerResponse{ + Data: &file.FileData{ + Bytes: lo.ToPtr(int64(len(req.Body))), + FileName: lo.ToPtr(key), + }, + BaseResp: base.NewBaseResp(), + }, nil +} + func (p *FileApplicationImpl) UploadLoopFileInner(ctx context.Context, req *file.UploadLoopFileInnerRequest) (r *file.UploadLoopFileInnerResponse, err error) { if req == nil || req.ContentType == "" || len(req.Body) == 0 { return nil, errorx.NewByCode(errno.CommonInvalidParamCode) diff --git a/backend/modules/foundation/application/file_test.go b/backend/modules/foundation/application/file_test.go new file mode 100644 index 000000000..79d41e94d --- /dev/null +++ b/backend/modules/foundation/application/file_test.go @@ -0,0 +1,189 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package application + +import ( + "context" + "testing" + + "github.com/samber/lo" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + + "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/foundation/file" + service "github.com/coze-dev/coze-loop/backend/modules/foundation/domain/file/service" + servicemocks "github.com/coze-dev/coze-loop/backend/modules/foundation/domain/file/service/mocks" + "github.com/coze-dev/coze-loop/backend/modules/foundation/pkg/errno" + "github.com/coze-dev/coze-loop/backend/pkg/errorx" + "github.com/coze-dev/coze-loop/backend/pkg/unittest" +) + +type dummyAuth struct{} + +func (a dummyAuth) CheckWorkspacePermission(context.Context, string, string) error { + return nil +} + +func TestFileApplicationImpl_UploadFileForServer(t *testing.T) { + type fields struct { + fileService service.FileService + } + type args struct { + ctx context.Context + req *file.UploadFileForServerRequest + } + + tests := []struct { + name string + fieldsGetter func(ctrl *gomock.Controller) fields + args args + wantErr error + wantKey string + }{ + { + name: "nil request returns error", + fieldsGetter: func(ctrl *gomock.Controller) fields { + return fields{} + }, + args: args{ + ctx: context.Background(), + req: nil, + }, + wantErr: errorx.NewByCode(errno.CommonInvalidParamCode), + }, + { + name: "missing mime type returns error", + fieldsGetter: func(ctrl *gomock.Controller) fields { + return fields{} + }, + args: args{ + ctx: context.Background(), + req: &file.UploadFileForServerRequest{ + Body: []byte("data"), + WorkspaceID: 1, + }, + }, + wantErr: errorx.NewByCode(errno.CommonInvalidParamCode), + }, + { + name: "missing body returns error", + fieldsGetter: func(ctrl *gomock.Controller) fields { + return fields{} + }, + args: args{ + ctx: context.Background(), + req: &file.UploadFileForServerRequest{ + MimeType: "text/plain", + WorkspaceID: 1, + }, + }, + wantErr: errorx.NewByCode(errno.CommonInvalidParamCode), + }, + { + name: "missing workspace id returns error", + fieldsGetter: func(ctrl *gomock.Controller) fields { + return fields{} + }, + args: args{ + ctx: context.Background(), + req: &file.UploadFileForServerRequest{ + MimeType: "text/plain", + Body: []byte("data"), + }, + }, + wantErr: errorx.NewByCode(errno.CommonInvalidParamCode), + }, + { + name: "success without option", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockService := servicemocks.NewMockFileService(ctrl) + mockService.EXPECT(). + UploadFileForServer(gomock.Any(), "image/png", []byte("img"), "123", gomock.AssignableToTypeOf(map[string]string(nil)), ""). + Return("123/generated.png", nil) + return fields{fileService: mockService} + }, + args: args{ + ctx: context.Background(), + req: &file.UploadFileForServerRequest{ + MimeType: "image/png", + Body: []byte("img"), + WorkspaceID: 123, + }, + }, + wantErr: nil, + wantKey: "123/generated.png", + }, + { + name: "success with option and file name", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockService := servicemocks.NewMockFileService(ctrl) + mockService.EXPECT(). + UploadFileForServer(gomock.Any(), "application/json", []byte("{\"a\":1}"), "45", gomock.Eq(map[string]string{"application/custom": ".cus"}), "name.json"). + Return("45/name.json", nil) + return fields{fileService: mockService} + }, + args: args{ + ctx: context.Background(), + req: &file.UploadFileForServerRequest{ + MimeType: "application/json", + Body: []byte("{\"a\":1}"), + WorkspaceID: 45, + Option: &file.UploadFileOption{ + FileName: lo.ToPtr("name.json"), + MimeTypeExtMapping: map[string]string{"application/custom": ".cus"}, + }, + }, + }, + wantErr: nil, + wantKey: "45/name.json", + }, + { + name: "service returns error", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockService := servicemocks.NewMockFileService(ctrl) + mockService.EXPECT(). + UploadFileForServer(gomock.Any(), "text/plain", []byte("fail"), "9", gomock.AssignableToTypeOf(map[string]string(nil)), ""). + Return("", assert.AnError) + return fields{fileService: mockService} + }, + args: args{ + ctx: context.Background(), + req: &file.UploadFileForServerRequest{ + MimeType: "text/plain", + Body: []byte("fail"), + WorkspaceID: 9, + }, + }, + wantErr: assert.AnError, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ttFields := tt.fieldsGetter(ctrl) + app := &FileApplicationImpl{ + auth: dummyAuth{}, + fileService: ttFields.fileService, + } + + got, err := app.UploadFileForServer(tt.args.ctx, tt.args.req) + + unittest.AssertErrorEqual(t, tt.wantErr, err) + if tt.wantErr != nil { + assert.Nil(t, got) + return + } + + assert.NotNil(t, got) + assert.NotNil(t, got.Data) + assert.Equal(t, tt.wantKey, got.GetData().GetFileName()) + assert.Equal(t, int64(len(tt.args.req.Body)), got.GetData().GetBytes()) + assert.NotNil(t, got.BaseResp) + }) + } +} diff --git a/backend/modules/foundation/domain/file/service/mocks/file_service.go b/backend/modules/foundation/domain/file/service/mocks/file_service.go index 1f5d43595..cb3e2ee66 100644 --- a/backend/modules/foundation/domain/file/service/mocks/file_service.go +++ b/backend/modules/foundation/domain/file/service/mocks/file_service.go @@ -73,6 +73,21 @@ func (mr *MockFileServiceMockRecorder) SignUploadFile(ctx, req any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignUploadFile", reflect.TypeOf((*MockFileService)(nil).SignUploadFile), ctx, req) } +// UploadFileForServer mocks base method. +func (m *MockFileService) UploadFileForServer(ctx context.Context, mimeType string, body []byte, spaceID string, customMimeTypeExtMap map[string]string, fileName string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UploadFileForServer", ctx, mimeType, body, spaceID, customMimeTypeExtMap, fileName) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UploadFileForServer indicates an expected call of UploadFileForServer. +func (mr *MockFileServiceMockRecorder) UploadFileForServer(ctx, mimeType, body, spaceID, customMimeTypeExtMap, fileName any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UploadFileForServer", reflect.TypeOf((*MockFileService)(nil).UploadFileForServer), ctx, mimeType, body, spaceID, customMimeTypeExtMap, fileName) +} + // UploadLoopFile mocks base method. func (m *MockFileService) UploadLoopFile(ctx context.Context, fileHeader *multipart.FileHeader, spaceID string) (string, error) { m.ctrl.T.Helper() diff --git a/backend/modules/foundation/domain/file/service/server.go b/backend/modules/foundation/domain/file/service/server.go index df761df5a..67bee34e1 100644 --- a/backend/modules/foundation/domain/file/service/server.go +++ b/backend/modules/foundation/domain/file/service/server.go @@ -4,16 +4,19 @@ package service import ( + "bytes" "context" "errors" "fmt" "io" + "mime" "mime/multipart" "net/http" "net/url" "path/filepath" "time" + "github.com/google/uuid" errors2 "github.com/pkg/errors" "github.com/coze-dev/coze-loop/backend/infra/fileserver" @@ -27,6 +30,7 @@ import ( //go:generate mockgen -destination=mocks/file_service.go -package=mocks . FileService type FileService interface { UploadLoopFile(ctx context.Context, fileHeader *multipart.FileHeader, spaceID string) (key string, err error) + UploadFileForServer(ctx context.Context, mimeType string, body []byte, spaceID string, customMimeTypeExtMap map[string]string, fileName string) (key string, err error) SignUploadFile(ctx context.Context, req *file.SignUploadFileRequest) (uris []string, heads []*file.SignHead, err error) SignDownLoadFile(ctx context.Context, req *file.SignDownloadFileRequest) (uris []string, err error) } @@ -86,6 +90,67 @@ func (fs fileService) UploadLoopFile(ctx context.Context, fileHeader *multipart. return fileName, nil } +func (fs fileService) UploadFileForServer(ctx context.Context, mimeType string, body []byte, spaceID string, customMimeTypeExtMap map[string]string, fileName string) (key string, err error) { + if len(body) == 0 { + return "", errorx.NewByCode(errno.CommonInvalidParamCode) + } + + // If user provided file name, use it directly (may already have extension) + if fileName == "" { + // Add custom mime type extension mappings if provided + if len(customMimeTypeExtMap) > 0 { + for mType, ext := range customMimeTypeExtMap { + if mType != "" && ext != "" { + // Ensure extension starts with a dot + if ext[0] != '.' { + ext = "." + ext + } + if err := mime.AddExtensionType(ext, mType); err != nil { + logs.CtxError(ctx, "add extension type failed, mimeType: %s, ext: %s, err: %v", mType, ext, err) + } + } + } + } + + // Get file extension from mime type + ext := "" + if mimeType != "" { + exts, err := mime.ExtensionsByType(mimeType) + if err == nil && len(exts) > 0 { + ext = exts[0] + } + } + + // Generate random file name + fileName = uuid.New().String() + + // Append extension if we have one + if ext != "" { + fileName = fileName + ext + } + } + + // Build full path with workspace ID + fullPath := filepath.Join(spaceID, "/", fileName) + + // Detect content type from file data + fileContentType := http.DetectContentType(body) + if mimeType != "" { + // Use provided mime type if available + fileContentType = mimeType + } + + // Upload file + reader := bytes.NewReader(body) + logs.CtxDebug(ctx, "start upload for server, fileName: %s, mimeType: %s", fullPath, fileContentType) + if err = fs.client.Upload(ctx, fullPath, reader, fileserver.UploadWithContentType(fileContentType)); err != nil { + logs.CtxError(ctx, "upload file for server failed, err: %v", err) + return "", err + } + + return fullPath, nil +} + func (fs fileService) SignUploadFile(ctx context.Context, req *file.SignUploadFileRequest) (uris []string, heads []*file.SignHead, err error) { signOpt := make([]fileserver.SignOpt, 0) if req.Option != nil { diff --git a/backend/modules/foundation/domain/file/service/server_test.go b/backend/modules/foundation/domain/file/service/server_test.go index 0e7875964..3600399d8 100644 --- a/backend/modules/foundation/domain/file/service/server_test.go +++ b/backend/modules/foundation/domain/file/service/server_test.go @@ -11,6 +11,7 @@ import ( "net/http" "net/http/httptest" "os" + "strings" "testing" "github.com/pkg/errors" @@ -183,6 +184,156 @@ func TestFileServiceImpl_UploadLoopFile(t *testing.T) { } } +func TestFileServiceImpl_UploadFileForServer(t *testing.T) { + type args struct { + ctx context.Context + mimeType string + body []byte + spaceID string + customMimeTypeExtMap map[string]string + fileName string + } + + tests := []struct { + name string + args args + expectedErr error + expectedKey string + expectedKeyPrefix string + expectedKeySuffix string + expectedContentType string + expectUpload bool + uploadErr error + }{ + { + name: "empty body returns invalid param error", + args: args{ + ctx: context.Background(), + mimeType: "text/plain", + body: []byte{}, + spaceID: "42", + }, + expectedErr: errorx.NewByCode(errno.CommonInvalidParamCode), + }, + { + name: "success with provided file name", + args: args{ + ctx: context.Background(), + mimeType: "text/plain", + body: []byte("hello world"), + spaceID: "workspace", + fileName: "custom.txt", + }, + expectUpload: true, + expectedErr: nil, + expectedKey: "workspace/custom.txt", + expectedContentType: "text/plain", + }, + { + name: "generate name when file name empty uses mime extension", + args: args{ + ctx: context.Background(), + mimeType: "image/png", + body: []byte{0x89, 0x50, 0x4e, 0x47, 0x00, 0x00}, + spaceID: "space", + }, + expectUpload: true, + expectedErr: nil, + expectedKeyPrefix: "space/", + expectedKeySuffix: ".png", + expectedContentType: "image/png", + }, + { + name: "custom mime mapping applies extension", + args: args{ + ctx: context.Background(), + mimeType: "application/x-coze", + body: []byte("coze-data"), + spaceID: "space", + customMimeTypeExtMap: map[string]string{"application/x-coze": "coze"}, + }, + expectUpload: true, + expectedErr: nil, + expectedKeyPrefix: "space/", + expectedKeySuffix: ".coze", + // Files without explicit mimeType should fall back to detection; since we pass mimeType, expect to reuse it. + expectedContentType: "application/x-coze", + }, + { + name: "upload failure bubbles up error", + args: args{ + ctx: context.Background(), + mimeType: "text/plain", + body: []byte("fail case"), + spaceID: "space", + }, + expectUpload: true, + uploadErr: errors.New("upload failed"), + expectedErr: errors.New("upload failed"), + expectedKeyPrefix: "space/", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + objectStorage := fsmocks.NewMockBatchObjectStorage(ctrl) + var uploadedKey string + var uploadedBody []byte + var uploadedContentTypes []string + + if tt.expectUpload { + objectStorage.EXPECT(). + Upload(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, key string, reader io.Reader, opts ...fileserver.UploadOpt) error { + uploadedKey = key + body, err := io.ReadAll(reader) + if err != nil { + return err + } + uploadedBody = body + option := fileserver.NewUploadOption(opts...) + uploadedContentTypes = option.ContentTypes + return tt.uploadErr + }) + } + + f := &fileService{ + client: objectStorage, + } + + got, err := f.UploadFileForServer(tt.args.ctx, tt.args.mimeType, tt.args.body, tt.args.spaceID, tt.args.customMimeTypeExtMap, tt.args.fileName) + + unittest.AssertErrorEqual(t, tt.expectedErr, err) + + if !tt.expectUpload { + assert.Equal(t, "", got) + return + } + + assert.Equal(t, tt.args.body, uploadedBody) + if tt.expectedContentType != "" { + assert.Equal(t, []string{tt.expectedContentType}, uploadedContentTypes) + } + if tt.expectedKey != "" { + assert.Equal(t, tt.expectedKey, uploadedKey) + } else { + assert.True(t, strings.HasPrefix(uploadedKey, tt.expectedKeyPrefix)) + assert.True(t, strings.HasSuffix(uploadedKey, tt.expectedKeySuffix)) + } + + if tt.expectedErr == nil { + assert.Equal(t, uploadedKey, got) + } else { + assert.Equal(t, "", got) + } + }) + } +} + func TestFileServiceImpl_SignUploadFile(t *testing.T) { _ = os.Setenv("COZE_LOOP_OSS_PROTOCOL", "http") _ = os.Setenv("COZE_LOOP_OSS_DOMAIN", "cozeloop-minio") diff --git a/backend/modules/prompt/application/convertor/openapi.go b/backend/modules/prompt/application/convertor/openapi.go index ca5af15df..d29f264e4 100644 --- a/backend/modules/prompt/application/convertor/openapi.go +++ b/backend/modules/prompt/application/convertor/openapi.go @@ -43,7 +43,7 @@ func OpenAPIPromptTemplateDO2DTO(do *entity.PromptTemplate) *openapi.PromptTempl TemplateType: ptr.Of(prompt.TemplateType(do.TemplateType)), Messages: OpenAPIBatchMessageDO2DTO(do.Messages), VariableDefs: OpenAPIBatchVariableDefDO2DTO(do.VariableDefs), - Metadata: do.Metadata, + Metadata: do.Metadata, } } @@ -72,7 +72,7 @@ func OpenAPIMessageDO2DTO(do *entity.Message) *openapi.Message { Parts: OpenAPIBatchContentPartDO2DTO(do.Parts), ToolCallID: do.ToolCallID, ToolCalls: OpenAPIBatchToolCallDO2DTO(do.ToolCalls), - Metadata: do.Metadata, + Metadata: do.Metadata, } } @@ -179,11 +179,26 @@ func OpenAPIContentPartDO2DTO(do *entity.ContentPart) *openapi.ContentPart { if do.ImageURL != nil { imageURL = ptr.Of(do.ImageURL.URL) } + var videoURL *string + var config *openapi.MediaConfig + if do.VideoURL != nil { + if do.VideoURL.URL != "" { + videoURL = ptr.Of(do.VideoURL.URL) + } + } + // Set Config with fps if available + if do.MediaConfig != nil && do.MediaConfig.Fps != nil { + config = &openapi.MediaConfig{ + Fps: do.MediaConfig.Fps, + } + } return &openapi.ContentPart{ Type: ptr.Of(OpenAPIContentTypeDO2DTO(do.Type)), Text: do.Text, ImageURL: imageURL, + VideoURL: videoURL, Base64Data: do.Base64Data, + Config: config, } } @@ -193,6 +208,8 @@ func OpenAPIContentTypeDO2DTO(do entity.ContentType) openapi.ContentType { return openapi.ContentTypeText case entity.ContentTypeImageURL: return openapi.ContentTypeImageURL + case entity.ContentTypeVideoURL: + return openapi.ContentTypeVideoURL case entity.ContentTypeBase64Data: return openapi.ContentTypeBase64Data case entity.ContentTypeMultiPartVariable: @@ -229,7 +246,7 @@ func OpenAPIMessageDTO2DO(dto *openapi.Message) *entity.Message { Parts: OpenAPIBatchContentPartDTO2DO(dto.Parts), ToolCallID: dto.ToolCallID, ToolCalls: OpenAPIBatchToolCallDTO2DO(dto.ToolCalls), - Metadata: dto.Metadata, + Metadata: dto.Metadata, } } @@ -259,11 +276,26 @@ func OpenAPIContentPartDTO2DO(dto *openapi.ContentPart) *entity.ContentPart { URL: *dto.ImageURL, } } + var videoURL *entity.VideoURL + if dto.VideoURL != nil && *dto.VideoURL != "" { + videoURL = &entity.VideoURL{ + URL: *dto.VideoURL, + } + } + var mediaConfig *entity.MediaConfig + // Set MediaConfig from Config if available + if dto.Config != nil && dto.Config.Fps != nil { + mediaConfig = &entity.MediaConfig{ + Fps: dto.Config.Fps, + } + } return &entity.ContentPart{ - Type: OpenAPIContentTypeDTO2DO(dto.GetType()), - Text: dto.Text, - ImageURL: imageURL, - Base64Data: dto.Base64Data, + Type: OpenAPIContentTypeDTO2DO(dto.GetType()), + Text: dto.Text, + ImageURL: imageURL, + VideoURL: videoURL, + Base64Data: dto.Base64Data, + MediaConfig: mediaConfig, } } @@ -274,6 +306,8 @@ func OpenAPIContentTypeDTO2DO(dto openapi.ContentType) entity.ContentType { return entity.ContentTypeText case openapi.ContentTypeImageURL: return entity.ContentTypeImageURL + case openapi.ContentTypeVideoURL: + return entity.ContentTypeVideoURL case openapi.ContentTypeBase64Data: return entity.ContentTypeBase64Data case openapi.ContentTypeMultiPartVariable: diff --git a/backend/modules/prompt/application/convertor/openapi_test.go b/backend/modules/prompt/application/convertor/openapi_test.go index 9d05c72b7..8a2210f06 100755 --- a/backend/modules/prompt/application/convertor/openapi_test.go +++ b/backend/modules/prompt/application/convertor/openapi_test.go @@ -365,12 +365,14 @@ func mockOpenAPIPromptCases() []openAPIPromptTestCase { }, }, }, - want: &openapi.Prompt{ + dto: &openapi.Prompt{ WorkspaceID: ptr.Of(int64(456)), PromptKey: ptr.Of("test_prompt"), Version: ptr.Of("1.0.0"), PromptTemplate: &openapi.PromptTemplate{ - Metadata: map[string]string{"commit": "meta"}, + TemplateType: ptr.Of(prompt.TemplateType("")), + VariableDefs: []*openapi.VariableDef{}, + Metadata: map[string]string{"commit": "meta"}, }, }, }, @@ -440,7 +442,9 @@ func TestOpenAPIPromptTemplateDO2DTO(t *testing.T) { Metadata: map[string]string{"k": "v"}, }, want: &openapi.PromptTemplate{ - Metadata: map[string]string{"k": "v"}, + TemplateType: ptr.Of(prompt.TemplateType("")), + VariableDefs: []*openapi.VariableDef{}, + Metadata: map[string]string{"k": "v"}, }, }, } @@ -544,6 +548,11 @@ func TestOpenAPIContentTypeDO2DTO(t *testing.T) { do: entity.ContentTypeImageURL, want: openapi.ContentTypeImageURL, }, + { + name: "video url content type", + do: entity.ContentTypeVideoURL, + want: openapi.ContentTypeVideoURL, + }, { name: "unknown content type - should default to text", do: entity.ContentType("unknown"), @@ -629,6 +638,38 @@ func TestOpenAPIContentPartDO2DTO(t *testing.T) { Text: ptr.Of(""), }, }, + { + name: "video url content part with fps", + do: &entity.ContentPart{ + Type: entity.ContentTypeVideoURL, + VideoURL: &entity.VideoURL{ + URL: "https://example.com/video.mp4", + }, + MediaConfig: &entity.MediaConfig{ + Fps: ptr.Of(2.0), + }, + }, + want: &openapi.ContentPart{ + Type: ptr.Of(openapi.ContentTypeVideoURL), + VideoURL: ptr.Of("https://example.com/video.mp4"), + Config: &openapi.MediaConfig{ + Fps: ptr.Of(2.0), + }, + }, + }, + { + name: "video url content part without fps", + do: &entity.ContentPart{ + Type: entity.ContentTypeVideoURL, + VideoURL: &entity.VideoURL{ + URL: "https://example.com/video.mp4", + }, + }, + want: &openapi.ContentPart{ + Type: ptr.Of(openapi.ContentTypeVideoURL), + VideoURL: ptr.Of("https://example.com/video.mp4"), + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -739,6 +780,50 @@ func TestOpenAPIBatchContentPartDO2DTO(t *testing.T) { }, want: []*openapi.ContentPart{}, }, + { + name: "array with video url part", + do: []*entity.ContentPart{ + { + Type: entity.ContentTypeVideoURL, + VideoURL: &entity.VideoURL{ + URL: "https://example.com/video.mp4", + }, + MediaConfig: &entity.MediaConfig{ + Fps: ptr.Of(1.5), + }, + }, + }, + want: []*openapi.ContentPart{ + { + Type: ptr.Of(openapi.ContentTypeVideoURL), + VideoURL: ptr.Of("https://example.com/video.mp4"), + Config: &openapi.MediaConfig{ + Fps: ptr.Of(1.5), + }, + }, + }, + }, + { + name: "base64 content part carries fps", + do: []*entity.ContentPart{ + { + Type: entity.ContentTypeBase64Data, + Base64Data: ptr.Of("data:video/mp4;base64,QUJDRA=="), + MediaConfig: &entity.MediaConfig{ + Fps: ptr.Of(2.4), + }, + }, + }, + want: []*openapi.ContentPart{ + { + Type: ptr.Of(openapi.ContentTypeBase64Data), + Base64Data: ptr.Of("data:video/mp4;base64,QUJDRA=="), + Config: &openapi.MediaConfig{ + Fps: ptr.Of(2.4), + }, + }, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -975,6 +1060,11 @@ func TestOpenAPIContentTypeDO2DTO_NewTypes(t *testing.T) { do: entity.ContentTypeImageURL, want: openapi.ContentTypeImageURL, }, + { + name: "video url content type", + do: entity.ContentTypeVideoURL, + want: openapi.ContentTypeVideoURL, + }, { name: "base64 data content type", do: entity.ContentTypeBase64Data, @@ -1289,6 +1379,50 @@ func TestOpenAPIBatchContentPartDTO2DO(t *testing.T) { }, }, }, + { + name: "video url handling with fps config", + dtos: []*openapi.ContentPart{ + { + Type: ptr.Of(openapi.ContentTypeVideoURL), + VideoURL: ptr.Of("https://example.com/video.mp4"), + Config: &openapi.MediaConfig{ + Fps: ptr.Of(1.8), + }, + }, + }, + want: []*entity.ContentPart{ + { + Type: entity.ContentTypeVideoURL, + VideoURL: &entity.VideoURL{ + URL: "https://example.com/video.mp4", + }, + MediaConfig: &entity.MediaConfig{ + Fps: ptr.Of(1.8), + }, + }, + }, + }, + { + name: "base64 video carries fps without video url", + dtos: []*openapi.ContentPart{ + { + Type: ptr.Of(openapi.ContentTypeBase64Data), + Base64Data: ptr.Of("data:video/mp4;base64,QUJDRA=="), + Config: &openapi.MediaConfig{ + Fps: ptr.Of(2.2), + }, + }, + }, + want: []*entity.ContentPart{ + { + Type: entity.ContentTypeBase64Data, + Base64Data: ptr.Of("data:video/mp4;base64,QUJDRA=="), + MediaConfig: &entity.MediaConfig{ + Fps: ptr.Of(2.2), + }, + }, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/backend/modules/prompt/application/convertor/prompt.go b/backend/modules/prompt/application/convertor/prompt.go index 5d87c2843..1c30c2413 100644 --- a/backend/modules/prompt/application/convertor/prompt.go +++ b/backend/modules/prompt/application/convertor/prompt.go @@ -109,7 +109,7 @@ func PromptTemplateDTO2DO(dto *prompt.PromptTemplate) *entity.PromptTemplate { TemplateType: TemplateTypeDTO2DO(dto.GetTemplateType()), Messages: BatchMessageDTO2DO(dto.Messages), VariableDefs: BatchVariableDefDTO2DO(dto.VariableDefs), - Metadata: dto.Metadata, + Metadata: dto.Metadata, } } @@ -150,7 +150,7 @@ func MessageDTO2DO(dto *prompt.Message) *entity.Message { Parts: BatchContentPartDTO2DO(dto.Parts), ToolCallID: dto.ToolCallID, ToolCalls: BatchToolCallDTO2DO(dto.ToolCalls), - Metadata: dto.Metadata, + Metadata: dto.Metadata, } } @@ -191,9 +191,11 @@ func ContentPartDTO2DO(dto *prompt.ContentPart) *entity.ContentPart { } return &entity.ContentPart{ - Type: ContentTypeDTO2DO(dto.GetType()), - Text: dto.Text, - ImageURL: ImageURLDTO2DO(dto.ImageURL), + Type: ContentTypeDTO2DO(dto.GetType()), + Text: dto.Text, + ImageURL: ImageURLDTO2DO(dto.ImageURL), + VideoURL: VideoURLDTO2DO(dto.VideoURL), + MediaConfig: MediaConfigDTO2DO(dto.MediaConfig), } } @@ -203,6 +205,8 @@ func ContentTypeDTO2DO(dto prompt.ContentType) entity.ContentType { return entity.ContentTypeText case prompt.ContentTypeImageURL: return entity.ContentTypeImageURL + case prompt.ContentTypeVideoURL: + return entity.ContentTypeVideoURL case prompt.ContentTypeMultiPartVariable: return entity.ContentTypeMultiPartVariable default: @@ -221,6 +225,27 @@ func ImageURLDTO2DO(dto *prompt.ImageURL) *entity.ImageURL { } } +func VideoURLDTO2DO(dto *prompt.VideoURL) *entity.VideoURL { + if dto == nil { + return nil + } + + return &entity.VideoURL{ + URI: dto.GetURI(), + URL: dto.GetURL(), + } +} + +func MediaConfigDTO2DO(dto *prompt.MediaConfig) *entity.MediaConfig { + if dto == nil { + return nil + } + + return &entity.MediaConfig{ + Fps: dto.Fps, + } +} + func BatchVariableDefDTO2DO(dtos []*prompt.VariableDef) []*entity.VariableDef { if dtos == nil { return nil @@ -525,9 +550,11 @@ func ContentPartDO2DTO(do *entity.ContentPart) *prompt.ContentPart { return nil } return &prompt.ContentPart{ - Type: ptr.Of(ContentTypeDO2DTO(do.Type)), - Text: do.Text, - ImageURL: ImageURLDO2DTO(do.ImageURL), + Type: ptr.Of(ContentTypeDO2DTO(do.Type)), + Text: do.Text, + ImageURL: ImageURLDO2DTO(do.ImageURL), + VideoURL: VideoURLDO2DTO(do.VideoURL), + MediaConfig: MediaConfigDO2DTO(do.MediaConfig), } } @@ -537,6 +564,8 @@ func ContentTypeDO2DTO(do entity.ContentType) prompt.ContentType { return prompt.ContentTypeText case entity.ContentTypeImageURL: return prompt.ContentTypeImageURL + case entity.ContentTypeVideoURL: + return prompt.ContentType("video_url") case entity.ContentTypeMultiPartVariable: return prompt.ContentTypeMultiPartVariable default: @@ -554,6 +583,25 @@ func ImageURLDO2DTO(do *entity.ImageURL) *prompt.ImageURL { } } +func VideoURLDO2DTO(do *entity.VideoURL) *prompt.VideoURL { + if do == nil { + return nil + } + return &prompt.VideoURL{ + URI: ptr.Of(do.URI), + URL: ptr.Of(do.URL), + } +} + +func MediaConfigDO2DTO(do *entity.MediaConfig) *prompt.MediaConfig { + if do == nil { + return nil + } + return &prompt.MediaConfig{ + Fps: do.Fps, + } +} + func BatchDebugToolCallDO2DTO(dos []*entity.DebugToolCall) []*prompt.DebugToolCall { if dos == nil { return nil @@ -630,7 +678,7 @@ func MessageDO2DTO(do *entity.Message) *prompt.Message { Parts: BatchContentPartDO2DTO(do.Parts), ToolCallID: do.ToolCallID, ToolCalls: BatchToolCallDO2DTO(do.ToolCalls), - Metadata: do.Metadata, + Metadata: do.Metadata, } } @@ -831,7 +879,7 @@ func PromptTemplateDO2DTO(do *entity.PromptTemplate) *prompt.PromptTemplate { TemplateType: ptr.Of(prompt.TemplateType(do.TemplateType)), Messages: BatchMessageDO2DTO(do.Messages), VariableDefs: BatchVariableDefDO2DTO(do.VariableDefs), - Metadata: do.Metadata, + Metadata: do.Metadata, } } diff --git a/backend/modules/prompt/application/convertor/prompt_test.go b/backend/modules/prompt/application/convertor/prompt_test.go index bbe5d8e8b..bd3ac6d15 100644 --- a/backend/modules/prompt/application/convertor/prompt_test.go +++ b/backend/modules/prompt/application/convertor/prompt_test.go @@ -345,17 +345,22 @@ func mockPromptCases() []promptTestCase { { name: "prompt template metadata", dto: &prompt.Prompt{ + ID: ptr.Of(int64(0)), + WorkspaceID: ptr.Of(int64(0)), + PromptKey: ptr.Of(""), PromptCommit: &prompt.PromptCommit{ Detail: &prompt.PromptDetail{ PromptTemplate: &prompt.PromptTemplate{ - Metadata: map[string]string{"commit-meta": "value"}, + TemplateType: ptr.Of(prompt.TemplateTypeNormal), + Metadata: map[string]string{"commit-meta": "value"}, }, }, }, PromptDraft: &prompt.PromptDraft{ Detail: &prompt.PromptDetail{ PromptTemplate: &prompt.PromptTemplate{ - Metadata: map[string]string{"draft-meta": "value"}, + TemplateType: ptr.Of(prompt.TemplateTypeNormal), + Metadata: map[string]string{"draft-meta": "value"}, }, }, }, @@ -364,14 +369,16 @@ func mockPromptCases() []promptTestCase { PromptCommit: &entity.PromptCommit{ PromptDetail: &entity.PromptDetail{ PromptTemplate: &entity.PromptTemplate{ - Metadata: map[string]string{"commit-meta": "value"}, + TemplateType: entity.TemplateTypeNormal, + Metadata: map[string]string{"commit-meta": "value"}, }, }, }, PromptDraft: &entity.PromptDraft{ PromptDetail: &entity.PromptDetail{ PromptTemplate: &entity.PromptTemplate{ - Metadata: map[string]string{"draft-meta": "value"}, + TemplateType: entity.TemplateTypeNormal, + Metadata: map[string]string{"draft-meta": "value"}, }, }, }, @@ -512,6 +519,39 @@ func mockMessageCases() []messageTestCase { }, }, }, + { + name: "user message with video content", + dto: &prompt.Message{ + Role: ptr.Of(prompt.RoleUser), + Parts: []*prompt.ContentPart{ + { + Type: ptr.Of(prompt.ContentTypeVideoURL), + VideoURL: &prompt.VideoURL{ + URL: ptr.Of("https://example.com/video.mp4"), + URI: ptr.Of("video-uri"), + }, + MediaConfig: &prompt.MediaConfig{ + Fps: ptr.Of(2.5), + }, + }, + }, + }, + do: &entity.Message{ + Role: entity.RoleUser, + Parts: []*entity.ContentPart{ + { + Type: entity.ContentTypeVideoURL, + VideoURL: &entity.VideoURL{ + URL: "https://example.com/video.mp4", + URI: "video-uri", + }, + MediaConfig: &entity.MediaConfig{ + Fps: ptr.Of(2.5), + }, + }, + }, + }, + }, { name: "assistant message with tool calls", dto: &prompt.Message{ diff --git a/backend/modules/prompt/application/debug.go b/backend/modules/prompt/application/debug.go index b84ac096d..af07f685f 100644 --- a/backend/modules/prompt/application/debug.go +++ b/backend/modules/prompt/application/debug.go @@ -285,6 +285,13 @@ func (p *PromptDebugApplicationImpl) doDebugStreaming(ctx context.Context, req * if reply == nil || reply.Item == nil { continue } + // Convert base64 files to download URLs + if reply.Item.Message != nil { + if err := p.promptService.MConvertBase64DataURLToFileURL(ctx, []*entity.Message{reply.Item.Message}, req.Prompt.GetWorkspaceID()); err != nil { + logs.CtxError(ctx, "failed to convert base64 to file URLs: %v", err) + return nil, err + } + } chunk := &debug.DebugStreamingResponse{ Delta: convertor.MessageDO2DTO(reply.Item.Message), FinishReason: ptr.Of(reply.Item.FinishReason), diff --git a/backend/modules/prompt/application/debug_test.go b/backend/modules/prompt/application/debug_test.go index 3553f28f3..2c0783e40 100644 --- a/backend/modules/prompt/application/debug_test.go +++ b/backend/modules/prompt/application/debug_test.go @@ -65,6 +65,7 @@ func TestPromptDebugApplicationImpl_DebugStreaming(t *testing.T) { mockDebugLogRepo.EXPECT().SaveDebugLog(gomock.Any(), gomock.Any()).Return(nil) mockPromptSvc := servicemocks.NewMockIPromptService(ctrl) mockPromptSvc.EXPECT().MCompleteMultiModalFileURL(gomock.Any(), gomock.Any(), nil).Return(nil) + mockPromptSvc.EXPECT().MConvertBase64ToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptSvc.EXPECT().ExecuteStreaming(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, param service.ExecuteStreamingParam) (*entity.Reply, error) { for _, v := range mockContent { param.ResultStream <- &entity.Reply{ @@ -126,6 +127,7 @@ func TestPromptDebugApplicationImpl_DebugStreaming(t *testing.T) { mockDebugLogRepo.EXPECT().SaveDebugLog(gomock.Any(), gomock.Any()).Return(nil) mockPromptSvc := servicemocks.NewMockIPromptService(ctrl) mockPromptSvc.EXPECT().MCompleteMultiModalFileURL(gomock.Any(), gomock.Any(), nil).Return(nil) + mockPromptSvc.EXPECT().MConvertBase64ToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptSvc.EXPECT().ExecuteStreaming(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, param service.ExecuteStreamingParam) (*entity.Reply, error) { for _, v := range mockContent { param.ResultStream <- &entity.Reply{ @@ -216,6 +218,73 @@ func TestPromptDebugApplicationImpl_DebugStreaming(t *testing.T) { }, wantErr: errorx.NewByCode(prompterr.CommonInvalidParamCode), }, + { + name: "base64 convert error", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockDebugLogRepo := repomocks.NewMockIDebugLogRepo(ctrl) + mockDebugLogRepo.EXPECT().SaveDebugLog(gomock.Any(), gomock.Any()).Return(nil) + mockPromptSvc := servicemocks.NewMockIPromptService(ctrl) + mockPromptSvc.EXPECT().MCompleteMultiModalFileURL(gomock.Any(), gomock.Any(), nil).Return(nil) + mockPromptSvc.EXPECT().ExecuteStreaming(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, param service.ExecuteStreamingParam) (*entity.Reply, error) { + param.ResultStream <- &entity.Reply{ + Item: &entity.ReplyItem{ + Message: &entity.Message{ + Role: entity.RoleAssistant, + Parts: []*entity.ContentPart{ + { + Type: entity.ContentTypeImageURL, + ImageURL: &entity.ImageURL{ + URL: "", + }, + }, + }, + }, + }, + } + return &entity.Reply{ + Item: &entity.ReplyItem{ + Message: &entity.Message{ + Role: entity.RoleAssistant, + }, + }, + }, nil + }) + convertErr := errors.New("convert error") + mockPromptSvc.EXPECT().MConvertBase64ToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(convertErr) + mockBenefitSvc := benefitmocks.NewMockIBenefitService(ctrl) + mockBenefitSvc.EXPECT().CheckPromptBenefit(gomock.Any(), gomock.Any()).Return(&benefit.CheckPromptBenefitResult{}, nil) + mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) + mockAuth.EXPECT().MCheckPromptPermission(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + return fields{ + debugLogRepo: mockDebugLogRepo, + debugContextRepo: nil, + promptService: mockPromptSvc, + benefitService: mockBenefitSvc, + auth: mockAuth, + file: nil, + } + }, + args: args{ + ctx: session.WithCtxUser(context.Background(), mockUser), + req: &debug.DebugStreamingRequest{ + Prompt: &prompt.Prompt{ + ID: ptr.Of(int64(123456)), + WorkspaceID: ptr.Of(int64(123456)), + PromptDraft: &prompt.PromptDraft{ + Detail: &prompt.PromptDetail{ + PromptTemplate: &prompt.PromptTemplate{ + TemplateType: ptr.Of(prompt.TemplateTypeNormal), + }, + ModelConfig: &prompt.ModelConfig{}, + }, + }, + }, + SingleStepDebug: ptr.Of(true), + }, + stream: localstream.NewInMemStream(context.Background(), make(chan *debug.DebugStreamingResponse), make(chan error)), + }, + wantErr: errorx.WrapByCode(errors.New("convert error"), prompterr.CommonInternalErrorCode), + }, { name: "goroutine panic", fieldsGetter: func(ctrl *gomock.Controller) fields { @@ -223,6 +292,7 @@ func TestPromptDebugApplicationImpl_DebugStreaming(t *testing.T) { mockDebugLogRepo.EXPECT().SaveDebugLog(gomock.Any(), gomock.Any()).Return(nil) mockPromptSvc := servicemocks.NewMockIPromptService(ctrl) mockPromptSvc.EXPECT().MCompleteMultiModalFileURL(gomock.Any(), gomock.Any(), nil).Return(nil) + mockPromptSvc.EXPECT().MConvertBase64ToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptSvc.EXPECT().ExecuteStreaming(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, param service.ExecuteStreamingParam) (*entity.Reply, error) { panic("mock panic") }) diff --git a/backend/modules/prompt/application/execute.go b/backend/modules/prompt/application/execute.go index df00e7e43..10e3c981c 100644 --- a/backend/modules/prompt/application/execute.go +++ b/backend/modules/prompt/application/execute.go @@ -79,6 +79,12 @@ func (p *PromptExecuteApplicationImpl) ExecuteInternal(ctx context.Context, req return r, err } if reply != nil && reply.Item != nil { + // Convert base64 files to download URLs + if reply.Item.Message != nil { + if err := p.promptService.MConvertBase64DataURLToFileURL(ctx, []*entity.Message{reply.Item.Message}, req.GetWorkspaceID()); err != nil { + return r, err + } + } r.Message = convertor.MessageDO2DTO(reply.Item.Message) r.FinishReason = ptr.Of(reply.Item.FinishReason) r.Usage = convertor.TokenUsageDO2DTO(reply.Item.TokenUsage) diff --git a/backend/modules/prompt/application/execute_test.go b/backend/modules/prompt/application/execute_test.go index aff21fe66..db9a57628 100755 --- a/backend/modules/prompt/application/execute_test.go +++ b/backend/modules/prompt/application/execute_test.go @@ -110,6 +110,7 @@ func TestPromptExecuteApplicationImpl_ExecuteInternal(t *testing.T) { mockPromptService := servicemocks.NewMockIPromptService(ctrl) mockPromptService.EXPECT().Execute(gomock.Any(), gomock.Any()).Return(mockReply, nil) + mockPromptService.EXPECT().MConvertBase64ToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) return fields{ promptService: mockPromptService, @@ -139,6 +140,34 @@ func TestPromptExecuteApplicationImpl_ExecuteInternal(t *testing.T) { }, wantErr: nil, }, + { + name: "base64 convert error", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockManageRepo := repomocks.NewMockIManageRepo(ctrl) + mockManageRepo.EXPECT().GetPrompt(gomock.Any(), gomock.Any()).Return(createMockPrompt(), nil) + + mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().Execute(gomock.Any(), gomock.Any()).Return(mockReply, nil) + mockPromptService.EXPECT().MConvertBase64ToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("convert error")) + + return fields{ + promptService: mockPromptService, + manageRepo: mockManageRepo, + } + }, + args: args{ + ctx: context.Background(), + req: &execute.ExecuteInternalRequest{ + PromptID: ptr.Of(int64(123)), + WorkspaceID: ptr.Of(int64(123456)), + Version: ptr.Of("1.0.0"), + Messages: []*prompt.Message{}, + VariableVals: []*prompt.VariableVal{}, + }, + }, + wantR: execute.NewExecuteInternalResponse(), + wantErr: errors.New("convert error"), + }, // 注释掉这个测试用例,因为getPromptByID方法在处理错误时会有空指针问题 // { // name: "get prompt error", @@ -196,6 +225,7 @@ func TestPromptExecuteApplicationImpl_ExecuteInternal(t *testing.T) { mockPromptService := servicemocks.NewMockIPromptService(ctrl) mockPromptService.EXPECT().Execute(gomock.Any(), gomock.Any()).Return(mockReply, nil) + mockPromptService.EXPECT().MConvertBase64ToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) return fields{ promptService: mockPromptService, @@ -239,6 +269,7 @@ func TestPromptExecuteApplicationImpl_ExecuteInternal(t *testing.T) { mockPromptService := servicemocks.NewMockIPromptService(ctrl) mockPromptService.EXPECT().Execute(gomock.Any(), gomock.Any()).Return(mockReply, nil) + mockPromptService.EXPECT().MConvertBase64ToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) return fields{ promptService: mockPromptService, diff --git a/backend/modules/prompt/application/openapi.go b/backend/modules/prompt/application/openapi.go index 774349b84..e6305bbeb 100644 --- a/backend/modules/prompt/application/openapi.go +++ b/backend/modules/prompt/application/openapi.go @@ -404,6 +404,14 @@ func (p *PromptOpenAPIApplicationImpl) doExecute(ctx context.Context, req *opena if err != nil { return promptDO, nil, err } + + // Convert base64 files to download URLs + if reply != nil && reply.Item != nil && reply.Item.Message != nil { + if err := p.promptService.MConvertBase64DataURLToFileURL(ctx, []*entity.Message{reply.Item.Message}, req.GetWorkspaceID()); err != nil { + return promptDO, nil, err + } + } + return promptDO, reply, nil } @@ -524,6 +532,13 @@ func (p *PromptOpenAPIApplicationImpl) doExecuteStreaming(ctx context.Context, r if reply == nil || reply.Item == nil { continue } + // Convert base64 files to download URLs + if reply.Item.Message != nil { + if err := p.promptService.MConvertBase64DataURLToFileURL(ctx, []*entity.Message{reply.Item.Message}, req.GetWorkspaceID()); err != nil { + logs.CtxError(ctx, "failed to convert base64 to file URLs: %v", err) + return promptDO, nil, err + } + } chunk := &openapi.ExecuteStreamingResponse{ Data: &openapi.ExecuteStreamingData{ Message: convertor.OpenAPIMessageDO2DTO(reply.Item.Message), diff --git a/backend/modules/prompt/application/openapi_test.go b/backend/modules/prompt/application/openapi_test.go index 1fa1e40ac..1fdedca53 100644 --- a/backend/modules/prompt/application/openapi_test.go +++ b/backend/modules/prompt/application/openapi_test.go @@ -2361,6 +2361,7 @@ func TestPromptOpenAPIApplicationImpl_doExecute(t *testing.T) { }, } mockPromptService.EXPECT().Execute(gomock.Any(), gomock.Any()).Return(expectedReply, nil) + mockPromptService.EXPECT().MConvertBase64ToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) return fields{ promptService: mockPromptService, @@ -2407,6 +2408,83 @@ func TestPromptOpenAPIApplicationImpl_doExecute(t *testing.T) { }, wantErr: nil, }, + { + name: "error: base64 convert failed", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockConfig := confmocks.NewMockIConfigProvider(ctrl) + mockConfig.EXPECT().GetPTaaSMaxQPSByPromptKey(gomock.Any(), int64(123456), "test_prompt").Return(100, nil) + + mockRateLimiter := limitermocks.NewMockIRateLimiter(ctrl) + mockRateLimiter.EXPECT().AllowN(gomock.Any(), "ptaas:qps:space_id:123456:prompt_key:test_prompt", 1, gomock.Any()).Return(&limiter.Result{ + Allowed: true, + }, nil) + + mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), int64(123456), []string{"test_prompt"}).Return(map[string]int64{ + "test_prompt": 123, + }, nil) + mockPromptService.EXPECT().MParseCommitVersion(gomock.Any(), int64(123456), gomock.Any()).Return(map[service.PromptQueryParam]string{ + {PromptID: 123, PromptKey: "test_prompt", Version: "1.0.0"}: "1.0.0", + }, nil) + + mockManageRepo := repomocks.NewMockIManageRepo(ctrl) + expectedPrompt := &entity.Prompt{ + ID: 123, + SpaceID: 123456, + PromptKey: "test_prompt", + } + mockManageRepo.EXPECT().MGetPrompt(gomock.Any(), gomock.Any(), gomock.Any()).Return(map[repo.GetPromptParam]*entity.Prompt{ + {PromptID: 123, WithCommit: true, CommitVersion: "1.0.0"}: expectedPrompt, + }, nil) + + mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) + mockAuth.EXPECT().MCheckPromptPermissionForOpenAPI(gomock.Any(), int64(123456), []int64{123}, consts.ActionLoopPromptExecute).Return(nil) + + expectedReply := &entity.Reply{ + DebugID: 456, + Item: &entity.ReplyItem{ + Message: &entity.Message{ + Role: entity.RoleAssistant, + Parts: []*entity.ContentPart{ + { + Type: entity.ContentTypeImageURL, + ImageURL: &entity.ImageURL{ + URL: "", + }, + }, + }, + }, + }, + } + mockPromptService.EXPECT().Execute(gomock.Any(), gomock.Any()).Return(expectedReply, nil) + mockPromptService.EXPECT().MConvertBase64ToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("convert error")) + + return fields{ + promptService: mockPromptService, + promptManageRepo: mockManageRepo, + config: mockConfig, + auth: mockAuth, + rateLimiter: mockRateLimiter, + } + }, + args: args{ + ctx: context.Background(), + req: &openapi.ExecuteRequest{ + WorkspaceID: ptr.Of(int64(123456)), + PromptIdentifier: &openapi.PromptQuery{ + PromptKey: ptr.Of("test_prompt"), + Version: ptr.Of("1.0.0"), + }, + }, + }, + wantPromptDO: &entity.Prompt{ + ID: 123, + SpaceID: 123456, + PromptKey: "test_prompt", + }, + wantReply: nil, + wantErr: errors.New("convert error"), + }, { name: "error: rate limit exceeded", fieldsGetter: func(ctrl *gomock.Controller) fields { @@ -2732,6 +2810,7 @@ func TestPromptOpenAPIApplicationImpl_Execute(t *testing.T) { }, } mockPromptService.EXPECT().Execute(gomock.Any(), gomock.Any()).Return(expectedReply, nil) + mockPromptService.EXPECT().MConvertBase64ToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) mockCollector := collectormocks.NewMockICollectorProvider(ctrl) mockCollector.EXPECT().CollectPTaaSEvent(gomock.Any(), gomock.Any()).Return() @@ -2776,6 +2855,88 @@ func TestPromptOpenAPIApplicationImpl_Execute(t *testing.T) { }, wantErr: nil, }, + { + name: "error: base64 convert failed", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockConfig := confmocks.NewMockIConfigProvider(ctrl) + mockConfig.EXPECT().GetPTaaSMaxQPSByPromptKey(gomock.Any(), int64(123456), "test_prompt").Return(100, nil) + + mockRateLimiter := limitermocks.NewMockIRateLimiter(ctrl) + mockRateLimiter.EXPECT().AllowN(gomock.Any(), "ptaas:qps:space_id:123456:prompt_key:test_prompt", 1, gomock.Any()).Return(&limiter.Result{ + Allowed: true, + }, nil) + + mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), int64(123456), []string{"test_prompt"}).Return(map[string]int64{ + "test_prompt": 123, + }, nil) + mockPromptService.EXPECT().MParseCommitVersion(gomock.Any(), int64(123456), gomock.Any()).Return(map[service.PromptQueryParam]string{ + {PromptID: 123, PromptKey: "test_prompt", Version: "1.0.0"}: "1.0.0", + }, nil) + + mockManageRepo := repomocks.NewMockIManageRepo(ctrl) + expectedPrompt := &entity.Prompt{ + ID: 123, + SpaceID: 123456, + PromptKey: "test_prompt", + } + mockManageRepo.EXPECT().MGetPrompt(gomock.Any(), gomock.Any(), gomock.Any()).Return(map[repo.GetPromptParam]*entity.Prompt{ + {PromptID: 123, WithCommit: true, CommitVersion: "1.0.0"}: expectedPrompt, + }, nil) + + mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) + mockAuth.EXPECT().MCheckPromptPermissionForOpenAPI(gomock.Any(), int64(123456), []int64{123}, consts.ActionLoopPromptExecute).Return(nil) + + expectedReply := &entity.Reply{ + DebugID: 456, + Item: &entity.ReplyItem{ + Message: &entity.Message{ + Role: entity.RoleAssistant, + Parts: []*entity.ContentPart{ + { + Type: entity.ContentTypeImageURL, + ImageURL: &entity.ImageURL{ + URL: "", + }, + }, + }, + }, + }, + } + mockPromptService.EXPECT().Execute(gomock.Any(), gomock.Any()).Return(expectedReply, nil) + mockPromptService.EXPECT().MConvertBase64ToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("convert error")) + + mockCollector := collectormocks.NewMockICollectorProvider(ctrl) + mockCollector.EXPECT().CollectPTaaSEvent(gomock.Any(), gomock.Any()).Return() + + return fields{ + promptService: mockPromptService, + promptManageRepo: mockManageRepo, + config: mockConfig, + auth: mockAuth, + rateLimiter: mockRateLimiter, + collector: mockCollector, + } + }, + args: args{ + ctx: context.Background(), + req: &openapi.ExecuteRequest{ + WorkspaceID: ptr.Of(int64(123456)), + PromptIdentifier: &openapi.PromptQuery{ + PromptKey: ptr.Of("test_prompt"), + Version: ptr.Of("1.0.0"), + }, + Messages: []*openapi.Message{ + { + Role: ptr.Of(prompt.RoleUser), + Content: ptr.Of("Hello"), + }, + }, + }, + }, + wantR: openapi.NewExecuteResponse(), + wantErr: errors.New("convert error"), + }, { name: "error: invalid request", fieldsGetter: func(ctrl *gomock.Controller) fields { @@ -3034,11 +3195,12 @@ func TestPromptOpenAPIApplicationImpl_ExecuteStreaming(t *testing.T) { } tests := []struct { - name string - fieldsGetter func(ctrl *gomock.Controller) fields - argsGetter func(ctrl *gomock.Controller) args - wantErr error - validateFunc func(t *testing.T, stream *mockExecuteStreamingServer) + name string + fieldsGetter func(ctrl *gomock.Controller) fields + argsGetter func(ctrl *gomock.Controller) args + wantErr error + validateFunc func(t *testing.T, stream *mockExecuteStreamingServer) + setupConvertMock func(mockSvc *servicemocks.MockIPromptService) }{ { name: "success: normal streaming execution", @@ -3197,6 +3359,100 @@ func TestPromptOpenAPIApplicationImpl_ExecuteStreaming(t *testing.T) { assert.Equal(t, "stop", calls[1].Data.GetFinishReason()) }, }, + { + name: "error: base64 convert failed", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockConfig := confmocks.NewMockIConfigProvider(ctrl) + mockConfig.EXPECT().GetPTaaSMaxQPSByPromptKey(gomock.Any(), int64(123456), "test_prompt").Return(100, nil) + + mockRateLimiter := limitermocks.NewMockIRateLimiter(ctrl) + mockRateLimiter.EXPECT().AllowN(gomock.Any(), "ptaas:qps:space_id:123456:prompt_key:test_prompt", 1, gomock.Any()).Return(&limiter.Result{ + Allowed: true, + }, nil) + + mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), int64(123456), []string{"test_prompt"}).Return(map[string]int64{ + "test_prompt": 123, + }, nil) + mockPromptService.EXPECT().MParseCommitVersion(gomock.Any(), int64(123456), gomock.Any()).Return(map[service.PromptQueryParam]string{ + {PromptID: 123, PromptKey: "test_prompt", Version: "1.0.0"}: "1.0.0", + }, nil) + + mockManageRepo := repomocks.NewMockIManageRepo(ctrl) + expectedPrompt := &entity.Prompt{ + ID: 123, + SpaceID: 123456, + PromptKey: "test_prompt", + } + mockManageRepo.EXPECT().MGetPrompt(gomock.Any(), gomock.Any(), gomock.Any()).Return(map[repo.GetPromptParam]*entity.Prompt{ + {PromptID: 123, WithCommit: true, CommitVersion: "1.0.0"}: expectedPrompt, + }, nil) + + mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) + mockAuth.EXPECT().MCheckPromptPermissionForOpenAPI(gomock.Any(), int64(123456), []int64{123}, consts.ActionLoopPromptExecute).Return(nil) + + mockPromptService.EXPECT().ExecuteStreaming(gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, param service.ExecuteStreamingParam) (*entity.Reply, error) { + param.ResultStream <- &entity.Reply{ + Item: &entity.ReplyItem{ + Message: &entity.Message{ + Role: entity.RoleAssistant, + Parts: []*entity.ContentPart{ + { + Type: entity.ContentTypeImageURL, + ImageURL: &entity.ImageURL{ + URL: "", + }, + }, + }, + }, + }, + } + return &entity.Reply{ + Item: &entity.ReplyItem{ + Message: &entity.Message{ + Role: entity.RoleAssistant, + }, + }, + }, nil + }) + + mockCollector := collectormocks.NewMockICollectorProvider(ctrl) + mockCollector.EXPECT().CollectPTaaSEvent(gomock.Any(), gomock.Any()).Return() + + return fields{ + promptService: mockPromptService, + promptManageRepo: mockManageRepo, + config: mockConfig, + auth: mockAuth, + rateLimiter: mockRateLimiter, + collector: mockCollector, + } + }, + argsGetter: func(ctrl *gomock.Controller) args { + ctx := context.Background() + stream := newMockExecuteStreamingServer(ctx) + return args{ + ctx: ctx, + req: &openapi.ExecuteRequest{ + WorkspaceID: ptr.Of(int64(123456)), + PromptIdentifier: &openapi.PromptQuery{ + PromptKey: ptr.Of("test_prompt"), + Version: ptr.Of("1.0.0"), + }, + }, + stream: stream, + } + }, + wantErr: errors.New("convert error"), + validateFunc: func(t *testing.T, stream *mockExecuteStreamingServer) { + calls := stream.GetSendCalls() + assert.Len(t, calls, 0) + }, + setupConvertMock: func(mockSvc *servicemocks.MockIPromptService) { + mockSvc.EXPECT().MConvertBase64ToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("convert error")) + }, + }, { name: "error: workspace_id is empty", fieldsGetter: func(ctrl *gomock.Controller) fields { @@ -4057,6 +4313,13 @@ func TestPromptOpenAPIApplicationImpl_ExecuteStreaming(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() ttFields := tt.fieldsGetter(ctrl) + if mockSvc, ok := ttFields.promptService.(*servicemocks.MockIPromptService); ok { + if tt.setupConvertMock != nil { + tt.setupConvertMock(mockSvc) + } else { + mockSvc.EXPECT().MConvertBase64ToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + } + } ttArgs := tt.argsGetter(ctrl) p := &PromptOpenAPIApplicationImpl{ promptService: ttFields.promptService, diff --git a/backend/modules/prompt/domain/component/rpc/file.go b/backend/modules/prompt/domain/component/rpc/file.go index 43d018bcd..f001956fe 100644 --- a/backend/modules/prompt/domain/component/rpc/file.go +++ b/backend/modules/prompt/domain/component/rpc/file.go @@ -8,4 +8,5 @@ import "context" //go:generate mockgen -destination=mocks/file_provider.go -package=mocks . IFileProvider type IFileProvider interface { MGetFileURL(ctx context.Context, keys []string) (urls map[string]string, err error) + UploadFileForServer(ctx context.Context, mimeType string, body []byte, workspaceID int64) (key string, err error) } diff --git a/backend/modules/prompt/domain/component/rpc/mocks/file_provider.go b/backend/modules/prompt/domain/component/rpc/mocks/file_provider.go index 33717aa60..6fa08ea13 100644 --- a/backend/modules/prompt/domain/component/rpc/mocks/file_provider.go +++ b/backend/modules/prompt/domain/component/rpc/mocks/file_provider.go @@ -54,3 +54,18 @@ func (mr *MockIFileProviderMockRecorder) MGetFileURL(ctx, keys any) *gomock.Call mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetFileURL", reflect.TypeOf((*MockIFileProvider)(nil).MGetFileURL), ctx, keys) } + +// UploadFileForServer mocks base method. +func (m *MockIFileProvider) UploadFileForServer(ctx context.Context, mimeType string, body []byte, workspaceID int64) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UploadFileForServer", ctx, mimeType, body, workspaceID) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UploadFileForServer indicates an expected call of UploadFileForServer. +func (mr *MockIFileProviderMockRecorder) UploadFileForServer(ctx, mimeType, body, workspaceID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UploadFileForServer", reflect.TypeOf((*MockIFileProvider)(nil).UploadFileForServer), ctx, mimeType, body, workspaceID) +} diff --git a/backend/modules/prompt/domain/entity/prompt_detail.go b/backend/modules/prompt/domain/entity/prompt_detail.go index e17ab4c5e..fb261f1a3 100644 --- a/backend/modules/prompt/domain/entity/prompt_detail.go +++ b/backend/modules/prompt/domain/entity/prompt_detail.go @@ -69,10 +69,12 @@ const ( ) type ContentPart struct { - Type ContentType `json:"type"` - Text *string `json:"text,omitempty"` - ImageURL *ImageURL `json:"image_url,omitempty"` - Base64Data *string `json:"base64_data,omitempty"` + Type ContentType `json:"type"` + Text *string `json:"text,omitempty"` + ImageURL *ImageURL `json:"image_url,omitempty"` + VideoURL *VideoURL `json:"video_url,omitempty"` + Base64Data *string `json:"base64_data,omitempty"` + MediaConfig *MediaConfig `json:"media_config,omitempty"` } type ContentType string @@ -80,6 +82,7 @@ type ContentType string const ( ContentTypeText ContentType = "text" ContentTypeImageURL ContentType = "image_url" + ContentTypeVideoURL ContentType = "video_url" ContentTypeBase64Data ContentType = "base64_data" ContentTypeMultiPartVariable ContentType = "multi_part_variable" ) @@ -89,6 +92,15 @@ type ImageURL struct { URL string `json:"url"` } +type VideoURL struct { + URI string `json:"uri"` + URL string `json:"url"` +} + +type MediaConfig struct { + Fps *float64 `json:"fps,omitempty"` +} + type VariableDef struct { Key string `json:"key"` Desc string `json:"desc"` @@ -263,7 +275,7 @@ func formatMultiPart(parts []*ContentPart, defMap map[string]*VariableDef, valMa if pt == nil { continue } - if ptr.From(pt.Text) != "" || pt.ImageURL != nil { + if ptr.From(pt.Text) != "" || pt.ImageURL != nil || pt.VideoURL != nil || ptr.From(pt.Base64Data) != "" { filtered = append(filtered, pt) } } diff --git a/backend/modules/prompt/domain/service/manage.go b/backend/modules/prompt/domain/service/manage.go index a7faa259c..b14f3db12 100644 --- a/backend/modules/prompt/domain/service/manage.go +++ b/backend/modules/prompt/domain/service/manage.go @@ -5,12 +5,15 @@ package service import ( "context" + "encoding/base64" "fmt" + "strings" "github.com/coze-dev/coze-loop/backend/modules/prompt/domain/entity" "github.com/coze-dev/coze-loop/backend/modules/prompt/domain/repo" prompterr "github.com/coze-dev/coze-loop/backend/modules/prompt/pkg/errno" "github.com/coze-dev/coze-loop/backend/pkg/errorx" + "github.com/coze-dev/coze-loop/backend/pkg/logs" ) func (p *PromptServiceImpl) MGetPromptIDs(ctx context.Context, spaceID int64, promptKeys []string) (PromptKeyIDMap map[string]int64, err error) { @@ -42,10 +45,15 @@ func (p *PromptServiceImpl) MCompleteMultiModalFileURL(ctx context.Context, mess continue } for _, part := range message.Parts { - if part == nil || part.ImageURL == nil { + if part == nil { continue } - fileKeys = append(fileKeys, part.ImageURL.URI) + if part.ImageURL != nil && part.ImageURL.URI != "" { + fileKeys = append(fileKeys, part.ImageURL.URI) + } + if part.VideoURL != nil && part.VideoURL.URI != "" { + fileKeys = append(fileKeys, part.VideoURL.URI) + } } } for _, val := range variableVals { @@ -53,10 +61,15 @@ func (p *PromptServiceImpl) MCompleteMultiModalFileURL(ctx context.Context, mess continue } for _, part := range val.MultiPartValues { - if part == nil || part.ImageURL == nil || part.ImageURL.URI == "" { + if part == nil { continue } - fileKeys = append(fileKeys, part.ImageURL.URI) + if part.ImageURL != nil && part.ImageURL.URI != "" { + fileKeys = append(fileKeys, part.ImageURL.URI) + } + if part.VideoURL != nil && part.VideoURL.URI != "" { + fileKeys = append(fileKeys, part.VideoURL.URI) + } } } if len(fileKeys) == 0 { @@ -72,10 +85,15 @@ func (p *PromptServiceImpl) MCompleteMultiModalFileURL(ctx context.Context, mess continue } for _, part := range message.Parts { - if part == nil || part.ImageURL == nil { + if part == nil { continue } - part.ImageURL.URL = urlMap[part.ImageURL.URI] + if part.ImageURL != nil { + part.ImageURL.URL = urlMap[part.ImageURL.URI] + } + if part.VideoURL != nil { + part.VideoURL.URL = urlMap[part.VideoURL.URI] + } } } for _, val := range variableVals { @@ -83,12 +101,117 @@ func (p *PromptServiceImpl) MCompleteMultiModalFileURL(ctx context.Context, mess continue } for _, part := range val.MultiPartValues { - if part == nil || part.ImageURL == nil || part.ImageURL.URI == "" { + if part == nil { + continue + } + if part.ImageURL != nil && part.ImageURL.URI != "" { + part.ImageURL.URL = urlMap[part.ImageURL.URI] + } + if part.VideoURL != nil && part.VideoURL.URI != "" { + part.VideoURL.URL = urlMap[part.VideoURL.URI] + } + } + } + return nil +} + +// MConvertBase64DataURLToFileURI converts base64 files to file URIs by uploading them +func (p *PromptServiceImpl) MConvertBase64DataURLToFileURI(ctx context.Context, messages []*entity.Message, workspaceID int64) error { + for _, message := range messages { + if message == nil || len(message.Parts) == 0 { + continue + } + + for _, part := range message.Parts { + if part == nil || part.ImageURL == nil { + continue + } + // Check if the URL is a base64 data URL + url := part.ImageURL.URL + if url == "" || !strings.HasPrefix(url, "data:") { + continue + } + + // Parse the data URL to extract mime type and base64 data + // Format: data:;base64, + parts := strings.SplitN(url, ",", 2) + if len(parts) != 2 { + logs.CtxWarn(ctx, "invalid data URL format: %s", url) + continue + } + + // Extract mime type from the first part + headerParts := strings.SplitN(parts[0], ";", 2) + if len(headerParts) != 2 { + logs.CtxWarn(ctx, "invalid data URL header: %s", parts[0]) + continue + } + mimeType := strings.TrimPrefix(headerParts[0], "data:") + if mimeType == "" { + logs.CtxWarn(ctx, "missing mime type in data URL") continue } - part.ImageURL.URL = urlMap[part.ImageURL.URI] + + // Decode base64 data + decodedData, err := base64.StdEncoding.DecodeString(parts[1]) + if err != nil { + logs.CtxError(ctx, "failed to decode base64 file: %v", err) + continue + } + + // Upload the file + fileKey, err := p.file.UploadFileForServer(ctx, mimeType, decodedData, workspaceID) + if err != nil { + logs.CtxError(ctx, "failed to upload file: %v", err) + return err + } + + // Replace the base64 URL with the file URI + part.ImageURL.URI = fileKey + part.ImageURL.URL = "" // Clear the URL, it will be filled later by MGetFileURL if needed + } + } + + return nil +} + +// messageContainsBase64File checks if messages contain base64 files +func (p *PromptServiceImpl) messageContainsBase64File(messages []*entity.Message) bool { + for _, message := range messages { + if message == nil || len(message.Parts) == 0 { + continue + } + for _, part := range message.Parts { + if part == nil || part.ImageURL == nil { + continue + } + // Check if the URL is a base64 data URL (format: data:;base64,) + url := part.ImageURL.URL + if url != "" && strings.HasPrefix(url, "data:") { + return true + } } } + return false +} + +// MConvertBase64DataURLToFileURL converts base64 files to download URLs +func (p *PromptServiceImpl) MConvertBase64DataURLToFileURL(ctx context.Context, messages []*entity.Message, workspaceID int64) error { + // Fast path: skip processing if no base64 files present + if !p.messageContainsBase64File(messages) { + return nil + } + + // Convert base64 files to file URIs + if err := p.MConvertBase64DataURLToFileURI(ctx, messages, workspaceID); err != nil { + return err + } + + // Convert file URIs to download URLs + if err := p.MCompleteMultiModalFileURL(ctx, messages, nil); err != nil { + return err + } + return nil } diff --git a/backend/modules/prompt/domain/service/manage_test.go b/backend/modules/prompt/domain/service/manage_test.go index 8debded80..8ae0145c1 100755 --- a/backend/modules/prompt/domain/service/manage_test.go +++ b/backend/modules/prompt/domain/service/manage_test.go @@ -5,6 +5,7 @@ package service import ( "context" + "encoding/base64" "testing" "github.com/stretchr/testify/assert" @@ -44,6 +45,7 @@ func TestPromptServiceImpl_MCompleteMultiModalFileURL(t *testing.T) { "test-image-1": "https://example.com/image1.jpg", "test-image-2": "https://example.com/image2.jpg", "test-image-3": "https://example.com/image3.jpg", + "test-video-1": "https://example.com/video1.mp4", } tests := []struct { name string @@ -158,6 +160,46 @@ func TestPromptServiceImpl_MCompleteMultiModalFileURL(t *testing.T) { }, wantErr: nil, }, + { + name: "video urls filled for messages and variable values", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockFile := mocks.NewMockIFileProvider(ctrl) + mockFile.EXPECT().MGetFileURL(gomock.Any(), gomock.Any()).Return(uri2URLMap, nil) + return fields{ + file: mockFile, + } + }, + args: args{ + ctx: context.Background(), + messages: []*entity.Message{ + { + Role: entity.RoleUser, + Parts: []*entity.ContentPart{ + { + Type: entity.ContentTypeVideoURL, + VideoURL: &entity.VideoURL{ + URI: "test-video-1", + }, + }, + }, + }, + }, + variableVals: []*entity.VariableVal{ + { + Key: "video-multi", + MultiPartValues: []*entity.ContentPart{ + { + Type: entity.ContentTypeVideoURL, + VideoURL: &entity.VideoURL{ + URI: "test-video-1", + }, + }, + }, + }, + }, + }, + wantErr: nil, + }, { name: "variableVals with nil MultiPartValues", fieldsGetter: func(ctrl *gomock.Controller) fields { @@ -486,11 +528,17 @@ func TestPromptServiceImpl_MCompleteMultiModalFileURL(t *testing.T) { continue } for _, part := range message.Parts { - if part == nil || part.ImageURL == nil { + if part == nil { continue } - assert.Equal(t, uri2URLMap[part.ImageURL.URI], part.ImageURL.URL) - part.ImageURL.URL = "" + if part.ImageURL != nil && part.ImageURL.URI != "" { + assert.Equal(t, uri2URLMap[part.ImageURL.URI], part.ImageURL.URL) + part.ImageURL.URL = "" + } + if part.VideoURL != nil && part.VideoURL.URI != "" { + assert.Equal(t, uri2URLMap[part.VideoURL.URI], part.VideoURL.URL) + part.VideoURL.URL = "" + } } } // 验证variableVals中的URL是否正确填充 @@ -499,11 +547,17 @@ func TestPromptServiceImpl_MCompleteMultiModalFileURL(t *testing.T) { continue } for _, part := range val.MultiPartValues { - if part == nil || part.ImageURL == nil || part.ImageURL.URI == "" { + if part == nil { continue } - assert.Equal(t, uri2URLMap[part.ImageURL.URI], part.ImageURL.URL) - part.ImageURL.URL = "" + if part.ImageURL != nil && part.ImageURL.URI != "" { + assert.Equal(t, uri2URLMap[part.ImageURL.URI], part.ImageURL.URL) + part.ImageURL.URL = "" + } + if part.VideoURL != nil && part.VideoURL.URI != "" { + assert.Equal(t, uri2URLMap[part.VideoURL.URI], part.VideoURL.URL) + part.VideoURL.URL = "" + } } } assert.Equal(t, originMessages, tt.args.messages) @@ -1211,3 +1265,375 @@ func TestPromptServiceImpl_MParseCommitVersion(t *testing.T) { }) } } + +func TestPromptServiceImpl_messageContainsBase64File(t *testing.T) { + encoded := base64.StdEncoding.EncodeToString([]byte("hello")) + dataURL := "data:image/png;base64," + encoded + + tests := []struct { + name string + messages []*entity.Message + want bool + }{ + { + name: "nil messages returns false", + messages: nil, + want: false, + }, + { + name: "message without parts returns false", + messages: []*entity.Message{ + {Role: entity.RoleUser}, + }, + want: false, + }, + { + name: "message without data url returns false", + messages: []*entity.Message{ + { + Role: entity.RoleUser, + Parts: []*entity.ContentPart{ + { + Type: entity.ContentTypeImageURL, + ImageURL: &entity.ImageURL{ + URL: "https://example.com/image.png", + }, + }, + }, + }, + }, + want: false, + }, + { + name: "contains base64 returns true", + messages: []*entity.Message{ + { + Role: entity.RoleUser, + Parts: []*entity.ContentPart{ + { + Type: entity.ContentTypeImageURL, + ImageURL: &entity.ImageURL{ + URL: dataURL, + }, + }, + }, + }, + }, + want: true, + }, + } + + p := &PromptServiceImpl{} + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.want, p.messageContainsBase64File(tt.messages)) + }) + } +} + +func TestPromptServiceImpl_MConvertBase64ToFileURI(t *testing.T) { + type args struct { + ctx context.Context + messages []*entity.Message + workspaceID int64 + } + + decoded := []byte("hello world") + dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(decoded) + + tests := []struct { + name string + args args + setupMock func(mock *mocks.MockIFileProvider) + wantErr error + validateFunc func(t *testing.T, messages []*entity.Message) + }{ + { + name: "successfully converts base64 image", + args: args{ + ctx: context.Background(), + messages: []*entity.Message{ + { + Role: entity.RoleUser, + Parts: []*entity.ContentPart{ + { + Type: entity.ContentTypeImageURL, + ImageURL: &entity.ImageURL{ + URL: dataURL, + }, + }, + }, + }, + }, + workspaceID: 101, + }, + setupMock: func(mock *mocks.MockIFileProvider) { + mock.EXPECT(). + UploadFileForServer(gomock.Any(), "image/png", gomock.Eq(decoded), int64(101)). + Return("workspace/101/file.png", nil) + }, + wantErr: nil, + validateFunc: func(t *testing.T, messages []*entity.Message) { + part := messages[0].Parts[0] + assert.Equal(t, "workspace/101/file.png", part.ImageURL.URI) + assert.Equal(t, "", part.ImageURL.URL) + }, + }, + { + name: "invalid data url skipped without error", + args: args{ + ctx: context.Background(), + messages: []*entity.Message{ + { + Role: entity.RoleUser, + Parts: []*entity.ContentPart{ + { + Type: entity.ContentTypeImageURL, + ImageURL: &entity.ImageURL{ + URL: "data:image/png;base64", + }, + }, + }, + }, + }, + workspaceID: 1, + }, + setupMock: func(mock *mocks.MockIFileProvider) { + mock.EXPECT().UploadFileForServer(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(0) + }, + wantErr: nil, + validateFunc: func(t *testing.T, messages []*entity.Message) { + part := messages[0].Parts[0] + assert.Equal(t, "", part.ImageURL.URI) + assert.Equal(t, "data:image/png;base64", part.ImageURL.URL) + }, + }, + { + name: "upload error returns error", + args: args{ + ctx: context.Background(), + messages: []*entity.Message{ + { + Role: entity.RoleUser, + Parts: []*entity.ContentPart{ + { + Type: entity.ContentTypeImageURL, + ImageURL: &entity.ImageURL{ + URL: dataURL, + }, + }, + }, + }, + }, + workspaceID: 7, + }, + setupMock: func(mock *mocks.MockIFileProvider) { + mock.EXPECT(). + UploadFileForServer(gomock.Any(), "image/png", gomock.Eq(decoded), int64(7)). + Return("", assert.AnError) + }, + wantErr: assert.AnError, + }, + { + name: "message without parts returns nil", + args: args{ + ctx: context.Background(), + messages: []*entity.Message{ + {Role: entity.RoleUser}, + }, + workspaceID: 5, + }, + setupMock: func(mock *mocks.MockIFileProvider) { + mock.EXPECT().UploadFileForServer(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(0) + }, + wantErr: nil, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockFile := mocks.NewMockIFileProvider(ctrl) + if tt.setupMock != nil { + tt.setupMock(mockFile) + } + p := &PromptServiceImpl{ + file: mockFile, + } + + err := p.MConvertBase64DataURLToFileURI(tt.args.ctx, tt.args.messages, tt.args.workspaceID) + unittest.AssertErrorEqual(t, tt.wantErr, err) + + if tt.validateFunc != nil { + tt.validateFunc(t, tt.args.messages) + } + }) + } +} + +func TestPromptServiceImpl_MConvertBase64ToFileURL(t *testing.T) { + type args struct { + ctx context.Context + messages []*entity.Message + workspaceID int64 + } + + decoded := []byte("image-bytes") + dataURL := "data:image/jpeg;base64," + base64.StdEncoding.EncodeToString(decoded) + + tests := []struct { + name string + args args + setupMock func(mock *mocks.MockIFileProvider) + wantErr error + validate func(t *testing.T, messages []*entity.Message) + }{ + { + name: "returns quickly when no base64 data", + args: args{ + ctx: context.Background(), + messages: []*entity.Message{ + { + Role: entity.RoleUser, + Parts: []*entity.ContentPart{ + { + Type: entity.ContentTypeImageURL, + ImageURL: &entity.ImageURL{URL: "https://example.com/image.png"}, + }, + }, + }, + }, + workspaceID: 1, + }, + setupMock: func(mock *mocks.MockIFileProvider) { + mock.EXPECT().UploadFileForServer(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(0) + mock.EXPECT().MGetFileURL(gomock.Any(), gomock.Any()).Times(0) + }, + wantErr: nil, + }, + { + name: "successfully converts base64 to downloadable url", + args: args{ + ctx: context.Background(), + messages: []*entity.Message{ + { + Role: entity.RoleUser, + Parts: []*entity.ContentPart{ + { + Type: entity.ContentTypeImageURL, + ImageURL: &entity.ImageURL{ + URL: dataURL, + }, + }, + }, + }, + }, + workspaceID: 200, + }, + setupMock: func(mock *mocks.MockIFileProvider) { + gomock.InOrder( + mock.EXPECT(). + UploadFileForServer(gomock.Any(), "image/jpeg", gomock.Eq(decoded), int64(200)). + Return("workspace/200/file.jpg", nil), + mock.EXPECT(). + MGetFileURL(gomock.Any(), gomock.Eq([]string{"workspace/200/file.jpg"})). + Return(map[string]string{"workspace/200/file.jpg": "https://example.com/file.jpg"}, nil), + ) + }, + wantErr: nil, + validate: func(t *testing.T, messages []*entity.Message) { + part := messages[0].Parts[0] + assert.Equal(t, "workspace/200/file.jpg", part.ImageURL.URI) + assert.Equal(t, "https://example.com/file.jpg", part.ImageURL.URL) + }, + }, + { + name: "upload error bubbles up", + args: args{ + ctx: context.Background(), + messages: []*entity.Message{ + { + Role: entity.RoleUser, + Parts: []*entity.ContentPart{ + { + Type: entity.ContentTypeImageURL, + ImageURL: &entity.ImageURL{ + URL: dataURL, + }, + }, + }, + }, + }, + workspaceID: 300, + }, + setupMock: func(mock *mocks.MockIFileProvider) { + mock.EXPECT(). + UploadFileForServer(gomock.Any(), "image/jpeg", gomock.Eq(decoded), int64(300)). + Return("", assert.AnError) + }, + wantErr: assert.AnError, + }, + { + name: "fetching url error bubbles up", + args: args{ + ctx: context.Background(), + messages: []*entity.Message{ + { + Role: entity.RoleUser, + Parts: []*entity.ContentPart{ + { + Type: entity.ContentTypeImageURL, + ImageURL: &entity.ImageURL{ + URL: dataURL, + }, + }, + }, + }, + }, + workspaceID: 400, + }, + setupMock: func(mock *mocks.MockIFileProvider) { + gomock.InOrder( + mock.EXPECT(). + UploadFileForServer(gomock.Any(), "image/jpeg", gomock.Eq(decoded), int64(400)). + Return("workspace/400/file.jpg", nil), + mock.EXPECT(). + MGetFileURL(gomock.Any(), gomock.Eq([]string{"workspace/400/file.jpg"})). + Return(nil, assert.AnError), + ) + }, + wantErr: assert.AnError, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockFile := mocks.NewMockIFileProvider(ctrl) + if tt.setupMock != nil { + tt.setupMock(mockFile) + } + + p := &PromptServiceImpl{ + file: mockFile, + } + + err := p.MConvertBase64DataURLToFileURL(tt.args.ctx, tt.args.messages, tt.args.workspaceID) + unittest.AssertErrorEqual(t, tt.wantErr, err) + + if tt.validate != nil { + tt.validate(t, tt.args.messages) + } + }) + } +} diff --git a/backend/modules/prompt/domain/service/mocks/prompt_service.go b/backend/modules/prompt/domain/service/mocks/prompt_service.go index 952f8617f..ee51ab8c4 100644 --- a/backend/modules/prompt/domain/service/mocks/prompt_service.go +++ b/backend/modules/prompt/domain/service/mocks/prompt_service.go @@ -161,6 +161,34 @@ func (mr *MockIPromptServiceMockRecorder) MCompleteMultiModalFileURL(ctx, messag return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MCompleteMultiModalFileURL", reflect.TypeOf((*MockIPromptService)(nil).MCompleteMultiModalFileURL), ctx, messages, variableVals) } +// MConvertBase64ToFileURI mocks base method. +func (m *MockIPromptService) MConvertBase64DataURLToFileURI(ctx context.Context, messages []*entity.Message, workspaceID int64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MConvertBase64DataURLToFileURI", ctx, messages, workspaceID) + ret0, _ := ret[0].(error) + return ret0 +} + +// MConvertBase64ToFileURI indicates an expected call of MConvertBase64ToFileURI. +func (mr *MockIPromptServiceMockRecorder) MConvertBase64ToFileURI(ctx, messages, workspaceID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MConvertBase64DataURLToFileURI", reflect.TypeOf((*MockIPromptService)(nil).MConvertBase64DataURLToFileURI), ctx, messages, workspaceID) +} + +// MConvertBase64ToFileURL mocks base method. +func (m *MockIPromptService) MConvertBase64DataURLToFileURL(ctx context.Context, messages []*entity.Message, workspaceID int64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MConvertBase64DataURLToFileURL", ctx, messages, workspaceID) + ret0, _ := ret[0].(error) + return ret0 +} + +// MConvertBase64ToFileURL indicates an expected call of MConvertBase64ToFileURL. +func (mr *MockIPromptServiceMockRecorder) MConvertBase64ToFileURL(ctx, messages, workspaceID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MConvertBase64DataURLToFileURL", reflect.TypeOf((*MockIPromptService)(nil).MConvertBase64DataURLToFileURL), ctx, messages, workspaceID) +} + // MGetPromptIDs mocks base method. func (m *MockIPromptService) MGetPromptIDs(ctx context.Context, spaceID int64, promptKeys []string) (map[string]int64, error) { m.ctrl.T.Helper() diff --git a/backend/modules/prompt/domain/service/service.go b/backend/modules/prompt/domain/service/service.go index 7708b0e12..9e2040df5 100644 --- a/backend/modules/prompt/domain/service/service.go +++ b/backend/modules/prompt/domain/service/service.go @@ -19,6 +19,8 @@ type IPromptService interface { ExecuteStreaming(ctx context.Context, param ExecuteStreamingParam) (*entity.Reply, error) Execute(ctx context.Context, param ExecuteParam) (*entity.Reply, error) MCompleteMultiModalFileURL(ctx context.Context, messages []*entity.Message, variableVals []*entity.VariableVal) error + MConvertBase64DataURLToFileURI(ctx context.Context, messages []*entity.Message, workspaceID int64) error + MConvertBase64DataURLToFileURL(ctx context.Context, messages []*entity.Message, workspaceID int64) error // MGetPromptIDs 根据prompt key获取prompt id MGetPromptIDs(ctx context.Context, spaceID int64, promptKeys []string) (PromptKeyIDMap map[string]int64, err error) // MParseCommitVersion 统一解析提交版本,支持version和label两种方式 diff --git a/backend/modules/prompt/infra/rpc/convertor/chat.go b/backend/modules/prompt/infra/rpc/convertor/chat.go index b3aa3b069..93f9ffdd0 100644 --- a/backend/modules/prompt/infra/rpc/convertor/chat.go +++ b/backend/modules/prompt/infra/rpc/convertor/chat.go @@ -4,7 +4,9 @@ package convertor import ( + "fmt" "strconv" + "strings" "github.com/bytedance/gg/gptr" "github.com/vincent-petithory/dataurl" @@ -130,44 +132,104 @@ func ContentPartDO2DTO(do *entity.ContentPart) *runtimedto.ChatMessagePart { if do == nil { return nil } - return &runtimedto.ChatMessagePart{ - Type: ptr.Of(ContentTypeDO2DTO(do.Type)), - Text: do.Text, - ImageURL: ImageURLDO2DTO(do.Type, do.ImageURL, do.Base64Data), + part := &runtimedto.ChatMessagePart{ + Type: ptr.Of(ContentTypeDO2DTO(do.Type, do.Base64Data)), + Text: do.Text, } + switch do.Type { + case entity.ContentTypeImageURL: + part.ImageURL = ImageURLDO2DTO(do.ImageURL) + case entity.ContentTypeVideoURL: + part.VideoURL = VideoURLDO2DTO(do.VideoURL, do.MediaConfig) + case entity.ContentTypeBase64Data: + imageURL, videoURL := base64DataToMedia(do) + if videoURL != nil { + part.Type = ptr.Of(runtimedto.ChatMessagePartTypeVideoURL) + part.VideoURL = videoURL + } else if imageURL != nil { + part.Type = ptr.Of(runtimedto.ChatMessagePartTypeImageURL) + part.ImageURL = imageURL + } + } + return part } -func ContentTypeDO2DTO(do entity.ContentType) runtimedto.ChatMessagePartType { - switch do { +func ContentTypeDO2DTO(contentType entity.ContentType, base64Data *string) runtimedto.ChatMessagePartType { + switch contentType { case entity.ContentTypeText: return runtimedto.ChatMessagePartTypeText case entity.ContentTypeImageURL: return runtimedto.ChatMessagePartTypeImageURL + case entity.ContentTypeVideoURL: + return runtimedto.ChatMessagePartTypeVideoURL case entity.ContentTypeBase64Data: - return runtimedto.ChatMessagePartTypeImageURL // 目前base64都通过image_url传递 + imageURL, videoURL := base64DataToMedia(&entity.ContentPart{Base64Data: base64Data}) + if videoURL != nil { + return runtimedto.ChatMessagePartTypeVideoURL + } + if imageURL != nil { + return runtimedto.ChatMessagePartTypeImageURL + } + return runtimedto.ChatMessagePartTypeImageURL default: return runtimedto.ChatMessagePartTypeText } } -func ImageURLDO2DTO(contentType entity.ContentType, url *entity.ImageURL, base64Data *string) *runtimedto.ChatMessageImageURL { - switch contentType { - case entity.ContentTypeImageURL: - return &runtimedto.ChatMessageImageURL{ - URL: ptr.Of(url.URL), - } - case entity.ContentTypeBase64Data: - dataURL, _ := dataurl.DecodeString(ptr.From(base64Data)) - if dataURL == nil { - return nil +func ImageURLDO2DTO(url *entity.ImageURL) *runtimedto.ChatMessageImageURL { + if url == nil { + return nil + } + return &runtimedto.ChatMessageImageURL{ + URL: ptr.Of(url.URL), + } +} + +func VideoURLDO2DTO(url *entity.VideoURL, mediaConfig *entity.MediaConfig) *runtimedto.ChatMessageVideoURL { + if url == nil { + return nil + } + var detail *runtimedto.VideoURLDetail + if mediaConfig != nil && mediaConfig.Fps != nil { + detail = &runtimedto.VideoURLDetail{ + Fps: mediaConfig.Fps, } + } + return &runtimedto.ChatMessageVideoURL{ + URL: ptr.Of(url.URL), + Detail: detail, + } +} + +func base64DataToMedia(part *entity.ContentPart) (*runtimedto.ChatMessageImageURL, *runtimedto.ChatMessageVideoURL) { + if part == nil || part.Base64Data == nil || ptr.From(part.Base64Data) == "" { + return nil, nil + } + dataURL, _ := dataurl.DecodeString(ptr.From(part.Base64Data)) + if dataURL == nil { + return nil, nil + } + mimeType := dataURL.ContentType() + if strings.HasPrefix(mimeType, runtimedto.MimePrefixImage) { return &runtimedto.ChatMessageImageURL{ - URL: base64Data, - MimeType: ptr.Of(dataURL.Type), + URL: part.Base64Data, + MimeType: ptr.Of(mimeType), + }, nil + } + if strings.HasPrefix(mimeType, runtimedto.MimePrefixVideo) { + videoURL := &runtimedto.ChatMessageVideoURL{ + URL: part.Base64Data, + MimeType: ptr.Of(mimeType), } - default: - return nil + // Preserve fps from MediaConfig if available + if part.MediaConfig != nil && part.MediaConfig.Fps != nil { + videoURL.Detail = &runtimedto.VideoURLDetail{ + Fps: part.MediaConfig.Fps, + } + } + return nil, videoURL } + return nil, nil } func BatchToolCallDO2DTO(dos []*entity.ToolCall) []*runtimedto.ToolCall { @@ -309,10 +371,13 @@ func MultimodalContentDTO2DO(dto *runtimedto.ChatMessagePart) *entity.ContentPar if dto == nil { return nil } + videoURL, mediaConfig := VideoURLDTO2DO(dto.VideoURL) return &entity.ContentPart{ - Type: ContentTypeDTO2DO(dto.GetType()), - Text: dto.Text, - ImageURL: ImageURLDTO2DO(dto.ImageURL), + Type: ContentTypeDTO2DO(dto.GetType()), + Text: dto.Text, + ImageURL: ImageURLDTO2DO(dto.ImageURL), + VideoURL: videoURL, + MediaConfig: mediaConfig, } } @@ -322,6 +387,8 @@ func ContentTypeDTO2DO(dto runtimedto.ChatMessagePartType) entity.ContentType { return entity.ContentTypeText case runtimedto.ChatMessagePartTypeImageURL: return entity.ContentTypeImageURL + case runtimedto.ChatMessagePartTypeVideoURL: + return entity.ContentTypeVideoURL default: return entity.ContentTypeText } @@ -331,9 +398,34 @@ func ImageURLDTO2DO(dto *runtimedto.ChatMessageImageURL) *entity.ImageURL { if dto == nil { return nil } + url := ptr.From(dto.URL) + // If mimetype is provided and URL is base64 string, convert to dataurl format + if dto.MimeType != nil && ptr.From(dto.MimeType) != "" && !strings.HasPrefix(url, "data:") { + url = fmt.Sprintf("data:%s;base64,%s", ptr.From(dto.MimeType), url) + } return &entity.ImageURL{ - URL: ptr.From(dto.URL), + URL: url, + } +} + +func VideoURLDTO2DO(dto *runtimedto.ChatMessageVideoURL) (*entity.VideoURL, *entity.MediaConfig) { + if dto == nil { + return nil, nil + } + var mediaConfig *entity.MediaConfig + if dto.Detail != nil && dto.Detail.Fps != nil { + mediaConfig = &entity.MediaConfig{ + Fps: dto.Detail.Fps, + } + } + url := ptr.From(dto.URL) + // If mimetype is provided and URL is base64 string, convert to dataurl format + if dto.MimeType != nil && ptr.From(dto.MimeType) != "" && !strings.HasPrefix(url, "data:") { + url = fmt.Sprintf("data:%s;base64,%s", ptr.From(dto.MimeType), url) } + return &entity.VideoURL{ + URL: url, + }, mediaConfig } func BatchToolCallDTO2DO(dtos []*runtimedto.ToolCall) []*entity.ToolCall { diff --git a/backend/modules/prompt/infra/rpc/convertor/chat_test.go b/backend/modules/prompt/infra/rpc/convertor/chat_test.go index 02a5ad9ee..3835e848a 100644 --- a/backend/modules/prompt/infra/rpc/convertor/chat_test.go +++ b/backend/modules/prompt/infra/rpc/convertor/chat_test.go @@ -212,6 +212,67 @@ func TestMessageDO2DTO(t *testing.T) { }, }, }, + { + name: "user video message with detail", + do: &entity.Message{ + Role: "user", + Parts: []*entity.ContentPart{ + { + Type: entity.ContentTypeVideoURL, + VideoURL: &entity.VideoURL{ + URL: "https://example.com/video.mp4", + }, + MediaConfig: &entity.MediaConfig{ + Fps: ptr.Of(1.25), + }, + }, + }, + }, + want: &runtimedto.Message{ + Role: runtimedto.RoleUser, + MultimodalContents: []*runtimedto.ChatMessagePart{ + { + Type: ptr.Of(runtimedto.ChatMessagePartTypeVideoURL), + VideoURL: &runtimedto.ChatMessageVideoURL{ + URL: ptr.Of("https://example.com/video.mp4"), + Detail: &runtimedto.VideoURLDetail{ + Fps: ptr.Of(1.25), + }, + }, + }, + }, + }, + }, + { + name: "user base64 video message", + do: &entity.Message{ + Role: "user", + Parts: []*entity.ContentPart{ + { + Type: entity.ContentTypeBase64Data, + Base64Data: ptr.Of("data:video/mp4;base64,QUJDRA=="), + MediaConfig: &entity.MediaConfig{ + Fps: ptr.Of(3.5), + }, + }, + }, + }, + want: &runtimedto.Message{ + Role: runtimedto.RoleUser, + MultimodalContents: []*runtimedto.ChatMessagePart{ + { + Type: ptr.Of(runtimedto.ChatMessagePartTypeVideoURL), + VideoURL: &runtimedto.ChatMessageVideoURL{ + URL: ptr.Of("data:video/mp4;base64,QUJDRA=="), + MimeType: ptr.Of("video/mp4"), + Detail: &runtimedto.VideoURLDetail{ + Fps: ptr.Of(3.5), + }, + }, + }, + }, + }, + }, { name: "ai tool call message", do: &entity.Message{ @@ -441,6 +502,37 @@ func TestMessageDTO2DO(t *testing.T) { }, }, }, + { + name: "video content part with detail", + dto: &runtimedto.Message{ + Role: runtimedto.RoleAssistant, + MultimodalContents: []*runtimedto.ChatMessagePart{ + { + Type: ptr.Of(runtimedto.ChatMessagePartTypeVideoURL), + VideoURL: &runtimedto.ChatMessageVideoURL{ + URL: ptr.Of("https://example.com/video.mp4"), + Detail: &runtimedto.VideoURLDetail{ + Fps: ptr.Of(2.5), + }, + }, + }, + }, + }, + want: &entity.Message{ + Role: entity.RoleAssistant, + Parts: []*entity.ContentPart{ + { + Type: entity.ContentTypeVideoURL, + VideoURL: &entity.VideoURL{ + URL: "https://example.com/video.mp4", + }, + MediaConfig: &entity.MediaConfig{ + Fps: ptr.Of(2.5), + }, + }, + }, + }, + }, { name: "message with tool call", dto: &runtimedto.Message{ diff --git a/backend/modules/prompt/infra/rpc/file.go b/backend/modules/prompt/infra/rpc/file.go index 65be2909a..fdffb2abd 100644 --- a/backend/modules/prompt/infra/rpc/file.go +++ b/backend/modules/prompt/infra/rpc/file.go @@ -13,6 +13,7 @@ import ( "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" ) +//go:generate mockgen -source=../../../../kitex_gen/coze/loop/foundation/file/fileservice/client.go -destination=mocks/fileservice_mock.go -package=mocks -mock_names Client=FileServiceClient type FileRPCAdapter struct { client fileservice.Client } @@ -45,3 +46,19 @@ func (f *FileRPCAdapter) MGetFileURL(ctx context.Context, keys []string) (urls m } return urls, nil } + +func (f *FileRPCAdapter) UploadFileForServer(ctx context.Context, mimeType string, body []byte, workspaceID int64) (key string, err error) { + req := &file.UploadFileForServerRequest{ + MimeType: mimeType, + Body: body, + WorkspaceID: workspaceID, + } + resp, err := f.client.UploadFileForServer(ctx, req) + if err != nil { + return "", err + } + if resp.Data == nil || resp.Data.FileName == nil { + return "", errorx.New("upload file response invalid: missing file name") + } + return *resp.Data.FileName, nil +} diff --git a/backend/modules/prompt/infra/rpc/file_test.go b/backend/modules/prompt/infra/rpc/file_test.go new file mode 100644 index 000000000..7a7204977 --- /dev/null +++ b/backend/modules/prompt/infra/rpc/file_test.go @@ -0,0 +1,201 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package rpc + +import ( + "context" + "errors" + "testing" + + "github.com/cloudwego/kitex/client/callopt" + "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/foundation/file" + "github.com/coze-dev/coze-loop/backend/modules/prompt/infra/rpc/mocks" + "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" +) + +func TestNewFileRPCProvider(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewFileServiceClient(ctrl) + provider := NewFileRPCProvider(mockClient) + + assert.NotNil(t, provider) + adapter, ok := provider.(*FileRPCAdapter) + assert.True(t, ok) + assert.Equal(t, mockClient, adapter.client) +} + +func TestFileRPCAdapter_MGetFileURL(t *testing.T) { + tests := []struct { + name string + keys []string + setupMock func(*mocks.FileServiceClient) + expectErr string + expectURLs map[string]string + }{ + { + name: "success - returns url map", + keys: []string{"file-1", "file-2"}, + setupMock: func(mc *mocks.FileServiceClient) { + mc.EXPECT().SignDownloadFile(gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, req *file.SignDownloadFileRequest, _ ...callopt.Option) (*file.SignDownloadFileResponse, error) { + assert.Equal(t, []string{"file-1", "file-2"}, req.Keys) + if assert.NotNil(t, req.Option) { + assert.Equal(t, int64(24*60*60), req.Option.GetTTL()) + } + if assert.NotNil(t, req.BusinessType) { + assert.Equal(t, file.BusinessTypePrompt, *req.BusinessType) + } + return &file.SignDownloadFileResponse{ + Uris: []string{"https://file-1", "https://file-2"}, + }, nil + }, + ) + }, + expectURLs: map[string]string{ + "file-1": "https://file-1", + "file-2": "https://file-2", + }, + }, + { + name: "failure - mismatched uri count", + keys: []string{"file-1", "file-2"}, + setupMock: func(mc *mocks.FileServiceClient) { + mc.EXPECT().SignDownloadFile(gomock.Any(), gomock.Any()).Return( + &file.SignDownloadFileResponse{ + Uris: []string{"https://file-1"}, + }, + nil, + ) + }, + expectErr: "url length mismatch with keys", + }, + { + name: "failure - rpc error", + keys: []string{"file-1"}, + setupMock: func(mc *mocks.FileServiceClient) { + mc.EXPECT().SignDownloadFile(gomock.Any(), gomock.Any()).Return(nil, errors.New("sign download failed")) + }, + expectErr: "sign download failed", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewFileServiceClient(ctrl) + if tt.setupMock != nil { + tt.setupMock(mockClient) + } + + adapter := &FileRPCAdapter{ + client: mockClient, + } + + urls, err := adapter.MGetFileURL(context.Background(), tt.keys) + + if tt.expectErr != "" { + assert.ErrorContains(t, err, tt.expectErr) + assert.Nil(t, urls) + return + } + + assert.NoError(t, err) + assert.Equal(t, tt.expectURLs, urls) + }) + } +} + +func TestFileRPCAdapter_UploadFileForServer(t *testing.T) { + tests := []struct { + name string + mimeType string + body []byte + workspaceID int64 + setupMock func(*mocks.FileServiceClient) + expectErr string + expectResult string + }{ + { + name: "success - returns uploaded key", + mimeType: "image/png", + body: []byte("content"), + workspaceID: 101, + setupMock: func(mc *mocks.FileServiceClient) { + mc.EXPECT().UploadFileForServer(gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, req *file.UploadFileForServerRequest, _ ...callopt.Option) (*file.UploadFileForServerResponse, error) { + assert.Equal(t, "image/png", req.MimeType) + assert.Equal(t, []byte("content"), req.Body) + assert.Equal(t, int64(101), req.WorkspaceID) + return &file.UploadFileForServerResponse{ + Data: &file.FileData{ + FileName: ptr.Of("uploaded-key"), + }, + }, nil + }, + ) + }, + expectResult: "uploaded-key", + }, + { + name: "failure - missing file name in response", + mimeType: "image/png", + body: []byte("content"), + workspaceID: 101, + setupMock: func(mc *mocks.FileServiceClient) { + mc.EXPECT().UploadFileForServer(gomock.Any(), gomock.Any()).Return( + &file.UploadFileForServerResponse{ + Data: nil, + }, + nil, + ) + }, + expectErr: "upload file response invalid: missing file name", + }, + { + name: "failure - rpc error", + mimeType: "image/png", + body: []byte("content"), + workspaceID: 101, + setupMock: func(mc *mocks.FileServiceClient) { + mc.EXPECT().UploadFileForServer(gomock.Any(), gomock.Any()).Return(nil, errors.New("upload failed")) + }, + expectErr: "upload failed", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewFileServiceClient(ctrl) + if tt.setupMock != nil { + tt.setupMock(mockClient) + } + + adapter := &FileRPCAdapter{ + client: mockClient, + } + + result, err := adapter.UploadFileForServer(context.Background(), tt.mimeType, tt.body, tt.workspaceID) + + if tt.expectErr != "" { + assert.ErrorContains(t, err, tt.expectErr) + assert.Empty(t, result) + return + } + + assert.NoError(t, err) + assert.Equal(t, tt.expectResult, result) + }) + } +} diff --git a/backend/modules/prompt/infra/rpc/mocks/fileservice_mock.go b/backend/modules/prompt/infra/rpc/mocks/fileservice_mock.go new file mode 100644 index 000000000..cc506f08c --- /dev/null +++ b/backend/modules/prompt/infra/rpc/mocks/fileservice_mock.go @@ -0,0 +1,123 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: backend/kitex_gen/coze/loop/foundation/file/fileservice/client.go +// +// Generated by this command: +// +// mockgen -source=backend/kitex_gen/coze/loop/foundation/file/fileservice/client.go -destination=backend/modules/prompt/infra/rpc/mocks/fileservice_mock.go -package=mocks -mock_names Client=FileServiceClient +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + callopt "github.com/cloudwego/kitex/client/callopt" + file "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/foundation/file" + gomock "go.uber.org/mock/gomock" +) + +// FileServiceClient is a mock of Client interface. +type FileServiceClient struct { + ctrl *gomock.Controller + recorder *FileServiceClientMockRecorder + isgomock struct{} +} + +// FileServiceClientMockRecorder is the mock recorder for FileServiceClient. +type FileServiceClientMockRecorder struct { + mock *FileServiceClient +} + +// NewFileServiceClient creates a new mock instance. +func NewFileServiceClient(ctrl *gomock.Controller) *FileServiceClient { + mock := &FileServiceClient{ctrl: ctrl} + mock.recorder = &FileServiceClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *FileServiceClient) EXPECT() *FileServiceClientMockRecorder { + return m.recorder +} + +// SignDownloadFile mocks base method. +func (m *FileServiceClient) SignDownloadFile(ctx context.Context, req *file.SignDownloadFileRequest, callOptions ...callopt.Option) (*file.SignDownloadFileResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, req} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "SignDownloadFile", varargs...) + ret0, _ := ret[0].(*file.SignDownloadFileResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SignDownloadFile indicates an expected call of SignDownloadFile. +func (mr *FileServiceClientMockRecorder) SignDownloadFile(ctx, req any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, req}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignDownloadFile", reflect.TypeOf((*FileServiceClient)(nil).SignDownloadFile), varargs...) +} + +// SignUploadFile mocks base method. +func (m *FileServiceClient) SignUploadFile(ctx context.Context, req *file.SignUploadFileRequest, callOptions ...callopt.Option) (*file.SignUploadFileResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, req} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "SignUploadFile", varargs...) + ret0, _ := ret[0].(*file.SignUploadFileResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SignUploadFile indicates an expected call of SignUploadFile. +func (mr *FileServiceClientMockRecorder) SignUploadFile(ctx, req any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, req}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignUploadFile", reflect.TypeOf((*FileServiceClient)(nil).SignUploadFile), varargs...) +} + +// UploadFileForServer mocks base method. +func (m *FileServiceClient) UploadFileForServer(ctx context.Context, req *file.UploadFileForServerRequest, callOptions ...callopt.Option) (*file.UploadFileForServerResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, req} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "UploadFileForServer", varargs...) + ret0, _ := ret[0].(*file.UploadFileForServerResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UploadFileForServer indicates an expected call of UploadFileForServer. +func (mr *FileServiceClientMockRecorder) UploadFileForServer(ctx, req any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, req}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UploadFileForServer", reflect.TypeOf((*FileServiceClient)(nil).UploadFileForServer), varargs...) +} + +// UploadLoopFileInner mocks base method. +func (m *FileServiceClient) UploadLoopFileInner(ctx context.Context, req *file.UploadLoopFileInnerRequest, callOptions ...callopt.Option) (*file.UploadLoopFileInnerResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, req} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "UploadLoopFileInner", varargs...) + ret0, _ := ret[0].(*file.UploadLoopFileInnerResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UploadLoopFileInner indicates an expected call of UploadLoopFileInner. +func (mr *FileServiceClientMockRecorder) UploadLoopFileInner(ctx, req any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, req}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UploadLoopFileInner", reflect.TypeOf((*FileServiceClient)(nil).UploadLoopFileInner), varargs...) +} diff --git a/idl/thrift/coze/loop/foundation/coze.loop.foundation.file.thrift b/idl/thrift/coze/loop/foundation/coze.loop.foundation.file.thrift index e20d548a7..c8057020d 100644 --- a/idl/thrift/coze/loop/foundation/coze.loop.foundation.file.thrift +++ b/idl/thrift/coze/loop/foundation/coze.loop.foundation.file.thrift @@ -87,8 +87,31 @@ struct SignDownloadFileResponse { 255: base.BaseResp BaseResp } +struct UploadFileOption { + 1: optional string file_name // file name + 2: optional map mime_type_ext_mapping // custom mimetype -> ext mapping +} + +struct UploadFileForServerRequest { + 1: required string mime_type // file mime type + 2: required binary body // file binary data + 3: required i64 workspace_id (api.js_conv='true', go.tag='json:"workspace_id"') // workspace id + 4: optional UploadFileOption option // upload options + + 255: optional base.Base Base +} + +struct UploadFileForServerResponse { + 1: optional i32 code + 2: optional string msg + 3: optional FileData data + + 255: base.BaseResp BaseResp +} + service FileService { UploadLoopFileInnerResponse UploadLoopFileInner(1: UploadLoopFileInnerRequest req) // for inner service, etc prompt or eval + UploadFileForServerResponse UploadFileForServer(1: UploadFileForServerRequest req) // for internal server upload SignUploadFileResponse SignUploadFile(1: SignUploadFileRequest req) (api.post='/api/foundation/v1/sign_upload_files') SignDownloadFileResponse SignDownloadFile(1: SignDownloadFileRequest req) // for inner service, etc prompt or eval } \ No newline at end of file diff --git a/idl/thrift/coze/loop/llm/domain/runtime.thrift b/idl/thrift/coze/loop/llm/domain/runtime.thrift index 0ae685dc8..f88820b4f 100644 --- a/idl/thrift/coze/loop/llm/domain/runtime.thrift +++ b/idl/thrift/coze/loop/llm/domain/runtime.thrift @@ -31,10 +31,19 @@ struct ChatMessagePart { 2: optional string text 3: optional ChatMessageImageURL image_url // 4: optional ChatMessageAudioURL audio_url 占位,暂不支持 -// 5: optional ChatMessageVideoURL video_url 占位,暂不支持 + 5: optional ChatMessageVideoURL video_url // 6: optional ChatMessageFileURL file_url 占位,暂不支持 } +struct ChatMessageVideoURL { + 1: optional string url + 2: optional VideoURLDetail detail + 3: optional string mime_type +} +struct VideoURLDetail { + 1: optional double fps (vt.ge="0.2", vt.le="5") +} + struct ChatMessageImageURL { 1: optional string url 2: optional ImageURLDetail detail @@ -111,10 +120,14 @@ typedef string ChatMessagePartType (ts.enum="true") const ChatMessagePartType chat_message_part_type_text = "text" const ChatMessagePartType chat_message_part_type_image_url = "image_url" // const ChatMessagePartType chat_message_part_type_audio_url = "audio_url" -// const ChatMessagePartType chat_message_part_type_video_url = "video_url" + const ChatMessagePartType chat_message_part_type_video_url = "video_url" // const ChatMessagePartType chat_message_part_type_file_url = "file_url" typedef string ImageURLDetail (ts.enum="true") const ImageURLDetail image_url_detail_auto = "auto" const ImageURLDetail image_url_detail_low = "low" const ImageURLDetail image_url_detail_high = "high" + +typedef string MimeTypePrefix (ts.enum="true") +const MimeTypePrefix mime_prefix_image = "image/" +const MimeTypePrefix mime_prefix_video = "video/" diff --git a/idl/thrift/coze/loop/prompt/coze.loop.prompt.openapi.thrift b/idl/thrift/coze/loop/prompt/coze.loop.prompt.openapi.thrift index 0917d6018..038b4e83d 100644 --- a/idl/thrift/coze/loop/prompt/coze.loop.prompt.openapi.thrift +++ b/idl/thrift/coze/loop/prompt/coze.loop.prompt.openapi.thrift @@ -128,11 +128,18 @@ struct ContentPart { 2: optional string text 3: optional string image_url 4: optional string base64_data + 5: optional string video_url + 6: optional MediaConfig config +} + +struct MediaConfig { + 1: optional double fps (vt.ge="0.2", vt.le="5") } typedef string ContentType (ts.enum="true") const ContentType ContentType_Text = "text" const ContentType ContentType_ImageURL = "image_url" +const ContentType ContentType_VideoURL = "video_url" const ContentType ContentType_Base64Data = "base64_data" const ContentType ContentType_MultiPartVariable = "multi_part_variable" diff --git a/idl/thrift/coze/loop/prompt/domain/prompt.thrift b/idl/thrift/coze/loop/prompt/domain/prompt.thrift index c78f0a95c..6c24c61a6 100644 --- a/idl/thrift/coze/loop/prompt/domain/prompt.thrift +++ b/idl/thrift/coze/loop/prompt/domain/prompt.thrift @@ -124,11 +124,14 @@ struct ContentPart { 1: optional ContentType type 2: optional string text 3: optional ImageURL image_url + 4: optional VideoURL video_url + 5: optional MediaConfig media_config } typedef string ContentType (ts.enum="true") const ContentType ContentType_Text = "text" const ContentType ContentType_ImageURL = "image_url" +const ContentType ContentType_VideoURL = "video_url" const ContentType ContentType_MultiPartVariable = "multi_part_variable" struct ImageURL { @@ -136,6 +139,15 @@ struct ImageURL { 2: optional string url } +struct VideoURL { + 1: optional string url + 2: optional string uri +} + +struct MediaConfig { + 1: optional double fps (vt.ge="0.2", vt.le="5") +} + struct ToolCall { 1: optional i64 index (api.js_conv="true", go.tag='json:"index"') 2: optional string id From 17eb70c9cfa4b20ee14db79464aabb6bfdd77f4a Mon Sep 17 00:00:00 2001 From: caijialin0626 <61818131+caijialin0626@users.noreply.github.com> Date: Mon, 27 Oct 2025 18:44:48 +0800 Subject: [PATCH 06/12] [feat][prompt] prompt support model config extra field (#260) --- .../loop/prompt/domain/prompt/k-prompt.go | 56 ++++++++++++++ .../coze/loop/prompt/domain/prompt/prompt.go | 77 +++++++++++++++++++ .../prompt/application/convertor/prompt.go | 2 + .../application/convertor/prompt_test.go | 15 ++++ .../prompt/application/execute_test.go | 16 +++- .../prompt/domain/entity/prompt_detail.go | 1 + .../prompt/domain/service/execute_test.go | 49 ++++++++++++ .../coze/loop/prompt/domain/prompt.thrift | 1 + 8 files changed, 214 insertions(+), 3 deletions(-) diff --git a/backend/kitex_gen/coze/loop/prompt/domain/prompt/k-prompt.go b/backend/kitex_gen/coze/loop/prompt/domain/prompt/k-prompt.go index 874e621b7..ae4053b2f 100644 --- a/backend/kitex_gen/coze/loop/prompt/domain/prompt/k-prompt.go +++ b/backend/kitex_gen/coze/loop/prompt/domain/prompt/k-prompt.go @@ -3384,6 +3384,20 @@ func (p *ModelConfig) FastRead(buf []byte) (int, error) { goto SkipFieldError } } + case 9: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField9(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } default: l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l @@ -3514,6 +3528,20 @@ func (p *ModelConfig) FastReadField8(buf []byte) (int, error) { return offset, nil } +func (p *ModelConfig) FastReadField9(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Extra = _field + return offset, nil +} + func (p *ModelConfig) FastWrite(buf []byte) int { return p.FastWriteNocopy(buf, nil) } @@ -3529,6 +3557,7 @@ func (p *ModelConfig) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { offset += p.fastWriteField6(buf[offset:], w) offset += p.fastWriteField7(buf[offset:], w) offset += p.fastWriteField8(buf[offset:], w) + offset += p.fastWriteField9(buf[offset:], w) } offset += thrift.Binary.WriteFieldStop(buf[offset:]) return offset @@ -3545,6 +3574,7 @@ func (p *ModelConfig) BLength() int { l += p.field6Length() l += p.field7Length() l += p.field8Length() + l += p.field9Length() } l += thrift.Binary.FieldStopLength() return l @@ -3622,6 +3652,15 @@ func (p *ModelConfig) fastWriteField8(buf []byte, w thrift.NocopyWriter) int { return offset } +func (p *ModelConfig) fastWriteField9(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetExtra() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 9) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Extra) + } + return offset +} + func (p *ModelConfig) field1Length() int { l := 0 if p.IsSetModelID() { @@ -3694,6 +3733,15 @@ func (p *ModelConfig) field8Length() int { return l } +func (p *ModelConfig) field9Length() int { + l := 0 + if p.IsSetExtra() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Extra) + } + return l +} + func (p *ModelConfig) DeepCopy(s interface{}) error { src, ok := s.(*ModelConfig) if !ok { @@ -3740,6 +3788,14 @@ func (p *ModelConfig) DeepCopy(s interface{}) error { p.JSONMode = &tmp } + if src.Extra != nil { + var tmp string + if *src.Extra != "" { + tmp = kutils.StringDeepCopy(*src.Extra) + } + p.Extra = &tmp + } + return nil } diff --git a/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go b/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go index cb2c99b73..2cb24c558 100644 --- a/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go +++ b/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go @@ -4579,6 +4579,7 @@ type ModelConfig struct { PresencePenalty *float64 `thrift:"presence_penalty,6,optional" frugal:"6,optional,double" form:"presence_penalty" json:"presence_penalty,omitempty" query:"presence_penalty"` FrequencyPenalty *float64 `thrift:"frequency_penalty,7,optional" frugal:"7,optional,double" form:"frequency_penalty" json:"frequency_penalty,omitempty" query:"frequency_penalty"` JSONMode *bool `thrift:"json_mode,8,optional" frugal:"8,optional,bool" form:"json_mode" json:"json_mode,omitempty" query:"json_mode"` + Extra *string `thrift:"extra,9,optional" frugal:"9,optional,string" form:"extra" json:"extra,omitempty" query:"extra"` } func NewModelConfig() *ModelConfig { @@ -4683,6 +4684,18 @@ func (p *ModelConfig) GetJSONMode() (v bool) { } return *p.JSONMode } + +var ModelConfig_Extra_DEFAULT string + +func (p *ModelConfig) GetExtra() (v string) { + if p == nil { + return + } + if !p.IsSetExtra() { + return ModelConfig_Extra_DEFAULT + } + return *p.Extra +} func (p *ModelConfig) SetModelID(val *int64) { p.ModelID = val } @@ -4707,6 +4720,9 @@ func (p *ModelConfig) SetFrequencyPenalty(val *float64) { func (p *ModelConfig) SetJSONMode(val *bool) { p.JSONMode = val } +func (p *ModelConfig) SetExtra(val *string) { + p.Extra = val +} var fieldIDToName_ModelConfig = map[int16]string{ 1: "model_id", @@ -4717,6 +4733,7 @@ var fieldIDToName_ModelConfig = map[int16]string{ 6: "presence_penalty", 7: "frequency_penalty", 8: "json_mode", + 9: "extra", } func (p *ModelConfig) IsSetModelID() bool { @@ -4751,6 +4768,10 @@ func (p *ModelConfig) IsSetJSONMode() bool { return p.JSONMode != nil } +func (p *ModelConfig) IsSetExtra() bool { + return p.Extra != nil +} + func (p *ModelConfig) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -4833,6 +4854,14 @@ func (p *ModelConfig) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 9: + if fieldTypeId == thrift.STRING { + if err = p.ReadField9(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } default: if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError @@ -4950,6 +4979,17 @@ func (p *ModelConfig) ReadField8(iprot thrift.TProtocol) error { p.JSONMode = _field return nil } +func (p *ModelConfig) ReadField9(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Extra = _field + return nil +} func (p *ModelConfig) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 @@ -4989,6 +5029,10 @@ func (p *ModelConfig) Write(oprot thrift.TProtocol) (err error) { fieldId = 8 goto WriteFieldError } + if err = p.writeField9(oprot); err != nil { + fieldId = 9 + goto WriteFieldError + } } if err = oprot.WriteFieldStop(); err != nil { goto WriteFieldStopError @@ -5151,6 +5195,24 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 8 end error: ", p), err) } +func (p *ModelConfig) writeField9(oprot thrift.TProtocol) (err error) { + if p.IsSetExtra() { + if err = oprot.WriteFieldBegin("extra", thrift.STRING, 9); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Extra); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 9 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 9 end error: ", p), err) +} func (p *ModelConfig) String() string { if p == nil { @@ -5190,6 +5252,9 @@ func (p *ModelConfig) DeepEqual(ano *ModelConfig) bool { if !p.Field8DeepEqual(ano.JSONMode) { return false } + if !p.Field9DeepEqual(ano.Extra) { + return false + } return true } @@ -5289,6 +5354,18 @@ func (p *ModelConfig) Field8DeepEqual(src *bool) bool { } return true } +func (p *ModelConfig) Field9DeepEqual(src *string) bool { + + if p.Extra == src { + return true + } else if p.Extra == nil || src == nil { + return false + } + if strings.Compare(*p.Extra, *src) != 0 { + return false + } + return true +} type Message struct { Role *Role `thrift:"role,1,optional" frugal:"1,optional,string" form:"role" json:"role,omitempty" query:"role"` diff --git a/backend/modules/prompt/application/convertor/prompt.go b/backend/modules/prompt/application/convertor/prompt.go index 1c30c2413..febbafc6a 100644 --- a/backend/modules/prompt/application/convertor/prompt.go +++ b/backend/modules/prompt/application/convertor/prompt.go @@ -421,6 +421,7 @@ func ModelConfigDTO2DO(dto *prompt.ModelConfig) *entity.ModelConfig { PresencePenalty: dto.PresencePenalty, FrequencyPenalty: dto.FrequencyPenalty, JSONMode: dto.JSONMode, + Extra: dto.Extra, } } @@ -824,6 +825,7 @@ func ModelConfigDO2DTO(do *entity.ModelConfig) *prompt.ModelConfig { PresencePenalty: do.PresencePenalty, FrequencyPenalty: do.FrequencyPenalty, JSONMode: do.JSONMode, + Extra: do.Extra, } } diff --git a/backend/modules/prompt/application/convertor/prompt_test.go b/backend/modules/prompt/application/convertor/prompt_test.go index bd3ac6d15..816ae1ba1 100644 --- a/backend/modules/prompt/application/convertor/prompt_test.go +++ b/backend/modules/prompt/application/convertor/prompt_test.go @@ -650,3 +650,18 @@ func TestMessageDO2DTO(t *testing.T) { }) } } + +func TestModelConfigExtraConversion(t *testing.T) { + extra := ptr.Of(`{"foo":"bar"}`) + dto := &prompt.ModelConfig{ + Extra: extra, + } + + do := ModelConfigDTO2DO(dto) + assert.NotNil(t, do) + assert.Equal(t, extra, do.Extra) + + dtoBack := ModelConfigDO2DTO(do) + assert.NotNil(t, dtoBack) + assert.Equal(t, extra, dtoBack.Extra) +} diff --git a/backend/modules/prompt/application/execute_test.go b/backend/modules/prompt/application/execute_test.go index db9a57628..0400c9ffe 100755 --- a/backend/modules/prompt/application/execute_test.go +++ b/backend/modules/prompt/application/execute_test.go @@ -519,6 +519,7 @@ func TestOverridePromptParams(t *testing.T) { ModelConfig: &entity.ModelConfig{ ModelID: 456, Temperature: ptr.Of(0.7), + Extra: ptr.Of(`{"source":"base"}`), }, }, }, @@ -586,6 +587,7 @@ func TestOverridePromptParams(t *testing.T) { ModelID: ptr.Of(int64(789)), Temperature: ptr.Of(0.9), MaxTokens: ptr.Of(int32(2000)), + Extra: ptr.Of(`{"source":"override"}`), }, }, }, @@ -598,6 +600,7 @@ func TestOverridePromptParams(t *testing.T) { ModelID: 789, Temperature: ptr.Of(0.9), MaxTokens: ptr.Of(int32(2000)), + Extra: ptr.Of(`{"source":"override"}`), }, }, } @@ -651,10 +654,17 @@ func TestOverridePromptParams(t *testing.T) { if tt.args.promptDO.PromptCommit.PromptDetail != nil { promptCopy.PromptCommit.PromptDetail = &entity.PromptDetail{} if tt.args.promptDO.PromptCommit.PromptDetail.ModelConfig != nil { + orig := tt.args.promptDO.PromptCommit.PromptDetail.ModelConfig promptCopy.PromptCommit.PromptDetail.ModelConfig = &entity.ModelConfig{ - ModelID: tt.args.promptDO.PromptCommit.PromptDetail.ModelConfig.ModelID, - Temperature: tt.args.promptDO.PromptCommit.PromptDetail.ModelConfig.Temperature, - MaxTokens: tt.args.promptDO.PromptCommit.PromptDetail.ModelConfig.MaxTokens, + ModelID: orig.ModelID, + MaxTokens: orig.MaxTokens, + Temperature: orig.Temperature, + TopK: orig.TopK, + TopP: orig.TopP, + PresencePenalty: orig.PresencePenalty, + FrequencyPenalty: orig.FrequencyPenalty, + JSONMode: orig.JSONMode, + Extra: orig.Extra, } } } diff --git a/backend/modules/prompt/domain/entity/prompt_detail.go b/backend/modules/prompt/domain/entity/prompt_detail.go index fb261f1a3..0c9c23157 100644 --- a/backend/modules/prompt/domain/entity/prompt_detail.go +++ b/backend/modules/prompt/domain/entity/prompt_detail.go @@ -181,6 +181,7 @@ type ModelConfig struct { PresencePenalty *float64 `json:"presence_penalty,omitempty"` FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` JSONMode *bool `json:"json_mode,omitempty"` + Extra *string `json:"extra,omitempty"` } func (pt *PromptTemplate) formatMessages(messages []*Message, variableVals []*VariableVal) ([]*Message, error) { diff --git a/backend/modules/prompt/domain/service/execute_test.go b/backend/modules/prompt/domain/service/execute_test.go index 4eb316ca0..1f97f1315 100644 --- a/backend/modules/prompt/domain/service/execute_test.go +++ b/backend/modules/prompt/domain/service/execute_test.go @@ -854,3 +854,52 @@ func TestPromptServiceImpl_Execute(t *testing.T) { }) } } + +func TestPromptServiceImpl_prepareLLMCallParam_PreservesExtra(t *testing.T) { + t.Parallel() + extra := ptr.Of(`{"foo":"bar"}`) + prompt := &entity.Prompt{ + ID: 1, + SpaceID: 42, + PromptKey: "test_prompt", + PromptCommit: &entity.PromptCommit{ + CommitInfo: &entity.CommitInfo{ + Version: "v1", + }, + PromptDetail: &entity.PromptDetail{ + ModelConfig: &entity.ModelConfig{ + ModelID: 99, + Extra: extra, + JSONMode: ptr.Of(true), + }, + PromptTemplate: &entity.PromptTemplate{ + TemplateType: entity.TemplateTypeNormal, + Messages: []*entity.Message{ + { + Role: entity.RoleSystem, + Content: ptr.Of("System prompt"), + }, + }, + }, + }, + }, + } + svc := &PromptServiceImpl{} + param := ExecuteParam{ + Prompt: prompt, + Messages: []*entity.Message{ + { + Role: entity.RoleUser, + Content: ptr.Of("Hi"), + }, + }, + VariableVals: nil, + Scenario: entity.ScenarioPromptDebug, + } + got, err := svc.prepareLLMCallParam(context.Background(), param) + assert.NoError(t, err) + if assert.NotNil(t, got.ModelConfig) { + assert.Equal(t, extra, got.ModelConfig.Extra) + assert.Equal(t, prompt.PromptCommit.PromptDetail.ModelConfig.Extra, got.ModelConfig.Extra) + } +} diff --git a/idl/thrift/coze/loop/prompt/domain/prompt.thrift b/idl/thrift/coze/loop/prompt/domain/prompt.thrift index 6c24c61a6..76b874e14 100644 --- a/idl/thrift/coze/loop/prompt/domain/prompt.thrift +++ b/idl/thrift/coze/loop/prompt/domain/prompt.thrift @@ -100,6 +100,7 @@ struct ModelConfig { 6: optional double presence_penalty 7: optional double frequency_penalty 8: optional bool json_mode + 9: optional string extra } struct Message { From 1e4f43f484ac7dd259dcc9e0181eb3c8c71dbcb5 Mon Sep 17 00:00:00 2001 From: caijialin0626 <61818131+caijialin0626@users.noreply.github.com> Date: Tue, 28 Oct 2025 16:45:23 +0800 Subject: [PATCH 07/12] [feat][prompt] prompt support go template (#269) * [feat][prompt] prompt support go template * [feat][prompt] prompt support go template * [feat][prompt] prompt support go template ut * [feat][prompt] prompt support go template ut --- .../coze/loop/prompt/domain/prompt/prompt.go | 2 + .../openapi/coze.loop.prompt.openapi.go | 2 + .../prompt/application/convertor/prompt.go | 2 + .../application/convertor/prompt_test.go | 119 +++++++ .../prompt/domain/entity/prompt_detail.go | 18 +- .../domain/entity/prompt_detail_test.go | 208 ++++++++++++ .../prompt/pkg/template/go_template.go | 29 ++ .../prompt/pkg/template/go_template_test.go | 318 ++++++++++++++++++ .../prompt/coze.loop.prompt.openapi.thrift | 1 + .../coze/loop/prompt/domain/prompt.thrift | 1 + 10 files changed, 698 insertions(+), 2 deletions(-) create mode 100644 backend/modules/prompt/pkg/template/go_template.go create mode 100644 backend/modules/prompt/pkg/template/go_template_test.go diff --git a/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go b/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go index 2cb24c558..0b5b8575e 100644 --- a/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go +++ b/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go @@ -13,6 +13,8 @@ const ( TemplateTypeJinja2 = "jinja2" + TemplateTypeGoTemplate = "go_template" + ToolTypeFunction = "function" ToolChoiceTypeNone = "none" diff --git a/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi.go b/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi.go index 18c4cf28d..d509564b1 100644 --- a/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi.go +++ b/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi.go @@ -16,6 +16,8 @@ const ( TemplateTypeJinja2 = "jinja2" + TemplateTypeGoTemplate = "go_template" + ToolChoiceTypeAuto = "auto" ToolChoiceTypeNone = "none" diff --git a/backend/modules/prompt/application/convertor/prompt.go b/backend/modules/prompt/application/convertor/prompt.go index febbafc6a..c1448fafb 100644 --- a/backend/modules/prompt/application/convertor/prompt.go +++ b/backend/modules/prompt/application/convertor/prompt.go @@ -119,6 +119,8 @@ func TemplateTypeDTO2DO(dto prompt.TemplateType) entity.TemplateType { return entity.TemplateTypeNormal case prompt.TemplateTypeJinja2: return entity.TemplateTypeJinja2 + case prompt.TemplateTypeGoTemplate: + return entity.TemplateTypeGoTemplate default: return entity.TemplateTypeNormal } diff --git a/backend/modules/prompt/application/convertor/prompt_test.go b/backend/modules/prompt/application/convertor/prompt_test.go index 816ae1ba1..18d1b9aef 100644 --- a/backend/modules/prompt/application/convertor/prompt_test.go +++ b/backend/modules/prompt/application/convertor/prompt_test.go @@ -665,3 +665,122 @@ func TestModelConfigExtraConversion(t *testing.T) { assert.NotNil(t, dtoBack) assert.Equal(t, extra, dtoBack.Extra) } + +func TestTemplateTypeDTO2DO(t *testing.T) { + tests := []struct { + name string + dto prompt.TemplateType + want entity.TemplateType + }{ + { + name: "normal template type", + dto: prompt.TemplateTypeNormal, + want: entity.TemplateTypeNormal, + }, + { + name: "jinja2 template type", + dto: prompt.TemplateTypeJinja2, + want: entity.TemplateTypeJinja2, + }, + { + name: "go template type", + dto: prompt.TemplateTypeGoTemplate, + want: entity.TemplateTypeGoTemplate, + }, + { + name: "unknown template type defaults to normal", + dto: prompt.TemplateType("unknown"), + want: entity.TemplateTypeNormal, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := TemplateTypeDTO2DO(tt.dto) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestPromptTemplateWithDifferentTypes(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + dto *prompt.PromptTemplate + want *entity.PromptTemplate + }{ + { + name: "normal template", + dto: &prompt.PromptTemplate{ + TemplateType: ptr.Of(prompt.TemplateTypeNormal), + Messages: []*prompt.Message{ + { + Role: ptr.Of(prompt.RoleUser), + Content: ptr.Of("Hello {{name}}"), + }, + }, + }, + want: &entity.PromptTemplate{ + TemplateType: entity.TemplateTypeNormal, + Messages: []*entity.Message{ + { + Role: entity.RoleUser, + Content: ptr.Of("Hello {{name}}"), + }, + }, + }, + }, + { + name: "jinja2 template", + dto: &prompt.PromptTemplate{ + TemplateType: ptr.Of(prompt.TemplateTypeJinja2), + Messages: []*prompt.Message{ + { + Role: ptr.Of(prompt.RoleUser), + Content: ptr.Of("Hello {{ name }}"), + }, + }, + }, + want: &entity.PromptTemplate{ + TemplateType: entity.TemplateTypeJinja2, + Messages: []*entity.Message{ + { + Role: entity.RoleUser, + Content: ptr.Of("Hello {{ name }}"), + }, + }, + }, + }, + { + name: "go template", + dto: &prompt.PromptTemplate{ + TemplateType: ptr.Of(prompt.TemplateTypeGoTemplate), + Messages: []*prompt.Message{ + { + Role: ptr.Of(prompt.RoleUser), + Content: ptr.Of("Hello {{.name}}"), + }, + }, + }, + want: &entity.PromptTemplate{ + TemplateType: entity.TemplateTypeGoTemplate, + Messages: []*entity.Message{ + { + Role: entity.RoleUser, + Content: ptr.Of("Hello {{.name}}"), + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := PromptTemplateDTO2DO(tt.dto) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/backend/modules/prompt/domain/entity/prompt_detail.go b/backend/modules/prompt/domain/entity/prompt_detail.go index 0c9c23157..1bcc8434b 100644 --- a/backend/modules/prompt/domain/entity/prompt_detail.go +++ b/backend/modules/prompt/domain/entity/prompt_detail.go @@ -43,8 +43,9 @@ type PromptTemplate struct { type TemplateType string const ( - TemplateTypeNormal TemplateType = "normal" - TemplateTypeJinja2 TemplateType = "jinja2" + TemplateTypeNormal TemplateType = "normal" + TemplateTypeJinja2 TemplateType = "jinja2" + TemplateTypeGoTemplate TemplateType = "go_template" ) type Message struct { @@ -300,6 +301,8 @@ func formatText(templateType TemplateType, templateStr string, defMap map[string }), nil case TemplateTypeJinja2: return renderJinja2Template(templateStr, defMap, valMap) + case TemplateTypeGoTemplate: + return renderGoTemplate(templateStr, defMap, valMap) default: return "", errorx.NewByCode(prompterr.UnsupportedTemplateTypeCode, errorx.WithExtraMsg("unknown template type: "+string(templateType))) } @@ -316,6 +319,17 @@ func renderJinja2Template(templateStr string, defMap map[string]*VariableDef, va return template.InterpolateJinja2(templateStr, variables) } +// renderGoTemplate 渲染 Go Template 模板 +func renderGoTemplate(templateStr string, defMap map[string]*VariableDef, valMap map[string]*VariableVal) (string, error) { + // 转换变量为 map[string]any 格式 + variables, err := convertVariablesToMap(defMap, valMap) + if err != nil { + return "", err + } + + return template.InterpolateGoTemplate(templateStr, variables) +} + // convertVariablesToMap 将变量定义和变量值转换为模板引擎可用的 map func convertVariablesToMap(defMap map[string]*VariableDef, valMap map[string]*VariableVal) (map[string]any, error) { if len(defMap) == 0 || len(valMap) == 0 { diff --git a/backend/modules/prompt/domain/entity/prompt_detail_test.go b/backend/modules/prompt/domain/entity/prompt_detail_test.go index a1435ae9a..47bd5c6e4 100644 --- a/backend/modules/prompt/domain/entity/prompt_detail_test.go +++ b/backend/modules/prompt/domain/entity/prompt_detail_test.go @@ -1289,3 +1289,211 @@ func TestPromptTemplate_formatMessages_Jinja2(t *testing.T) { }) } } + +func TestRenderGoTemplate(t *testing.T) { + tests := []struct { + name string + templateStr string + defMap map[string]*VariableDef + valMap map[string]*VariableVal + expected string + expectedError error + }{ + { + name: "simple string variable", + templateStr: "Hello {{.name}}!", + defMap: map[string]*VariableDef{ + "name": {Key: "name", Type: VariableTypeString}, + }, + valMap: map[string]*VariableVal{ + "name": {Key: "name", Value: ptr.Of("John")}, + }, + expected: "Hello John!", + }, + { + name: "multiple variables", + templateStr: "Hello {{.name}}, you are {{.age}} years old.", + defMap: map[string]*VariableDef{ + "name": {Key: "name", Type: VariableTypeString}, + "age": {Key: "age", Type: VariableTypeInteger}, + }, + valMap: map[string]*VariableVal{ + "name": {Key: "name", Value: ptr.Of("John")}, + "age": {Key: "age", Value: ptr.Of("30")}, + }, + expected: "Hello John, you are 30 years old.", + }, + { + name: "boolean variable in condition", + templateStr: "{{if .enabled}}Feature is enabled{{else}}Feature is disabled{{end}}", + defMap: map[string]*VariableDef{ + "enabled": {Key: "enabled", Type: VariableTypeBoolean}, + }, + valMap: map[string]*VariableVal{ + "enabled": {Key: "enabled", Value: ptr.Of("true")}, + }, + expected: "Feature is enabled", + }, + { + name: "boolean variable false in condition", + templateStr: "{{if .enabled}}Feature is enabled{{else}}Feature is disabled{{end}}", + defMap: map[string]*VariableDef{ + "enabled": {Key: "enabled", Type: VariableTypeBoolean}, + }, + valMap: map[string]*VariableVal{ + "enabled": {Key: "enabled", Value: ptr.Of("false")}, + }, + expected: "Feature is disabled", + }, + { + name: "array iteration", + templateStr: "Items: {{range $i, $item := .items}}{{if $i}}, {{end}}{{$item}}{{end}}", + defMap: map[string]*VariableDef{ + "items": {Key: "items", Type: VariableTypeArrayString}, + }, + valMap: map[string]*VariableVal{ + "items": {Key: "items", Value: ptr.Of(`["apple", "banana", "cherry"]`)}, + }, + expected: "Items: apple, banana, cherry", + }, + { + name: "integer variable", + templateStr: "Count: {{.count}}", + defMap: map[string]*VariableDef{ + "count": {Key: "count", Type: VariableTypeInteger}, + }, + valMap: map[string]*VariableVal{ + "count": {Key: "count", Value: ptr.Of("42")}, + }, + expected: "Count: 42", + }, + { + name: "float variable", + templateStr: "Price: ${{.price}}", + defMap: map[string]*VariableDef{ + "price": {Key: "price", Type: VariableTypeFloat}, + }, + valMap: map[string]*VariableVal{ + "price": {Key: "price", Value: ptr.Of("3.14")}, + }, + expected: "Price: $3.14", + }, + { + name: "invalid template syntax", + templateStr: "Hello {{.name", + defMap: map[string]*VariableDef{ + "name": {Key: "name", Type: VariableTypeString}, + }, + valMap: map[string]*VariableVal{ + "name": {Key: "name", Value: ptr.Of("John")}, + }, + expectedError: errorx.NewByCode(prompterr.TemplateParseErrorCode), + }, + { + name: "variable conversion error", + templateStr: "Count: {{.count}}", + defMap: map[string]*VariableDef{ + "count": {Key: "count", Type: VariableTypeInteger}, + }, + valMap: map[string]*VariableVal{ + "count": {Key: "count", Value: ptr.Of("not_a_number")}, + }, + expectedError: errorx.NewByCode(prompterr.CommonInvalidParamCode), + }, + { + name: "template with undefined variable", + templateStr: "Hello {{.undefined_var}}", + defMap: map[string]*VariableDef{}, + valMap: map[string]*VariableVal{}, + expected: "Hello ", + }, + { + name: "empty variable maps", + templateStr: "Hello World", + defMap: map[string]*VariableDef{}, + valMap: map[string]*VariableVal{}, + expected: "Hello World", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := renderGoTemplate(tt.templateStr, tt.defMap, tt.valMap) + unittest.AssertErrorEqual(t, tt.expectedError, err) + if tt.expectedError == nil { + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func TestFormatText_GoTemplate(t *testing.T) { + tests := []struct { + name string + templateType TemplateType + templateStr string + defMap map[string]*VariableDef + valMap map[string]*VariableVal + expected string + expectedError error + }{ + { + name: "goTemplate template type", + templateType: TemplateTypeGoTemplate, + templateStr: "Hello {{.name}}!", + defMap: map[string]*VariableDef{ + "name": {Key: "name", Type: VariableTypeString}, + }, + valMap: map[string]*VariableVal{ + "name": {Key: "name", Value: ptr.Of("John")}, + }, + expected: "Hello John!", + }, + { + name: "goTemplate with condition", + templateType: TemplateTypeGoTemplate, + templateStr: "{{if .enabled}}Active{{else}}Inactive{{end}}", + defMap: map[string]*VariableDef{ + "enabled": {Key: "enabled", Type: VariableTypeBoolean}, + }, + valMap: map[string]*VariableVal{ + "enabled": {Key: "enabled", Value: ptr.Of("true")}, + }, + expected: "Active", + }, + { + name: "goTemplate parse error", + templateType: TemplateTypeGoTemplate, + templateStr: "Hello {{.name", + defMap: map[string]*VariableDef{ + "name": {Key: "name", Type: VariableTypeString}, + }, + valMap: map[string]*VariableVal{ + "name": {Key: "name", Value: ptr.Of("John")}, + }, + expectedError: errorx.NewByCode(prompterr.TemplateParseErrorCode), + }, + { + name: "goTemplate with integer", + templateType: TemplateTypeGoTemplate, + templateStr: "Count: {{.count}}", + defMap: map[string]*VariableDef{ + "count": {Key: "count", Type: VariableTypeInteger}, + }, + valMap: map[string]*VariableVal{ + "count": {Key: "count", Value: ptr.Of("100")}, + }, + expected: "Count: 100", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := formatText(tt.templateType, tt.templateStr, tt.defMap, tt.valMap) + unittest.AssertErrorEqual(t, tt.expectedError, err) + if tt.expectedError == nil { + assert.Equal(t, tt.expected, result) + } + }) + } +} diff --git a/backend/modules/prompt/pkg/template/go_template.go b/backend/modules/prompt/pkg/template/go_template.go new file mode 100644 index 000000000..0ed67eefe --- /dev/null +++ b/backend/modules/prompt/pkg/template/go_template.go @@ -0,0 +1,29 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package template + +import ( + "bytes" + "text/template" + + prompterr "github.com/coze-dev/coze-loop/backend/modules/prompt/pkg/errno" + "github.com/coze-dev/coze-loop/backend/pkg/errorx" +) + +func InterpolateGoTemplate(templateStr string, variables map[string]any) (string, error) { + // 解析模板 + tpl, err := template.New("prompt").Parse(templateStr) + if err != nil { + return "", errorx.NewByCode(prompterr.TemplateParseErrorCode, errorx.WithExtraMsg(err.Error())) + } + + // 执行模板渲染 + var out bytes.Buffer + err = tpl.Execute(&out, variables) + if err != nil { + return "", errorx.NewByCode(prompterr.TemplateRenderErrorCode, errorx.WithExtraMsg(err.Error())) + } + + return out.String(), nil +} diff --git a/backend/modules/prompt/pkg/template/go_template_test.go b/backend/modules/prompt/pkg/template/go_template_test.go new file mode 100644 index 000000000..83e89258a --- /dev/null +++ b/backend/modules/prompt/pkg/template/go_template_test.go @@ -0,0 +1,318 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package template + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + prompterr "github.com/coze-dev/coze-loop/backend/modules/prompt/pkg/errno" + "github.com/coze-dev/coze-loop/backend/pkg/errorx" +) + +func TestInterpolateGoTemplate(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + templateStr string + variables map[string]any + want string + wantErr bool + errCode int32 + }{ + { + name: "simple string interpolation", + templateStr: "Hello, {{.name}}!", + variables: map[string]any{"name": "World"}, + want: "Hello, World!", + wantErr: false, + }, + { + name: "multiple variables", + templateStr: "{{.greeting}}, {{.name}}! You are {{.age}} years old.", + variables: map[string]any{ + "greeting": "Hi", + "name": "Alice", + "age": 30, + }, + want: "Hi, Alice! You are 30 years old.", + wantErr: false, + }, + { + name: "integer variable", + templateStr: "The answer is {{.number}}", + variables: map[string]any{"number": 42}, + want: "The answer is 42", + wantErr: false, + }, + { + name: "float variable", + templateStr: "Pi is approximately {{.pi}}", + variables: map[string]any{"pi": 3.14159}, + want: "Pi is approximately 3.14159", + wantErr: false, + }, + { + name: "boolean variable", + templateStr: "Is active: {{.active}}", + variables: map[string]any{"active": true}, + want: "Is active: true", + wantErr: false, + }, + { + name: "empty template", + templateStr: "", + variables: map[string]any{}, + want: "", + wantErr: false, + }, + { + name: "template with no variables", + templateStr: "Static text without variables", + variables: map[string]any{}, + want: "Static text without variables", + wantErr: false, + }, + { + name: "empty variables map", + templateStr: "Hello, World!", + variables: map[string]any{}, + want: "Hello, World!", + wantErr: false, + }, + { + name: "nil variables map", + templateStr: "Hello, World!", + variables: nil, + want: "Hello, World!", + wantErr: false, + }, + { + name: "conditional in template", + templateStr: "{{if .show}}Visible{{else}}Hidden{{end}}", + variables: map[string]any{"show": true}, + want: "Visible", + wantErr: false, + }, + { + name: "range over slice", + templateStr: "{{range .items}}{{.}},{{end}}", + variables: map[string]any{"items": []string{"a", "b", "c"}}, + want: "a,b,c,", + wantErr: false, + }, + { + name: "nested object access", + templateStr: "{{.user.name}} is {{.user.age}} years old", + variables: map[string]any{ + "user": map[string]any{ + "name": "Bob", + "age": 25, + }, + }, + want: "Bob is 25 years old", + wantErr: false, + }, + { + name: "template with newlines", + templateStr: "Line 1: {{.line1}}\nLine 2: {{.line2}}", + variables: map[string]any{ + "line1": "First", + "line2": "Second", + }, + want: "Line 1: First\nLine 2: Second", + wantErr: false, + }, + { + name: "template parse error - unclosed action", + templateStr: "Hello, {{.name!", + variables: map[string]any{"name": "World"}, + want: "", + wantErr: true, + errCode: prompterr.TemplateParseErrorCode, + }, + { + name: "template parse error - invalid syntax", + templateStr: "Hello, {{..name}}", + variables: map[string]any{"name": "World"}, + want: "", + wantErr: true, + errCode: prompterr.TemplateParseErrorCode, + }, + { + name: "missing variable returns no value", + templateStr: "Hello, {{.name}}!", + variables: map[string]any{}, + want: "Hello, !", + wantErr: false, + }, + { + name: "accessing field on non-existent map key returns no value", + templateStr: "{{.user.name}}", + variables: map[string]any{}, + want: "", + wantErr: false, + }, + { + name: "special characters in string", + templateStr: "Special: {{.text}}", + variables: map[string]any{"text": "<>&\"'"}, + want: "Special: <>&\"'", + wantErr: false, + }, + { + name: "unicode characters", + templateStr: "你好,{{.name}}!", + variables: map[string]any{"name": "世界"}, + want: "你好,世界!", + wantErr: false, + }, + { + name: "with function - printf", + templateStr: "{{printf \"Hello, %s!\" .name}}", + variables: map[string]any{"name": "World"}, + want: "Hello, World!", + wantErr: false, + }, + { + name: "template execution error - index out of range", + templateStr: "{{index .items 10}}", + variables: map[string]any{"items": []int{1, 2, 3}}, + want: "", + wantErr: true, + errCode: prompterr.TemplateRenderErrorCode, + }, + { + name: "template execution error - invalid index operation on non-indexable type", + templateStr: "{{index .value 0}}", + variables: map[string]any{"value": 42}, + want: "", + wantErr: true, + errCode: prompterr.TemplateRenderErrorCode, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, err := InterpolateGoTemplate(tt.templateStr, tt.variables) + + if tt.wantErr { + assert.Error(t, err) + if tt.errCode != 0 { + statusErr, ok := errorx.FromStatusError(err) + if assert.True(t, ok, "Error should be a StatusError") { + assert.Equal(t, tt.errCode, statusErr.Code()) + } + } + assert.Empty(t, got) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} + +// TestInterpolateGoTemplate_ComplexScenarios tests more complex template scenarios +func TestInterpolateGoTemplate_ComplexScenarios(t *testing.T) { + t.Parallel() + + t.Run("template with multiple operations", func(t *testing.T) { + t.Parallel() + templateStr := ` +{{- if .showGreeting -}} +Hello, {{.user.name}}! +{{- end -}} +{{- if .showDetails -}} +Age: {{.user.age}} +City: {{.user.city}} +{{- end -}} +` + variables := map[string]any{ + "showGreeting": true, + "showDetails": true, + "user": map[string]any{ + "name": "Alice", + "age": 30, + "city": "NYC", + }, + } + + got, err := InterpolateGoTemplate(templateStr, variables) + assert.NoError(t, err) + assert.Contains(t, got, "Hello, Alice!") + assert.Contains(t, got, "Age: 30") + assert.Contains(t, got, "City: NYC") + }) + + t.Run("template with array of structs", func(t *testing.T) { + t.Parallel() + templateStr := `Users: +{{range .users}} +- {{.name}} ({{.role}}) +{{end}}` + variables := map[string]any{ + "users": []map[string]any{ + {"name": "Alice", "role": "Admin"}, + {"name": "Bob", "role": "User"}, + }, + } + + got, err := InterpolateGoTemplate(templateStr, variables) + assert.NoError(t, err) + assert.Contains(t, got, "Alice (Admin)") + assert.Contains(t, got, "Bob (User)") + }) +} + +// TestInterpolateGoTemplate_EdgeCases tests edge cases +func TestInterpolateGoTemplate_EdgeCases(t *testing.T) { + t.Parallel() + + t.Run("very long template", func(t *testing.T) { + t.Parallel() + var templateStr string + for i := 0; i < 100; i++ { + templateStr += "{{.value}}" + } + variables := map[string]any{"value": "x"} + + got, err := InterpolateGoTemplate(templateStr, variables) + assert.NoError(t, err) + assert.Len(t, got, 100) + }) + + t.Run("deeply nested structure", func(t *testing.T) { + t.Parallel() + templateStr := "{{.a.b.c.d.e}}" + variables := map[string]any{ + "a": map[string]any{ + "b": map[string]any{ + "c": map[string]any{ + "d": map[string]any{ + "e": "deep", + }, + }, + }, + }, + } + + got, err := InterpolateGoTemplate(templateStr, variables) + assert.NoError(t, err) + assert.Equal(t, "deep", got) + }) + + t.Run("nil value in map", func(t *testing.T) { + t.Parallel() + templateStr := "Value: {{.value}}" + variables := map[string]any{"value": nil} + + got, err := InterpolateGoTemplate(templateStr, variables) + assert.NoError(t, err) + assert.Equal(t, "Value: ", got) + }) +} diff --git a/idl/thrift/coze/loop/prompt/coze.loop.prompt.openapi.thrift b/idl/thrift/coze/loop/prompt/coze.loop.prompt.openapi.thrift index 038b4e83d..0ebc2eebb 100644 --- a/idl/thrift/coze/loop/prompt/coze.loop.prompt.openapi.thrift +++ b/idl/thrift/coze/loop/prompt/coze.loop.prompt.openapi.thrift @@ -102,6 +102,7 @@ struct PromptTemplate { typedef string TemplateType const TemplateType TemplateType_Normal = "normal" const TemplateType TemplateType_Jinja2 = "jinja2" +const TemplateType TemplateType_GoTemplate = "go_template" typedef string ToolChoiceType diff --git a/idl/thrift/coze/loop/prompt/domain/prompt.thrift b/idl/thrift/coze/loop/prompt/domain/prompt.thrift index 76b874e14..e715ef07d 100644 --- a/idl/thrift/coze/loop/prompt/domain/prompt.thrift +++ b/idl/thrift/coze/loop/prompt/domain/prompt.thrift @@ -68,6 +68,7 @@ struct PromptTemplate { typedef string TemplateType (ts.enum="true") const TemplateType TemplateType_Normal = "normal" const TemplateType TemplateType_Jinja2 = "jinja2" +const TemplateType TemplateType_GoTemplate = "go_template" struct Tool { 1: optional ToolType type From a23764dc1e8d162b37ed65f640e59be5d8944d48 Mon Sep 17 00:00:00 2001 From: caijialin0626 <61818131+caijialin0626@users.noreply.github.com> Date: Thu, 30 Oct 2025 17:10:25 +0800 Subject: [PATCH 08/12] [refactor][prompt] prompt support custom format (#271) * [feat][prompt] prompt support custom format * [feat][prompt] prompt support custom format * [feat][prompt] prompt support custom format * [feat][prompt] prompt support custom format * [feat][prompt] prompt support custom format * [feat][prompt] prompt support custom format * [feat][prompt] prompt support custom format * [feat][prompt] prompt support custom format * [feat][prompt] prompt support custom format * [feat][prompt] prompt support custom format * [feat][prompt] prompt support custom format --- .../coze/loop/prompt/domain/prompt/prompt.go | 2 + .../openapi/coze.loop.prompt.openapi.go | 2 + .../prompt/application/convertor/prompt.go | 2 + .../application/convertor/prompt_test.go | 26 ++ .../modules/prompt/application/debug_test.go | 8 +- .../prompt/application/execute_test.go | 8 +- .../prompt/application/openapi_test.go | 12 +- backend/modules/prompt/application/wire.go | 1 + .../modules/prompt/application/wire_gen.go | 14 +- .../prompt/domain/entity/prompt_detail.go | 7 +- .../modules/prompt/domain/service/execute.go | 28 +- .../prompt/domain/service/execute_test.go | 105 ++++++- .../prompt/domain/service/formatter.go | 60 ++++ .../prompt/domain/service/formatter_test.go | 270 ++++++++++++++++++ .../prompt/domain/service/mocks/formatter.go | 57 ++++ .../domain/service/mocks/prompt_service.go | 12 +- .../modules/prompt/domain/service/service.go | 3 + .../prompt/domain/service/service_test.go | 96 +++++++ .../prompt/coze.loop.prompt.openapi.thrift | 1 + .../coze/loop/prompt/domain/prompt.thrift | 1 + 20 files changed, 656 insertions(+), 59 deletions(-) create mode 100644 backend/modules/prompt/domain/service/formatter.go create mode 100644 backend/modules/prompt/domain/service/formatter_test.go create mode 100644 backend/modules/prompt/domain/service/mocks/formatter.go create mode 100644 backend/modules/prompt/domain/service/service_test.go diff --git a/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go b/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go index 0b5b8575e..b0666a814 100644 --- a/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go +++ b/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go @@ -15,6 +15,8 @@ const ( TemplateTypeGoTemplate = "go_template" + TemplateTypeCustomTemplateM = "custom_template_m" + ToolTypeFunction = "function" ToolChoiceTypeNone = "none" diff --git a/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi.go b/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi.go index d509564b1..9dc3751ec 100644 --- a/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi.go +++ b/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi.go @@ -18,6 +18,8 @@ const ( TemplateTypeGoTemplate = "go_template" + TemplateTypeCustomTemplateM = "custom_template_m" + ToolChoiceTypeAuto = "auto" ToolChoiceTypeNone = "none" diff --git a/backend/modules/prompt/application/convertor/prompt.go b/backend/modules/prompt/application/convertor/prompt.go index c1448fafb..154f0875c 100644 --- a/backend/modules/prompt/application/convertor/prompt.go +++ b/backend/modules/prompt/application/convertor/prompt.go @@ -121,6 +121,8 @@ func TemplateTypeDTO2DO(dto prompt.TemplateType) entity.TemplateType { return entity.TemplateTypeJinja2 case prompt.TemplateTypeGoTemplate: return entity.TemplateTypeGoTemplate + case prompt.TemplateTypeCustomTemplateM: + return entity.TemplateTYpeCustomTemplateM default: return entity.TemplateTypeNormal } diff --git a/backend/modules/prompt/application/convertor/prompt_test.go b/backend/modules/prompt/application/convertor/prompt_test.go index 18d1b9aef..4e9b15311 100644 --- a/backend/modules/prompt/application/convertor/prompt_test.go +++ b/backend/modules/prompt/application/convertor/prompt_test.go @@ -687,6 +687,11 @@ func TestTemplateTypeDTO2DO(t *testing.T) { dto: prompt.TemplateTypeGoTemplate, want: entity.TemplateTypeGoTemplate, }, + { + name: "custom template m type", + dto: prompt.TemplateTypeCustomTemplateM, + want: entity.TemplateTYpeCustomTemplateM, + }, { name: "unknown template type defaults to normal", dto: prompt.TemplateType("unknown"), @@ -774,6 +779,27 @@ func TestPromptTemplateWithDifferentTypes(t *testing.T) { }, }, }, + { + name: "custom template m", + dto: &prompt.PromptTemplate{ + TemplateType: ptr.Of(prompt.TemplateTypeCustomTemplateM), + Messages: []*prompt.Message{ + { + Role: ptr.Of(prompt.RoleUser), + Content: ptr.Of("Hello world"), + }, + }, + }, + want: &entity.PromptTemplate{ + TemplateType: entity.TemplateTYpeCustomTemplateM, + Messages: []*entity.Message{ + { + Role: entity.RoleUser, + Content: ptr.Of("Hello world"), + }, + }, + }, + }, } for _, tt := range tests { diff --git a/backend/modules/prompt/application/debug_test.go b/backend/modules/prompt/application/debug_test.go index 2c0783e40..7b79c147d 100644 --- a/backend/modules/prompt/application/debug_test.go +++ b/backend/modules/prompt/application/debug_test.go @@ -65,7 +65,7 @@ func TestPromptDebugApplicationImpl_DebugStreaming(t *testing.T) { mockDebugLogRepo.EXPECT().SaveDebugLog(gomock.Any(), gomock.Any()).Return(nil) mockPromptSvc := servicemocks.NewMockIPromptService(ctrl) mockPromptSvc.EXPECT().MCompleteMultiModalFileURL(gomock.Any(), gomock.Any(), nil).Return(nil) - mockPromptSvc.EXPECT().MConvertBase64ToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + mockPromptSvc.EXPECT().MConvertBase64DataURLToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptSvc.EXPECT().ExecuteStreaming(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, param service.ExecuteStreamingParam) (*entity.Reply, error) { for _, v := range mockContent { param.ResultStream <- &entity.Reply{ @@ -127,7 +127,7 @@ func TestPromptDebugApplicationImpl_DebugStreaming(t *testing.T) { mockDebugLogRepo.EXPECT().SaveDebugLog(gomock.Any(), gomock.Any()).Return(nil) mockPromptSvc := servicemocks.NewMockIPromptService(ctrl) mockPromptSvc.EXPECT().MCompleteMultiModalFileURL(gomock.Any(), gomock.Any(), nil).Return(nil) - mockPromptSvc.EXPECT().MConvertBase64ToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + mockPromptSvc.EXPECT().MConvertBase64DataURLToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptSvc.EXPECT().ExecuteStreaming(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, param service.ExecuteStreamingParam) (*entity.Reply, error) { for _, v := range mockContent { param.ResultStream <- &entity.Reply{ @@ -250,7 +250,7 @@ func TestPromptDebugApplicationImpl_DebugStreaming(t *testing.T) { }, nil }) convertErr := errors.New("convert error") - mockPromptSvc.EXPECT().MConvertBase64ToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(convertErr) + mockPromptSvc.EXPECT().MConvertBase64DataURLToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(convertErr) mockBenefitSvc := benefitmocks.NewMockIBenefitService(ctrl) mockBenefitSvc.EXPECT().CheckPromptBenefit(gomock.Any(), gomock.Any()).Return(&benefit.CheckPromptBenefitResult{}, nil) mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) @@ -292,7 +292,7 @@ func TestPromptDebugApplicationImpl_DebugStreaming(t *testing.T) { mockDebugLogRepo.EXPECT().SaveDebugLog(gomock.Any(), gomock.Any()).Return(nil) mockPromptSvc := servicemocks.NewMockIPromptService(ctrl) mockPromptSvc.EXPECT().MCompleteMultiModalFileURL(gomock.Any(), gomock.Any(), nil).Return(nil) - mockPromptSvc.EXPECT().MConvertBase64ToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + mockPromptSvc.EXPECT().MConvertBase64DataURLToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptSvc.EXPECT().ExecuteStreaming(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, param service.ExecuteStreamingParam) (*entity.Reply, error) { panic("mock panic") }) diff --git a/backend/modules/prompt/application/execute_test.go b/backend/modules/prompt/application/execute_test.go index 0400c9ffe..0b64848c0 100755 --- a/backend/modules/prompt/application/execute_test.go +++ b/backend/modules/prompt/application/execute_test.go @@ -110,7 +110,7 @@ func TestPromptExecuteApplicationImpl_ExecuteInternal(t *testing.T) { mockPromptService := servicemocks.NewMockIPromptService(ctrl) mockPromptService.EXPECT().Execute(gomock.Any(), gomock.Any()).Return(mockReply, nil) - mockPromptService.EXPECT().MConvertBase64ToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + mockPromptService.EXPECT().MConvertBase64DataURLToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) return fields{ promptService: mockPromptService, @@ -148,7 +148,7 @@ func TestPromptExecuteApplicationImpl_ExecuteInternal(t *testing.T) { mockPromptService := servicemocks.NewMockIPromptService(ctrl) mockPromptService.EXPECT().Execute(gomock.Any(), gomock.Any()).Return(mockReply, nil) - mockPromptService.EXPECT().MConvertBase64ToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("convert error")) + mockPromptService.EXPECT().MConvertBase64DataURLToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("convert error")) return fields{ promptService: mockPromptService, @@ -225,7 +225,7 @@ func TestPromptExecuteApplicationImpl_ExecuteInternal(t *testing.T) { mockPromptService := servicemocks.NewMockIPromptService(ctrl) mockPromptService.EXPECT().Execute(gomock.Any(), gomock.Any()).Return(mockReply, nil) - mockPromptService.EXPECT().MConvertBase64ToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + mockPromptService.EXPECT().MConvertBase64DataURLToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) return fields{ promptService: mockPromptService, @@ -269,7 +269,7 @@ func TestPromptExecuteApplicationImpl_ExecuteInternal(t *testing.T) { mockPromptService := servicemocks.NewMockIPromptService(ctrl) mockPromptService.EXPECT().Execute(gomock.Any(), gomock.Any()).Return(mockReply, nil) - mockPromptService.EXPECT().MConvertBase64ToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + mockPromptService.EXPECT().MConvertBase64DataURLToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) return fields{ promptService: mockPromptService, diff --git a/backend/modules/prompt/application/openapi_test.go b/backend/modules/prompt/application/openapi_test.go index 1fdedca53..3fd599ec0 100644 --- a/backend/modules/prompt/application/openapi_test.go +++ b/backend/modules/prompt/application/openapi_test.go @@ -2361,7 +2361,7 @@ func TestPromptOpenAPIApplicationImpl_doExecute(t *testing.T) { }, } mockPromptService.EXPECT().Execute(gomock.Any(), gomock.Any()).Return(expectedReply, nil) - mockPromptService.EXPECT().MConvertBase64ToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + mockPromptService.EXPECT().MConvertBase64DataURLToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) return fields{ promptService: mockPromptService, @@ -2457,7 +2457,7 @@ func TestPromptOpenAPIApplicationImpl_doExecute(t *testing.T) { }, } mockPromptService.EXPECT().Execute(gomock.Any(), gomock.Any()).Return(expectedReply, nil) - mockPromptService.EXPECT().MConvertBase64ToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("convert error")) + mockPromptService.EXPECT().MConvertBase64DataURLToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("convert error")) return fields{ promptService: mockPromptService, @@ -2810,7 +2810,7 @@ func TestPromptOpenAPIApplicationImpl_Execute(t *testing.T) { }, } mockPromptService.EXPECT().Execute(gomock.Any(), gomock.Any()).Return(expectedReply, nil) - mockPromptService.EXPECT().MConvertBase64ToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + mockPromptService.EXPECT().MConvertBase64DataURLToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) mockCollector := collectormocks.NewMockICollectorProvider(ctrl) mockCollector.EXPECT().CollectPTaaSEvent(gomock.Any(), gomock.Any()).Return() @@ -2904,7 +2904,7 @@ func TestPromptOpenAPIApplicationImpl_Execute(t *testing.T) { }, } mockPromptService.EXPECT().Execute(gomock.Any(), gomock.Any()).Return(expectedReply, nil) - mockPromptService.EXPECT().MConvertBase64ToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("convert error")) + mockPromptService.EXPECT().MConvertBase64DataURLToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("convert error")) mockCollector := collectormocks.NewMockICollectorProvider(ctrl) mockCollector.EXPECT().CollectPTaaSEvent(gomock.Any(), gomock.Any()).Return() @@ -3450,7 +3450,7 @@ func TestPromptOpenAPIApplicationImpl_ExecuteStreaming(t *testing.T) { assert.Len(t, calls, 0) }, setupConvertMock: func(mockSvc *servicemocks.MockIPromptService) { - mockSvc.EXPECT().MConvertBase64ToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("convert error")) + mockSvc.EXPECT().MConvertBase64DataURLToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("convert error")) }, }, { @@ -4317,7 +4317,7 @@ func TestPromptOpenAPIApplicationImpl_ExecuteStreaming(t *testing.T) { if tt.setupConvertMock != nil { tt.setupConvertMock(mockSvc) } else { - mockSvc.EXPECT().MConvertBase64ToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + mockSvc.EXPECT().MConvertBase64DataURLToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() } } ttArgs := tt.argsGetter(ctrl) diff --git a/backend/modules/prompt/application/wire.go b/backend/modules/prompt/application/wire.go index 6a18eff86..37359831b 100644 --- a/backend/modules/prompt/application/wire.go +++ b/backend/modules/prompt/application/wire.go @@ -36,6 +36,7 @@ import ( var ( promptDomainSet = wire.NewSet( + service.NewPromptFormatter, service.NewPromptService, repo.NewManageRepo, repo.NewLabelRepo, diff --git a/backend/modules/prompt/application/wire_gen.go b/backend/modules/prompt/application/wire_gen.go index bcbbfb37b..b506425eb 100644 --- a/backend/modules/prompt/application/wire_gen.go +++ b/backend/modules/prompt/application/wire_gen.go @@ -50,13 +50,14 @@ func InitPromptManageApplication(idgen2 idgen.IIDGenerator, db2 db.Provider, red } iPromptLabelVersionDAO := redis2.NewPromptLabelVersionDAO(redisCli, iConfigProvider) iLabelRepo := repo.NewLabelRepo(db2, idgen2, meter, iLabelDAO, iCommitLabelMappingDAO, iPromptBasicDAO, iPromptLabelVersionDAO) + iPromptFormatter := service.NewPromptFormatter() iDebugLogDAO := mysql.NewDebugLogDAO(db2) iDebugLogRepo := repo.NewDebugLogRepo(idgen2, iDebugLogDAO) iDebugContextDAO := mysql.NewDebugContextDAO(db2) iDebugContextRepo := repo.NewDebugContextRepo(idgen2, iDebugContextDAO) illmProvider := rpc.NewLLMRPCProvider(llmClient) iFileProvider := rpc.NewFileRPCProvider(fileClient) - iPromptService := service.NewPromptService(idgen2, iDebugLogRepo, iDebugContextRepo, iManageRepo, iLabelRepo, iConfigProvider, illmProvider, iFileProvider) + iPromptService := service.NewPromptService(iPromptFormatter, idgen2, iDebugLogRepo, iDebugContextRepo, iManageRepo, iLabelRepo, iConfigProvider, illmProvider, iFileProvider) iAuthProvider := rpc.NewAuthRPCProvider(authClient) iUserProvider := rpc.NewUserRPCProvider(userClient) iAuditProvider := rpc.NewAuditRPCProvider(auditClient) @@ -69,6 +70,7 @@ func InitPromptDebugApplication(idgen2 idgen.IIDGenerator, db2 db.Provider, redi iDebugLogRepo := repo.NewDebugLogRepo(idgen2, iDebugLogDAO) iDebugContextDAO := mysql.NewDebugContextDAO(db2) iDebugContextRepo := repo.NewDebugContextRepo(idgen2, iDebugContextDAO) + iPromptFormatter := service.NewPromptFormatter() iPromptBasicDAO := mysql.NewPromptBasicDAO(db2, redisCli) iPromptCommitDAO := mysql.NewPromptCommitDAO(db2, redisCli) iPromptUserDraftDAO := mysql.NewPromptUserDraftDAO(db2, redisCli) @@ -85,13 +87,14 @@ func InitPromptDebugApplication(idgen2 idgen.IIDGenerator, db2 db.Provider, redi iLabelRepo := repo.NewLabelRepo(db2, idgen2, meter, iLabelDAO, iCommitLabelMappingDAO, iPromptBasicDAO, iPromptLabelVersionDAO) illmProvider := rpc.NewLLMRPCProvider(llmClient) iFileProvider := rpc.NewFileRPCProvider(fileClient) - iPromptService := service.NewPromptService(idgen2, iDebugLogRepo, iDebugContextRepo, iManageRepo, iLabelRepo, iConfigProvider, illmProvider, iFileProvider) + iPromptService := service.NewPromptService(iPromptFormatter, idgen2, iDebugLogRepo, iDebugContextRepo, iManageRepo, iLabelRepo, iConfigProvider, illmProvider, iFileProvider) iAuthProvider := rpc.NewAuthRPCProvider(authClient) promptDebugService := NewPromptDebugApplication(iDebugLogRepo, iDebugContextRepo, iPromptService, benefitSvc, iAuthProvider, iFileProvider) return promptDebugService, nil } func InitPromptExecuteApplication(idgen2 idgen.IIDGenerator, db2 db.Provider, redisCli redis.Cmdable, meter metrics.Meter, configFactory conf.IConfigLoaderFactory, llmClient llmruntimeservice.Client, fileClient fileservice.Client) (execute.PromptExecuteService, error) { + iPromptFormatter := service.NewPromptFormatter() iDebugLogDAO := mysql.NewDebugLogDAO(db2) iDebugLogRepo := repo.NewDebugLogRepo(idgen2, iDebugLogDAO) iDebugContextDAO := mysql.NewDebugContextDAO(db2) @@ -112,12 +115,13 @@ func InitPromptExecuteApplication(idgen2 idgen.IIDGenerator, db2 db.Provider, re iLabelRepo := repo.NewLabelRepo(db2, idgen2, meter, iLabelDAO, iCommitLabelMappingDAO, iPromptBasicDAO, iPromptLabelVersionDAO) illmProvider := rpc.NewLLMRPCProvider(llmClient) iFileProvider := rpc.NewFileRPCProvider(fileClient) - iPromptService := service.NewPromptService(idgen2, iDebugLogRepo, iDebugContextRepo, iManageRepo, iLabelRepo, iConfigProvider, illmProvider, iFileProvider) + iPromptService := service.NewPromptService(iPromptFormatter, idgen2, iDebugLogRepo, iDebugContextRepo, iManageRepo, iLabelRepo, iConfigProvider, illmProvider, iFileProvider) promptExecuteService := NewPromptExecuteApplication(iPromptService, iManageRepo) return promptExecuteService, nil } func InitPromptOpenAPIApplication(idgen2 idgen.IIDGenerator, db2 db.Provider, redisCli redis.Cmdable, meter metrics.Meter, configFactory conf.IConfigLoaderFactory, limiterFactory limiter.IRateLimiterFactory, llmClient llmruntimeservice.Client, authClient authservice.Client, fileClient fileservice.Client) (openapi.PromptOpenAPIService, error) { + iPromptFormatter := service.NewPromptFormatter() iDebugLogDAO := mysql.NewDebugLogDAO(db2) iDebugLogRepo := repo.NewDebugLogRepo(idgen2, iDebugLogDAO) iDebugContextDAO := mysql.NewDebugContextDAO(db2) @@ -138,7 +142,7 @@ func InitPromptOpenAPIApplication(idgen2 idgen.IIDGenerator, db2 db.Provider, re iLabelRepo := repo.NewLabelRepo(db2, idgen2, meter, iLabelDAO, iCommitLabelMappingDAO, iPromptBasicDAO, iPromptLabelVersionDAO) illmProvider := rpc.NewLLMRPCProvider(llmClient) iFileProvider := rpc.NewFileRPCProvider(fileClient) - iPromptService := service.NewPromptService(idgen2, iDebugLogRepo, iDebugContextRepo, iManageRepo, iLabelRepo, iConfigProvider, illmProvider, iFileProvider) + iPromptService := service.NewPromptService(iPromptFormatter, idgen2, iDebugLogRepo, iDebugContextRepo, iManageRepo, iLabelRepo, iConfigProvider, illmProvider, iFileProvider) iAuthProvider := rpc.NewAuthRPCProvider(authClient) iCollectorProvider := collector.NewEventCollectorProvider() promptOpenAPIService, err := NewPromptOpenAPIApplication(iPromptService, iManageRepo, iConfigProvider, iAuthProvider, limiterFactory, iCollectorProvider) @@ -151,7 +155,7 @@ func InitPromptOpenAPIApplication(idgen2 idgen.IIDGenerator, db2 db.Provider, re // wire.go: var ( - promptDomainSet = wire.NewSet(service.NewPromptService, repo.NewManageRepo, repo.NewLabelRepo, repo.NewDebugLogRepo, repo.NewDebugContextRepo, mysql.NewPromptBasicDAO, mysql.NewPromptCommitDAO, mysql.NewPromptUserDraftDAO, mysql.NewLabelDAO, mysql.NewCommitLabelMappingDAO, mysql.NewDebugLogDAO, mysql.NewDebugContextDAO, redis2.NewPromptBasicDAO, redis2.NewPromptDAO, redis2.NewPromptLabelVersionDAO, conf2.NewPromptConfigProvider, rpc.NewLLMRPCProvider, rpc.NewAuthRPCProvider, rpc.NewFileRPCProvider, rpc.NewUserRPCProvider, rpc.NewAuditRPCProvider, collector.NewEventCollectorProvider) + promptDomainSet = wire.NewSet(service.NewPromptFormatter, service.NewPromptService, repo.NewManageRepo, repo.NewLabelRepo, repo.NewDebugLogRepo, repo.NewDebugContextRepo, mysql.NewPromptBasicDAO, mysql.NewPromptCommitDAO, mysql.NewPromptUserDraftDAO, mysql.NewLabelDAO, mysql.NewCommitLabelMappingDAO, mysql.NewDebugLogDAO, mysql.NewDebugContextDAO, redis2.NewPromptBasicDAO, redis2.NewPromptDAO, redis2.NewPromptLabelVersionDAO, conf2.NewPromptConfigProvider, rpc.NewLLMRPCProvider, rpc.NewAuthRPCProvider, rpc.NewFileRPCProvider, rpc.NewUserRPCProvider, rpc.NewAuditRPCProvider, collector.NewEventCollectorProvider) manageSet = wire.NewSet( NewPromptManageApplication, promptDomainSet, diff --git a/backend/modules/prompt/domain/entity/prompt_detail.go b/backend/modules/prompt/domain/entity/prompt_detail.go index 1bcc8434b..7b608298c 100644 --- a/backend/modules/prompt/domain/entity/prompt_detail.go +++ b/backend/modules/prompt/domain/entity/prompt_detail.go @@ -43,9 +43,10 @@ type PromptTemplate struct { type TemplateType string const ( - TemplateTypeNormal TemplateType = "normal" - TemplateTypeJinja2 TemplateType = "jinja2" - TemplateTypeGoTemplate TemplateType = "go_template" + TemplateTypeNormal TemplateType = "normal" + TemplateTypeJinja2 TemplateType = "jinja2" + TemplateTypeGoTemplate TemplateType = "go_template" + TemplateTYpeCustomTemplateM TemplateType = "custom_template_m" ) type Message struct { diff --git a/backend/modules/prompt/domain/service/execute.go b/backend/modules/prompt/domain/service/execute.go index e04c4afbf..eb29da0d8 100644 --- a/backend/modules/prompt/domain/service/execute.go +++ b/backend/modules/prompt/domain/service/execute.go @@ -54,28 +54,8 @@ type ExecuteStreamingParam struct { } func (p *PromptServiceImpl) FormatPrompt(ctx context.Context, prompt *entity.Prompt, messages []*entity.Message, variableVals []*entity.VariableVal) (formattedMessages []*entity.Message, err error) { - if parentSpan := looptracer.GetTracer().GetSpanFromContext(ctx); parentSpan != nil { - var span looptracer.Span - ctx, span = looptracer.GetTracer().StartSpan(ctx, consts.SpanNamePromptTemplate, tracespec.VPromptTemplateSpanType, looptracer.WithSpanWorkspaceID(strconv.FormatInt(prompt.SpaceID, 10))) - if span != nil { - span.SetPrompt(ctx, loopentity.Prompt{PromptKey: prompt.PromptKey, Version: prompt.GetVersion()}) - span.SetInput(ctx, json.Jsonify(tracespec.PromptInput{ - Templates: trace.MessagesToSpanMessages(prompt.GetTemplateMessages(messages)), - Arguments: trace.VariableValsToSpanPromptVariables(variableVals), - })) - defer func() { - span.SetOutput(ctx, json.Jsonify(tracespec.PromptOutput{ - Prompts: trace.MessagesToSpanMessages(formattedMessages), - })) - if err != nil { - span.SetStatusCode(ctx, int(traceutil.GetTraceStatusCode(err))) - span.SetError(ctx, errors.New(errorx.ErrorWithoutStack(err))) - } - span.Finish(ctx) - }() - } - } - return prompt.FormatMessages(messages, variableVals) + // Delegate to the formatter interface + return p.formatter.FormatPrompt(ctx, prompt, messages, variableVals) } func (p *PromptServiceImpl) ExecuteStreaming(ctx context.Context, param ExecuteStreamingParam) (aggregatedReply *entity.Reply, err error) { @@ -395,8 +375,8 @@ func (p *PromptServiceImpl) finishSequenceSpan(ctx context.Context, span cozeloo } func (p *PromptServiceImpl) prepareLLMCallParam(ctx context.Context, param ExecuteParam) (rpc.LLMCallParam, error) { - // format messages - messages, err := p.FormatPrompt(ctx, param.Prompt, param.Messages, param.VariableVals) + // format messages using the formatter interface + messages, err := p.formatter.FormatPrompt(ctx, param.Prompt, param.Messages, param.VariableVals) if err != nil { return rpc.LLMCallParam{}, err } diff --git a/backend/modules/prompt/domain/service/execute_test.go b/backend/modules/prompt/domain/service/execute_test.go index 1f97f1315..e22fbcfe6 100644 --- a/backend/modules/prompt/domain/service/execute_test.go +++ b/backend/modules/prompt/domain/service/execute_test.go @@ -17,6 +17,7 @@ import ( rpcmocks "github.com/coze-dev/coze-loop/backend/modules/prompt/domain/component/rpc/mocks" "github.com/coze-dev/coze-loop/backend/modules/prompt/domain/entity" "github.com/coze-dev/coze-loop/backend/modules/prompt/domain/repo" + prompterr "github.com/coze-dev/coze-loop/backend/modules/prompt/pkg/errno" "github.com/coze-dev/coze-loop/backend/pkg/errorx" "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" "github.com/coze-dev/coze-loop/backend/pkg/unittest" @@ -226,6 +227,7 @@ func TestPromptServiceImpl_FormatPrompt(t *testing.T) { ttFields := tt.fieldsGetter(ctrl) p := &PromptServiceImpl{ + formatter: NewPromptFormatter(), idgen: ttFields.idgen, debugLogRepo: ttFields.debugLogRepo, debugContextRepo: ttFields.debugContextRepo, @@ -248,7 +250,9 @@ func TestPromptServiceImpl_ExecuteStreaming(t *testing.T) { t.Run("nil prompt", func(t *testing.T) { t.Parallel() - p := &PromptServiceImpl{} + p := &PromptServiceImpl{ + formatter: NewPromptFormatter(), + } param := ExecuteStreamingParam{ ExecuteParam: ExecuteParam{ Prompt: nil, @@ -262,7 +266,9 @@ func TestPromptServiceImpl_ExecuteStreaming(t *testing.T) { t.Run("nil result stream", func(t *testing.T) { t.Parallel() - p := &PromptServiceImpl{} + p := &PromptServiceImpl{ + formatter: NewPromptFormatter(), + } param := ExecuteStreamingParam{ ExecuteParam: ExecuteParam{ Prompt: &entity.Prompt{}, @@ -327,8 +333,9 @@ func TestPromptServiceImpl_ExecuteStreaming(t *testing.T) { DebugStep: 1, } p := &PromptServiceImpl{ - idgen: mockIDGen, - llm: mockLLM, + formatter: NewPromptFormatter(), + idgen: mockIDGen, + llm: mockLLM, } stream := make(chan *entity.Reply) @@ -527,8 +534,9 @@ func TestPromptServiceImpl_ExecuteStreaming(t *testing.T) { DebugStep: 2, } p := &PromptServiceImpl{ - idgen: mockIDGen, - llm: mockLLM, + formatter: NewPromptFormatter(), + idgen: mockIDGen, + llm: mockLLM, } stream := make(chan *entity.Reply) @@ -824,6 +832,86 @@ func TestPromptServiceImpl_Execute(t *testing.T) { DebugStep: 2, }, }, + { + name: "error_llm_call_failed", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockIDGen := idgenmocks.NewMockIIDGenerator(ctrl) + mockIDGen.EXPECT().GenID(gomock.Any()).Return(int64(123456789), nil) + mockLLM := rpcmocks.NewMockILLMProvider(ctrl) + mockLLM.EXPECT().Call(gomock.Any(), gomock.Any()).Return(nil, errorx.New("llm call failed")) + return fields{ + llm: mockLLM, + idgen: mockIDGen, + } + }, + args: args{ + ctx: context.Background(), + param: ExecuteParam{ + Prompt: &entity.Prompt{ + ID: 1, + SpaceID: 123, + PromptKey: "test_prompt", + PromptDraft: &entity.PromptDraft{ + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + TemplateType: entity.TemplateTypeNormal, + Messages: []*entity.Message{ + { + Role: entity.RoleSystem, + Content: ptr.Of("You are a helpful assistant."), + }, + }, + }, + }, + }, + }, + Messages: []*entity.Message{ + { + Role: entity.RoleUser, + Content: ptr.Of("Hello"), + }, + }, + SingleStep: true, + }, + }, + wantErr: errorx.New("llm call failed"), + }, + { + name: "error_format_prompt_failed", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockIDGen := idgenmocks.NewMockIIDGenerator(ctrl) + mockIDGen.EXPECT().GenID(gomock.Any()).Return(int64(123456789), nil) + return fields{ + idgen: mockIDGen, + } + }, + args: args{ + ctx: context.Background(), + param: ExecuteParam{ + Prompt: &entity.Prompt{ + ID: 1, + SpaceID: 123, + PromptKey: "test_prompt", + PromptDraft: &entity.PromptDraft{ + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + TemplateType: entity.TemplateTypeGoTemplate, + Messages: []*entity.Message{ + { + Role: entity.RoleSystem, + Content: ptr.Of("You are a {{.InvalidSyntax"), // Invalid template + }, + }, + }, + }, + }, + }, + SingleStep: true, + }, + }, + wantReply: nil, + wantErr: errorx.NewByCode(prompterr.TemplateParseErrorCode), + }, } for _, tt := range tests { @@ -834,6 +922,7 @@ func TestPromptServiceImpl_Execute(t *testing.T) { ttFields := tt.fieldsGetter(ctrl) p := &PromptServiceImpl{ + formatter: NewPromptFormatter(), idgen: ttFields.idgen, debugLogRepo: ttFields.debugLogRepo, debugContextRepo: ttFields.debugContextRepo, @@ -884,7 +973,9 @@ func TestPromptServiceImpl_prepareLLMCallParam_PreservesExtra(t *testing.T) { }, }, } - svc := &PromptServiceImpl{} + svc := &PromptServiceImpl{ + formatter: NewPromptFormatter(), + } param := ExecuteParam{ Prompt: prompt, Messages: []*entity.Message{ diff --git a/backend/modules/prompt/domain/service/formatter.go b/backend/modules/prompt/domain/service/formatter.go new file mode 100644 index 000000000..485715090 --- /dev/null +++ b/backend/modules/prompt/domain/service/formatter.go @@ -0,0 +1,60 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package service + +import ( + "context" + "errors" + "strconv" + + loopentity "github.com/coze-dev/cozeloop-go/entity" + "github.com/coze-dev/cozeloop-go/spec/tracespec" + + "github.com/coze-dev/coze-loop/backend/infra/looptracer" + "github.com/coze-dev/coze-loop/backend/modules/prompt/domain/component/trace" + "github.com/coze-dev/coze-loop/backend/modules/prompt/domain/entity" + "github.com/coze-dev/coze-loop/backend/modules/prompt/pkg/consts" + "github.com/coze-dev/coze-loop/backend/pkg/errorx" + "github.com/coze-dev/coze-loop/backend/pkg/json" + "github.com/coze-dev/coze-loop/backend/pkg/traceutil" +) + +// IPromptFormatter defines the interface for formatting prompts +type IPromptFormatter interface { + FormatPrompt(ctx context.Context, prompt *entity.Prompt, messages []*entity.Message, variableVals []*entity.VariableVal) (formattedMessages []*entity.Message, err error) +} + +// PromptFormatter provides the default implementation of IPromptFormatter +type PromptFormatter struct{} + +// NewPromptFormatter creates a new instance of PromptFormatter +func NewPromptFormatter() IPromptFormatter { + return &PromptFormatter{} +} + +// FormatPrompt implements the IPromptFormatter interface +func (f *PromptFormatter) FormatPrompt(ctx context.Context, prompt *entity.Prompt, messages []*entity.Message, variableVals []*entity.VariableVal) (formattedMessages []*entity.Message, err error) { + if parentSpan := looptracer.GetTracer().GetSpanFromContext(ctx); parentSpan != nil { + var span looptracer.Span + ctx, span = looptracer.GetTracer().StartSpan(ctx, consts.SpanNamePromptTemplate, tracespec.VPromptTemplateSpanType, looptracer.WithSpanWorkspaceID(strconv.FormatInt(prompt.SpaceID, 10))) + if span != nil { + span.SetPrompt(ctx, loopentity.Prompt{PromptKey: prompt.PromptKey, Version: prompt.GetVersion()}) + span.SetInput(ctx, json.Jsonify(tracespec.PromptInput{ + Templates: trace.MessagesToSpanMessages(prompt.GetTemplateMessages(messages)), + Arguments: trace.VariableValsToSpanPromptVariables(variableVals), + })) + defer func() { + span.SetOutput(ctx, json.Jsonify(tracespec.PromptOutput{ + Prompts: trace.MessagesToSpanMessages(formattedMessages), + })) + if err != nil { + span.SetStatusCode(ctx, int(traceutil.GetTraceStatusCode(err))) + span.SetError(ctx, errors.New(errorx.ErrorWithoutStack(err))) + } + span.Finish(ctx) + }() + } + } + return prompt.FormatMessages(messages, variableVals) +} diff --git a/backend/modules/prompt/domain/service/formatter_test.go b/backend/modules/prompt/domain/service/formatter_test.go new file mode 100644 index 000000000..b977fedaf --- /dev/null +++ b/backend/modules/prompt/domain/service/formatter_test.go @@ -0,0 +1,270 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/coze-dev/coze-loop/backend/modules/prompt/domain/entity" + "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" +) + +func TestPromptFormatter_FormatPrompt(t *testing.T) { + type args struct { + ctx context.Context + prompt *entity.Prompt + messages []*entity.Message + variableVals []*entity.VariableVal + } + tests := []struct { + name string + args args + wantFormattedMessages []*entity.Message + wantErr bool + }{ + { + name: "success_simple_template", + args: args{ + ctx: context.Background(), + prompt: &entity.Prompt{ + ID: 123, + SpaceID: 456, + PromptKey: "test_key", + PromptDraft: &entity.PromptDraft{ + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + TemplateType: entity.TemplateTypeNormal, + Messages: []*entity.Message{ + { + Role: entity.RoleSystem, + Content: ptr.Of("You are a helpful assistant."), + }, + { + Role: entity.RoleUser, + Content: ptr.Of("Hello {{name}}"), + }, + }, + VariableDefs: []*entity.VariableDef{ + { + Key: "name", + Desc: "User name", + Type: entity.VariableTypeString, + }, + }, + }, + }, + }, + }, + variableVals: []*entity.VariableVal{ + { + Key: "name", + Value: ptr.Of("World"), + }, + }, + }, + wantFormattedMessages: []*entity.Message{ + { + Role: entity.RoleSystem, + Content: ptr.Of("You are a helpful assistant."), + }, + { + Role: entity.RoleUser, + Content: ptr.Of("Hello World"), + }, + }, + wantErr: false, + }, + { + name: "success_with_additional_messages", + args: args{ + ctx: context.Background(), + prompt: &entity.Prompt{ + ID: 123, + SpaceID: 456, + PromptKey: "test_key", + PromptDraft: &entity.PromptDraft{ + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + TemplateType: entity.TemplateTypeNormal, + Messages: []*entity.Message{ + { + Role: entity.RoleSystem, + Content: ptr.Of("You are a helpful assistant."), + }, + }, + }, + }, + }, + }, + messages: []*entity.Message{ + { + Role: entity.RoleUser, + Content: ptr.Of("What is AI?"), + }, + }, + }, + wantFormattedMessages: []*entity.Message{ + { + Role: entity.RoleSystem, + Content: ptr.Of("You are a helpful assistant."), + }, + { + Role: entity.RoleUser, + Content: ptr.Of("What is AI?"), + }, + }, + wantErr: false, + }, + { + name: "success_multimodal_content", + args: args{ + ctx: context.Background(), + prompt: &entity.Prompt{ + ID: 123, + SpaceID: 456, + PromptKey: "test_key", + PromptDraft: &entity.PromptDraft{ + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + TemplateType: entity.TemplateTypeNormal, + Messages: []*entity.Message{ + { + Role: entity.RoleUser, + Parts: []*entity.ContentPart{ + { + Type: entity.ContentTypeText, + Text: ptr.Of("Describe this image:"), + }, + { + Type: entity.ContentTypeImageURL, + ImageURL: &entity.ImageURL{ + URI: "test-image-uri", + URL: "https://example.com/test.jpg", + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + wantFormattedMessages: []*entity.Message{ + { + Role: entity.RoleUser, + Parts: []*entity.ContentPart{ + { + Type: entity.ContentTypeText, + Text: ptr.Of("Describe this image:"), + }, + { + Type: entity.ContentTypeImageURL, + ImageURL: &entity.ImageURL{ + URI: "test-image-uri", + URL: "https://example.com/test.jpg", + }, + }, + }, + }, + }, + wantErr: false, + }, + { + name: "error_invalid_placeholder_role", + args: args{ + ctx: context.Background(), + prompt: &entity.Prompt{ + ID: 123, + SpaceID: 456, + PromptKey: "test_key", + PromptDraft: &entity.PromptDraft{ + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + TemplateType: entity.TemplateTypeNormal, + Messages: []*entity.Message{ + { + Role: entity.RolePlaceholder, + Content: ptr.Of("history"), + }, + }, + VariableDefs: []*entity.VariableDef{ + { + Key: "history", + Desc: "Chat history", + Type: entity.VariableTypePlaceholder, + }, + }, + }, + }, + }, + }, + variableVals: []*entity.VariableVal{ + { + Key: "history", + PlaceholderMessages: []*entity.Message{ + { + Role: entity.RolePlaceholder, // Invalid role for placeholder message + Content: ptr.Of("Invalid"), + }, + }, + }, + }, + }, + wantFormattedMessages: nil, + wantErr: true, + }, + { + name: "error_go_template_syntax_error", + args: args{ + ctx: context.Background(), + prompt: &entity.Prompt{ + ID: 123, + SpaceID: 456, + PromptKey: "test_key", + PromptDraft: &entity.PromptDraft{ + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + TemplateType: entity.TemplateTypeGoTemplate, + Messages: []*entity.Message{ + { + Role: entity.RoleUser, + Content: ptr.Of("Hello {{.InvalidSyntax"), // Invalid Go template syntax + }, + }, + }, + }, + }, + }, + }, + wantFormattedMessages: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + formatter := NewPromptFormatter() + gotFormattedMessages, err := formatter.FormatPrompt(tt.args.ctx, tt.args.prompt, tt.args.messages, tt.args.variableVals) + + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.wantFormattedMessages, gotFormattedMessages) + } + }) + } +} + +func TestNewPromptFormatter(t *testing.T) { + formatter := NewPromptFormatter() + assert.NotNil(t, formatter) + // Verify it implements the interface + _ = formatter +} diff --git a/backend/modules/prompt/domain/service/mocks/formatter.go b/backend/modules/prompt/domain/service/mocks/formatter.go new file mode 100644 index 000000000..4fa36c54f --- /dev/null +++ b/backend/modules/prompt/domain/service/mocks/formatter.go @@ -0,0 +1,57 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/coze-dev/coze-loop/backend/modules/prompt/domain/service (interfaces: IPromptFormatter) +// +// Generated by this command: +// +// mockgen -destination=mocks/formatter.go -package=mocks . IPromptFormatter +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + entity "github.com/coze-dev/coze-loop/backend/modules/prompt/domain/entity" + gomock "go.uber.org/mock/gomock" +) + +// MockIPromptFormatter is a mock of IPromptFormatter interface. +type MockIPromptFormatter struct { + ctrl *gomock.Controller + recorder *MockIPromptFormatterMockRecorder + isgomock struct{} +} + +// MockIPromptFormatterMockRecorder is the mock recorder for MockIPromptFormatter. +type MockIPromptFormatterMockRecorder struct { + mock *MockIPromptFormatter +} + +// NewMockIPromptFormatter creates a new mock instance. +func NewMockIPromptFormatter(ctrl *gomock.Controller) *MockIPromptFormatter { + mock := &MockIPromptFormatter{ctrl: ctrl} + mock.recorder = &MockIPromptFormatterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockIPromptFormatter) EXPECT() *MockIPromptFormatterMockRecorder { + return m.recorder +} + +// FormatPrompt mocks base method. +func (m *MockIPromptFormatter) FormatPrompt(ctx context.Context, prompt *entity.Prompt, messages []*entity.Message, variableVals []*entity.VariableVal) ([]*entity.Message, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FormatPrompt", ctx, prompt, messages, variableVals) + ret0, _ := ret[0].([]*entity.Message) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FormatPrompt indicates an expected call of FormatPrompt. +func (mr *MockIPromptFormatterMockRecorder) FormatPrompt(ctx, prompt, messages, variableVals any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FormatPrompt", reflect.TypeOf((*MockIPromptFormatter)(nil).FormatPrompt), ctx, prompt, messages, variableVals) +} diff --git a/backend/modules/prompt/domain/service/mocks/prompt_service.go b/backend/modules/prompt/domain/service/mocks/prompt_service.go index ee51ab8c4..d1582e4be 100644 --- a/backend/modules/prompt/domain/service/mocks/prompt_service.go +++ b/backend/modules/prompt/domain/service/mocks/prompt_service.go @@ -161,7 +161,7 @@ func (mr *MockIPromptServiceMockRecorder) MCompleteMultiModalFileURL(ctx, messag return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MCompleteMultiModalFileURL", reflect.TypeOf((*MockIPromptService)(nil).MCompleteMultiModalFileURL), ctx, messages, variableVals) } -// MConvertBase64ToFileURI mocks base method. +// MConvertBase64DataURLToFileURI mocks base method. func (m *MockIPromptService) MConvertBase64DataURLToFileURI(ctx context.Context, messages []*entity.Message, workspaceID int64) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "MConvertBase64DataURLToFileURI", ctx, messages, workspaceID) @@ -169,13 +169,13 @@ func (m *MockIPromptService) MConvertBase64DataURLToFileURI(ctx context.Context, return ret0 } -// MConvertBase64ToFileURI indicates an expected call of MConvertBase64ToFileURI. -func (mr *MockIPromptServiceMockRecorder) MConvertBase64ToFileURI(ctx, messages, workspaceID any) *gomock.Call { +// MConvertBase64DataURLToFileURI indicates an expected call of MConvertBase64DataURLToFileURI. +func (mr *MockIPromptServiceMockRecorder) MConvertBase64DataURLToFileURI(ctx, messages, workspaceID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MConvertBase64DataURLToFileURI", reflect.TypeOf((*MockIPromptService)(nil).MConvertBase64DataURLToFileURI), ctx, messages, workspaceID) } -// MConvertBase64ToFileURL mocks base method. +// MConvertBase64DataURLToFileURL mocks base method. func (m *MockIPromptService) MConvertBase64DataURLToFileURL(ctx context.Context, messages []*entity.Message, workspaceID int64) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "MConvertBase64DataURLToFileURL", ctx, messages, workspaceID) @@ -183,8 +183,8 @@ func (m *MockIPromptService) MConvertBase64DataURLToFileURL(ctx context.Context, return ret0 } -// MConvertBase64ToFileURL indicates an expected call of MConvertBase64ToFileURL. -func (mr *MockIPromptServiceMockRecorder) MConvertBase64ToFileURL(ctx, messages, workspaceID any) *gomock.Call { +// MConvertBase64DataURLToFileURL indicates an expected call of MConvertBase64DataURLToFileURL. +func (mr *MockIPromptServiceMockRecorder) MConvertBase64DataURLToFileURL(ctx, messages, workspaceID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MConvertBase64DataURLToFileURL", reflect.TypeOf((*MockIPromptService)(nil).MConvertBase64DataURLToFileURL), ctx, messages, workspaceID) } diff --git a/backend/modules/prompt/domain/service/service.go b/backend/modules/prompt/domain/service/service.go index 9e2040df5..f234bf005 100644 --- a/backend/modules/prompt/domain/service/service.go +++ b/backend/modules/prompt/domain/service/service.go @@ -67,6 +67,7 @@ type PromptLabelQuery struct { } type PromptServiceImpl struct { + formatter IPromptFormatter idgen idgen.IIDGenerator debugLogRepo repo.IDebugLogRepo debugContextRepo repo.IDebugContextRepo @@ -78,6 +79,7 @@ type PromptServiceImpl struct { } func NewPromptService( + formatter IPromptFormatter, idgen idgen.IIDGenerator, debugLogRepo repo.IDebugLogRepo, debugContextRepo repo.IDebugContextRepo, @@ -88,6 +90,7 @@ func NewPromptService( file rpc.IFileProvider, ) IPromptService { return &PromptServiceImpl{ + formatter: formatter, idgen: idgen, debugLogRepo: debugLogRepo, debugContextRepo: debugContextRepo, diff --git a/backend/modules/prompt/domain/service/service_test.go b/backend/modules/prompt/domain/service/service_test.go new file mode 100644 index 000000000..586a7f268 --- /dev/null +++ b/backend/modules/prompt/domain/service/service_test.go @@ -0,0 +1,96 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package service + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + + "github.com/coze-dev/coze-loop/backend/infra/idgen/mocks" + confmocks "github.com/coze-dev/coze-loop/backend/modules/prompt/domain/component/conf/mocks" + rpcmocks "github.com/coze-dev/coze-loop/backend/modules/prompt/domain/component/rpc/mocks" + repomocks "github.com/coze-dev/coze-loop/backend/modules/prompt/domain/repo/mocks" +) + +func TestNewPromptService(t *testing.T) { + t.Run("creates service with all dependencies", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Create mock dependencies + mockFormatter := NewPromptFormatter() + mockIDGen := mocks.NewMockIIDGenerator(ctrl) + mockDebugLogRepo := repomocks.NewMockIDebugLogRepo(ctrl) + mockDebugContextRepo := repomocks.NewMockIDebugContextRepo(ctrl) + mockManageRepo := repomocks.NewMockIManageRepo(ctrl) + mockLabelRepo := repomocks.NewMockILabelRepo(ctrl) + mockConfigProvider := confmocks.NewMockIConfigProvider(ctrl) + mockLLM := rpcmocks.NewMockILLMProvider(ctrl) + mockFile := rpcmocks.NewMockIFileProvider(ctrl) + + // Call constructor + service := NewPromptService( + mockFormatter, + mockIDGen, + mockDebugLogRepo, + mockDebugContextRepo, + mockManageRepo, + mockLabelRepo, + mockConfigProvider, + mockLLM, + mockFile, + ) + + // Verify + assert.NotNil(t, service) + + // Verify it returns the interface type + _ = service + + // Verify implementation has all fields set (by converting to concrete type for inspection) + impl, ok := service.(*PromptServiceImpl) + assert.True(t, ok, "should return *PromptServiceImpl") + assert.NotNil(t, impl.formatter) + assert.NotNil(t, impl.idgen) + assert.NotNil(t, impl.debugLogRepo) + assert.NotNil(t, impl.debugContextRepo) + assert.NotNil(t, impl.manageRepo) + assert.NotNil(t, impl.labelRepo) + assert.NotNil(t, impl.configProvider) + assert.NotNil(t, impl.llm) + assert.NotNil(t, impl.file) + }) + + t.Run("sets formatter correctly", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockFormatter := NewPromptFormatter() + mockIDGen := mocks.NewMockIIDGenerator(ctrl) + mockDebugLogRepo := repomocks.NewMockIDebugLogRepo(ctrl) + mockDebugContextRepo := repomocks.NewMockIDebugContextRepo(ctrl) + mockManageRepo := repomocks.NewMockIManageRepo(ctrl) + mockLabelRepo := repomocks.NewMockILabelRepo(ctrl) + mockConfigProvider := confmocks.NewMockIConfigProvider(ctrl) + mockLLM := rpcmocks.NewMockILLMProvider(ctrl) + mockFile := rpcmocks.NewMockIFileProvider(ctrl) + + service := NewPromptService( + mockFormatter, + mockIDGen, + mockDebugLogRepo, + mockDebugContextRepo, + mockManageRepo, + mockLabelRepo, + mockConfigProvider, + mockLLM, + mockFile, + ) + + impl := service.(*PromptServiceImpl) + assert.Equal(t, mockFormatter, impl.formatter, "formatter should be set correctly") + }) +} diff --git a/idl/thrift/coze/loop/prompt/coze.loop.prompt.openapi.thrift b/idl/thrift/coze/loop/prompt/coze.loop.prompt.openapi.thrift index 0ebc2eebb..f00e44235 100644 --- a/idl/thrift/coze/loop/prompt/coze.loop.prompt.openapi.thrift +++ b/idl/thrift/coze/loop/prompt/coze.loop.prompt.openapi.thrift @@ -103,6 +103,7 @@ typedef string TemplateType const TemplateType TemplateType_Normal = "normal" const TemplateType TemplateType_Jinja2 = "jinja2" const TemplateType TemplateType_GoTemplate = "go_template" +const TemplateType TemplateType_CustomTemplate_M = "custom_template_m" typedef string ToolChoiceType diff --git a/idl/thrift/coze/loop/prompt/domain/prompt.thrift b/idl/thrift/coze/loop/prompt/domain/prompt.thrift index e715ef07d..3c27d37cd 100644 --- a/idl/thrift/coze/loop/prompt/domain/prompt.thrift +++ b/idl/thrift/coze/loop/prompt/domain/prompt.thrift @@ -69,6 +69,7 @@ typedef string TemplateType (ts.enum="true") const TemplateType TemplateType_Normal = "normal" const TemplateType TemplateType_Jinja2 = "jinja2" const TemplateType TemplateType_GoTemplate = "go_template" +const TemplateType TemplateType_CustomTemplate_M = "custom_template_m" struct Tool { 1: optional ToolType type From 84faae61b98466bb9d1a4fd3a0f391aab933a27b Mon Sep 17 00:00:00 2001 From: yanghoule Date: Wed, 5 Nov 2025 20:32:51 +0800 Subject: [PATCH 09/12] update idl:llm add image gen and video --- .../coze/loop/llm/domain/manage/k-manage.go | 356 ++++++++++++ .../coze/loop/llm/domain/manage/manage.go | 543 +++++++++++++++++- .../llm/domain/manage/manage_validator.go | 8 + idl/thrift/coze/loop/llm/domain/manage.thrift | 27 +- 4 files changed, 929 insertions(+), 5 deletions(-) diff --git a/backend/kitex_gen/coze/loop/llm/domain/manage/k-manage.go b/backend/kitex_gen/coze/loop/llm/domain/manage/k-manage.go index d195556bb..d501fc201 100644 --- a/backend/kitex_gen/coze/loop/llm/domain/manage/k-manage.go +++ b/backend/kitex_gen/coze/loop/llm/domain/manage/k-manage.go @@ -1108,6 +1108,34 @@ func (p *AbilityMultiModal) FastRead(buf []byte) (int, error) { goto SkipFieldError } } + case 3: + if fieldTypeId == thrift.BOOL { + l, err = p.FastReadField3(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 4: + if fieldTypeId == thrift.STRUCT { + l, err = p.FastReadField4(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } default: l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l @@ -1152,6 +1180,32 @@ func (p *AbilityMultiModal) FastReadField2(buf []byte) (int, error) { return offset, nil } +func (p *AbilityMultiModal) FastReadField3(buf []byte) (int, error) { + offset := 0 + + var _field *bool + if v, l, err := thrift.Binary.ReadBool(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Video = _field + return offset, nil +} + +func (p *AbilityMultiModal) FastReadField4(buf []byte) (int, error) { + offset := 0 + _field := NewAbilityVideo() + if l, err := _field.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + p.AbilityVideo = _field + return offset, nil +} + func (p *AbilityMultiModal) FastWrite(buf []byte) int { return p.FastWriteNocopy(buf, nil) } @@ -1160,7 +1214,9 @@ func (p *AbilityMultiModal) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) i offset := 0 if p != nil { offset += p.fastWriteField1(buf[offset:], w) + offset += p.fastWriteField3(buf[offset:], w) offset += p.fastWriteField2(buf[offset:], w) + offset += p.fastWriteField4(buf[offset:], w) } offset += thrift.Binary.WriteFieldStop(buf[offset:]) return offset @@ -1171,6 +1227,8 @@ func (p *AbilityMultiModal) BLength() int { if p != nil { l += p.field1Length() l += p.field2Length() + l += p.field3Length() + l += p.field4Length() } l += thrift.Binary.FieldStopLength() return l @@ -1194,6 +1252,24 @@ func (p *AbilityMultiModal) fastWriteField2(buf []byte, w thrift.NocopyWriter) i return offset } +func (p *AbilityMultiModal) fastWriteField3(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetVideo() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.BOOL, 3) + offset += thrift.Binary.WriteBool(buf[offset:], *p.Video) + } + return offset +} + +func (p *AbilityMultiModal) fastWriteField4(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetAbilityVideo() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 4) + offset += p.AbilityVideo.FastWriteNocopy(buf[offset:], w) + } + return offset +} + func (p *AbilityMultiModal) field1Length() int { l := 0 if p.IsSetImage() { @@ -1212,6 +1288,24 @@ func (p *AbilityMultiModal) field2Length() int { return l } +func (p *AbilityMultiModal) field3Length() int { + l := 0 + if p.IsSetVideo() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.BoolLength() + } + return l +} + +func (p *AbilityMultiModal) field4Length() int { + l := 0 + if p.IsSetAbilityVideo() { + l += thrift.Binary.FieldBeginLength() + l += p.AbilityVideo.BLength() + } + return l +} + func (p *AbilityMultiModal) DeepCopy(s interface{}) error { src, ok := s.(*AbilityMultiModal) if !ok { @@ -1232,6 +1326,20 @@ func (p *AbilityMultiModal) DeepCopy(s interface{}) error { } p.AbilityImage = _abilityImage + if src.Video != nil { + tmp := *src.Video + p.Video = &tmp + } + + var _abilityVideo *AbilityVideo + if src.AbilityVideo != nil { + _abilityVideo = &AbilityVideo{} + if err := _abilityVideo.DeepCopy(src.AbilityVideo); err != nil { + return err + } + } + p.AbilityVideo = _abilityVideo + return nil } @@ -1308,6 +1416,20 @@ func (p *AbilityImage) FastRead(buf []byte) (int, error) { goto SkipFieldError } } + case 5: + if fieldTypeId == thrift.BOOL { + l, err = p.FastReadField5(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } default: l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l @@ -1382,6 +1504,20 @@ func (p *AbilityImage) FastReadField4(buf []byte) (int, error) { return offset, nil } +func (p *AbilityImage) FastReadField5(buf []byte) (int, error) { + offset := 0 + + var _field *bool + if v, l, err := thrift.Binary.ReadBool(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.ImageGenEnabled = _field + return offset, nil +} + func (p *AbilityImage) FastWrite(buf []byte) int { return p.FastWriteNocopy(buf, nil) } @@ -1393,6 +1529,7 @@ func (p *AbilityImage) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { offset += p.fastWriteField2(buf[offset:], w) offset += p.fastWriteField3(buf[offset:], w) offset += p.fastWriteField4(buf[offset:], w) + offset += p.fastWriteField5(buf[offset:], w) } offset += thrift.Binary.WriteFieldStop(buf[offset:]) return offset @@ -1405,6 +1542,7 @@ func (p *AbilityImage) BLength() int { l += p.field2Length() l += p.field3Length() l += p.field4Length() + l += p.field5Length() } l += thrift.Binary.FieldStopLength() return l @@ -1446,6 +1584,15 @@ func (p *AbilityImage) fastWriteField4(buf []byte, w thrift.NocopyWriter) int { return offset } +func (p *AbilityImage) fastWriteField5(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetImageGenEnabled() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.BOOL, 5) + offset += thrift.Binary.WriteBool(buf[offset:], *p.ImageGenEnabled) + } + return offset +} + func (p *AbilityImage) field1Length() int { l := 0 if p.IsSetURLEnabled() { @@ -1482,6 +1629,15 @@ func (p *AbilityImage) field4Length() int { return l } +func (p *AbilityImage) field5Length() int { + l := 0 + if p.IsSetImageGenEnabled() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.BoolLength() + } + return l +} + func (p *AbilityImage) DeepCopy(s interface{}) error { src, ok := s.(*AbilityImage) if !ok { @@ -1508,6 +1664,206 @@ func (p *AbilityImage) DeepCopy(s interface{}) error { p.MaxImageCount = &tmp } + if src.ImageGenEnabled != nil { + tmp := *src.ImageGenEnabled + p.ImageGenEnabled = &tmp + } + + return nil +} + +func (p *AbilityVideo) FastRead(buf []byte) (int, error) { + + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + for { + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 1: + if fieldTypeId == thrift.I32 { + l, err = p.FastReadField1(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 2: + if fieldTypeId == thrift.LIST { + l, err = p.FastReadField2(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + } + + return offset, nil +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_AbilityVideo[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +} + +func (p *AbilityVideo) FastReadField1(buf []byte) (int, error) { + offset := 0 + + var _field *int32 + if v, l, err := thrift.Binary.ReadI32(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.MaxVideoSizeInMb = _field + return offset, nil +} + +func (p *AbilityVideo) FastReadField2(buf []byte) (int, error) { + offset := 0 + + _, size, l, err := thrift.Binary.ReadListBegin(buf[offset:]) + offset += l + if err != nil { + return offset, err + } + _field := make([]VideoFormat, 0, size) + for i := 0; i < size; i++ { + var _elem VideoFormat + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _elem = v + } + + _field = append(_field, _elem) + } + p.SupportedVideoFormats = _field + return offset, nil +} + +func (p *AbilityVideo) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *AbilityVideo) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField1(buf[offset:], w) + offset += p.fastWriteField2(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *AbilityVideo) BLength() int { + l := 0 + if p != nil { + l += p.field1Length() + l += p.field2Length() + } + l += thrift.Binary.FieldStopLength() + return l +} + +func (p *AbilityVideo) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetMaxVideoSizeInMb() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.I32, 1) + offset += thrift.Binary.WriteI32(buf[offset:], *p.MaxVideoSizeInMb) + } + return offset +} + +func (p *AbilityVideo) fastWriteField2(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetSupportedVideoFormats() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.LIST, 2) + listBeginOffset := offset + offset += thrift.Binary.ListBeginLength() + var length int + for _, v := range p.SupportedVideoFormats { + length++ + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, v) + } + thrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.STRING, length) + } + return offset +} + +func (p *AbilityVideo) field1Length() int { + l := 0 + if p.IsSetMaxVideoSizeInMb() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.I32Length() + } + return l +} + +func (p *AbilityVideo) field2Length() int { + l := 0 + if p.IsSetSupportedVideoFormats() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.ListBeginLength() + for _, v := range p.SupportedVideoFormats { + _ = v + l += thrift.Binary.StringLengthNocopy(v) + } + } + return l +} + +func (p *AbilityVideo) DeepCopy(s interface{}) error { + src, ok := s.(*AbilityVideo) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + if src.MaxVideoSizeInMb != nil { + tmp := *src.MaxVideoSizeInMb + p.MaxVideoSizeInMb = &tmp + } + + if src.SupportedVideoFormats != nil { + p.SupportedVideoFormats = make([]VideoFormat, 0, len(src.SupportedVideoFormats)) + for _, elem := range src.SupportedVideoFormats { + var _elem VideoFormat + _elem = elem + p.SupportedVideoFormats = append(p.SupportedVideoFormats, _elem) + } + } + return nil } diff --git a/backend/kitex_gen/coze/loop/llm/domain/manage/manage.go b/backend/kitex_gen/coze/loop/llm/domain/manage/manage.go index 523d3bbf0..6f0e4dbaf 100644 --- a/backend/kitex_gen/coze/loop/llm/domain/manage/manage.go +++ b/backend/kitex_gen/coze/loop/llm/domain/manage/manage.go @@ -35,12 +35,44 @@ const ( ParamTypeBoolean = "boolean" ParamTypeString = "string" + + VideoFormatUndefined = "undefined" + + VideoFormatMp4 = "mp4" + + VideoFormatAvi = "avi" + + VideoFormatMov = "mov" + + VideoFormatMpg = "mpg" + + VideoFormatWebm = "webm" + + VideoFormatRvmb = "rvmb" + + VideoFormatWmv = "wmv" + + VideoFormatMkv = "mkv" + + VideoFormatT3gp = "t3gp" + + VideoFormatFlv = "flv" + + VideoFormatMpeg = "mpeg" + + VideoFormatTs = "ts" + + VideoFormatRm = "rm" + + VideoFormatM4v = "m4v" ) type Protocol = string type ParamType = string +type VideoFormat = string + type Model struct { ModelID *int64 `thrift:"model_id,1,optional" frugal:"1,optional,i64" json:"model_id" form:"model_id" query:"model_id"` WorkspaceID *int64 `thrift:"workspace_id,2,optional" frugal:"2,optional,i64" json:"workspace_id" form:"workspace_id" query:"workspace_id"` @@ -1482,6 +1514,8 @@ func (p *Ability) Field7DeepEqual(src *AbilityMultiModal) bool { type AbilityMultiModal struct { Image *bool `thrift:"image,1,optional" frugal:"1,optional,bool" form:"image" json:"image,omitempty" query:"image"` AbilityImage *AbilityImage `thrift:"ability_image,2,optional" frugal:"2,optional,AbilityImage" form:"ability_image" json:"ability_image,omitempty" query:"ability_image"` + Video *bool `thrift:"video,3,optional" frugal:"3,optional,bool" form:"video" json:"video,omitempty" query:"video"` + AbilityVideo *AbilityVideo `thrift:"ability_video,4,optional" frugal:"4,optional,AbilityVideo" form:"ability_video" json:"ability_video,omitempty" query:"ability_video"` } func NewAbilityMultiModal() *AbilityMultiModal { @@ -1514,16 +1548,48 @@ func (p *AbilityMultiModal) GetAbilityImage() (v *AbilityImage) { } return p.AbilityImage } + +var AbilityMultiModal_Video_DEFAULT bool + +func (p *AbilityMultiModal) GetVideo() (v bool) { + if p == nil { + return + } + if !p.IsSetVideo() { + return AbilityMultiModal_Video_DEFAULT + } + return *p.Video +} + +var AbilityMultiModal_AbilityVideo_DEFAULT *AbilityVideo + +func (p *AbilityMultiModal) GetAbilityVideo() (v *AbilityVideo) { + if p == nil { + return + } + if !p.IsSetAbilityVideo() { + return AbilityMultiModal_AbilityVideo_DEFAULT + } + return p.AbilityVideo +} func (p *AbilityMultiModal) SetImage(val *bool) { p.Image = val } func (p *AbilityMultiModal) SetAbilityImage(val *AbilityImage) { p.AbilityImage = val } +func (p *AbilityMultiModal) SetVideo(val *bool) { + p.Video = val +} +func (p *AbilityMultiModal) SetAbilityVideo(val *AbilityVideo) { + p.AbilityVideo = val +} var fieldIDToName_AbilityMultiModal = map[int16]string{ 1: "image", 2: "ability_image", + 3: "video", + 4: "ability_video", } func (p *AbilityMultiModal) IsSetImage() bool { @@ -1534,6 +1600,14 @@ func (p *AbilityMultiModal) IsSetAbilityImage() bool { return p.AbilityImage != nil } +func (p *AbilityMultiModal) IsSetVideo() bool { + return p.Video != nil +} + +func (p *AbilityMultiModal) IsSetAbilityVideo() bool { + return p.AbilityVideo != nil +} + func (p *AbilityMultiModal) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -1568,6 +1642,22 @@ func (p *AbilityMultiModal) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 3: + if fieldTypeId == thrift.BOOL { + if err = p.ReadField3(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 4: + if fieldTypeId == thrift.STRUCT { + if err = p.ReadField4(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } default: if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError @@ -1616,6 +1706,25 @@ func (p *AbilityMultiModal) ReadField2(iprot thrift.TProtocol) error { p.AbilityImage = _field return nil } +func (p *AbilityMultiModal) ReadField3(iprot thrift.TProtocol) error { + + var _field *bool + if v, err := iprot.ReadBool(); err != nil { + return err + } else { + _field = &v + } + p.Video = _field + return nil +} +func (p *AbilityMultiModal) ReadField4(iprot thrift.TProtocol) error { + _field := NewAbilityVideo() + if err := _field.Read(iprot); err != nil { + return err + } + p.AbilityVideo = _field + return nil +} func (p *AbilityMultiModal) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 @@ -1631,6 +1740,14 @@ func (p *AbilityMultiModal) Write(oprot thrift.TProtocol) (err error) { fieldId = 2 goto WriteFieldError } + if err = p.writeField3(oprot); err != nil { + fieldId = 3 + goto WriteFieldError + } + if err = p.writeField4(oprot); err != nil { + fieldId = 4 + goto WriteFieldError + } } if err = oprot.WriteFieldStop(); err != nil { goto WriteFieldStopError @@ -1685,6 +1802,42 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) } +func (p *AbilityMultiModal) writeField3(oprot thrift.TProtocol) (err error) { + if p.IsSetVideo() { + if err = oprot.WriteFieldBegin("video", thrift.BOOL, 3); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteBool(*p.Video); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) +} +func (p *AbilityMultiModal) writeField4(oprot thrift.TProtocol) (err error) { + if p.IsSetAbilityVideo() { + if err = oprot.WriteFieldBegin("ability_video", thrift.STRUCT, 4); err != nil { + goto WriteFieldBeginError + } + if err := p.AbilityVideo.Write(oprot); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 4 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 4 end error: ", p), err) +} func (p *AbilityMultiModal) String() string { if p == nil { @@ -1706,6 +1859,12 @@ func (p *AbilityMultiModal) DeepEqual(ano *AbilityMultiModal) bool { if !p.Field2DeepEqual(ano.AbilityImage) { return false } + if !p.Field3DeepEqual(ano.Video) { + return false + } + if !p.Field4DeepEqual(ano.AbilityVideo) { + return false + } return true } @@ -1728,12 +1887,32 @@ func (p *AbilityMultiModal) Field2DeepEqual(src *AbilityImage) bool { } return true } +func (p *AbilityMultiModal) Field3DeepEqual(src *bool) bool { + + if p.Video == src { + return true + } else if p.Video == nil || src == nil { + return false + } + if *p.Video != *src { + return false + } + return true +} +func (p *AbilityMultiModal) Field4DeepEqual(src *AbilityVideo) bool { + + if !p.AbilityVideo.DeepEqual(src) { + return false + } + return true +} type AbilityImage struct { - URLEnabled *bool `thrift:"url_enabled,1,optional" frugal:"1,optional,bool" form:"url_enabled" json:"url_enabled,omitempty" query:"url_enabled"` - BinaryEnabled *bool `thrift:"binary_enabled,2,optional" frugal:"2,optional,bool" form:"binary_enabled" json:"binary_enabled,omitempty" query:"binary_enabled"` - MaxImageSize *int64 `thrift:"max_image_size,3,optional" frugal:"3,optional,i64" json:"max_image_size" form:"max_image_size" query:"max_image_size"` - MaxImageCount *int64 `thrift:"max_image_count,4,optional" frugal:"4,optional,i64" json:"max_image_count" form:"max_image_count" query:"max_image_count"` + URLEnabled *bool `thrift:"url_enabled,1,optional" frugal:"1,optional,bool" form:"url_enabled" json:"url_enabled,omitempty" query:"url_enabled"` + BinaryEnabled *bool `thrift:"binary_enabled,2,optional" frugal:"2,optional,bool" form:"binary_enabled" json:"binary_enabled,omitempty" query:"binary_enabled"` + MaxImageSize *int64 `thrift:"max_image_size,3,optional" frugal:"3,optional,i64" json:"max_image_size" form:"max_image_size" query:"max_image_size"` + MaxImageCount *int64 `thrift:"max_image_count,4,optional" frugal:"4,optional,i64" json:"max_image_count" form:"max_image_count" query:"max_image_count"` + ImageGenEnabled *bool `thrift:"image_gen_enabled,5,optional" frugal:"5,optional,bool" form:"image_gen_enabled" json:"image_gen_enabled,omitempty" query:"image_gen_enabled"` } func NewAbilityImage() *AbilityImage { @@ -1790,6 +1969,18 @@ func (p *AbilityImage) GetMaxImageCount() (v int64) { } return *p.MaxImageCount } + +var AbilityImage_ImageGenEnabled_DEFAULT bool + +func (p *AbilityImage) GetImageGenEnabled() (v bool) { + if p == nil { + return + } + if !p.IsSetImageGenEnabled() { + return AbilityImage_ImageGenEnabled_DEFAULT + } + return *p.ImageGenEnabled +} func (p *AbilityImage) SetURLEnabled(val *bool) { p.URLEnabled = val } @@ -1802,12 +1993,16 @@ func (p *AbilityImage) SetMaxImageSize(val *int64) { func (p *AbilityImage) SetMaxImageCount(val *int64) { p.MaxImageCount = val } +func (p *AbilityImage) SetImageGenEnabled(val *bool) { + p.ImageGenEnabled = val +} var fieldIDToName_AbilityImage = map[int16]string{ 1: "url_enabled", 2: "binary_enabled", 3: "max_image_size", 4: "max_image_count", + 5: "image_gen_enabled", } func (p *AbilityImage) IsSetURLEnabled() bool { @@ -1826,6 +2021,10 @@ func (p *AbilityImage) IsSetMaxImageCount() bool { return p.MaxImageCount != nil } +func (p *AbilityImage) IsSetImageGenEnabled() bool { + return p.ImageGenEnabled != nil +} + func (p *AbilityImage) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -1876,6 +2075,14 @@ func (p *AbilityImage) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 5: + if fieldTypeId == thrift.BOOL { + if err = p.ReadField5(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } default: if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError @@ -1949,6 +2156,17 @@ func (p *AbilityImage) ReadField4(iprot thrift.TProtocol) error { p.MaxImageCount = _field return nil } +func (p *AbilityImage) ReadField5(iprot thrift.TProtocol) error { + + var _field *bool + if v, err := iprot.ReadBool(); err != nil { + return err + } else { + _field = &v + } + p.ImageGenEnabled = _field + return nil +} func (p *AbilityImage) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 @@ -1972,6 +2190,10 @@ func (p *AbilityImage) Write(oprot thrift.TProtocol) (err error) { fieldId = 4 goto WriteFieldError } + if err = p.writeField5(oprot); err != nil { + fieldId = 5 + goto WriteFieldError + } } if err = oprot.WriteFieldStop(); err != nil { goto WriteFieldStopError @@ -2062,6 +2284,24 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 4 end error: ", p), err) } +func (p *AbilityImage) writeField5(oprot thrift.TProtocol) (err error) { + if p.IsSetImageGenEnabled() { + if err = oprot.WriteFieldBegin("image_gen_enabled", thrift.BOOL, 5); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteBool(*p.ImageGenEnabled); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 5 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 5 end error: ", p), err) +} func (p *AbilityImage) String() string { if p == nil { @@ -2089,6 +2329,9 @@ func (p *AbilityImage) DeepEqual(ano *AbilityImage) bool { if !p.Field4DeepEqual(ano.MaxImageCount) { return false } + if !p.Field5DeepEqual(ano.ImageGenEnabled) { + return false + } return true } @@ -2140,6 +2383,298 @@ func (p *AbilityImage) Field4DeepEqual(src *int64) bool { } return true } +func (p *AbilityImage) Field5DeepEqual(src *bool) bool { + + if p.ImageGenEnabled == src { + return true + } else if p.ImageGenEnabled == nil || src == nil { + return false + } + if *p.ImageGenEnabled != *src { + return false + } + return true +} + +type AbilityVideo struct { + // the size limit of single video + MaxVideoSizeInMb *int32 `thrift:"max_video_size_in_mb,1,optional" frugal:"1,optional,i32" form:"max_video_size_in_mb" json:"max_video_size_in_mb,omitempty" query:"max_video_size_in_mb"` + SupportedVideoFormats []VideoFormat `thrift:"supported_video_formats,2,optional" frugal:"2,optional,list" form:"supported_video_formats" json:"supported_video_formats,omitempty" query:"supported_video_formats"` +} + +func NewAbilityVideo() *AbilityVideo { + return &AbilityVideo{} +} + +func (p *AbilityVideo) InitDefault() { +} + +var AbilityVideo_MaxVideoSizeInMb_DEFAULT int32 + +func (p *AbilityVideo) GetMaxVideoSizeInMb() (v int32) { + if p == nil { + return + } + if !p.IsSetMaxVideoSizeInMb() { + return AbilityVideo_MaxVideoSizeInMb_DEFAULT + } + return *p.MaxVideoSizeInMb +} + +var AbilityVideo_SupportedVideoFormats_DEFAULT []VideoFormat + +func (p *AbilityVideo) GetSupportedVideoFormats() (v []VideoFormat) { + if p == nil { + return + } + if !p.IsSetSupportedVideoFormats() { + return AbilityVideo_SupportedVideoFormats_DEFAULT + } + return p.SupportedVideoFormats +} +func (p *AbilityVideo) SetMaxVideoSizeInMb(val *int32) { + p.MaxVideoSizeInMb = val +} +func (p *AbilityVideo) SetSupportedVideoFormats(val []VideoFormat) { + p.SupportedVideoFormats = val +} + +var fieldIDToName_AbilityVideo = map[int16]string{ + 1: "max_video_size_in_mb", + 2: "supported_video_formats", +} + +func (p *AbilityVideo) IsSetMaxVideoSizeInMb() bool { + return p.MaxVideoSizeInMb != nil +} + +func (p *AbilityVideo) IsSetSupportedVideoFormats() bool { + return p.SupportedVideoFormats != nil +} + +func (p *AbilityVideo) Read(iprot thrift.TProtocol) (err error) { + var fieldTypeId thrift.TType + var fieldId int16 + + if _, err = iprot.ReadStructBegin(); err != nil { + goto ReadStructBeginError + } + + for { + _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + + switch fieldId { + case 1: + if fieldTypeId == thrift.I32 { + if err = p.ReadField1(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 2: + if fieldTypeId == thrift.LIST { + if err = p.ReadField2(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } + } + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError + } + + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_AbilityVideo[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *AbilityVideo) ReadField1(iprot thrift.TProtocol) error { + + var _field *int32 + if v, err := iprot.ReadI32(); err != nil { + return err + } else { + _field = &v + } + p.MaxVideoSizeInMb = _field + return nil +} +func (p *AbilityVideo) ReadField2(iprot thrift.TProtocol) error { + _, size, err := iprot.ReadListBegin() + if err != nil { + return err + } + _field := make([]VideoFormat, 0, size) + for i := 0; i < size; i++ { + + var _elem VideoFormat + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _elem = v + } + + _field = append(_field, _elem) + } + if err := iprot.ReadListEnd(); err != nil { + return err + } + p.SupportedVideoFormats = _field + return nil +} + +func (p *AbilityVideo) Write(oprot thrift.TProtocol) (err error) { + var fieldId int16 + if err = oprot.WriteStructBegin("AbilityVideo"); err != nil { + goto WriteStructBeginError + } + if p != nil { + if err = p.writeField1(oprot); err != nil { + fieldId = 1 + goto WriteFieldError + } + if err = p.writeField2(oprot); err != nil { + fieldId = 2 + goto WriteFieldError + } + } + if err = oprot.WriteFieldStop(); err != nil { + goto WriteFieldStopError + } + if err = oprot.WriteStructEnd(); err != nil { + goto WriteStructEndError + } + return nil +WriteStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) +WriteFieldError: + return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) +WriteFieldStopError: + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) +WriteStructEndError: + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) +} + +func (p *AbilityVideo) writeField1(oprot thrift.TProtocol) (err error) { + if p.IsSetMaxVideoSizeInMb() { + if err = oprot.WriteFieldBegin("max_video_size_in_mb", thrift.I32, 1); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteI32(*p.MaxVideoSizeInMb); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) +} +func (p *AbilityVideo) writeField2(oprot thrift.TProtocol) (err error) { + if p.IsSetSupportedVideoFormats() { + if err = oprot.WriteFieldBegin("supported_video_formats", thrift.LIST, 2); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteListBegin(thrift.STRING, len(p.SupportedVideoFormats)); err != nil { + return err + } + for _, v := range p.SupportedVideoFormats { + if err := oprot.WriteString(v); err != nil { + return err + } + } + if err := oprot.WriteListEnd(); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) +} + +func (p *AbilityVideo) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("AbilityVideo(%+v)", *p) + +} + +func (p *AbilityVideo) DeepEqual(ano *AbilityVideo) bool { + if p == ano { + return true + } else if p == nil || ano == nil { + return false + } + if !p.Field1DeepEqual(ano.MaxVideoSizeInMb) { + return false + } + if !p.Field2DeepEqual(ano.SupportedVideoFormats) { + return false + } + return true +} + +func (p *AbilityVideo) Field1DeepEqual(src *int32) bool { + + if p.MaxVideoSizeInMb == src { + return true + } else if p.MaxVideoSizeInMb == nil || src == nil { + return false + } + if *p.MaxVideoSizeInMb != *src { + return false + } + return true +} +func (p *AbilityVideo) Field2DeepEqual(src []VideoFormat) bool { + + if len(p.SupportedVideoFormats) != len(src) { + return false + } + for i, v := range p.SupportedVideoFormats { + _src := src[i] + if strings.Compare(v, _src) != 0 { + return false + } + } + return true +} type ProtocolConfig struct { BaseURL *string `thrift:"base_url,1,optional" frugal:"1,optional,string" form:"base_url" json:"base_url,omitempty" query:"base_url"` diff --git a/backend/kitex_gen/coze/loop/llm/domain/manage/manage_validator.go b/backend/kitex_gen/coze/loop/llm/domain/manage/manage_validator.go index 321eb92a6..b10464682 100644 --- a/backend/kitex_gen/coze/loop/llm/domain/manage/manage_validator.go +++ b/backend/kitex_gen/coze/loop/llm/domain/manage/manage_validator.go @@ -53,11 +53,19 @@ func (p *AbilityMultiModal) IsValid() error { return fmt.Errorf("field AbilityImage not valid, %w", err) } } + if p.AbilityVideo != nil { + if err := p.AbilityVideo.IsValid(); err != nil { + return fmt.Errorf("field AbilityVideo not valid, %w", err) + } + } return nil } func (p *AbilityImage) IsValid() error { return nil } +func (p *AbilityVideo) IsValid() error { + return nil +} func (p *ProtocolConfig) IsValid() error { if p.ProtocolConfigArk != nil { if err := p.ProtocolConfigArk.IsValid(); err != nil { diff --git a/idl/thrift/coze/loop/llm/domain/manage.thrift b/idl/thrift/coze/loop/llm/domain/manage.thrift index 0e3b45731..a3462c26e 100644 --- a/idl/thrift/coze/loop/llm/domain/manage.thrift +++ b/idl/thrift/coze/loop/llm/domain/manage.thrift @@ -27,6 +27,8 @@ struct Ability { struct AbilityMultiModal { 1: optional bool image 2: optional AbilityImage ability_image + 3: optional bool video + 4: optional AbilityVideo ability_video } struct AbilityImage { @@ -34,6 +36,12 @@ struct AbilityImage { 2: optional bool binary_enabled 3: optional i64 max_image_size (api.js_conv='true', go.tag='json:"max_image_size"') 4: optional i64 max_image_count (api.js_conv='true', go.tag='json:"max_image_count"') + 5: optional bool image_gen_enabled +} + +struct AbilityVideo { + 1: optional i32 max_video_size_in_mb // the size limit of single video + 2: optional list supported_video_formats } struct ProtocolConfig { @@ -162,4 +170,21 @@ typedef string ParamType (ts.enum="true") const ParamType param_type_float = "float" const ParamType param_type_int = "int" const ParamType param_type_boolean = "boolean" -const ParamType param_type_string = "string" \ No newline at end of file +const ParamType param_type_string = "string" + +typedef string VideoFormat (ts.enum="true") +const VideoFormat video_format_undefined = "undefined" +const VideoFormat video_format_mp4 = "mp4" +const VideoFormat video_format_avi = "avi" +const VideoFormat video_format_mov = "mov" +const VideoFormat video_format_mpg = "mpg" +const VideoFormat video_format_webm = "webm" +const VideoFormat video_format_rvmb = "rvmb" +const VideoFormat video_format_wmv = "wmv" +const VideoFormat video_format_mkv = "mkv" +const VideoFormat video_format_t3gp = "t3gp" +const VideoFormat video_format_flv = "flv" +const VideoFormat video_format_mpeg = "mpeg" +const VideoFormat video_format_ts = "ts" +const VideoFormat video_format_rm = "rm" +const VideoFormat video_format_m4v = "m4v" \ No newline at end of file From 74f5ee485e7cc6bccc69337119a491481a41d63a Mon Sep 17 00:00:00 2001 From: caijialin0626 <61818131+caijialin0626@users.noreply.github.com> Date: Tue, 11 Nov 2025 19:40:40 +0800 Subject: [PATCH 10/12] [feat][prompt] add google_search and tool call specification (#304) --- .../loop/prompt/domain/prompt/k-prompt.go | 228 ++++++++++++ .../coze/loop/prompt/domain/prompt/prompt.go | 333 +++++++++++++++++- .../prompt/domain/prompt/prompt_validator.go | 8 + .../openapi/coze.loop.prompt.openapi.go | 333 +++++++++++++++++- .../coze.loop.prompt.openapi_validator.go | 8 + .../openapi/k-coze.loop.prompt.openapi.go | 228 ++++++++++++ .../prompt/application/convertor/openapi.go | 17 +- .../application/convertor/openapi_test.go | 166 +++++++++ .../prompt/application/convertor/prompt.go | 37 +- .../application/convertor/prompt_test.go | 307 ++++++++++++++++ .../prompt/domain/entity/prompt_detail.go | 16 +- .../modules/prompt/domain/service/execute.go | 15 +- .../prompt/domain/service/execute_test.go | 274 ++++++++++++++ .../prompt/coze.loop.prompt.openapi.thrift | 8 + .../coze/loop/prompt/domain/prompt.thrift | 8 + 15 files changed, 1976 insertions(+), 10 deletions(-) diff --git a/backend/kitex_gen/coze/loop/prompt/domain/prompt/k-prompt.go b/backend/kitex_gen/coze/loop/prompt/domain/prompt/k-prompt.go index ae4053b2f..2a1409c02 100644 --- a/backend/kitex_gen/coze/loop/prompt/domain/prompt/k-prompt.go +++ b/backend/kitex_gen/coze/loop/prompt/domain/prompt/k-prompt.go @@ -3169,6 +3169,20 @@ func (p *ToolCallConfig) FastRead(buf []byte) (int, error) { goto SkipFieldError } } + case 2: + if fieldTypeId == thrift.STRUCT { + l, err = p.FastReadField2(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } default: l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l @@ -3201,6 +3215,18 @@ func (p *ToolCallConfig) FastReadField1(buf []byte) (int, error) { return offset, nil } +func (p *ToolCallConfig) FastReadField2(buf []byte) (int, error) { + offset := 0 + _field := NewToolChoiceSpecification() + if l, err := _field.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + p.ToolChoiceSpecification = _field + return offset, nil +} + func (p *ToolCallConfig) FastWrite(buf []byte) int { return p.FastWriteNocopy(buf, nil) } @@ -3209,6 +3235,7 @@ func (p *ToolCallConfig) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int offset := 0 if p != nil { offset += p.fastWriteField1(buf[offset:], w) + offset += p.fastWriteField2(buf[offset:], w) } offset += thrift.Binary.WriteFieldStop(buf[offset:]) return offset @@ -3218,6 +3245,7 @@ func (p *ToolCallConfig) BLength() int { l := 0 if p != nil { l += p.field1Length() + l += p.field2Length() } l += thrift.Binary.FieldStopLength() return l @@ -3232,6 +3260,15 @@ func (p *ToolCallConfig) fastWriteField1(buf []byte, w thrift.NocopyWriter) int return offset } +func (p *ToolCallConfig) fastWriteField2(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetToolChoiceSpecification() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 2) + offset += p.ToolChoiceSpecification.FastWriteNocopy(buf[offset:], w) + } + return offset +} + func (p *ToolCallConfig) field1Length() int { l := 0 if p.IsSetToolChoice() { @@ -3241,6 +3278,15 @@ func (p *ToolCallConfig) field1Length() int { return l } +func (p *ToolCallConfig) field2Length() int { + l := 0 + if p.IsSetToolChoiceSpecification() { + l += thrift.Binary.FieldBeginLength() + l += p.ToolChoiceSpecification.BLength() + } + return l +} + func (p *ToolCallConfig) DeepCopy(s interface{}) error { src, ok := s.(*ToolCallConfig) if !ok { @@ -3252,6 +3298,188 @@ func (p *ToolCallConfig) DeepCopy(s interface{}) error { p.ToolChoice = &tmp } + var _toolChoiceSpecification *ToolChoiceSpecification + if src.ToolChoiceSpecification != nil { + _toolChoiceSpecification = &ToolChoiceSpecification{} + if err := _toolChoiceSpecification.DeepCopy(src.ToolChoiceSpecification); err != nil { + return err + } + } + p.ToolChoiceSpecification = _toolChoiceSpecification + + return nil +} + +func (p *ToolChoiceSpecification) FastRead(buf []byte) (int, error) { + + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + for { + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 1: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField1(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 2: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField2(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + } + + return offset, nil +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_ToolChoiceSpecification[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +} + +func (p *ToolChoiceSpecification) FastReadField1(buf []byte) (int, error) { + offset := 0 + + var _field *ToolType + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Type = _field + return offset, nil +} + +func (p *ToolChoiceSpecification) FastReadField2(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Name = _field + return offset, nil +} + +func (p *ToolChoiceSpecification) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *ToolChoiceSpecification) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField1(buf[offset:], w) + offset += p.fastWriteField2(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *ToolChoiceSpecification) BLength() int { + l := 0 + if p != nil { + l += p.field1Length() + l += p.field2Length() + } + l += thrift.Binary.FieldStopLength() + return l +} + +func (p *ToolChoiceSpecification) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetType() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 1) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Type) + } + return offset +} + +func (p *ToolChoiceSpecification) fastWriteField2(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetName() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 2) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Name) + } + return offset +} + +func (p *ToolChoiceSpecification) field1Length() int { + l := 0 + if p.IsSetType() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Type) + } + return l +} + +func (p *ToolChoiceSpecification) field2Length() int { + l := 0 + if p.IsSetName() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Name) + } + return l +} + +func (p *ToolChoiceSpecification) DeepCopy(s interface{}) error { + src, ok := s.(*ToolChoiceSpecification) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + if src.Type != nil { + tmp := *src.Type + p.Type = &tmp + } + + if src.Name != nil { + var tmp string + if *src.Name != "" { + tmp = kutils.StringDeepCopy(*src.Name) + } + p.Name = &tmp + } + return nil } diff --git a/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go b/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go index b0666a814..6d7c795fd 100644 --- a/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go +++ b/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go @@ -19,10 +19,14 @@ const ( ToolTypeFunction = "function" + ToolTypeGoogleSearch = "google_search" + ToolChoiceTypeNone = "none" ToolChoiceTypeAuto = "auto" + ToolChoiceTypeSpecific = "specific" + RoleSystem = "system" RoleUser = "user" @@ -4394,7 +4398,8 @@ func (p *Function) Field3DeepEqual(src *string) bool { } type ToolCallConfig struct { - ToolChoice *ToolChoiceType `thrift:"tool_choice,1,optional" frugal:"1,optional,string" form:"tool_choice" json:"tool_choice,omitempty" query:"tool_choice"` + ToolChoice *ToolChoiceType `thrift:"tool_choice,1,optional" frugal:"1,optional,string" form:"tool_choice" json:"tool_choice,omitempty" query:"tool_choice"` + ToolChoiceSpecification *ToolChoiceSpecification `thrift:"tool_choice_specification,2,optional" frugal:"2,optional,ToolChoiceSpecification" form:"tool_choice_specification" json:"tool_choice_specification,omitempty" query:"tool_choice_specification"` } func NewToolCallConfig() *ToolCallConfig { @@ -4415,18 +4420,38 @@ func (p *ToolCallConfig) GetToolChoice() (v ToolChoiceType) { } return *p.ToolChoice } + +var ToolCallConfig_ToolChoiceSpecification_DEFAULT *ToolChoiceSpecification + +func (p *ToolCallConfig) GetToolChoiceSpecification() (v *ToolChoiceSpecification) { + if p == nil { + return + } + if !p.IsSetToolChoiceSpecification() { + return ToolCallConfig_ToolChoiceSpecification_DEFAULT + } + return p.ToolChoiceSpecification +} func (p *ToolCallConfig) SetToolChoice(val *ToolChoiceType) { p.ToolChoice = val } +func (p *ToolCallConfig) SetToolChoiceSpecification(val *ToolChoiceSpecification) { + p.ToolChoiceSpecification = val +} var fieldIDToName_ToolCallConfig = map[int16]string{ 1: "tool_choice", + 2: "tool_choice_specification", } func (p *ToolCallConfig) IsSetToolChoice() bool { return p.ToolChoice != nil } +func (p *ToolCallConfig) IsSetToolChoiceSpecification() bool { + return p.ToolChoiceSpecification != nil +} + func (p *ToolCallConfig) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -4453,6 +4478,14 @@ func (p *ToolCallConfig) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 2: + if fieldTypeId == thrift.STRUCT { + if err = p.ReadField2(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } default: if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError @@ -4493,6 +4526,14 @@ func (p *ToolCallConfig) ReadField1(iprot thrift.TProtocol) error { p.ToolChoice = _field return nil } +func (p *ToolCallConfig) ReadField2(iprot thrift.TProtocol) error { + _field := NewToolChoiceSpecification() + if err := _field.Read(iprot); err != nil { + return err + } + p.ToolChoiceSpecification = _field + return nil +} func (p *ToolCallConfig) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 @@ -4504,6 +4545,10 @@ func (p *ToolCallConfig) Write(oprot thrift.TProtocol) (err error) { fieldId = 1 goto WriteFieldError } + if err = p.writeField2(oprot); err != nil { + fieldId = 2 + goto WriteFieldError + } } if err = oprot.WriteFieldStop(); err != nil { goto WriteFieldStopError @@ -4540,6 +4585,24 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) } +func (p *ToolCallConfig) writeField2(oprot thrift.TProtocol) (err error) { + if p.IsSetToolChoiceSpecification() { + if err = oprot.WriteFieldBegin("tool_choice_specification", thrift.STRUCT, 2); err != nil { + goto WriteFieldBeginError + } + if err := p.ToolChoiceSpecification.Write(oprot); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) +} func (p *ToolCallConfig) String() string { if p == nil { @@ -4558,6 +4621,9 @@ func (p *ToolCallConfig) DeepEqual(ano *ToolCallConfig) bool { if !p.Field1DeepEqual(ano.ToolChoice) { return false } + if !p.Field2DeepEqual(ano.ToolChoiceSpecification) { + return false + } return true } @@ -4573,6 +4639,271 @@ func (p *ToolCallConfig) Field1DeepEqual(src *ToolChoiceType) bool { } return true } +func (p *ToolCallConfig) Field2DeepEqual(src *ToolChoiceSpecification) bool { + + if !p.ToolChoiceSpecification.DeepEqual(src) { + return false + } + return true +} + +type ToolChoiceSpecification struct { + Type *ToolType `thrift:"type,1,optional" frugal:"1,optional,string" form:"type" json:"type,omitempty" query:"type"` + Name *string `thrift:"name,2,optional" frugal:"2,optional,string" form:"name" json:"name,omitempty" query:"name"` +} + +func NewToolChoiceSpecification() *ToolChoiceSpecification { + return &ToolChoiceSpecification{} +} + +func (p *ToolChoiceSpecification) InitDefault() { +} + +var ToolChoiceSpecification_Type_DEFAULT ToolType + +func (p *ToolChoiceSpecification) GetType() (v ToolType) { + if p == nil { + return + } + if !p.IsSetType() { + return ToolChoiceSpecification_Type_DEFAULT + } + return *p.Type +} + +var ToolChoiceSpecification_Name_DEFAULT string + +func (p *ToolChoiceSpecification) GetName() (v string) { + if p == nil { + return + } + if !p.IsSetName() { + return ToolChoiceSpecification_Name_DEFAULT + } + return *p.Name +} +func (p *ToolChoiceSpecification) SetType(val *ToolType) { + p.Type = val +} +func (p *ToolChoiceSpecification) SetName(val *string) { + p.Name = val +} + +var fieldIDToName_ToolChoiceSpecification = map[int16]string{ + 1: "type", + 2: "name", +} + +func (p *ToolChoiceSpecification) IsSetType() bool { + return p.Type != nil +} + +func (p *ToolChoiceSpecification) IsSetName() bool { + return p.Name != nil +} + +func (p *ToolChoiceSpecification) Read(iprot thrift.TProtocol) (err error) { + var fieldTypeId thrift.TType + var fieldId int16 + + if _, err = iprot.ReadStructBegin(); err != nil { + goto ReadStructBeginError + } + + for { + _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + + switch fieldId { + case 1: + if fieldTypeId == thrift.STRING { + if err = p.ReadField1(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 2: + if fieldTypeId == thrift.STRING { + if err = p.ReadField2(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } + } + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError + } + + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_ToolChoiceSpecification[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *ToolChoiceSpecification) ReadField1(iprot thrift.TProtocol) error { + + var _field *ToolType + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Type = _field + return nil +} +func (p *ToolChoiceSpecification) ReadField2(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Name = _field + return nil +} + +func (p *ToolChoiceSpecification) Write(oprot thrift.TProtocol) (err error) { + var fieldId int16 + if err = oprot.WriteStructBegin("ToolChoiceSpecification"); err != nil { + goto WriteStructBeginError + } + if p != nil { + if err = p.writeField1(oprot); err != nil { + fieldId = 1 + goto WriteFieldError + } + if err = p.writeField2(oprot); err != nil { + fieldId = 2 + goto WriteFieldError + } + } + if err = oprot.WriteFieldStop(); err != nil { + goto WriteFieldStopError + } + if err = oprot.WriteStructEnd(); err != nil { + goto WriteStructEndError + } + return nil +WriteStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) +WriteFieldError: + return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) +WriteFieldStopError: + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) +WriteStructEndError: + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) +} + +func (p *ToolChoiceSpecification) writeField1(oprot thrift.TProtocol) (err error) { + if p.IsSetType() { + if err = oprot.WriteFieldBegin("type", thrift.STRING, 1); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Type); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) +} +func (p *ToolChoiceSpecification) writeField2(oprot thrift.TProtocol) (err error) { + if p.IsSetName() { + if err = oprot.WriteFieldBegin("name", thrift.STRING, 2); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Name); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) +} + +func (p *ToolChoiceSpecification) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("ToolChoiceSpecification(%+v)", *p) + +} + +func (p *ToolChoiceSpecification) DeepEqual(ano *ToolChoiceSpecification) bool { + if p == ano { + return true + } else if p == nil || ano == nil { + return false + } + if !p.Field1DeepEqual(ano.Type) { + return false + } + if !p.Field2DeepEqual(ano.Name) { + return false + } + return true +} + +func (p *ToolChoiceSpecification) Field1DeepEqual(src *ToolType) bool { + + if p.Type == src { + return true + } else if p.Type == nil || src == nil { + return false + } + if strings.Compare(*p.Type, *src) != 0 { + return false + } + return true +} +func (p *ToolChoiceSpecification) Field2DeepEqual(src *string) bool { + + if p.Name == src { + return true + } else if p.Name == nil || src == nil { + return false + } + if strings.Compare(*p.Name, *src) != 0 { + return false + } + return true +} type ModelConfig struct { ModelID *int64 `thrift:"model_id,1,optional" frugal:"1,optional,i64" json:"model_id" form:"model_id" query:"model_id"` diff --git a/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt_validator.go b/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt_validator.go index 658007ba5..2cc88b636 100644 --- a/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt_validator.go +++ b/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt_validator.go @@ -107,6 +107,14 @@ func (p *Function) IsValid() error { return nil } func (p *ToolCallConfig) IsValid() error { + if p.ToolChoiceSpecification != nil { + if err := p.ToolChoiceSpecification.IsValid(); err != nil { + return fmt.Errorf("field ToolChoiceSpecification not valid, %w", err) + } + } + return nil +} +func (p *ToolChoiceSpecification) IsValid() error { return nil } func (p *ModelConfig) IsValid() error { diff --git a/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi.go b/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi.go index 9dc3751ec..f6888a966 100644 --- a/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi.go +++ b/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi.go @@ -24,6 +24,8 @@ const ( ToolChoiceTypeNone = "none" + ToolChoiceTypeSpecific = "specific" + ContentTypeText = "text" ContentTypeImageURL = "image_url" @@ -69,6 +71,8 @@ const ( RolePlaceholder = "placeholder" ToolTypeFunction = "function" + + ToolTypeGoogleSearch = "google_search" ) type TemplateType = string @@ -4931,7 +4935,8 @@ func (p *PromptTemplate) Field100DeepEqual(src map[string]string) bool { } type ToolCallConfig struct { - ToolChoice *ToolChoiceType `thrift:"tool_choice,1,optional" frugal:"1,optional,string" form:"tool_choice" json:"tool_choice,omitempty" query:"tool_choice"` + ToolChoice *ToolChoiceType `thrift:"tool_choice,1,optional" frugal:"1,optional,string" form:"tool_choice" json:"tool_choice,omitempty" query:"tool_choice"` + ToolChoiceSpecification *ToolChoiceSpecification `thrift:"tool_choice_specification,2,optional" frugal:"2,optional,ToolChoiceSpecification" form:"tool_choice_specification" json:"tool_choice_specification,omitempty" query:"tool_choice_specification"` } func NewToolCallConfig() *ToolCallConfig { @@ -4952,18 +4957,38 @@ func (p *ToolCallConfig) GetToolChoice() (v ToolChoiceType) { } return *p.ToolChoice } + +var ToolCallConfig_ToolChoiceSpecification_DEFAULT *ToolChoiceSpecification + +func (p *ToolCallConfig) GetToolChoiceSpecification() (v *ToolChoiceSpecification) { + if p == nil { + return + } + if !p.IsSetToolChoiceSpecification() { + return ToolCallConfig_ToolChoiceSpecification_DEFAULT + } + return p.ToolChoiceSpecification +} func (p *ToolCallConfig) SetToolChoice(val *ToolChoiceType) { p.ToolChoice = val } +func (p *ToolCallConfig) SetToolChoiceSpecification(val *ToolChoiceSpecification) { + p.ToolChoiceSpecification = val +} var fieldIDToName_ToolCallConfig = map[int16]string{ 1: "tool_choice", + 2: "tool_choice_specification", } func (p *ToolCallConfig) IsSetToolChoice() bool { return p.ToolChoice != nil } +func (p *ToolCallConfig) IsSetToolChoiceSpecification() bool { + return p.ToolChoiceSpecification != nil +} + func (p *ToolCallConfig) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -4990,6 +5015,14 @@ func (p *ToolCallConfig) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 2: + if fieldTypeId == thrift.STRUCT { + if err = p.ReadField2(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } default: if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError @@ -5030,6 +5063,14 @@ func (p *ToolCallConfig) ReadField1(iprot thrift.TProtocol) error { p.ToolChoice = _field return nil } +func (p *ToolCallConfig) ReadField2(iprot thrift.TProtocol) error { + _field := NewToolChoiceSpecification() + if err := _field.Read(iprot); err != nil { + return err + } + p.ToolChoiceSpecification = _field + return nil +} func (p *ToolCallConfig) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 @@ -5041,6 +5082,10 @@ func (p *ToolCallConfig) Write(oprot thrift.TProtocol) (err error) { fieldId = 1 goto WriteFieldError } + if err = p.writeField2(oprot); err != nil { + fieldId = 2 + goto WriteFieldError + } } if err = oprot.WriteFieldStop(); err != nil { goto WriteFieldStopError @@ -5077,6 +5122,24 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) } +func (p *ToolCallConfig) writeField2(oprot thrift.TProtocol) (err error) { + if p.IsSetToolChoiceSpecification() { + if err = oprot.WriteFieldBegin("tool_choice_specification", thrift.STRUCT, 2); err != nil { + goto WriteFieldBeginError + } + if err := p.ToolChoiceSpecification.Write(oprot); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) +} func (p *ToolCallConfig) String() string { if p == nil { @@ -5095,6 +5158,9 @@ func (p *ToolCallConfig) DeepEqual(ano *ToolCallConfig) bool { if !p.Field1DeepEqual(ano.ToolChoice) { return false } + if !p.Field2DeepEqual(ano.ToolChoiceSpecification) { + return false + } return true } @@ -5110,6 +5176,271 @@ func (p *ToolCallConfig) Field1DeepEqual(src *ToolChoiceType) bool { } return true } +func (p *ToolCallConfig) Field2DeepEqual(src *ToolChoiceSpecification) bool { + + if !p.ToolChoiceSpecification.DeepEqual(src) { + return false + } + return true +} + +type ToolChoiceSpecification struct { + Type *ToolType `thrift:"type,1,optional" frugal:"1,optional,string" form:"type" json:"type,omitempty" query:"type"` + Name *string `thrift:"name,2,optional" frugal:"2,optional,string" form:"name" json:"name,omitempty" query:"name"` +} + +func NewToolChoiceSpecification() *ToolChoiceSpecification { + return &ToolChoiceSpecification{} +} + +func (p *ToolChoiceSpecification) InitDefault() { +} + +var ToolChoiceSpecification_Type_DEFAULT ToolType + +func (p *ToolChoiceSpecification) GetType() (v ToolType) { + if p == nil { + return + } + if !p.IsSetType() { + return ToolChoiceSpecification_Type_DEFAULT + } + return *p.Type +} + +var ToolChoiceSpecification_Name_DEFAULT string + +func (p *ToolChoiceSpecification) GetName() (v string) { + if p == nil { + return + } + if !p.IsSetName() { + return ToolChoiceSpecification_Name_DEFAULT + } + return *p.Name +} +func (p *ToolChoiceSpecification) SetType(val *ToolType) { + p.Type = val +} +func (p *ToolChoiceSpecification) SetName(val *string) { + p.Name = val +} + +var fieldIDToName_ToolChoiceSpecification = map[int16]string{ + 1: "type", + 2: "name", +} + +func (p *ToolChoiceSpecification) IsSetType() bool { + return p.Type != nil +} + +func (p *ToolChoiceSpecification) IsSetName() bool { + return p.Name != nil +} + +func (p *ToolChoiceSpecification) Read(iprot thrift.TProtocol) (err error) { + var fieldTypeId thrift.TType + var fieldId int16 + + if _, err = iprot.ReadStructBegin(); err != nil { + goto ReadStructBeginError + } + + for { + _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + + switch fieldId { + case 1: + if fieldTypeId == thrift.STRING { + if err = p.ReadField1(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 2: + if fieldTypeId == thrift.STRING { + if err = p.ReadField2(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } + } + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError + } + + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_ToolChoiceSpecification[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *ToolChoiceSpecification) ReadField1(iprot thrift.TProtocol) error { + + var _field *ToolType + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Type = _field + return nil +} +func (p *ToolChoiceSpecification) ReadField2(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Name = _field + return nil +} + +func (p *ToolChoiceSpecification) Write(oprot thrift.TProtocol) (err error) { + var fieldId int16 + if err = oprot.WriteStructBegin("ToolChoiceSpecification"); err != nil { + goto WriteStructBeginError + } + if p != nil { + if err = p.writeField1(oprot); err != nil { + fieldId = 1 + goto WriteFieldError + } + if err = p.writeField2(oprot); err != nil { + fieldId = 2 + goto WriteFieldError + } + } + if err = oprot.WriteFieldStop(); err != nil { + goto WriteFieldStopError + } + if err = oprot.WriteStructEnd(); err != nil { + goto WriteStructEndError + } + return nil +WriteStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) +WriteFieldError: + return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) +WriteFieldStopError: + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) +WriteStructEndError: + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) +} + +func (p *ToolChoiceSpecification) writeField1(oprot thrift.TProtocol) (err error) { + if p.IsSetType() { + if err = oprot.WriteFieldBegin("type", thrift.STRING, 1); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Type); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) +} +func (p *ToolChoiceSpecification) writeField2(oprot thrift.TProtocol) (err error) { + if p.IsSetName() { + if err = oprot.WriteFieldBegin("name", thrift.STRING, 2); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Name); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) +} + +func (p *ToolChoiceSpecification) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("ToolChoiceSpecification(%+v)", *p) + +} + +func (p *ToolChoiceSpecification) DeepEqual(ano *ToolChoiceSpecification) bool { + if p == ano { + return true + } else if p == nil || ano == nil { + return false + } + if !p.Field1DeepEqual(ano.Type) { + return false + } + if !p.Field2DeepEqual(ano.Name) { + return false + } + return true +} + +func (p *ToolChoiceSpecification) Field1DeepEqual(src *ToolType) bool { + + if p.Type == src { + return true + } else if p.Type == nil || src == nil { + return false + } + if strings.Compare(*p.Type, *src) != 0 { + return false + } + return true +} +func (p *ToolChoiceSpecification) Field2DeepEqual(src *string) bool { + + if p.Name == src { + return true + } else if p.Name == nil || src == nil { + return false + } + if strings.Compare(*p.Name, *src) != 0 { + return false + } + return true +} type Message struct { // 角色 diff --git a/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi_validator.go b/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi_validator.go index 86bbe7350..a2fb6fb98 100644 --- a/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi_validator.go +++ b/backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi_validator.go @@ -148,6 +148,14 @@ func (p *PromptTemplate) IsValid() error { return nil } func (p *ToolCallConfig) IsValid() error { + if p.ToolChoiceSpecification != nil { + if err := p.ToolChoiceSpecification.IsValid(); err != nil { + return fmt.Errorf("field ToolChoiceSpecification not valid, %w", err) + } + } + return nil +} +func (p *ToolChoiceSpecification) IsValid() error { return nil } func (p *Message) IsValid() error { diff --git a/backend/kitex_gen/coze/loop/prompt/openapi/k-coze.loop.prompt.openapi.go b/backend/kitex_gen/coze/loop/prompt/openapi/k-coze.loop.prompt.openapi.go index 8d7ce402c..485c86299 100644 --- a/backend/kitex_gen/coze/loop/prompt/openapi/k-coze.loop.prompt.openapi.go +++ b/backend/kitex_gen/coze/loop/prompt/openapi/k-coze.loop.prompt.openapi.go @@ -3606,6 +3606,20 @@ func (p *ToolCallConfig) FastRead(buf []byte) (int, error) { goto SkipFieldError } } + case 2: + if fieldTypeId == thrift.STRUCT { + l, err = p.FastReadField2(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } default: l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l @@ -3638,6 +3652,18 @@ func (p *ToolCallConfig) FastReadField1(buf []byte) (int, error) { return offset, nil } +func (p *ToolCallConfig) FastReadField2(buf []byte) (int, error) { + offset := 0 + _field := NewToolChoiceSpecification() + if l, err := _field.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + p.ToolChoiceSpecification = _field + return offset, nil +} + func (p *ToolCallConfig) FastWrite(buf []byte) int { return p.FastWriteNocopy(buf, nil) } @@ -3646,6 +3672,7 @@ func (p *ToolCallConfig) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int offset := 0 if p != nil { offset += p.fastWriteField1(buf[offset:], w) + offset += p.fastWriteField2(buf[offset:], w) } offset += thrift.Binary.WriteFieldStop(buf[offset:]) return offset @@ -3655,6 +3682,7 @@ func (p *ToolCallConfig) BLength() int { l := 0 if p != nil { l += p.field1Length() + l += p.field2Length() } l += thrift.Binary.FieldStopLength() return l @@ -3669,6 +3697,15 @@ func (p *ToolCallConfig) fastWriteField1(buf []byte, w thrift.NocopyWriter) int return offset } +func (p *ToolCallConfig) fastWriteField2(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetToolChoiceSpecification() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 2) + offset += p.ToolChoiceSpecification.FastWriteNocopy(buf[offset:], w) + } + return offset +} + func (p *ToolCallConfig) field1Length() int { l := 0 if p.IsSetToolChoice() { @@ -3678,6 +3715,15 @@ func (p *ToolCallConfig) field1Length() int { return l } +func (p *ToolCallConfig) field2Length() int { + l := 0 + if p.IsSetToolChoiceSpecification() { + l += thrift.Binary.FieldBeginLength() + l += p.ToolChoiceSpecification.BLength() + } + return l +} + func (p *ToolCallConfig) DeepCopy(s interface{}) error { src, ok := s.(*ToolCallConfig) if !ok { @@ -3689,6 +3735,188 @@ func (p *ToolCallConfig) DeepCopy(s interface{}) error { p.ToolChoice = &tmp } + var _toolChoiceSpecification *ToolChoiceSpecification + if src.ToolChoiceSpecification != nil { + _toolChoiceSpecification = &ToolChoiceSpecification{} + if err := _toolChoiceSpecification.DeepCopy(src.ToolChoiceSpecification); err != nil { + return err + } + } + p.ToolChoiceSpecification = _toolChoiceSpecification + + return nil +} + +func (p *ToolChoiceSpecification) FastRead(buf []byte) (int, error) { + + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + for { + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 1: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField1(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 2: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField2(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + } + + return offset, nil +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_ToolChoiceSpecification[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +} + +func (p *ToolChoiceSpecification) FastReadField1(buf []byte) (int, error) { + offset := 0 + + var _field *ToolType + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Type = _field + return offset, nil +} + +func (p *ToolChoiceSpecification) FastReadField2(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Name = _field + return offset, nil +} + +func (p *ToolChoiceSpecification) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *ToolChoiceSpecification) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField1(buf[offset:], w) + offset += p.fastWriteField2(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *ToolChoiceSpecification) BLength() int { + l := 0 + if p != nil { + l += p.field1Length() + l += p.field2Length() + } + l += thrift.Binary.FieldStopLength() + return l +} + +func (p *ToolChoiceSpecification) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetType() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 1) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Type) + } + return offset +} + +func (p *ToolChoiceSpecification) fastWriteField2(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetName() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 2) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Name) + } + return offset +} + +func (p *ToolChoiceSpecification) field1Length() int { + l := 0 + if p.IsSetType() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Type) + } + return l +} + +func (p *ToolChoiceSpecification) field2Length() int { + l := 0 + if p.IsSetName() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Name) + } + return l +} + +func (p *ToolChoiceSpecification) DeepCopy(s interface{}) error { + src, ok := s.(*ToolChoiceSpecification) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + if src.Type != nil { + tmp := *src.Type + p.Type = &tmp + } + + if src.Name != nil { + var tmp string + if *src.Name != "" { + tmp = kutils.StringDeepCopy(*src.Name) + } + p.Name = &tmp + } + return nil } diff --git a/backend/modules/prompt/application/convertor/openapi.go b/backend/modules/prompt/application/convertor/openapi.go index d29f264e4..c9acf5416 100644 --- a/backend/modules/prompt/application/convertor/openapi.go +++ b/backend/modules/prompt/application/convertor/openapi.go @@ -138,7 +138,18 @@ func OpenAPIToolCallConfigDO2DTO(do *entity.ToolCallConfig) *openapi.ToolCallCon return nil } return &openapi.ToolCallConfig{ - ToolChoice: ptr.Of(prompt.ToolChoiceType(do.ToolChoice)), + ToolChoice: ptr.Of(prompt.ToolChoiceType(do.ToolChoice)), + ToolChoiceSpecification: OpenAPIToolChoiceSpecificationDO2DTO(do.ToolChoiceSpecification), + } +} + +func OpenAPIToolChoiceSpecificationDO2DTO(do *entity.ToolChoiceSpecification) *openapi.ToolChoiceSpecification { + if do == nil { + return nil + } + return &openapi.ToolChoiceSpecification{ + Type: ptr.Of(prompt.ToolType(do.Type)), + Name: ptr.Of(do.Name), } } @@ -389,6 +400,8 @@ func OpenAPIToolTypeDO2DTO(do entity.ToolType) openapi.ToolType { switch do { case entity.ToolTypeFunction: return openapi.ToolTypeFunction + case entity.ToolTypeGoogleSearch: + return openapi.ToolTypeGoogleSearch default: return openapi.ToolTypeFunction } @@ -438,6 +451,8 @@ func OpenAPIToolTypeDTO2DO(dto openapi.ToolType) entity.ToolType { switch dto { case openapi.ToolTypeFunction: return entity.ToolTypeFunction + case openapi.ToolTypeGoogleSearch: + return entity.ToolTypeGoogleSearch default: return entity.ToolTypeFunction } diff --git a/backend/modules/prompt/application/convertor/openapi_test.go b/backend/modules/prompt/application/convertor/openapi_test.go index 8a2210f06..2b3234ea1 100755 --- a/backend/modules/prompt/application/convertor/openapi_test.go +++ b/backend/modules/prompt/application/convertor/openapi_test.go @@ -2084,3 +2084,169 @@ func TestOpenAPIPromptBasicDO2DTO(t *testing.T) { }) } } + +func TestOpenAPIToolTypeDO2DTO(t *testing.T) { + tests := []struct { + name string + do entity.ToolType + want openapi.ToolType + }{ + { + name: "function type", + do: entity.ToolTypeFunction, + want: openapi.ToolTypeFunction, + }, + { + name: "google_search type", + do: entity.ToolTypeGoogleSearch, + want: openapi.ToolTypeGoogleSearch, + }, + { + name: "unknown type defaults to function", + do: entity.ToolType("unknown"), + want: openapi.ToolTypeFunction, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.want, OpenAPIToolTypeDO2DTO(tt.do)) + }) + } +} + +func TestOpenAPIToolTypeDTO2DO(t *testing.T) { + tests := []struct { + name string + dto openapi.ToolType + want entity.ToolType + }{ + { + name: "function type", + dto: openapi.ToolTypeFunction, + want: entity.ToolTypeFunction, + }, + { + name: "google_search type", + dto: openapi.ToolTypeGoogleSearch, + want: entity.ToolTypeGoogleSearch, + }, + { + name: "unknown type defaults to function", + dto: openapi.ToolType("unknown"), + want: entity.ToolTypeFunction, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.want, OpenAPIToolTypeDTO2DO(tt.dto)) + }) + } +} + +func TestOpenAPIToolChoiceSpecificationDO2DTO(t *testing.T) { + tests := []struct { + name string + do *entity.ToolChoiceSpecification + want *openapi.ToolChoiceSpecification + }{ + { + name: "nil input", + do: nil, + want: nil, + }, + { + name: "specification with function type", + do: &entity.ToolChoiceSpecification{ + Type: entity.ToolTypeFunction, + Name: "get_weather", + }, + want: &openapi.ToolChoiceSpecification{ + Type: ptr.Of(prompt.ToolTypeFunction), + Name: ptr.Of("get_weather"), + }, + }, + { + name: "specification with google_search type", + do: &entity.ToolChoiceSpecification{ + Type: entity.ToolTypeGoogleSearch, + Name: "search", + }, + want: &openapi.ToolChoiceSpecification{ + Type: ptr.Of(prompt.ToolTypeGoogleSearch), + Name: ptr.Of("search"), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.want, OpenAPIToolChoiceSpecificationDO2DTO(tt.do)) + }) + } +} + +func TestOpenAPIToolCallConfigDO2DTO_WithSpecification(t *testing.T) { + tests := []struct { + name string + do *entity.ToolCallConfig + want *openapi.ToolCallConfig + }{ + { + name: "nil input", + do: nil, + want: nil, + }, + { + name: "auto without specification", + do: &entity.ToolCallConfig{ + ToolChoice: entity.ToolChoiceTypeAuto, + }, + want: &openapi.ToolCallConfig{ + ToolChoice: ptr.Of(prompt.ToolChoiceTypeAuto), + ToolChoiceSpecification: nil, + }, + }, + { + name: "specific with specification", + do: &entity.ToolCallConfig{ + ToolChoice: entity.ToolChoiceTypeSpecific, + ToolChoiceSpecification: &entity.ToolChoiceSpecification{ + Type: entity.ToolTypeFunction, + Name: "get_weather", + }, + }, + want: &openapi.ToolCallConfig{ + ToolChoice: ptr.Of(prompt.ToolChoiceTypeSpecific), + ToolChoiceSpecification: &openapi.ToolChoiceSpecification{ + Type: ptr.Of(prompt.ToolTypeFunction), + Name: ptr.Of("get_weather"), + }, + }, + }, + { + name: "specific with google_search specification", + do: &entity.ToolCallConfig{ + ToolChoice: entity.ToolChoiceTypeSpecific, + ToolChoiceSpecification: &entity.ToolChoiceSpecification{ + Type: entity.ToolTypeGoogleSearch, + Name: "search", + }, + }, + want: &openapi.ToolCallConfig{ + ToolChoice: ptr.Of(prompt.ToolChoiceTypeSpecific), + ToolChoiceSpecification: &openapi.ToolChoiceSpecification{ + Type: ptr.Of(prompt.ToolTypeGoogleSearch), + Name: ptr.Of("search"), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.want, OpenAPIToolCallConfigDO2DTO(tt.do)) + }) + } +} diff --git a/backend/modules/prompt/application/convertor/prompt.go b/backend/modules/prompt/application/convertor/prompt.go index 154f0875c..12e49b60b 100644 --- a/backend/modules/prompt/application/convertor/prompt.go +++ b/backend/modules/prompt/application/convertor/prompt.go @@ -374,6 +374,10 @@ func ToolCallDTO2DO(dto *prompt.ToolCall) *entity.ToolCall { func ToolTypeDTO2DO(dto prompt.ToolType) entity.ToolType { switch dto { + case prompt.ToolTypeFunction: + return entity.ToolTypeFunction + case prompt.ToolTypeGoogleSearch: + return entity.ToolTypeGoogleSearch default: return entity.ToolTypeFunction } @@ -396,7 +400,19 @@ func ToolCallConfigDTO2DO(dto *prompt.ToolCallConfig) *entity.ToolCallConfig { } return &entity.ToolCallConfig{ - ToolChoice: ToolChoiceTypeDTO2DO(dto.GetToolChoice()), + ToolChoice: ToolChoiceTypeDTO2DO(dto.GetToolChoice()), + ToolChoiceSpecification: ToolChoiceSpecificationDTO2DO(dto.ToolChoiceSpecification), + } +} + +func ToolChoiceSpecificationDTO2DO(dto *prompt.ToolChoiceSpecification) *entity.ToolChoiceSpecification { + if dto == nil { + return nil + } + + return &entity.ToolChoiceSpecification{ + Type: ToolTypeDTO2DO(dto.GetType()), + Name: dto.GetName(), } } @@ -406,6 +422,8 @@ func ToolChoiceTypeDTO2DO(dto prompt.ToolChoiceType) entity.ToolChoiceType { return entity.ToolChoiceTypeNone case prompt.ToolChoiceTypeAuto: return entity.ToolChoiceTypeAuto + case prompt.ToolChoiceTypeSpecific: + return entity.ToolChoiceTypeSpecific default: return entity.ToolChoiceTypeAuto } @@ -511,6 +529,10 @@ func ToolCallDO2DTO(do *entity.ToolCall) *prompt.ToolCall { func ToolTypeDO2DTO(do entity.ToolType) prompt.ToolType { switch do { + case entity.ToolTypeFunction: + return prompt.ToolTypeFunction + case entity.ToolTypeGoogleSearch: + return prompt.ToolTypeGoogleSearch default: return prompt.ToolTypeFunction } @@ -838,7 +860,18 @@ func ToolCallConfigDO2DTO(do *entity.ToolCallConfig) *prompt.ToolCallConfig { return nil } return &prompt.ToolCallConfig{ - ToolChoice: ptr.Of(prompt.ToolChoiceType(do.ToolChoice)), + ToolChoice: ptr.Of(prompt.ToolChoiceType(do.ToolChoice)), + ToolChoiceSpecification: ToolChoiceSpecificationDO2DTO(do.ToolChoiceSpecification), + } +} + +func ToolChoiceSpecificationDO2DTO(do *entity.ToolChoiceSpecification) *prompt.ToolChoiceSpecification { + if do == nil { + return nil + } + return &prompt.ToolChoiceSpecification{ + Type: ptr.Of(prompt.ToolType(do.Type)), + Name: ptr.Of(do.Name), } } diff --git a/backend/modules/prompt/application/convertor/prompt_test.go b/backend/modules/prompt/application/convertor/prompt_test.go index 4e9b15311..81917072d 100644 --- a/backend/modules/prompt/application/convertor/prompt_test.go +++ b/backend/modules/prompt/application/convertor/prompt_test.go @@ -810,3 +810,310 @@ func TestPromptTemplateWithDifferentTypes(t *testing.T) { }) } } + +func TestToolTypeDO2DTO(t *testing.T) { + tests := []struct { + name string + do entity.ToolType + want prompt.ToolType + }{ + { + name: "function type", + do: entity.ToolTypeFunction, + want: prompt.ToolTypeFunction, + }, + { + name: "google_search type", + do: entity.ToolTypeGoogleSearch, + want: prompt.ToolTypeGoogleSearch, + }, + { + name: "unknown type defaults to function", + do: entity.ToolType("unknown"), + want: prompt.ToolTypeFunction, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.want, ToolTypeDO2DTO(tt.do)) + }) + } +} + +func TestToolTypeDTO2DO(t *testing.T) { + tests := []struct { + name string + dto prompt.ToolType + want entity.ToolType + }{ + { + name: "function type", + dto: prompt.ToolTypeFunction, + want: entity.ToolTypeFunction, + }, + { + name: "google_search type", + dto: prompt.ToolTypeGoogleSearch, + want: entity.ToolTypeGoogleSearch, + }, + { + name: "unknown type defaults to function", + dto: prompt.ToolType("unknown"), + want: entity.ToolTypeFunction, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.want, ToolTypeDTO2DO(tt.dto)) + }) + } +} + +func TestToolChoiceSpecificationDO2DTO(t *testing.T) { + tests := []struct { + name string + do *entity.ToolChoiceSpecification + want *prompt.ToolChoiceSpecification + }{ + { + name: "nil input", + do: nil, + want: nil, + }, + { + name: "specification with function type", + do: &entity.ToolChoiceSpecification{ + Type: entity.ToolTypeFunction, + Name: "get_weather", + }, + want: &prompt.ToolChoiceSpecification{ + Type: ptr.Of(prompt.ToolTypeFunction), + Name: ptr.Of("get_weather"), + }, + }, + { + name: "specification with google_search type", + do: &entity.ToolChoiceSpecification{ + Type: entity.ToolTypeGoogleSearch, + Name: "search", + }, + want: &prompt.ToolChoiceSpecification{ + Type: ptr.Of(prompt.ToolTypeGoogleSearch), + Name: ptr.Of("search"), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.want, ToolChoiceSpecificationDO2DTO(tt.do)) + }) + } +} + +func TestToolChoiceSpecificationDTO2DO(t *testing.T) { + tests := []struct { + name string + dto *prompt.ToolChoiceSpecification + want *entity.ToolChoiceSpecification + }{ + { + name: "nil input", + dto: nil, + want: nil, + }, + { + name: "specification with function type", + dto: &prompt.ToolChoiceSpecification{ + Type: ptr.Of(prompt.ToolTypeFunction), + Name: ptr.Of("get_weather"), + }, + want: &entity.ToolChoiceSpecification{ + Type: entity.ToolTypeFunction, + Name: "get_weather", + }, + }, + { + name: "specification with google_search type", + dto: &prompt.ToolChoiceSpecification{ + Type: ptr.Of(prompt.ToolTypeGoogleSearch), + Name: ptr.Of("search"), + }, + want: &entity.ToolChoiceSpecification{ + Type: entity.ToolTypeGoogleSearch, + Name: "search", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.want, ToolChoiceSpecificationDTO2DO(tt.dto)) + }) + } +} + +func TestToolCallConfigDO2DTO_WithSpecification(t *testing.T) { + tests := []struct { + name string + do *entity.ToolCallConfig + want *prompt.ToolCallConfig + }{ + { + name: "nil input", + do: nil, + want: nil, + }, + { + name: "auto without specification", + do: &entity.ToolCallConfig{ + ToolChoice: entity.ToolChoiceTypeAuto, + }, + want: &prompt.ToolCallConfig{ + ToolChoice: ptr.Of(prompt.ToolChoiceTypeAuto), + ToolChoiceSpecification: nil, + }, + }, + { + name: "specific with specification", + do: &entity.ToolCallConfig{ + ToolChoice: entity.ToolChoiceTypeSpecific, + ToolChoiceSpecification: &entity.ToolChoiceSpecification{ + Type: entity.ToolTypeFunction, + Name: "get_weather", + }, + }, + want: &prompt.ToolCallConfig{ + ToolChoice: ptr.Of(prompt.ToolChoiceTypeSpecific), + ToolChoiceSpecification: &prompt.ToolChoiceSpecification{ + Type: ptr.Of(prompt.ToolTypeFunction), + Name: ptr.Of("get_weather"), + }, + }, + }, + { + name: "specific with google_search specification", + do: &entity.ToolCallConfig{ + ToolChoice: entity.ToolChoiceTypeSpecific, + ToolChoiceSpecification: &entity.ToolChoiceSpecification{ + Type: entity.ToolTypeGoogleSearch, + Name: "search", + }, + }, + want: &prompt.ToolCallConfig{ + ToolChoice: ptr.Of(prompt.ToolChoiceTypeSpecific), + ToolChoiceSpecification: &prompt.ToolChoiceSpecification{ + Type: ptr.Of(prompt.ToolTypeGoogleSearch), + Name: ptr.Of("search"), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.want, ToolCallConfigDO2DTO(tt.do)) + }) + } +} + +func TestToolCallConfigDTO2DO_WithSpecification(t *testing.T) { + tests := []struct { + name string + dto *prompt.ToolCallConfig + want *entity.ToolCallConfig + }{ + { + name: "nil input", + dto: nil, + want: nil, + }, + { + name: "auto without specification", + dto: &prompt.ToolCallConfig{ + ToolChoice: ptr.Of(prompt.ToolChoiceTypeAuto), + }, + want: &entity.ToolCallConfig{ + ToolChoice: entity.ToolChoiceTypeAuto, + ToolChoiceSpecification: nil, + }, + }, + { + name: "specific with specification", + dto: &prompt.ToolCallConfig{ + ToolChoice: ptr.Of(prompt.ToolChoiceTypeSpecific), + ToolChoiceSpecification: &prompt.ToolChoiceSpecification{ + Type: ptr.Of(prompt.ToolTypeFunction), + Name: ptr.Of("get_weather"), + }, + }, + want: &entity.ToolCallConfig{ + ToolChoice: entity.ToolChoiceTypeSpecific, + ToolChoiceSpecification: &entity.ToolChoiceSpecification{ + Type: entity.ToolTypeFunction, + Name: "get_weather", + }, + }, + }, + { + name: "specific with google_search specification", + dto: &prompt.ToolCallConfig{ + ToolChoice: ptr.Of(prompt.ToolChoiceTypeSpecific), + ToolChoiceSpecification: &prompt.ToolChoiceSpecification{ + Type: ptr.Of(prompt.ToolTypeGoogleSearch), + Name: ptr.Of("search"), + }, + }, + want: &entity.ToolCallConfig{ + ToolChoice: entity.ToolChoiceTypeSpecific, + ToolChoiceSpecification: &entity.ToolChoiceSpecification{ + Type: entity.ToolTypeGoogleSearch, + Name: "search", + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.want, ToolCallConfigDTO2DO(tt.dto)) + }) + } +} + +func TestToolChoiceTypeDTO2DO(t *testing.T) { + tests := []struct { + name string + dto prompt.ToolChoiceType + want entity.ToolChoiceType + }{ + { + name: "none type", + dto: prompt.ToolChoiceTypeNone, + want: entity.ToolChoiceTypeNone, + }, + { + name: "auto type", + dto: prompt.ToolChoiceTypeAuto, + want: entity.ToolChoiceTypeAuto, + }, + { + name: "specific type", + dto: prompt.ToolChoiceTypeSpecific, + want: entity.ToolChoiceTypeSpecific, + }, + { + name: "unknown type defaults to auto", + dto: prompt.ToolChoiceType("unknown"), + want: entity.ToolChoiceTypeAuto, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.want, ToolChoiceTypeDTO2DO(tt.dto)) + }) + } +} diff --git a/backend/modules/prompt/domain/entity/prompt_detail.go b/backend/modules/prompt/domain/entity/prompt_detail.go index 7b608298c..c5558dc19 100644 --- a/backend/modules/prompt/domain/entity/prompt_detail.go +++ b/backend/modules/prompt/domain/entity/prompt_detail.go @@ -142,7 +142,8 @@ type Tool struct { type ToolType string const ( - ToolTypeFunction ToolType = "function" + ToolTypeFunction ToolType = "function" + ToolTypeGoogleSearch ToolType = "google_search" ) type Function struct { @@ -152,16 +153,23 @@ type Function struct { } type ToolCallConfig struct { - ToolChoice ToolChoiceType `json:"tool_choice"` + ToolChoice ToolChoiceType `json:"tool_choice"` + ToolChoiceSpecification *ToolChoiceSpecification `json:"tool_choice_specification,omitempty"` } type ToolChoiceType string const ( - ToolChoiceTypeNone ToolChoiceType = "none" - ToolChoiceTypeAuto ToolChoiceType = "auto" + ToolChoiceTypeNone ToolChoiceType = "none" + ToolChoiceTypeAuto ToolChoiceType = "auto" + ToolChoiceTypeSpecific ToolChoiceType = "specific" ) +type ToolChoiceSpecification struct { + Type ToolType `json:"type"` + Name string `json:"name"` +} + type ToolCall struct { Index int64 `json:"index"` ID string `json:"id"` diff --git a/backend/modules/prompt/domain/service/execute.go b/backend/modules/prompt/domain/service/execute.go index eb29da0d8..3f857de34 100644 --- a/backend/modules/prompt/domain/service/execute.go +++ b/backend/modules/prompt/domain/service/execute.go @@ -383,9 +383,22 @@ func (p *PromptServiceImpl) prepareLLMCallParam(ctx context.Context, param Execu // call llm promptDetail := param.Prompt.GetPromptDetail() var tools []*entity.Tool + var toolCallConfig *entity.ToolCallConfig if promptDetail != nil { if promptDetail.ToolCallConfig != nil && promptDetail.ToolCallConfig.ToolChoice != entity.ToolChoiceTypeNone { tools = promptDetail.Tools + toolCallConfig = promptDetail.ToolCallConfig + } + } + // Validate tool choice specification + if toolCallConfig != nil && toolCallConfig.ToolChoice == entity.ToolChoiceTypeSpecific { + // When tool choice is specific, must be in single step mode + if !param.SingleStep { + return rpc.LLMCallParam{}, errorx.New("tool choice specific must be used with single step mode to avoid infinite loops") + } + // ToolChoiceSpecification must not be empty + if toolCallConfig.ToolChoiceSpecification == nil { + return rpc.LLMCallParam{}, errorx.New("tool_choice_specification must not be empty when tool choice is specific") } } var modelConfig *entity.ModelConfig @@ -405,7 +418,7 @@ func (p *PromptServiceImpl) prepareLLMCallParam(ctx context.Context, param Execu UserID: userID, Messages: messages, Tools: tools, - ToolCallConfig: nil, + ToolCallConfig: toolCallConfig, ModelConfig: modelConfig, }, nil } diff --git a/backend/modules/prompt/domain/service/execute_test.go b/backend/modules/prompt/domain/service/execute_test.go index e22fbcfe6..7d5aef92b 100644 --- a/backend/modules/prompt/domain/service/execute_test.go +++ b/backend/modules/prompt/domain/service/execute_test.go @@ -994,3 +994,277 @@ func TestPromptServiceImpl_prepareLLMCallParam_PreservesExtra(t *testing.T) { assert.Equal(t, prompt.PromptCommit.PromptDetail.ModelConfig.Extra, got.ModelConfig.Extra) } } + +func TestPromptServiceImpl_prepareLLMCallParam_ValidationErrors(t *testing.T) { + t.Parallel() + svc := &PromptServiceImpl{ + formatter: NewPromptFormatter(), + } + + tests := []struct { + name string + param ExecuteParam + wantErr bool + errContains string + }{ + { + name: "specific tool choice without single step mode - should error", + param: ExecuteParam{ + Prompt: &entity.Prompt{ + ID: 1, + SpaceID: 42, + PromptKey: "test_prompt", + PromptCommit: &entity.PromptCommit{ + CommitInfo: &entity.CommitInfo{ + Version: "v1", + }, + PromptDetail: &entity.PromptDetail{ + ToolCallConfig: &entity.ToolCallConfig{ + ToolChoice: entity.ToolChoiceTypeSpecific, + ToolChoiceSpecification: &entity.ToolChoiceSpecification{ + Type: entity.ToolTypeFunction, + Name: "get_weather", + }, + }, + Tools: []*entity.Tool{ + { + Type: entity.ToolTypeFunction, + Function: &entity.Function{ + Name: "get_weather", + Description: "Get weather", + Parameters: "{}", + }, + }, + }, + ModelConfig: &entity.ModelConfig{ + ModelID: 1, + }, + PromptTemplate: &entity.PromptTemplate{ + TemplateType: entity.TemplateTypeNormal, + Messages: []*entity.Message{ + { + Role: entity.RoleSystem, + Content: ptr.Of("Test"), + }, + }, + }, + }, + }, + }, + Messages: []*entity.Message{}, + SingleStep: false, // Should be true for specific tool choice + Scenario: entity.ScenarioPromptDebug, + }, + wantErr: true, + errContains: "single step mode", + }, + { + name: "specific tool choice without specification - should error", + param: ExecuteParam{ + Prompt: &entity.Prompt{ + ID: 1, + SpaceID: 42, + PromptKey: "test_prompt", + PromptCommit: &entity.PromptCommit{ + CommitInfo: &entity.CommitInfo{ + Version: "v1", + }, + PromptDetail: &entity.PromptDetail{ + ToolCallConfig: &entity.ToolCallConfig{ + ToolChoice: entity.ToolChoiceTypeSpecific, + ToolChoiceSpecification: nil, // Should not be nil + }, + Tools: []*entity.Tool{ + { + Type: entity.ToolTypeFunction, + Function: &entity.Function{ + Name: "get_weather", + Description: "Get weather", + Parameters: "{}", + }, + }, + }, + ModelConfig: &entity.ModelConfig{ + ModelID: 1, + }, + PromptTemplate: &entity.PromptTemplate{ + TemplateType: entity.TemplateTypeNormal, + Messages: []*entity.Message{ + { + Role: entity.RoleSystem, + Content: ptr.Of("Test"), + }, + }, + }, + }, + }, + }, + Messages: []*entity.Message{}, + SingleStep: true, + Scenario: entity.ScenarioPromptDebug, + }, + wantErr: true, + errContains: "must not be empty", + }, + { + name: "specific tool choice with single step and specification - should succeed", + param: ExecuteParam{ + Prompt: &entity.Prompt{ + ID: 1, + SpaceID: 42, + PromptKey: "test_prompt", + PromptCommit: &entity.PromptCommit{ + CommitInfo: &entity.CommitInfo{ + Version: "v1", + }, + PromptDetail: &entity.PromptDetail{ + ToolCallConfig: &entity.ToolCallConfig{ + ToolChoice: entity.ToolChoiceTypeSpecific, + ToolChoiceSpecification: &entity.ToolChoiceSpecification{ + Type: entity.ToolTypeFunction, + Name: "get_weather", + }, + }, + Tools: []*entity.Tool{ + { + Type: entity.ToolTypeFunction, + Function: &entity.Function{ + Name: "get_weather", + Description: "Get weather", + Parameters: "{}", + }, + }, + }, + ModelConfig: &entity.ModelConfig{ + ModelID: 1, + }, + PromptTemplate: &entity.PromptTemplate{ + TemplateType: entity.TemplateTypeNormal, + Messages: []*entity.Message{ + { + Role: entity.RoleSystem, + Content: ptr.Of("Test"), + }, + }, + }, + }, + }, + }, + Messages: []*entity.Message{}, + SingleStep: true, + Scenario: entity.ScenarioPromptDebug, + }, + wantErr: false, + }, + { + name: "specific tool choice with google_search - should succeed", + param: ExecuteParam{ + Prompt: &entity.Prompt{ + ID: 1, + SpaceID: 42, + PromptKey: "test_prompt", + PromptCommit: &entity.PromptCommit{ + CommitInfo: &entity.CommitInfo{ + Version: "v1", + }, + PromptDetail: &entity.PromptDetail{ + ToolCallConfig: &entity.ToolCallConfig{ + ToolChoice: entity.ToolChoiceTypeSpecific, + ToolChoiceSpecification: &entity.ToolChoiceSpecification{ + Type: entity.ToolTypeGoogleSearch, + Name: "search", + }, + }, + Tools: []*entity.Tool{ + { + Type: entity.ToolTypeGoogleSearch, + }, + }, + ModelConfig: &entity.ModelConfig{ + ModelID: 1, + }, + PromptTemplate: &entity.PromptTemplate{ + TemplateType: entity.TemplateTypeNormal, + Messages: []*entity.Message{ + { + Role: entity.RoleSystem, + Content: ptr.Of("Test"), + }, + }, + }, + }, + }, + }, + Messages: []*entity.Message{}, + SingleStep: true, + Scenario: entity.ScenarioPromptDebug, + }, + wantErr: false, + }, + { + name: "auto tool choice - should succeed without validation", + param: ExecuteParam{ + Prompt: &entity.Prompt{ + ID: 1, + SpaceID: 42, + PromptKey: "test_prompt", + PromptCommit: &entity.PromptCommit{ + CommitInfo: &entity.CommitInfo{ + Version: "v1", + }, + PromptDetail: &entity.PromptDetail{ + ToolCallConfig: &entity.ToolCallConfig{ + ToolChoice: entity.ToolChoiceTypeAuto, + }, + Tools: []*entity.Tool{ + { + Type: entity.ToolTypeFunction, + Function: &entity.Function{ + Name: "get_weather", + Description: "Get weather", + Parameters: "{}", + }, + }, + }, + ModelConfig: &entity.ModelConfig{ + ModelID: 1, + }, + PromptTemplate: &entity.PromptTemplate{ + TemplateType: entity.TemplateTypeNormal, + Messages: []*entity.Message{ + { + Role: entity.RoleSystem, + Content: ptr.Of("Test"), + }, + }, + }, + }, + }, + }, + Messages: []*entity.Message{}, + SingleStep: false, + Scenario: entity.ScenarioPromptDebug, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, err := svc.prepareLLMCallParam(context.Background(), tt.param) + if tt.wantErr { + assert.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + assert.NoError(t, err) + assert.NotNil(t, got) + if tt.param.Prompt.PromptCommit.PromptDetail.ToolCallConfig != nil { + assert.Equal(t, tt.param.Prompt.PromptCommit.PromptDetail.ToolCallConfig, got.ToolCallConfig) + } + } + }) + } +} diff --git a/idl/thrift/coze/loop/prompt/coze.loop.prompt.openapi.thrift b/idl/thrift/coze/loop/prompt/coze.loop.prompt.openapi.thrift index f00e44235..fc6164aad 100644 --- a/idl/thrift/coze/loop/prompt/coze.loop.prompt.openapi.thrift +++ b/idl/thrift/coze/loop/prompt/coze.loop.prompt.openapi.thrift @@ -109,9 +109,16 @@ const TemplateType TemplateType_CustomTemplate_M = "custom_template_m" typedef string ToolChoiceType const ToolChoiceType ToolChoiceType_Auto = "auto" const ToolChoiceType ToolChoiceType_None = "none" +const ToolChoiceType ToolChoiceType_Specific = "specific" struct ToolCallConfig { 1: optional ToolChoiceType tool_choice + 2: optional ToolChoiceSpecification tool_choice_specification +} + +struct ToolChoiceSpecification { + 1: optional ToolType type + 2: optional string name } struct Message { @@ -179,6 +186,7 @@ struct Tool { typedef string ToolType (ts.enum="true") const ToolType ToolType_Function = "function" +const ToolType ToolType_GoogleSearch = "google_search" struct Function { 1: optional string name diff --git a/idl/thrift/coze/loop/prompt/domain/prompt.thrift b/idl/thrift/coze/loop/prompt/domain/prompt.thrift index 3c27d37cd..34a6bc443 100644 --- a/idl/thrift/coze/loop/prompt/domain/prompt.thrift +++ b/idl/thrift/coze/loop/prompt/domain/prompt.thrift @@ -78,6 +78,7 @@ struct Tool { typedef string ToolType (ts.enum="true") const ToolType ToolType_Function = "function" +const ToolType ToolType_GoogleSearch = "google_search" struct Function { 1: optional string name @@ -87,11 +88,18 @@ struct Function { struct ToolCallConfig { 1: optional ToolChoiceType tool_choice + 2: optional ToolChoiceSpecification tool_choice_specification +} + +struct ToolChoiceSpecification { + 1: optional ToolType type + 2: optional string name } typedef string ToolChoiceType (ts.enum="true") const ToolChoiceType ToolChoiceType_None = "none" const ToolChoiceType ToolChoiceType_Auto = "auto" +const ToolChoiceType ToolChoiceType_Specific = "specific" struct ModelConfig { 1: optional i64 model_id (api.js_conv="true", go.tag='json:"model_id"') From 3174aff0b61e3bf5757a8e91f888d073f8791808 Mon Sep 17 00:00:00 2001 From: "caijialin.626" Date: Tue, 11 Nov 2025 19:42:39 +0800 Subject: [PATCH 11/12] [feat][prompt] fix conflict --- backend/api/router/coze/loop/apis/middleware.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/backend/api/router/coze/loop/apis/middleware.go b/backend/api/router/coze/loop/apis/middleware.go index 527d0b72f..e9d9a1605 100644 --- a/backend/api/router/coze/loop/apis/middleware.go +++ b/backend/api/router/coze/loop/apis/middleware.go @@ -1645,3 +1645,8 @@ func _updateevaluationsetoapiMw(handler *apis.APIHandler) []app.HandlerFunc { // your code... return nil } + +func _listpromptbasicMw(handler *apis.APIHandler) []app.HandlerFunc { + // your code... + return nil +} From d0e59251dd60ef9d6eca312472e521f50e826ae8 Mon Sep 17 00:00:00 2001 From: kasarolzzw <39260341+kasarolzzw@users.noreply.github.com> Date: Tue, 18 Nov 2025 19:50:27 +0800 Subject: [PATCH 12/12] [feat][prompt] prompt snippet (#240) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [feat][prompt] add prompt list openapi * [feat][prompt] metadata (#224) * feat: [Coda] add prompt metadata pipelines Change-Id: I142040fd89ef89429df790dc6e4dfc5d04e53609 * sql Change-Id: I44b779e1c2a749217cec05dfbc609ae74474a217 * sql Change-Id: Ie46a446e3eca59cee9c3170d85b751c7c2aa630b * git ignore Change-Id: I59dd9b1517ded4a40e3fb3efcae608164adf9678 * ci Change-Id: I75a355cb6ec1700a54108f7e16a4717563cb7046 * [feat][prompt] prompt img video file upload (#244) * [feat][prompt] prompt support model config extra field (#260) * [feat][prompt] prompt support go template (#269) * [feat][prompt] prompt support go template * [feat][prompt] prompt support go template * [feat][prompt] prompt support go template ut * [feat][prompt] prompt support go template ut * [refactor][prompt] prompt support custom format (#271) * [feat][prompt] prompt support custom format * [feat][prompt] prompt support custom format * [feat][prompt] prompt support custom format * [feat][prompt] prompt support custom format * [feat][prompt] prompt support custom format * [feat][prompt] prompt support custom format * [feat][prompt] prompt support custom format * [feat][prompt] prompt support custom format * [feat][prompt] prompt support custom format * [feat][prompt] prompt support custom format * [feat][prompt] prompt support custom format * Snippet Co-Authored-By: Coda * refactor: [Coda] stop merging snippet variable definitions (LogID: 202510241537520100911040168388BC8) Co-Authored-By: Coda * list commit info Change-Id: I8194f0db96952a81c93a032791aed3732d25ba36 * feat:snippet bugfix * feat:UpdatedBy * feat:UpdatedBy * feat:snippet idl update * feat:snippet update * feat:snippet update * feat: [Coda] repo commit versions and sub prompt retrieval (LogID: 202511112106110100910922425618E3A) Co-Authored-By: Coda * feat:snippet update * feat:snippet update * Revert "feat:snippet update" This reverts commit 5bbd45757ba0d7d2e71bd9822164c103ea197c1c. * feat: [Coda] prompt commit versions query refinement (LogID: 20251113132425010091092242951D7F2) Co-Authored-By: Coda * feat:snippet update * feat:snippet update * feat:snippet update * feat:list repo * feat:list * Revert "feat:list" This reverts commit 2e99b10bef1325181b6b294c130d860fd3707bf0. * feat:list * feat:parentPromptCommitVersions * feat:parentPromptCommitVersions * feat:parentPromptCommitVersions * feat:idl * feat:解冲突 * feat:解冲突 * feat:snippet * feat:snippet * feat:snippet * feat:unit * feat:unit * feat:unit * feat:unit * feat:unit * feat:unit --------- Co-authored-by: caijialin.626 Co-authored-by: caijialin0626 <61818131+caijialin0626@users.noreply.github.com> Co-authored-by: Coda Co-authored-by: wangluming.wlm --- .../coze/loop/apis/prompt_manage_service.go | 6 + .../router/coze/loop/apis/coze.loop.apis.go | 1 + .../api/router/coze/loop/apis/middleware.go | 5 + backend/go.mod | 2 +- backend/go.sum | 2 + backend/infra/limiter/mocks/rate_limiter.go | 11 +- .../limiter/mocks/rate_limiter_factory.go | 9 +- .../infra/looptracer/mocks/cozeloop_client.go | 40 + .../session/mocks/session_service.go | 72 + .../platestwrite/latest_write_tracker.go | 1 + .../mocks/latest_write_tracker.go | 35 +- .../loop/apis/promptmanageservice/client.go | 6 + .../promptmanageservice.go | 36 + .../loop/prompt/domain/prompt/k-prompt.go | 552 +++ .../coze/loop/prompt/domain/prompt/prompt.go | 776 +++- .../prompt/domain/prompt/prompt_validator.go | 8 + .../prompt/manage/coze.loop.prompt.manage.go | 2781 +++++++++--- .../coze.loop.prompt.manage_validator.go | 28 + .../manage/k-coze.loop.prompt.manage.go | 1308 +++++- .../manage/promptmanageservice/client.go | 6 + .../promptmanageservice.go | 36 + .../loop/prompt/promptmanageservice/client.go | 6 + .../promptmanageservice.go | 36 + .../lomanage/local_promptmanageservice.go | 23 + .../data/domain/component/conf/mocks/conf.go | 1 + .../component/rpc/mocks/user_provider.go | 9 +- .../component/userinfo/mocks/userinfo.go | 9 +- .../data/domain/tag/repo/mocks/tag_mock.go | 131 +- .../tag/service/mocks/tag_service_mock.go | 129 +- .../entity/mocks/expt_scheduler_mock.go | 62 +- .../mocks/evaluator_event_publisher_mock.go | 14 +- .../events/mocks/expt_event_publisher_mock.go | 65 +- .../domain/repo/mocks/evaluator_mock.go | 110 +- .../repo/mocks/evaluator_record_mock.go | 38 +- .../domain/repo/mocks/ratelimiter_mock.go | 14 +- .../evaluator/mysql/mocks/evaluator_mock.go | 96 +- .../mysql/mocks/evaluator_record_mock.go | 46 +- .../mysql/mocks/evaluator_version_mock.go | 96 +- .../infra/repo/experiment/ck/mocks/expt.go | 30 +- .../experiment/mysql/mocks/annotate_record.go | 47 +- .../mysql/mocks/expt_aggr_result.go | 89 +- .../mysql/mocks/expt_evaluator_ref.go | 22 +- .../expt_insight_analysis_feedback_comment.go | 52 +- .../expt_insight_analysis_feedback_vote.go | 44 +- .../mocks/expt_insight_analysis_record.go | 52 +- .../mysql/mocks/expt_item_result.go | 142 +- .../mysql/mocks/expt_result_export_record.go | 44 +- .../experiment/mysql/mocks/expt_run_log.go | 46 +- .../repo/experiment/mysql/mocks/expt_stats.go | 54 +- .../mocks/expt_turn_annotate_record_ref.go | 68 +- .../mocks/expt_turn_evaluator_result_ref.go | 36 +- .../mysql/mocks/expt_turn_result.go | 179 +- .../expt_turn_result_filter_key_mapping.go | 36 +- .../repo/experiment/redis/dao/mocks/quota.go | 33 +- .../domain/authn/repo/mocks/authn_repo.go | 62 +- .../domain/user/service/mocks/user_service.go | 80 +- .../domain/component/metrics/mocks/metrics.go | 4 +- .../rpc/mocks/dataset_provider_mock.go | 49 +- .../collector/confmap/mocks/provider.go | 85 + .../entity/collector/mocks/conf_provider.go | 57 + .../mocks/trace_export_service_mock.go | 17 +- .../infra/rpc/dataset/dataset_test.go | 3 +- .../mocks/mock_datasetservice_client.go | 263 +- .../rpc/evaluationset/evaluation_set_test.go | 5 +- .../mocks/mock_evaluationsetservice_client.go | 363 ++ .../convertor/debug_context_test.go | 4 + .../application/convertor/openapi_test.go | 5 +- .../prompt/application/convertor/prompt.go | 59 + .../application/convertor/prompt_test.go | 127 + backend/modules/prompt/application/debug.go | 5 + .../modules/prompt/application/debug_test.go | 43 + backend/modules/prompt/application/execute.go | 6 + .../prompt/application/execute_test.go | 33 +- backend/modules/prompt/application/manage.go | 169 +- .../modules/prompt/application/manage_test.go | 3740 +++++++++-------- backend/modules/prompt/application/openapi.go | 16 +- .../prompt/application/openapi_test.go | 124 + backend/modules/prompt/application/wire.go | 2 + .../modules/prompt/application/wire_gen.go | 26 +- .../prompt/domain/entity/prompt_basic.go | 8 + .../prompt/domain/entity/prompt_detail.go | 11 +- backend/modules/prompt/domain/repo/manage.go | 32 +- .../prompt/domain/repo/mocks/manage_repo.go | 30 + .../modules/prompt/domain/service/manage.go | 293 ++ .../prompt/domain/service/manage_test.go | 1221 ++++++ .../domain/service/mocks/prompt_service.go | 59 + .../modules/prompt/domain/service/service.go | 22 + .../prompt/domain/service/service_test.go | 2 + .../prompt/domain/service/snippet_parser.go | 78 + .../domain/service/snippet_parser_test.go | 108 + backend/modules/prompt/infra/repo/manage.go | 336 ++ .../modules/prompt/infra/repo/manage_test.go | 1256 +++++- .../mysql/convertor/debug_context_test.go | 6 +- .../infra/repo/mysql/convertor/manage.go | 29 + .../infra/repo/mysql/convertor/manage_test.go | 5 + .../mysql/gorm_gen/model/prompt_basic.gen.go | 25 +- .../mysql/gorm_gen/model/prompt_commit.gen.go | 1 + .../gorm_gen/model/prompt_relation.gen.go | 29 + .../gorm_gen/model/prompt_user_draft.gen.go | 1 + .../infra/repo/mysql/gorm_gen/query/gen.go | 6 + .../mysql/gorm_gen/query/prompt_basic.gen.go | 6 +- .../mysql/gorm_gen/query/prompt_commit.gen.go | 6 +- .../gorm_gen/query/prompt_relation.gen.go | 364 ++ .../gorm_gen/query/prompt_user_draft.gen.go | 6 +- .../repo/mysql/mocks/prompt_commit_dao.go | 20 + .../repo/mysql/mocks/prompt_relation_dao.go | 140 + .../prompt/infra/repo/mysql/prompt_basic.go | 8 + .../prompt/infra/repo/mysql/prompt_commit.go | 30 + .../infra/repo/mysql/prompt_relation.go | 174 + .../infra/repo/mysql/prompt_user_draft.go | 1 + backend/script/gorm_gen/generate.go | 2 +- .../prompt/coze.loop.prompt.manage.thrift | 23 + .../coze/loop/prompt/domain/prompt.thrift | 15 + .../mysql-init/init-sql/prompt_basic.sql | 4 +- .../mysql-init/init-sql/prompt_commit.sql | 1 + .../mysql-init/init-sql/prompt_relation.sql | 15 + .../mysql-init/init-sql/prompt_user_draft.sql | 1 + .../patch-sql/prompt_basic_alter.sql | 2 + .../patch-sql/prompt_commit_alter.sql | 1 + .../patch-sql/prompt_user_draft_alter.sql | 1 + .../init/mysql/init-sql/prompt_basic.sql | 10 +- .../mysql/init-sql/prompt_basic_alter.sql | 2 + .../init/mysql/init-sql/prompt_commit.sql | 7 +- .../mysql/init-sql/prompt_commit_alter.sql | 1 + .../init/mysql/init-sql/prompt_relation.sql | 15 + .../init/mysql/init-sql/prompt_user_draft.sql | 7 +- .../init-sql/prompt_user_draft_alter.sql | 1 + 127 files changed, 13778 insertions(+), 3570 deletions(-) create mode 100644 backend/infra/middleware/session/mocks/session_service.go create mode 100644 backend/modules/observability/domain/trace/entity/collector/confmap/mocks/provider.go create mode 100644 backend/modules/observability/domain/trace/entity/collector/mocks/conf_provider.go create mode 100644 backend/modules/observability/infra/rpc/evaluationset/mocks/mock_evaluationsetservice_client.go create mode 100644 backend/modules/prompt/domain/service/snippet_parser.go create mode 100755 backend/modules/prompt/domain/service/snippet_parser_test.go create mode 100644 backend/modules/prompt/infra/repo/mysql/gorm_gen/model/prompt_relation.gen.go create mode 100644 backend/modules/prompt/infra/repo/mysql/gorm_gen/query/prompt_relation.gen.go create mode 100644 backend/modules/prompt/infra/repo/mysql/mocks/prompt_relation_dao.go create mode 100644 backend/modules/prompt/infra/repo/mysql/prompt_relation.go create mode 100644 release/deployment/docker-compose/bootstrap/mysql-init/init-sql/prompt_relation.sql create mode 100644 release/deployment/docker-compose/bootstrap/mysql-init/patch-sql/prompt_basic_alter.sql create mode 100644 release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_basic_alter.sql create mode 100644 release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_relation.sql diff --git a/backend/api/handler/coze/loop/apis/prompt_manage_service.go b/backend/api/handler/coze/loop/apis/prompt_manage_service.go index 50fe9fb89..388b0d4a8 100644 --- a/backend/api/handler/coze/loop/apis/prompt_manage_service.go +++ b/backend/api/handler/coze/loop/apis/prompt_manage_service.go @@ -98,3 +98,9 @@ func UpdateCommitLabels(ctx context.Context, c *app.RequestContext) { func BatchGetLabel(ctx context.Context, c *app.RequestContext) { invokeAndRender(ctx, c, promptManageSvc.BatchGetLabel) } + +// ListParentPrompt . +// @router /api/prompt/v1/prompts/list_parent [POST] +func ListParentPrompt(ctx context.Context, c *app.RequestContext) { + invokeAndRender(ctx, c, promptManageSvc.ListParentPrompt) +} diff --git a/backend/api/router/coze/loop/apis/coze.loop.apis.go b/backend/api/router/coze/loop/apis/coze.loop.apis.go index baf132698..6c81cfa7d 100644 --- a/backend/api/router/coze/loop/apis/coze.loop.apis.go +++ b/backend/api/router/coze/loop/apis/coze.loop.apis.go @@ -341,6 +341,7 @@ func Register(r *server.Hertz, handler *apis.APIHandler) { _v15.POST("/prompts", append(_promptsMw(handler), apis.CreatePrompt)...) _prompts := _v15.Group("/prompts", _promptsMw(handler)...) _prompts.POST("/list", append(_listpromptMw(handler), apis.ListPrompt)...) + _prompts.POST("/list_parent", append(_listparentpromptMw(handler), apis.ListParentPrompt)...) _prompts.DELETE("/:prompt_id", append(_prompt_idMw(handler), apis.DeletePrompt)...) _prompt_id := _prompts.Group("/:prompt_id", _prompt_idMw(handler)...) _prompt_id.POST("/debug_streaming", append(_debugstreamingMw(handler), apis.DebugStreaming)...) diff --git a/backend/api/router/coze/loop/apis/middleware.go b/backend/api/router/coze/loop/apis/middleware.go index e9d9a1605..978688b63 100644 --- a/backend/api/router/coze/loop/apis/middleware.go +++ b/backend/api/router/coze/loop/apis/middleware.go @@ -1650,3 +1650,8 @@ func _listpromptbasicMw(handler *apis.APIHandler) []app.HandlerFunc { // your code... return nil } + +func _listparentpromptMw(handler *apis.APIHandler) []app.HandlerFunc { + // your code... + return nil +} diff --git a/backend/go.mod b/backend/go.mod index a8ccc2e13..b335f5970 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -37,7 +37,7 @@ require ( github.com/cloudwego/kitex v0.13.1 github.com/coocood/freecache v1.2.4 github.com/coreos/go-semver v0.3.0 - github.com/coze-dev/cozeloop-go v0.1.10-0.20250901062520-61d3699b1e83 + github.com/coze-dev/cozeloop-go v0.1.14 github.com/coze-dev/cozeloop-go/spec v0.1.4-0.20250829072213-3812ddbfb735 github.com/deatil/go-encoding v1.0.3003 github.com/dimchansky/utfbom v1.1.1 diff --git a/backend/go.sum b/backend/go.sum index 53bf8e523..755515be3 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -254,6 +254,8 @@ github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3Ee github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coze-dev/cozeloop-go v0.1.10-0.20250901062520-61d3699b1e83 h1:7Jh4flr9XqvissJtafWhTcs1vcErUcsjNkkniH/szxY= github.com/coze-dev/cozeloop-go v0.1.10-0.20250901062520-61d3699b1e83/go.mod h1:RMH0F6ZMwZm4ZL92IHLjTf4lmr8QHxYJVPCdz60ZbbI= +github.com/coze-dev/cozeloop-go v0.1.14 h1:Pu6P+G72czlGn9e86aSXpXuRqJvu388fWXN8J/heVOc= +github.com/coze-dev/cozeloop-go v0.1.14/go.mod h1:lM7cmUEZlnAlQYdwfk4Li0SC3RdZ++QMHX75nvKceSc= github.com/coze-dev/cozeloop-go/spec v0.1.4-0.20250829072213-3812ddbfb735 h1:qxAwjHy0SLQazDO3oGJ8D24vOeM2Oz2+n27bNPegBls= github.com/coze-dev/cozeloop-go/spec v0.1.4-0.20250829072213-3812ddbfb735/go.mod h1:/f3BrWehffwXIpd4b5rYIqktLd/v5dlLBw0h9F/LQIU= github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= diff --git a/backend/infra/limiter/mocks/rate_limiter.go b/backend/infra/limiter/mocks/rate_limiter.go index a36427311..81e2fb50c 100644 --- a/backend/infra/limiter/mocks/rate_limiter.go +++ b/backend/infra/limiter/mocks/rate_limiter.go @@ -21,6 +21,7 @@ import ( type MockIRateLimiter struct { ctrl *gomock.Controller recorder *MockIRateLimiterMockRecorder + isgomock struct{} } // MockIRateLimiterMockRecorder is the mock recorder for MockIRateLimiter. @@ -41,10 +42,10 @@ func (m *MockIRateLimiter) EXPECT() *MockIRateLimiterMockRecorder { } // AllowN mocks base method. -func (m *MockIRateLimiter) AllowN(arg0 context.Context, arg1 string, arg2 int, arg3 ...limiter.LimitOptionFn) (*limiter.Result, error) { +func (m *MockIRateLimiter) AllowN(ctx context.Context, key string, n int, opts ...limiter.LimitOptionFn) (*limiter.Result, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1, arg2} - for _, a := range arg3 { + varargs := []any{ctx, key, n} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "AllowN", varargs...) @@ -54,8 +55,8 @@ func (m *MockIRateLimiter) AllowN(arg0 context.Context, arg1 string, arg2 int, a } // AllowN indicates an expected call of AllowN. -func (mr *MockIRateLimiterMockRecorder) AllowN(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call { +func (mr *MockIRateLimiterMockRecorder) AllowN(ctx, key, n any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1, arg2}, arg3...) + varargs := append([]any{ctx, key, n}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowN", reflect.TypeOf((*MockIRateLimiter)(nil).AllowN), varargs...) } diff --git a/backend/infra/limiter/mocks/rate_limiter_factory.go b/backend/infra/limiter/mocks/rate_limiter_factory.go index 3a6c0154a..57612b48f 100644 --- a/backend/infra/limiter/mocks/rate_limiter_factory.go +++ b/backend/infra/limiter/mocks/rate_limiter_factory.go @@ -20,6 +20,7 @@ import ( type MockIRateLimiterFactory struct { ctrl *gomock.Controller recorder *MockIRateLimiterFactoryMockRecorder + isgomock struct{} } // MockIRateLimiterFactoryMockRecorder is the mock recorder for MockIRateLimiterFactory. @@ -40,10 +41,10 @@ func (m *MockIRateLimiterFactory) EXPECT() *MockIRateLimiterFactoryMockRecorder } // NewRateLimiter mocks base method. -func (m *MockIRateLimiterFactory) NewRateLimiter(arg0 ...limiter.FactoryOptionFn) limiter.IRateLimiter { +func (m *MockIRateLimiterFactory) NewRateLimiter(opts ...limiter.FactoryOptionFn) limiter.IRateLimiter { m.ctrl.T.Helper() varargs := []any{} - for _, a := range arg0 { + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "NewRateLimiter", varargs...) @@ -52,7 +53,7 @@ func (m *MockIRateLimiterFactory) NewRateLimiter(arg0 ...limiter.FactoryOptionFn } // NewRateLimiter indicates an expected call of NewRateLimiter. -func (mr *MockIRateLimiterFactoryMockRecorder) NewRateLimiter(arg0 ...any) *gomock.Call { +func (mr *MockIRateLimiterFactoryMockRecorder) NewRateLimiter(opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewRateLimiter", reflect.TypeOf((*MockIRateLimiterFactory)(nil).NewRateLimiter), arg0...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewRateLimiter", reflect.TypeOf((*MockIRateLimiterFactory)(nil).NewRateLimiter), opts...) } diff --git a/backend/infra/looptracer/mocks/cozeloop_client.go b/backend/infra/looptracer/mocks/cozeloop_client.go index d90001aa9..54c1437d9 100644 --- a/backend/infra/looptracer/mocks/cozeloop_client.go +++ b/backend/infra/looptracer/mocks/cozeloop_client.go @@ -54,6 +54,46 @@ func (mr *MockClientMockRecorder) Close(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockClient)(nil).Close), ctx) } +// Execute mocks base method. +func (m *MockClient) Execute(ctx context.Context, param *entity.ExecuteParam, options ...cozeloop.ExecuteOption) (entity.ExecuteResult, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, param} + for _, a := range options { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Execute", varargs...) + ret0, _ := ret[0].(entity.ExecuteResult) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Execute indicates an expected call of Execute. +func (mr *MockClientMockRecorder) Execute(ctx, param any, options ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, param}, options...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Execute", reflect.TypeOf((*MockClient)(nil).Execute), varargs...) +} + +// ExecuteStreaming mocks base method. +func (m *MockClient) ExecuteStreaming(ctx context.Context, param *entity.ExecuteParam, options ...cozeloop.ExecuteStreamingOption) (entity.StreamReader[entity.ExecuteResult], error) { + m.ctrl.T.Helper() + varargs := []any{ctx, param} + for _, a := range options { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "ExecuteStreaming", varargs...) + ret0, _ := ret[0].(entity.StreamReader[entity.ExecuteResult]) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ExecuteStreaming indicates an expected call of ExecuteStreaming. +func (mr *MockClientMockRecorder) ExecuteStreaming(ctx, param any, options ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, param}, options...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecuteStreaming", reflect.TypeOf((*MockClient)(nil).ExecuteStreaming), varargs...) +} + // Flush mocks base method. func (m *MockClient) Flush(ctx context.Context) { m.ctrl.T.Helper() diff --git a/backend/infra/middleware/session/mocks/session_service.go b/backend/infra/middleware/session/mocks/session_service.go new file mode 100644 index 000000000..beede4f94 --- /dev/null +++ b/backend/infra/middleware/session/mocks/session_service.go @@ -0,0 +1,72 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/coze-dev/coze-loop/backend/infra/middleware/session (interfaces: ISessionService) +// +// Generated by this command: +// +// mockgen -destination=mocks/session_service.go -package=mock_session . ISessionService +// + +// Package mock_session is a generated GoMock package. +package mock_session + +import ( + context "context" + reflect "reflect" + + session "github.com/coze-dev/coze-loop/backend/infra/middleware/session" + gomock "go.uber.org/mock/gomock" +) + +// MockISessionService is a mock of ISessionService interface. +type MockISessionService struct { + ctrl *gomock.Controller + recorder *MockISessionServiceMockRecorder + isgomock struct{} +} + +// MockISessionServiceMockRecorder is the mock recorder for MockISessionService. +type MockISessionServiceMockRecorder struct { + mock *MockISessionService +} + +// NewMockISessionService creates a new mock instance. +func NewMockISessionService(ctrl *gomock.Controller) *MockISessionService { + mock := &MockISessionService{ctrl: ctrl} + mock.recorder = &MockISessionServiceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockISessionService) EXPECT() *MockISessionServiceMockRecorder { + return m.recorder +} + +// GenerateSessionKey mocks base method. +func (m *MockISessionService) GenerateSessionKey(ctx context.Context, arg1 *session.Session) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GenerateSessionKey", ctx, arg1) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GenerateSessionKey indicates an expected call of GenerateSessionKey. +func (mr *MockISessionServiceMockRecorder) GenerateSessionKey(ctx, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenerateSessionKey", reflect.TypeOf((*MockISessionService)(nil).GenerateSessionKey), ctx, arg1) +} + +// ValidateSession mocks base method. +func (m *MockISessionService) ValidateSession(ctx context.Context, sessionID string) (*session.Session, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ValidateSession", ctx, sessionID) + ret0, _ := ret[0].(*session.Session) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ValidateSession indicates an expected call of ValidateSession. +func (mr *MockISessionServiceMockRecorder) ValidateSession(ctx, sessionID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateSession", reflect.TypeOf((*MockISessionService)(nil).ValidateSession), ctx, sessionID) +} diff --git a/backend/infra/platestwrite/latest_write_tracker.go b/backend/infra/platestwrite/latest_write_tracker.go index dd4707b02..814848156 100644 --- a/backend/infra/platestwrite/latest_write_tracker.go +++ b/backend/infra/platestwrite/latest_write_tracker.go @@ -130,6 +130,7 @@ const ( ResourceTypePromptLabel ResourceType = "prompt_label" ResourceTypePromptCommitLabelMapping ResourceType = "prompt_commit_label_mapping" ResourceTypeCozeloopOptimizeTask ResourceType = "cozeloop_optimize_task" // 外场智能优化 + ResourceTypePromptRelation ResourceType = "prompt_relation" ResourceTypeExperiment ResourceType = "experiment" ResourceTypeEvalSet ResourceType = "eval_set" diff --git a/backend/infra/platestwrite/mocks/latest_write_tracker.go b/backend/infra/platestwrite/mocks/latest_write_tracker.go index 382a0f98d..adbdaf1b5 100644 --- a/backend/infra/platestwrite/mocks/latest_write_tracker.go +++ b/backend/infra/platestwrite/mocks/latest_write_tracker.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/coze-dev/coze-loop/backend/infra/platestwrite (interfaces: ILatestWriteTracker) +// +// Generated by this command: +// +// mockgen -destination ./mocks/latest_write_tracker.go --package mocks . ILatestWriteTracker +// // Package mocks is a generated GoMock package. package mocks @@ -8,15 +13,15 @@ import ( context "context" reflect "reflect" - gomock "go.uber.org/mock/gomock" - platestwrite "github.com/coze-dev/coze-loop/backend/infra/platestwrite" + gomock "go.uber.org/mock/gomock" ) // MockILatestWriteTracker is a mock of ILatestWriteTracker interface. type MockILatestWriteTracker struct { ctrl *gomock.Controller recorder *MockILatestWriteTrackerMockRecorder + isgomock struct{} } // MockILatestWriteTrackerMockRecorder is the mock recorder for MockILatestWriteTracker. @@ -37,46 +42,46 @@ func (m *MockILatestWriteTracker) EXPECT() *MockILatestWriteTrackerMockRecorder } // CheckWriteFlagByID mocks base method. -func (m *MockILatestWriteTracker) CheckWriteFlagByID(arg0 context.Context, arg1 platestwrite.ResourceType, arg2 int64) bool { +func (m *MockILatestWriteTracker) CheckWriteFlagByID(ctx context.Context, resourceType platestwrite.ResourceType, id int64) bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CheckWriteFlagByID", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "CheckWriteFlagByID", ctx, resourceType, id) ret0, _ := ret[0].(bool) return ret0 } // CheckWriteFlagByID indicates an expected call of CheckWriteFlagByID. -func (mr *MockILatestWriteTrackerMockRecorder) CheckWriteFlagByID(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockILatestWriteTrackerMockRecorder) CheckWriteFlagByID(ctx, resourceType, id any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckWriteFlagByID", reflect.TypeOf((*MockILatestWriteTracker)(nil).CheckWriteFlagByID), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckWriteFlagByID", reflect.TypeOf((*MockILatestWriteTracker)(nil).CheckWriteFlagByID), ctx, resourceType, id) } // CheckWriteFlagBySearchParam mocks base method. -func (m *MockILatestWriteTracker) CheckWriteFlagBySearchParam(arg0 context.Context, arg1 platestwrite.ResourceType, arg2 string) bool { +func (m *MockILatestWriteTracker) CheckWriteFlagBySearchParam(ctx context.Context, resourceType platestwrite.ResourceType, searchParam string) bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CheckWriteFlagBySearchParam", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "CheckWriteFlagBySearchParam", ctx, resourceType, searchParam) ret0, _ := ret[0].(bool) return ret0 } // CheckWriteFlagBySearchParam indicates an expected call of CheckWriteFlagBySearchParam. -func (mr *MockILatestWriteTrackerMockRecorder) CheckWriteFlagBySearchParam(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockILatestWriteTrackerMockRecorder) CheckWriteFlagBySearchParam(ctx, resourceType, searchParam any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckWriteFlagBySearchParam", reflect.TypeOf((*MockILatestWriteTracker)(nil).CheckWriteFlagBySearchParam), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckWriteFlagBySearchParam", reflect.TypeOf((*MockILatestWriteTracker)(nil).CheckWriteFlagBySearchParam), ctx, resourceType, searchParam) } // SetWriteFlag mocks base method. -func (m *MockILatestWriteTracker) SetWriteFlag(arg0 context.Context, arg1 platestwrite.ResourceType, arg2 int64, arg3 ...platestwrite.SetWriteFlagOpt) { +func (m *MockILatestWriteTracker) SetWriteFlag(ctx context.Context, resourceType platestwrite.ResourceType, resourceID int64, opts ...platestwrite.SetWriteFlagOpt) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2} - for _, a := range arg3 { + varargs := []any{ctx, resourceType, resourceID} + for _, a := range opts { varargs = append(varargs, a) } m.ctrl.Call(m, "SetWriteFlag", varargs...) } // SetWriteFlag indicates an expected call of SetWriteFlag. -func (mr *MockILatestWriteTrackerMockRecorder) SetWriteFlag(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { +func (mr *MockILatestWriteTrackerMockRecorder) SetWriteFlag(ctx, resourceType, resourceID any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) + varargs := append([]any{ctx, resourceType, resourceID}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteFlag", reflect.TypeOf((*MockILatestWriteTracker)(nil).SetWriteFlag), varargs...) } diff --git a/backend/kitex_gen/coze/loop/apis/promptmanageservice/client.go b/backend/kitex_gen/coze/loop/apis/promptmanageservice/client.go index bc2ea53f6..3e7b1c158 100644 --- a/backend/kitex_gen/coze/loop/apis/promptmanageservice/client.go +++ b/backend/kitex_gen/coze/loop/apis/promptmanageservice/client.go @@ -17,6 +17,7 @@ type Client interface { GetPrompt(ctx context.Context, request *manage.GetPromptRequest, callOptions ...callopt.Option) (r *manage.GetPromptResponse, err error) BatchGetPrompt(ctx context.Context, request *manage.BatchGetPromptRequest, callOptions ...callopt.Option) (r *manage.BatchGetPromptResponse, err error) ListPrompt(ctx context.Context, request *manage.ListPromptRequest, callOptions ...callopt.Option) (r *manage.ListPromptResponse, err error) + ListParentPrompt(ctx context.Context, request *manage.ListParentPromptRequest, callOptions ...callopt.Option) (r *manage.ListParentPromptResponse, err error) UpdatePrompt(ctx context.Context, request *manage.UpdatePromptRequest, callOptions ...callopt.Option) (r *manage.UpdatePromptResponse, err error) SaveDraft(ctx context.Context, request *manage.SaveDraftRequest, callOptions ...callopt.Option) (r *manage.SaveDraftResponse, err error) CreateLabel(ctx context.Context, request *manage.CreateLabelRequest, callOptions ...callopt.Option) (r *manage.CreateLabelResponse, err error) @@ -87,6 +88,11 @@ func (p *kPromptManageServiceClient) ListPrompt(ctx context.Context, request *ma return p.kClient.ListPrompt(ctx, request) } +func (p *kPromptManageServiceClient) ListParentPrompt(ctx context.Context, request *manage.ListParentPromptRequest, callOptions ...callopt.Option) (r *manage.ListParentPromptResponse, err error) { + ctx = client.NewCtxWithCallOptions(ctx, callOptions) + return p.kClient.ListParentPrompt(ctx, request) +} + func (p *kPromptManageServiceClient) UpdatePrompt(ctx context.Context, request *manage.UpdatePromptRequest, callOptions ...callopt.Option) (r *manage.UpdatePromptResponse, err error) { ctx = client.NewCtxWithCallOptions(ctx, callOptions) return p.kClient.UpdatePrompt(ctx, request) diff --git a/backend/kitex_gen/coze/loop/apis/promptmanageservice/promptmanageservice.go b/backend/kitex_gen/coze/loop/apis/promptmanageservice/promptmanageservice.go index 45e15a627..331bb3c02 100644 --- a/backend/kitex_gen/coze/loop/apis/promptmanageservice/promptmanageservice.go +++ b/backend/kitex_gen/coze/loop/apis/promptmanageservice/promptmanageservice.go @@ -56,6 +56,13 @@ var serviceMethods = map[string]kitex.MethodInfo{ false, kitex.WithStreamingMode(kitex.StreamingNone), ), + "ListParentPrompt": kitex.NewMethodInfo( + listParentPromptHandler, + newPromptManageServiceListParentPromptArgs, + newPromptManageServiceListParentPromptResult, + false, + kitex.WithStreamingMode(kitex.StreamingNone), + ), "UpdatePrompt": kitex.NewMethodInfo( updatePromptHandler, newPromptManageServiceUpdatePromptArgs, @@ -266,6 +273,25 @@ func newPromptManageServiceListPromptResult() interface{} { return manage.NewPromptManageServiceListPromptResult() } +func listParentPromptHandler(ctx context.Context, handler interface{}, arg, result interface{}) error { + realArg := arg.(*manage.PromptManageServiceListParentPromptArgs) + realResult := result.(*manage.PromptManageServiceListParentPromptResult) + success, err := handler.(manage.PromptManageService).ListParentPrompt(ctx, realArg.Request) + if err != nil { + return err + } + realResult.Success = success + return nil +} + +func newPromptManageServiceListParentPromptArgs() interface{} { + return manage.NewPromptManageServiceListParentPromptArgs() +} + +func newPromptManageServiceListParentPromptResult() interface{} { + return manage.NewPromptManageServiceListParentPromptResult() +} + func updatePromptHandler(ctx context.Context, handler interface{}, arg, result interface{}) error { realArg := arg.(*manage.PromptManageServiceUpdatePromptArgs) realResult := result.(*manage.PromptManageServiceUpdatePromptResult) @@ -509,6 +535,16 @@ func (p *kClient) ListPrompt(ctx context.Context, request *manage.ListPromptRequ return _result.GetSuccess(), nil } +func (p *kClient) ListParentPrompt(ctx context.Context, request *manage.ListParentPromptRequest) (r *manage.ListParentPromptResponse, err error) { + var _args manage.PromptManageServiceListParentPromptArgs + _args.Request = request + var _result manage.PromptManageServiceListParentPromptResult + if err = p.c.Call(ctx, "ListParentPrompt", &_args, &_result); err != nil { + return + } + return _result.GetSuccess(), nil +} + func (p *kClient) UpdatePrompt(ctx context.Context, request *manage.UpdatePromptRequest) (r *manage.UpdatePromptResponse, err error) { var _args manage.PromptManageServiceUpdatePromptArgs _args.Request = request diff --git a/backend/kitex_gen/coze/loop/prompt/domain/prompt/k-prompt.go b/backend/kitex_gen/coze/loop/prompt/domain/prompt/k-prompt.go index 2a1409c02..4a0cb3dba 100644 --- a/backend/kitex_gen/coze/loop/prompt/domain/prompt/k-prompt.go +++ b/backend/kitex_gen/coze/loop/prompt/domain/prompt/k-prompt.go @@ -541,6 +541,20 @@ func (p *PromptBasic) FastRead(buf []byte) (int, error) { goto SkipFieldError } } + case 9: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField9(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } default: l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l @@ -671,6 +685,20 @@ func (p *PromptBasic) FastReadField8(buf []byte) (int, error) { return offset, nil } +func (p *PromptBasic) FastReadField9(buf []byte) (int, error) { + offset := 0 + + var _field *PromptType + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.PromptType = _field + return offset, nil +} + func (p *PromptBasic) FastWrite(buf []byte) int { return p.FastWriteNocopy(buf, nil) } @@ -686,6 +714,7 @@ func (p *PromptBasic) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { offset += p.fastWriteField3(buf[offset:], w) offset += p.fastWriteField4(buf[offset:], w) offset += p.fastWriteField5(buf[offset:], w) + offset += p.fastWriteField9(buf[offset:], w) } offset += thrift.Binary.WriteFieldStop(buf[offset:]) return offset @@ -702,6 +731,7 @@ func (p *PromptBasic) BLength() int { l += p.field6Length() l += p.field7Length() l += p.field8Length() + l += p.field9Length() } l += thrift.Binary.FieldStopLength() return l @@ -779,6 +809,15 @@ func (p *PromptBasic) fastWriteField8(buf []byte, w thrift.NocopyWriter) int { return offset } +func (p *PromptBasic) fastWriteField9(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetPromptType() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 9) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.PromptType) + } + return offset +} + func (p *PromptBasic) field1Length() int { l := 0 if p.IsSetDisplayName() { @@ -851,6 +890,15 @@ func (p *PromptBasic) field8Length() int { return l } +func (p *PromptBasic) field9Length() int { + l := 0 + if p.IsSetPromptType() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.PromptType) + } + return l +} + func (p *PromptBasic) DeepCopy(s interface{}) error { src, ok := s.(*PromptBasic) if !ok { @@ -912,6 +960,11 @@ func (p *PromptBasic) DeepCopy(s interface{}) error { p.LatestCommittedAt = &tmp } + if src.PromptType != nil { + tmp := *src.PromptType + p.PromptType = &tmp + } + return nil } @@ -2409,6 +2462,34 @@ func (p *PromptTemplate) FastRead(buf []byte) (int, error) { goto SkipFieldError } } + case 4: + if fieldTypeId == thrift.BOOL { + l, err = p.FastReadField4(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 5: + if fieldTypeId == thrift.LIST { + l, err = p.FastReadField5(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } case 100: if fieldTypeId == thrift.MAP { l, err = p.FastReadField100(buf[offset:]) @@ -2505,6 +2586,45 @@ func (p *PromptTemplate) FastReadField3(buf []byte) (int, error) { return offset, nil } +func (p *PromptTemplate) FastReadField4(buf []byte) (int, error) { + offset := 0 + + var _field *bool + if v, l, err := thrift.Binary.ReadBool(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.HasSnippet = _field + return offset, nil +} + +func (p *PromptTemplate) FastReadField5(buf []byte) (int, error) { + offset := 0 + + _, size, l, err := thrift.Binary.ReadListBegin(buf[offset:]) + offset += l + if err != nil { + return offset, err + } + _field := make([]*Prompt, 0, size) + values := make([]Prompt, size) + for i := 0; i < size; i++ { + _elem := &values[i] + _elem.InitDefault() + if l, err := _elem.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + + _field = append(_field, _elem) + } + p.Snippets = _field + return offset, nil +} + func (p *PromptTemplate) FastReadField100(buf []byte) (int, error) { offset := 0 @@ -2544,9 +2664,11 @@ func (p *PromptTemplate) FastWrite(buf []byte) int { func (p *PromptTemplate) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { offset := 0 if p != nil { + offset += p.fastWriteField4(buf[offset:], w) offset += p.fastWriteField1(buf[offset:], w) offset += p.fastWriteField2(buf[offset:], w) offset += p.fastWriteField3(buf[offset:], w) + offset += p.fastWriteField5(buf[offset:], w) offset += p.fastWriteField100(buf[offset:], w) } offset += thrift.Binary.WriteFieldStop(buf[offset:]) @@ -2559,6 +2681,8 @@ func (p *PromptTemplate) BLength() int { l += p.field1Length() l += p.field2Length() l += p.field3Length() + l += p.field4Length() + l += p.field5Length() l += p.field100Length() } l += thrift.Binary.FieldStopLength() @@ -2606,6 +2730,31 @@ func (p *PromptTemplate) fastWriteField3(buf []byte, w thrift.NocopyWriter) int return offset } +func (p *PromptTemplate) fastWriteField4(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetHasSnippet() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.BOOL, 4) + offset += thrift.Binary.WriteBool(buf[offset:], *p.HasSnippet) + } + return offset +} + +func (p *PromptTemplate) fastWriteField5(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetSnippets() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.LIST, 5) + listBeginOffset := offset + offset += thrift.Binary.ListBeginLength() + var length int + for _, v := range p.Snippets { + length++ + offset += v.FastWriteNocopy(buf[offset:], w) + } + thrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.STRUCT, length) + } + return offset +} + func (p *PromptTemplate) fastWriteField100(buf []byte, w thrift.NocopyWriter) int { offset := 0 if p.IsSetMetadata() { @@ -2658,6 +2807,28 @@ func (p *PromptTemplate) field3Length() int { return l } +func (p *PromptTemplate) field4Length() int { + l := 0 + if p.IsSetHasSnippet() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.BoolLength() + } + return l +} + +func (p *PromptTemplate) field5Length() int { + l := 0 + if p.IsSetSnippets() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.ListBeginLength() + for _, v := range p.Snippets { + _ = v + l += v.BLength() + } + } + return l +} + func (p *PromptTemplate) field100Length() int { l := 0 if p.IsSetMetadata() { @@ -2714,6 +2885,26 @@ func (p *PromptTemplate) DeepCopy(s interface{}) error { } } + if src.HasSnippet != nil { + tmp := *src.HasSnippet + p.HasSnippet = &tmp + } + + if src.Snippets != nil { + p.Snippets = make([]*Prompt, 0, len(src.Snippets)) + for _, elem := range src.Snippets { + var _elem *Prompt + if elem != nil { + _elem = &Prompt{} + if err := _elem.DeepCopy(elem); err != nil { + return err + } + } + + p.Snippets = append(p.Snippets, _elem) + } + } + if src.Metadata != nil { p.Metadata = make(map[string]string, len(src.Metadata)) for key, val := range src.Metadata { @@ -9786,3 +9977,364 @@ func (p *OverridePromptParams) DeepCopy(s interface{}) error { return nil } + +func (p *PromptCommitVersions) FastRead(buf []byte) (int, error) { + + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + for { + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 1: + if fieldTypeId == thrift.I64 { + l, err = p.FastReadField1(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 2: + if fieldTypeId == thrift.I64 { + l, err = p.FastReadField2(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 3: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField3(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 4: + if fieldTypeId == thrift.STRUCT { + l, err = p.FastReadField4(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 5: + if fieldTypeId == thrift.LIST { + l, err = p.FastReadField5(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + } + + return offset, nil +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptCommitVersions[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +} + +func (p *PromptCommitVersions) FastReadField1(buf []byte) (int, error) { + offset := 0 + + var _field *int64 + if v, l, err := thrift.Binary.ReadI64(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.ID = _field + return offset, nil +} + +func (p *PromptCommitVersions) FastReadField2(buf []byte) (int, error) { + offset := 0 + + var _field *int64 + if v, l, err := thrift.Binary.ReadI64(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.WorkspaceID = _field + return offset, nil +} + +func (p *PromptCommitVersions) FastReadField3(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.PromptKey = _field + return offset, nil +} + +func (p *PromptCommitVersions) FastReadField4(buf []byte) (int, error) { + offset := 0 + _field := NewPromptBasic() + if l, err := _field.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + p.PromptBasic = _field + return offset, nil +} + +func (p *PromptCommitVersions) FastReadField5(buf []byte) (int, error) { + offset := 0 + + _, size, l, err := thrift.Binary.ReadListBegin(buf[offset:]) + offset += l + if err != nil { + return offset, err + } + _field := make([]string, 0, size) + for i := 0; i < size; i++ { + var _elem string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _elem = v + } + + _field = append(_field, _elem) + } + p.CommitVersions = _field + return offset, nil +} + +func (p *PromptCommitVersions) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *PromptCommitVersions) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField1(buf[offset:], w) + offset += p.fastWriteField2(buf[offset:], w) + offset += p.fastWriteField3(buf[offset:], w) + offset += p.fastWriteField4(buf[offset:], w) + offset += p.fastWriteField5(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *PromptCommitVersions) BLength() int { + l := 0 + if p != nil { + l += p.field1Length() + l += p.field2Length() + l += p.field3Length() + l += p.field4Length() + l += p.field5Length() + } + l += thrift.Binary.FieldStopLength() + return l +} + +func (p *PromptCommitVersions) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetID() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.I64, 1) + offset += thrift.Binary.WriteI64(buf[offset:], *p.ID) + } + return offset +} + +func (p *PromptCommitVersions) fastWriteField2(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetWorkspaceID() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.I64, 2) + offset += thrift.Binary.WriteI64(buf[offset:], *p.WorkspaceID) + } + return offset +} + +func (p *PromptCommitVersions) fastWriteField3(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetPromptKey() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 3) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.PromptKey) + } + return offset +} + +func (p *PromptCommitVersions) fastWriteField4(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetPromptBasic() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 4) + offset += p.PromptBasic.FastWriteNocopy(buf[offset:], w) + } + return offset +} + +func (p *PromptCommitVersions) fastWriteField5(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetCommitVersions() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.LIST, 5) + listBeginOffset := offset + offset += thrift.Binary.ListBeginLength() + var length int + for _, v := range p.CommitVersions { + length++ + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, v) + } + thrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.STRING, length) + } + return offset +} + +func (p *PromptCommitVersions) field1Length() int { + l := 0 + if p.IsSetID() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.I64Length() + } + return l +} + +func (p *PromptCommitVersions) field2Length() int { + l := 0 + if p.IsSetWorkspaceID() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.I64Length() + } + return l +} + +func (p *PromptCommitVersions) field3Length() int { + l := 0 + if p.IsSetPromptKey() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.PromptKey) + } + return l +} + +func (p *PromptCommitVersions) field4Length() int { + l := 0 + if p.IsSetPromptBasic() { + l += thrift.Binary.FieldBeginLength() + l += p.PromptBasic.BLength() + } + return l +} + +func (p *PromptCommitVersions) field5Length() int { + l := 0 + if p.IsSetCommitVersions() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.ListBeginLength() + for _, v := range p.CommitVersions { + _ = v + l += thrift.Binary.StringLengthNocopy(v) + } + } + return l +} + +func (p *PromptCommitVersions) DeepCopy(s interface{}) error { + src, ok := s.(*PromptCommitVersions) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + if src.ID != nil { + tmp := *src.ID + p.ID = &tmp + } + + if src.WorkspaceID != nil { + tmp := *src.WorkspaceID + p.WorkspaceID = &tmp + } + + if src.PromptKey != nil { + var tmp string + if *src.PromptKey != "" { + tmp = kutils.StringDeepCopy(*src.PromptKey) + } + p.PromptKey = &tmp + } + + var _promptBasic *PromptBasic + if src.PromptBasic != nil { + _promptBasic = &PromptBasic{} + if err := _promptBasic.DeepCopy(src.PromptBasic); err != nil { + return err + } + } + p.PromptBasic = _promptBasic + + if src.CommitVersions != nil { + p.CommitVersions = make([]string, 0, len(src.CommitVersions)) + for _, elem := range src.CommitVersions { + var _elem string + if elem != "" { + _elem = kutils.StringDeepCopy(elem) + } + p.CommitVersions = append(p.CommitVersions, _elem) + } + } + + return nil +} diff --git a/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go b/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go index 6d7c795fd..7f700a440 100644 --- a/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go +++ b/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go @@ -9,6 +9,10 @@ import ( ) const ( + PromptTypeNormal = "normal" + + PromptTypeSnippet = "snippet" + TemplateTypeNormal = "normal" TemplateTypeJinja2 = "jinja2" @@ -74,6 +78,8 @@ const ( ScenarioEvalTarget = "eval_target" ) +type PromptType = string + type TemplateType = string type ToolType = string @@ -631,14 +637,15 @@ func (p *Prompt) Field6DeepEqual(src *PromptCommit) bool { } type PromptBasic struct { - DisplayName *string `thrift:"display_name,1,optional" frugal:"1,optional,string" form:"display_name" json:"display_name,omitempty" query:"display_name"` - Description *string `thrift:"description,2,optional" frugal:"2,optional,string" form:"description" json:"description,omitempty" query:"description"` - LatestVersion *string `thrift:"latest_version,3,optional" frugal:"3,optional,string" form:"latest_version" json:"latest_version,omitempty" query:"latest_version"` - CreatedBy *string `thrift:"created_by,4,optional" frugal:"4,optional,string" form:"created_by" json:"created_by,omitempty" query:"created_by"` - UpdatedBy *string `thrift:"updated_by,5,optional" frugal:"5,optional,string" form:"updated_by" json:"updated_by,omitempty" query:"updated_by"` - CreatedAt *int64 `thrift:"created_at,6,optional" frugal:"6,optional,i64" json:"created_at" form:"created_at" query:"created_at"` - UpdatedAt *int64 `thrift:"updated_at,7,optional" frugal:"7,optional,i64" json:"updated_at" form:"updated_at" query:"updated_at"` - LatestCommittedAt *int64 `thrift:"latest_committed_at,8,optional" frugal:"8,optional,i64" json:"latest_committed_at" form:"latest_committed_at" query:"latest_committed_at"` + DisplayName *string `thrift:"display_name,1,optional" frugal:"1,optional,string" form:"display_name" json:"display_name,omitempty" query:"display_name"` + Description *string `thrift:"description,2,optional" frugal:"2,optional,string" form:"description" json:"description,omitempty" query:"description"` + LatestVersion *string `thrift:"latest_version,3,optional" frugal:"3,optional,string" form:"latest_version" json:"latest_version,omitempty" query:"latest_version"` + CreatedBy *string `thrift:"created_by,4,optional" frugal:"4,optional,string" form:"created_by" json:"created_by,omitempty" query:"created_by"` + UpdatedBy *string `thrift:"updated_by,5,optional" frugal:"5,optional,string" form:"updated_by" json:"updated_by,omitempty" query:"updated_by"` + CreatedAt *int64 `thrift:"created_at,6,optional" frugal:"6,optional,i64" json:"created_at" form:"created_at" query:"created_at"` + UpdatedAt *int64 `thrift:"updated_at,7,optional" frugal:"7,optional,i64" json:"updated_at" form:"updated_at" query:"updated_at"` + LatestCommittedAt *int64 `thrift:"latest_committed_at,8,optional" frugal:"8,optional,i64" json:"latest_committed_at" form:"latest_committed_at" query:"latest_committed_at"` + PromptType *PromptType `thrift:"prompt_type,9,optional" frugal:"9,optional,string" form:"prompt_type" json:"prompt_type,omitempty" query:"prompt_type"` } func NewPromptBasic() *PromptBasic { @@ -743,6 +750,18 @@ func (p *PromptBasic) GetLatestCommittedAt() (v int64) { } return *p.LatestCommittedAt } + +var PromptBasic_PromptType_DEFAULT PromptType + +func (p *PromptBasic) GetPromptType() (v PromptType) { + if p == nil { + return + } + if !p.IsSetPromptType() { + return PromptBasic_PromptType_DEFAULT + } + return *p.PromptType +} func (p *PromptBasic) SetDisplayName(val *string) { p.DisplayName = val } @@ -767,6 +786,9 @@ func (p *PromptBasic) SetUpdatedAt(val *int64) { func (p *PromptBasic) SetLatestCommittedAt(val *int64) { p.LatestCommittedAt = val } +func (p *PromptBasic) SetPromptType(val *PromptType) { + p.PromptType = val +} var fieldIDToName_PromptBasic = map[int16]string{ 1: "display_name", @@ -777,6 +799,7 @@ var fieldIDToName_PromptBasic = map[int16]string{ 6: "created_at", 7: "updated_at", 8: "latest_committed_at", + 9: "prompt_type", } func (p *PromptBasic) IsSetDisplayName() bool { @@ -811,6 +834,10 @@ func (p *PromptBasic) IsSetLatestCommittedAt() bool { return p.LatestCommittedAt != nil } +func (p *PromptBasic) IsSetPromptType() bool { + return p.PromptType != nil +} + func (p *PromptBasic) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -893,6 +920,14 @@ func (p *PromptBasic) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 9: + if fieldTypeId == thrift.STRING { + if err = p.ReadField9(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } default: if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError @@ -1010,6 +1045,17 @@ func (p *PromptBasic) ReadField8(iprot thrift.TProtocol) error { p.LatestCommittedAt = _field return nil } +func (p *PromptBasic) ReadField9(iprot thrift.TProtocol) error { + + var _field *PromptType + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.PromptType = _field + return nil +} func (p *PromptBasic) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 @@ -1049,6 +1095,10 @@ func (p *PromptBasic) Write(oprot thrift.TProtocol) (err error) { fieldId = 8 goto WriteFieldError } + if err = p.writeField9(oprot); err != nil { + fieldId = 9 + goto WriteFieldError + } } if err = oprot.WriteFieldStop(); err != nil { goto WriteFieldStopError @@ -1211,6 +1261,24 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 8 end error: ", p), err) } +func (p *PromptBasic) writeField9(oprot thrift.TProtocol) (err error) { + if p.IsSetPromptType() { + if err = oprot.WriteFieldBegin("prompt_type", thrift.STRING, 9); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.PromptType); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 9 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 9 end error: ", p), err) +} func (p *PromptBasic) String() string { if p == nil { @@ -1250,6 +1318,9 @@ func (p *PromptBasic) DeepEqual(ano *PromptBasic) bool { if !p.Field8DeepEqual(ano.LatestCommittedAt) { return false } + if !p.Field9DeepEqual(ano.PromptType) { + return false + } return true } @@ -1349,6 +1420,18 @@ func (p *PromptBasic) Field8DeepEqual(src *int64) bool { } return true } +func (p *PromptBasic) Field9DeepEqual(src *PromptType) bool { + + if p.PromptType == src { + return true + } else if p.PromptType == nil || src == nil { + return false + } + if strings.Compare(*p.PromptType, *src) != 0 { + return false + } + return true +} type PromptCommit struct { Detail *PromptDetail `thrift:"detail,1,optional" frugal:"1,optional,PromptDetail" form:"detail" json:"detail,omitempty" query:"detail"` @@ -3332,6 +3415,8 @@ type PromptTemplate struct { TemplateType *TemplateType `thrift:"template_type,1,optional" frugal:"1,optional,string" form:"template_type" json:"template_type,omitempty" query:"template_type"` Messages []*Message `thrift:"messages,2,optional" frugal:"2,optional,list" form:"messages" json:"messages,omitempty" query:"messages"` VariableDefs []*VariableDef `thrift:"variable_defs,3,optional" frugal:"3,optional,list" form:"variable_defs" json:"variable_defs,omitempty" query:"variable_defs"` + HasSnippet *bool `thrift:"has_snippet,4,optional" frugal:"4,optional,bool" form:"has_snippet" json:"has_snippet,omitempty" query:"has_snippet"` + Snippets []*Prompt `thrift:"snippets,5,optional" frugal:"5,optional,list" form:"snippets" json:"snippets,omitempty" query:"snippets"` Metadata map[string]string `thrift:"metadata,100,optional" frugal:"100,optional,map" form:"metadata" json:"metadata,omitempty" query:"metadata"` } @@ -3378,6 +3463,30 @@ func (p *PromptTemplate) GetVariableDefs() (v []*VariableDef) { return p.VariableDefs } +var PromptTemplate_HasSnippet_DEFAULT bool + +func (p *PromptTemplate) GetHasSnippet() (v bool) { + if p == nil { + return + } + if !p.IsSetHasSnippet() { + return PromptTemplate_HasSnippet_DEFAULT + } + return *p.HasSnippet +} + +var PromptTemplate_Snippets_DEFAULT []*Prompt + +func (p *PromptTemplate) GetSnippets() (v []*Prompt) { + if p == nil { + return + } + if !p.IsSetSnippets() { + return PromptTemplate_Snippets_DEFAULT + } + return p.Snippets +} + var PromptTemplate_Metadata_DEFAULT map[string]string func (p *PromptTemplate) GetMetadata() (v map[string]string) { @@ -3398,6 +3507,12 @@ func (p *PromptTemplate) SetMessages(val []*Message) { func (p *PromptTemplate) SetVariableDefs(val []*VariableDef) { p.VariableDefs = val } +func (p *PromptTemplate) SetHasSnippet(val *bool) { + p.HasSnippet = val +} +func (p *PromptTemplate) SetSnippets(val []*Prompt) { + p.Snippets = val +} func (p *PromptTemplate) SetMetadata(val map[string]string) { p.Metadata = val } @@ -3406,6 +3521,8 @@ var fieldIDToName_PromptTemplate = map[int16]string{ 1: "template_type", 2: "messages", 3: "variable_defs", + 4: "has_snippet", + 5: "snippets", 100: "metadata", } @@ -3421,6 +3538,14 @@ func (p *PromptTemplate) IsSetVariableDefs() bool { return p.VariableDefs != nil } +func (p *PromptTemplate) IsSetHasSnippet() bool { + return p.HasSnippet != nil +} + +func (p *PromptTemplate) IsSetSnippets() bool { + return p.Snippets != nil +} + func (p *PromptTemplate) IsSetMetadata() bool { return p.Metadata != nil } @@ -3467,6 +3592,22 @@ func (p *PromptTemplate) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 4: + if fieldTypeId == thrift.BOOL { + if err = p.ReadField4(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 5: + if fieldTypeId == thrift.LIST { + if err = p.ReadField5(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } case 100: if fieldTypeId == thrift.MAP { if err = p.ReadField100(iprot); err != nil { @@ -3561,6 +3702,40 @@ func (p *PromptTemplate) ReadField3(iprot thrift.TProtocol) error { p.VariableDefs = _field return nil } +func (p *PromptTemplate) ReadField4(iprot thrift.TProtocol) error { + + var _field *bool + if v, err := iprot.ReadBool(); err != nil { + return err + } else { + _field = &v + } + p.HasSnippet = _field + return nil +} +func (p *PromptTemplate) ReadField5(iprot thrift.TProtocol) error { + _, size, err := iprot.ReadListBegin() + if err != nil { + return err + } + _field := make([]*Prompt, 0, size) + values := make([]Prompt, size) + for i := 0; i < size; i++ { + _elem := &values[i] + _elem.InitDefault() + + if err := _elem.Read(iprot); err != nil { + return err + } + + _field = append(_field, _elem) + } + if err := iprot.ReadListEnd(); err != nil { + return err + } + p.Snippets = _field + return nil +} func (p *PromptTemplate) ReadField100(iprot thrift.TProtocol) error { _, _, size, err := iprot.ReadMapBegin() if err != nil { @@ -3609,6 +3784,14 @@ func (p *PromptTemplate) Write(oprot thrift.TProtocol) (err error) { fieldId = 3 goto WriteFieldError } + if err = p.writeField4(oprot); err != nil { + fieldId = 4 + goto WriteFieldError + } + if err = p.writeField5(oprot); err != nil { + fieldId = 5 + goto WriteFieldError + } if err = p.writeField100(oprot); err != nil { fieldId = 100 goto WriteFieldError @@ -3701,6 +3884,50 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) } +func (p *PromptTemplate) writeField4(oprot thrift.TProtocol) (err error) { + if p.IsSetHasSnippet() { + if err = oprot.WriteFieldBegin("has_snippet", thrift.BOOL, 4); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteBool(*p.HasSnippet); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 4 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 4 end error: ", p), err) +} +func (p *PromptTemplate) writeField5(oprot thrift.TProtocol) (err error) { + if p.IsSetSnippets() { + if err = oprot.WriteFieldBegin("snippets", thrift.LIST, 5); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteListBegin(thrift.STRUCT, len(p.Snippets)); err != nil { + return err + } + for _, v := range p.Snippets { + if err := v.Write(oprot); err != nil { + return err + } + } + if err := oprot.WriteListEnd(); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 5 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 5 end error: ", p), err) +} func (p *PromptTemplate) writeField100(oprot thrift.TProtocol) (err error) { if p.IsSetMetadata() { if err = oprot.WriteFieldBegin("metadata", thrift.MAP, 100); err != nil { @@ -3754,6 +3981,12 @@ func (p *PromptTemplate) DeepEqual(ano *PromptTemplate) bool { if !p.Field3DeepEqual(ano.VariableDefs) { return false } + if !p.Field4DeepEqual(ano.HasSnippet) { + return false + } + if !p.Field5DeepEqual(ano.Snippets) { + return false + } if !p.Field100DeepEqual(ano.Metadata) { return false } @@ -3798,6 +4031,31 @@ func (p *PromptTemplate) Field3DeepEqual(src []*VariableDef) bool { } return true } +func (p *PromptTemplate) Field4DeepEqual(src *bool) bool { + + if p.HasSnippet == src { + return true + } else if p.HasSnippet == nil || src == nil { + return false + } + if *p.HasSnippet != *src { + return false + } + return true +} +func (p *PromptTemplate) Field5DeepEqual(src []*Prompt) bool { + + if len(p.Snippets) != len(src) { + return false + } + for i, v := range p.Snippets { + _src := src[i] + if !v.DeepEqual(_src) { + return false + } + } + return true +} func (p *PromptTemplate) Field100DeepEqual(src map[string]string) bool { if len(p.Metadata) != len(src) { @@ -13756,3 +14014,505 @@ func (p *OverridePromptParams) Field1DeepEqual(src *ModelConfig) bool { } return true } + +type PromptCommitVersions struct { + ID *int64 `thrift:"id,1,optional" frugal:"1,optional,i64" json:"id" form:"id" query:"id"` + WorkspaceID *int64 `thrift:"workspace_id,2,optional" frugal:"2,optional,i64" json:"workspace_id" form:"workspace_id" query:"workspace_id"` + PromptKey *string `thrift:"prompt_key,3,optional" frugal:"3,optional,string" form:"prompt_key" json:"prompt_key,omitempty" query:"prompt_key"` + PromptBasic *PromptBasic `thrift:"prompt_basic,4,optional" frugal:"4,optional,PromptBasic" form:"prompt_basic" json:"prompt_basic,omitempty" query:"prompt_basic"` + CommitVersions []string `thrift:"commit_versions,5,optional" frugal:"5,optional,list" form:"commit_versions" json:"commit_versions,omitempty" query:"commit_versions"` +} + +func NewPromptCommitVersions() *PromptCommitVersions { + return &PromptCommitVersions{} +} + +func (p *PromptCommitVersions) InitDefault() { +} + +var PromptCommitVersions_ID_DEFAULT int64 + +func (p *PromptCommitVersions) GetID() (v int64) { + if p == nil { + return + } + if !p.IsSetID() { + return PromptCommitVersions_ID_DEFAULT + } + return *p.ID +} + +var PromptCommitVersions_WorkspaceID_DEFAULT int64 + +func (p *PromptCommitVersions) GetWorkspaceID() (v int64) { + if p == nil { + return + } + if !p.IsSetWorkspaceID() { + return PromptCommitVersions_WorkspaceID_DEFAULT + } + return *p.WorkspaceID +} + +var PromptCommitVersions_PromptKey_DEFAULT string + +func (p *PromptCommitVersions) GetPromptKey() (v string) { + if p == nil { + return + } + if !p.IsSetPromptKey() { + return PromptCommitVersions_PromptKey_DEFAULT + } + return *p.PromptKey +} + +var PromptCommitVersions_PromptBasic_DEFAULT *PromptBasic + +func (p *PromptCommitVersions) GetPromptBasic() (v *PromptBasic) { + if p == nil { + return + } + if !p.IsSetPromptBasic() { + return PromptCommitVersions_PromptBasic_DEFAULT + } + return p.PromptBasic +} + +var PromptCommitVersions_CommitVersions_DEFAULT []string + +func (p *PromptCommitVersions) GetCommitVersions() (v []string) { + if p == nil { + return + } + if !p.IsSetCommitVersions() { + return PromptCommitVersions_CommitVersions_DEFAULT + } + return p.CommitVersions +} +func (p *PromptCommitVersions) SetID(val *int64) { + p.ID = val +} +func (p *PromptCommitVersions) SetWorkspaceID(val *int64) { + p.WorkspaceID = val +} +func (p *PromptCommitVersions) SetPromptKey(val *string) { + p.PromptKey = val +} +func (p *PromptCommitVersions) SetPromptBasic(val *PromptBasic) { + p.PromptBasic = val +} +func (p *PromptCommitVersions) SetCommitVersions(val []string) { + p.CommitVersions = val +} + +var fieldIDToName_PromptCommitVersions = map[int16]string{ + 1: "id", + 2: "workspace_id", + 3: "prompt_key", + 4: "prompt_basic", + 5: "commit_versions", +} + +func (p *PromptCommitVersions) IsSetID() bool { + return p.ID != nil +} + +func (p *PromptCommitVersions) IsSetWorkspaceID() bool { + return p.WorkspaceID != nil +} + +func (p *PromptCommitVersions) IsSetPromptKey() bool { + return p.PromptKey != nil +} + +func (p *PromptCommitVersions) IsSetPromptBasic() bool { + return p.PromptBasic != nil +} + +func (p *PromptCommitVersions) IsSetCommitVersions() bool { + return p.CommitVersions != nil +} + +func (p *PromptCommitVersions) Read(iprot thrift.TProtocol) (err error) { + var fieldTypeId thrift.TType + var fieldId int16 + + if _, err = iprot.ReadStructBegin(); err != nil { + goto ReadStructBeginError + } + + for { + _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + + switch fieldId { + case 1: + if fieldTypeId == thrift.I64 { + if err = p.ReadField1(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 2: + if fieldTypeId == thrift.I64 { + if err = p.ReadField2(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 3: + if fieldTypeId == thrift.STRING { + if err = p.ReadField3(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 4: + if fieldTypeId == thrift.STRUCT { + if err = p.ReadField4(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 5: + if fieldTypeId == thrift.LIST { + if err = p.ReadField5(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } + } + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError + } + + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptCommitVersions[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *PromptCommitVersions) ReadField1(iprot thrift.TProtocol) error { + + var _field *int64 + if v, err := iprot.ReadI64(); err != nil { + return err + } else { + _field = &v + } + p.ID = _field + return nil +} +func (p *PromptCommitVersions) ReadField2(iprot thrift.TProtocol) error { + + var _field *int64 + if v, err := iprot.ReadI64(); err != nil { + return err + } else { + _field = &v + } + p.WorkspaceID = _field + return nil +} +func (p *PromptCommitVersions) ReadField3(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.PromptKey = _field + return nil +} +func (p *PromptCommitVersions) ReadField4(iprot thrift.TProtocol) error { + _field := NewPromptBasic() + if err := _field.Read(iprot); err != nil { + return err + } + p.PromptBasic = _field + return nil +} +func (p *PromptCommitVersions) ReadField5(iprot thrift.TProtocol) error { + _, size, err := iprot.ReadListBegin() + if err != nil { + return err + } + _field := make([]string, 0, size) + for i := 0; i < size; i++ { + + var _elem string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _elem = v + } + + _field = append(_field, _elem) + } + if err := iprot.ReadListEnd(); err != nil { + return err + } + p.CommitVersions = _field + return nil +} + +func (p *PromptCommitVersions) Write(oprot thrift.TProtocol) (err error) { + var fieldId int16 + if err = oprot.WriteStructBegin("PromptCommitVersions"); err != nil { + goto WriteStructBeginError + } + if p != nil { + if err = p.writeField1(oprot); err != nil { + fieldId = 1 + goto WriteFieldError + } + if err = p.writeField2(oprot); err != nil { + fieldId = 2 + goto WriteFieldError + } + if err = p.writeField3(oprot); err != nil { + fieldId = 3 + goto WriteFieldError + } + if err = p.writeField4(oprot); err != nil { + fieldId = 4 + goto WriteFieldError + } + if err = p.writeField5(oprot); err != nil { + fieldId = 5 + goto WriteFieldError + } + } + if err = oprot.WriteFieldStop(); err != nil { + goto WriteFieldStopError + } + if err = oprot.WriteStructEnd(); err != nil { + goto WriteStructEndError + } + return nil +WriteStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) +WriteFieldError: + return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) +WriteFieldStopError: + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) +WriteStructEndError: + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) +} + +func (p *PromptCommitVersions) writeField1(oprot thrift.TProtocol) (err error) { + if p.IsSetID() { + if err = oprot.WriteFieldBegin("id", thrift.I64, 1); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteI64(*p.ID); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) +} +func (p *PromptCommitVersions) writeField2(oprot thrift.TProtocol) (err error) { + if p.IsSetWorkspaceID() { + if err = oprot.WriteFieldBegin("workspace_id", thrift.I64, 2); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteI64(*p.WorkspaceID); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) +} +func (p *PromptCommitVersions) writeField3(oprot thrift.TProtocol) (err error) { + if p.IsSetPromptKey() { + if err = oprot.WriteFieldBegin("prompt_key", thrift.STRING, 3); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.PromptKey); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) +} +func (p *PromptCommitVersions) writeField4(oprot thrift.TProtocol) (err error) { + if p.IsSetPromptBasic() { + if err = oprot.WriteFieldBegin("prompt_basic", thrift.STRUCT, 4); err != nil { + goto WriteFieldBeginError + } + if err := p.PromptBasic.Write(oprot); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 4 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 4 end error: ", p), err) +} +func (p *PromptCommitVersions) writeField5(oprot thrift.TProtocol) (err error) { + if p.IsSetCommitVersions() { + if err = oprot.WriteFieldBegin("commit_versions", thrift.LIST, 5); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteListBegin(thrift.STRING, len(p.CommitVersions)); err != nil { + return err + } + for _, v := range p.CommitVersions { + if err := oprot.WriteString(v); err != nil { + return err + } + } + if err := oprot.WriteListEnd(); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 5 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 5 end error: ", p), err) +} + +func (p *PromptCommitVersions) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("PromptCommitVersions(%+v)", *p) + +} + +func (p *PromptCommitVersions) DeepEqual(ano *PromptCommitVersions) bool { + if p == ano { + return true + } else if p == nil || ano == nil { + return false + } + if !p.Field1DeepEqual(ano.ID) { + return false + } + if !p.Field2DeepEqual(ano.WorkspaceID) { + return false + } + if !p.Field3DeepEqual(ano.PromptKey) { + return false + } + if !p.Field4DeepEqual(ano.PromptBasic) { + return false + } + if !p.Field5DeepEqual(ano.CommitVersions) { + return false + } + return true +} + +func (p *PromptCommitVersions) Field1DeepEqual(src *int64) bool { + + if p.ID == src { + return true + } else if p.ID == nil || src == nil { + return false + } + if *p.ID != *src { + return false + } + return true +} +func (p *PromptCommitVersions) Field2DeepEqual(src *int64) bool { + + if p.WorkspaceID == src { + return true + } else if p.WorkspaceID == nil || src == nil { + return false + } + if *p.WorkspaceID != *src { + return false + } + return true +} +func (p *PromptCommitVersions) Field3DeepEqual(src *string) bool { + + if p.PromptKey == src { + return true + } else if p.PromptKey == nil || src == nil { + return false + } + if strings.Compare(*p.PromptKey, *src) != 0 { + return false + } + return true +} +func (p *PromptCommitVersions) Field4DeepEqual(src *PromptBasic) bool { + + if !p.PromptBasic.DeepEqual(src) { + return false + } + return true +} +func (p *PromptCommitVersions) Field5DeepEqual(src []string) bool { + + if len(p.CommitVersions) != len(src) { + return false + } + for i, v := range p.CommitVersions { + _src := src[i] + if strings.Compare(v, _src) != 0 { + return false + } + } + return true +} diff --git a/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt_validator.go b/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt_validator.go index 2cc88b636..855872776 100644 --- a/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt_validator.go +++ b/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt_validator.go @@ -246,3 +246,11 @@ func (p *OverridePromptParams) IsValid() error { } return nil } +func (p *PromptCommitVersions) IsValid() error { + if p.PromptBasic != nil { + if err := p.PromptBasic.IsValid(); err != nil { + return fmt.Errorf("field PromptBasic not valid, %w", err) + } + } + return nil +} diff --git a/backend/kitex_gen/coze/loop/prompt/manage/coze.loop.prompt.manage.go b/backend/kitex_gen/coze/loop/prompt/manage/coze.loop.prompt.manage.go index 7faa7d53a..7aedcaeb4 100644 --- a/backend/kitex_gen/coze/loop/prompt/manage/coze.loop.prompt.manage.go +++ b/backend/kitex_gen/coze/loop/prompt/manage/coze.loop.prompt.manage.go @@ -26,6 +26,7 @@ type CreatePromptRequest struct { PromptName *string `thrift:"prompt_name,11,optional" frugal:"11,optional,string" form:"prompt_name" json:"prompt_name,omitempty" query:"prompt_name"` PromptKey *string `thrift:"prompt_key,12,optional" frugal:"12,optional,string" form:"prompt_key" json:"prompt_key,omitempty" query:"prompt_key"` PromptDescription *string `thrift:"prompt_description,13,optional" frugal:"13,optional,string" form:"prompt_description" json:"prompt_description,omitempty" query:"prompt_description"` + PromptType *prompt.PromptType `thrift:"prompt_type,14,optional" frugal:"14,optional,string" form:"prompt_type" json:"prompt_type,omitempty" query:"prompt_type"` DraftDetail *prompt.PromptDetail `thrift:"draft_detail,21,optional" frugal:"21,optional,prompt.PromptDetail" form:"draft_detail" json:"draft_detail,omitempty" query:"draft_detail"` Base *base.Base `thrift:"Base,255,optional" frugal:"255,optional,base.Base" form:"Base" json:"Base,omitempty" query:"Base"` } @@ -85,6 +86,18 @@ func (p *CreatePromptRequest) GetPromptDescription() (v string) { return *p.PromptDescription } +var CreatePromptRequest_PromptType_DEFAULT prompt.PromptType + +func (p *CreatePromptRequest) GetPromptType() (v prompt.PromptType) { + if p == nil { + return + } + if !p.IsSetPromptType() { + return CreatePromptRequest_PromptType_DEFAULT + } + return *p.PromptType +} + var CreatePromptRequest_DraftDetail_DEFAULT *prompt.PromptDetail func (p *CreatePromptRequest) GetDraftDetail() (v *prompt.PromptDetail) { @@ -120,6 +133,9 @@ func (p *CreatePromptRequest) SetPromptKey(val *string) { func (p *CreatePromptRequest) SetPromptDescription(val *string) { p.PromptDescription = val } +func (p *CreatePromptRequest) SetPromptType(val *prompt.PromptType) { + p.PromptType = val +} func (p *CreatePromptRequest) SetDraftDetail(val *prompt.PromptDetail) { p.DraftDetail = val } @@ -132,6 +148,7 @@ var fieldIDToName_CreatePromptRequest = map[int16]string{ 11: "prompt_name", 12: "prompt_key", 13: "prompt_description", + 14: "prompt_type", 21: "draft_detail", 255: "Base", } @@ -152,6 +169,10 @@ func (p *CreatePromptRequest) IsSetPromptDescription() bool { return p.PromptDescription != nil } +func (p *CreatePromptRequest) IsSetPromptType() bool { + return p.PromptType != nil +} + func (p *CreatePromptRequest) IsSetDraftDetail() bool { return p.DraftDetail != nil } @@ -210,6 +231,14 @@ func (p *CreatePromptRequest) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 14: + if fieldTypeId == thrift.STRING { + if err = p.ReadField14(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } case 21: if fieldTypeId == thrift.STRUCT { if err = p.ReadField21(iprot); err != nil { @@ -299,6 +328,17 @@ func (p *CreatePromptRequest) ReadField13(iprot thrift.TProtocol) error { p.PromptDescription = _field return nil } +func (p *CreatePromptRequest) ReadField14(iprot thrift.TProtocol) error { + + var _field *prompt.PromptType + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.PromptType = _field + return nil +} func (p *CreatePromptRequest) ReadField21(iprot thrift.TProtocol) error { _field := prompt.NewPromptDetail() if err := _field.Read(iprot); err != nil { @@ -338,6 +378,10 @@ func (p *CreatePromptRequest) Write(oprot thrift.TProtocol) (err error) { fieldId = 13 goto WriteFieldError } + if err = p.writeField14(oprot); err != nil { + fieldId = 14 + goto WriteFieldError + } if err = p.writeField21(oprot); err != nil { fieldId = 21 goto WriteFieldError @@ -436,6 +480,24 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 13 end error: ", p), err) } +func (p *CreatePromptRequest) writeField14(oprot thrift.TProtocol) (err error) { + if p.IsSetPromptType() { + if err = oprot.WriteFieldBegin("prompt_type", thrift.STRING, 14); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.PromptType); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 14 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 14 end error: ", p), err) +} func (p *CreatePromptRequest) writeField21(oprot thrift.TProtocol) (err error) { if p.IsSetDraftDetail() { if err = oprot.WriteFieldBegin("draft_detail", thrift.STRUCT, 21); err != nil { @@ -499,6 +561,9 @@ func (p *CreatePromptRequest) DeepEqual(ano *CreatePromptRequest) bool { if !p.Field13DeepEqual(ano.PromptDescription) { return false } + if !p.Field14DeepEqual(ano.PromptType) { + return false + } if !p.Field21DeepEqual(ano.DraftDetail) { return false } @@ -556,6 +621,18 @@ func (p *CreatePromptRequest) Field13DeepEqual(src *string) bool { } return true } +func (p *CreatePromptRequest) Field14DeepEqual(src *prompt.PromptType) bool { + + if p.PromptType == src { + return true + } else if p.PromptType == nil || src == nil { + return false + } + if strings.Compare(*p.PromptType, *src) != 0 { + return false + } + return true +} func (p *CreatePromptRequest) Field21DeepEqual(src *prompt.PromptDetail) bool { if !p.DraftDetail.DeepEqual(src) { @@ -2053,13 +2130,15 @@ func (p *DeletePromptResponse) Field255DeepEqual(src *base.BaseResp) bool { } type GetPromptRequest struct { - PromptID *int64 `thrift:"prompt_id,1,optional" frugal:"1,optional,i64" json:"prompt_id" path:"prompt_id" ` - WorkspaceID *int64 `thrift:"workspace_id,2,optional" frugal:"2,optional,i64" json:"workspace_id" query:"workspace_id" ` - WithCommit *bool `thrift:"with_commit,11,optional" frugal:"11,optional,bool" json:"with_commit,omitempty" query:"with_commit"` - CommitVersion *string `thrift:"commit_version,12,optional" frugal:"12,optional,string" json:"commit_version,omitempty" query:"commit_version"` - WithDraft *bool `thrift:"with_draft,21,optional" frugal:"21,optional,bool" json:"with_draft,omitempty" query:"with_draft"` - WithDefaultConfig *bool `thrift:"with_default_config,31,optional" frugal:"31,optional,bool" json:"with_default_config,omitempty" query:"with_default_config"` - Base *base.Base `thrift:"Base,255,optional" frugal:"255,optional,base.Base" form:"Base" json:"Base,omitempty" query:"Base"` + PromptID *int64 `thrift:"prompt_id,1,optional" frugal:"1,optional,i64" json:"prompt_id" path:"prompt_id" ` + WorkspaceID *int64 `thrift:"workspace_id,2,optional" frugal:"2,optional,i64" json:"workspace_id" query:"workspace_id" ` + WithCommit *bool `thrift:"with_commit,11,optional" frugal:"11,optional,bool" json:"with_commit,omitempty" query:"with_commit"` + CommitVersion *string `thrift:"commit_version,12,optional" frugal:"12,optional,string" json:"commit_version,omitempty" query:"commit_version"` + WithDraft *bool `thrift:"with_draft,21,optional" frugal:"21,optional,bool" json:"with_draft,omitempty" query:"with_draft"` + WithDefaultConfig *bool `thrift:"with_default_config,31,optional" frugal:"31,optional,bool" json:"with_default_config,omitempty" query:"with_default_config"` + // 是否展开子片段 + ExpandSnippet *bool `thrift:"expand_snippet,32,optional" frugal:"32,optional,bool" json:"expand_snippet,omitempty" query:"expand_snippet"` + Base *base.Base `thrift:"Base,255,optional" frugal:"255,optional,base.Base" form:"Base" json:"Base,omitempty" query:"Base"` } func NewGetPromptRequest() *GetPromptRequest { @@ -2141,6 +2220,18 @@ func (p *GetPromptRequest) GetWithDefaultConfig() (v bool) { return *p.WithDefaultConfig } +var GetPromptRequest_ExpandSnippet_DEFAULT bool + +func (p *GetPromptRequest) GetExpandSnippet() (v bool) { + if p == nil { + return + } + if !p.IsSetExpandSnippet() { + return GetPromptRequest_ExpandSnippet_DEFAULT + } + return *p.ExpandSnippet +} + var GetPromptRequest_Base_DEFAULT *base.Base func (p *GetPromptRequest) GetBase() (v *base.Base) { @@ -2170,6 +2261,9 @@ func (p *GetPromptRequest) SetWithDraft(val *bool) { func (p *GetPromptRequest) SetWithDefaultConfig(val *bool) { p.WithDefaultConfig = val } +func (p *GetPromptRequest) SetExpandSnippet(val *bool) { + p.ExpandSnippet = val +} func (p *GetPromptRequest) SetBase(val *base.Base) { p.Base = val } @@ -2181,6 +2275,7 @@ var fieldIDToName_GetPromptRequest = map[int16]string{ 12: "commit_version", 21: "with_draft", 31: "with_default_config", + 32: "expand_snippet", 255: "Base", } @@ -2208,6 +2303,10 @@ func (p *GetPromptRequest) IsSetWithDefaultConfig() bool { return p.WithDefaultConfig != nil } +func (p *GetPromptRequest) IsSetExpandSnippet() bool { + return p.ExpandSnippet != nil +} + func (p *GetPromptRequest) IsSetBase() bool { return p.Base != nil } @@ -2278,6 +2377,14 @@ func (p *GetPromptRequest) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 32: + if fieldTypeId == thrift.BOOL { + if err = p.ReadField32(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } case 255: if fieldTypeId == thrift.STRUCT { if err = p.ReadField255(iprot); err != nil { @@ -2381,6 +2488,17 @@ func (p *GetPromptRequest) ReadField31(iprot thrift.TProtocol) error { p.WithDefaultConfig = _field return nil } +func (p *GetPromptRequest) ReadField32(iprot thrift.TProtocol) error { + + var _field *bool + if v, err := iprot.ReadBool(); err != nil { + return err + } else { + _field = &v + } + p.ExpandSnippet = _field + return nil +} func (p *GetPromptRequest) ReadField255(iprot thrift.TProtocol) error { _field := base.NewBase() if err := _field.Read(iprot); err != nil { @@ -2420,6 +2538,10 @@ func (p *GetPromptRequest) Write(oprot thrift.TProtocol) (err error) { fieldId = 31 goto WriteFieldError } + if err = p.writeField32(oprot); err != nil { + fieldId = 32 + goto WriteFieldError + } if err = p.writeField255(oprot); err != nil { fieldId = 255 goto WriteFieldError @@ -2550,6 +2672,24 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 31 end error: ", p), err) } +func (p *GetPromptRequest) writeField32(oprot thrift.TProtocol) (err error) { + if p.IsSetExpandSnippet() { + if err = oprot.WriteFieldBegin("expand_snippet", thrift.BOOL, 32); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteBool(*p.ExpandSnippet); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 32 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 32 end error: ", p), err) +} func (p *GetPromptRequest) writeField255(oprot thrift.TProtocol) (err error) { if p.IsSetBase() { if err = oprot.WriteFieldBegin("Base", thrift.STRUCT, 255); err != nil { @@ -2601,6 +2741,9 @@ func (p *GetPromptRequest) DeepEqual(ano *GetPromptRequest) bool { if !p.Field31DeepEqual(ano.WithDefaultConfig) { return false } + if !p.Field32DeepEqual(ano.ExpandSnippet) { + return false + } if !p.Field255DeepEqual(ano.Base) { return false } @@ -2679,6 +2822,18 @@ func (p *GetPromptRequest) Field31DeepEqual(src *bool) bool { } return true } +func (p *GetPromptRequest) Field32DeepEqual(src *bool) bool { + + if p.ExpandSnippet == src { + return true + } else if p.ExpandSnippet == nil || src == nil { + return false + } + if *p.ExpandSnippet != *src { + return false + } + return true +} func (p *GetPromptRequest) Field255DeepEqual(src *base.Base) bool { if !p.Base.DeepEqual(src) { @@ -2690,7 +2845,9 @@ func (p *GetPromptRequest) Field255DeepEqual(src *base.Base) bool { type GetPromptResponse struct { Prompt *prompt.Prompt `thrift:"prompt,1,optional" frugal:"1,optional,prompt.Prompt" form:"prompt" json:"prompt,omitempty" query:"prompt"` DefaultConfig *prompt.PromptDetail `thrift:"default_config,11,optional" frugal:"11,optional,prompt.PromptDetail" form:"default_config" json:"default_config,omitempty" query:"default_config"` - BaseResp *base.BaseResp `thrift:"BaseResp,255,optional" frugal:"255,optional,base.BaseResp" form:"BaseResp" json:"BaseResp,omitempty" query:"BaseResp"` + // [片段]被引用的总数 + TotalParentReferences *int32 `thrift:"total_parent_references,12,optional" frugal:"12,optional,i32" form:"total_parent_references" json:"total_parent_references,omitempty" query:"total_parent_references"` + BaseResp *base.BaseResp `thrift:"BaseResp,255,optional" frugal:"255,optional,base.BaseResp" form:"BaseResp" json:"BaseResp,omitempty" query:"BaseResp"` } func NewGetPromptResponse() *GetPromptResponse { @@ -2724,6 +2881,18 @@ func (p *GetPromptResponse) GetDefaultConfig() (v *prompt.PromptDetail) { return p.DefaultConfig } +var GetPromptResponse_TotalParentReferences_DEFAULT int32 + +func (p *GetPromptResponse) GetTotalParentReferences() (v int32) { + if p == nil { + return + } + if !p.IsSetTotalParentReferences() { + return GetPromptResponse_TotalParentReferences_DEFAULT + } + return *p.TotalParentReferences +} + var GetPromptResponse_BaseResp_DEFAULT *base.BaseResp func (p *GetPromptResponse) GetBaseResp() (v *base.BaseResp) { @@ -2741,6 +2910,9 @@ func (p *GetPromptResponse) SetPrompt(val *prompt.Prompt) { func (p *GetPromptResponse) SetDefaultConfig(val *prompt.PromptDetail) { p.DefaultConfig = val } +func (p *GetPromptResponse) SetTotalParentReferences(val *int32) { + p.TotalParentReferences = val +} func (p *GetPromptResponse) SetBaseResp(val *base.BaseResp) { p.BaseResp = val } @@ -2748,6 +2920,7 @@ func (p *GetPromptResponse) SetBaseResp(val *base.BaseResp) { var fieldIDToName_GetPromptResponse = map[int16]string{ 1: "prompt", 11: "default_config", + 12: "total_parent_references", 255: "BaseResp", } @@ -2759,6 +2932,10 @@ func (p *GetPromptResponse) IsSetDefaultConfig() bool { return p.DefaultConfig != nil } +func (p *GetPromptResponse) IsSetTotalParentReferences() bool { + return p.TotalParentReferences != nil +} + func (p *GetPromptResponse) IsSetBaseResp() bool { return p.BaseResp != nil } @@ -2797,6 +2974,14 @@ func (p *GetPromptResponse) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 12: + if fieldTypeId == thrift.I32 { + if err = p.ReadField12(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } case 255: if fieldTypeId == thrift.STRUCT { if err = p.ReadField255(iprot); err != nil { @@ -2850,6 +3035,17 @@ func (p *GetPromptResponse) ReadField11(iprot thrift.TProtocol) error { p.DefaultConfig = _field return nil } +func (p *GetPromptResponse) ReadField12(iprot thrift.TProtocol) error { + + var _field *int32 + if v, err := iprot.ReadI32(); err != nil { + return err + } else { + _field = &v + } + p.TotalParentReferences = _field + return nil +} func (p *GetPromptResponse) ReadField255(iprot thrift.TProtocol) error { _field := base.NewBaseResp() if err := _field.Read(iprot); err != nil { @@ -2873,6 +3069,10 @@ func (p *GetPromptResponse) Write(oprot thrift.TProtocol) (err error) { fieldId = 11 goto WriteFieldError } + if err = p.writeField12(oprot); err != nil { + fieldId = 12 + goto WriteFieldError + } if err = p.writeField255(oprot); err != nil { fieldId = 255 goto WriteFieldError @@ -2931,6 +3131,24 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 11 end error: ", p), err) } +func (p *GetPromptResponse) writeField12(oprot thrift.TProtocol) (err error) { + if p.IsSetTotalParentReferences() { + if err = oprot.WriteFieldBegin("total_parent_references", thrift.I32, 12); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteI32(*p.TotalParentReferences); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 12 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 12 end error: ", p), err) +} func (p *GetPromptResponse) writeField255(oprot thrift.TProtocol) (err error) { if p.IsSetBaseResp() { if err = oprot.WriteFieldBegin("BaseResp", thrift.STRUCT, 255); err != nil { @@ -2970,6 +3188,9 @@ func (p *GetPromptResponse) DeepEqual(ano *GetPromptResponse) bool { if !p.Field11DeepEqual(ano.DefaultConfig) { return false } + if !p.Field12DeepEqual(ano.TotalParentReferences) { + return false + } if !p.Field255DeepEqual(ano.BaseResp) { return false } @@ -2990,6 +3211,18 @@ func (p *GetPromptResponse) Field11DeepEqual(src *prompt.PromptDetail) bool { } return true } +func (p *GetPromptResponse) Field12DeepEqual(src *int32) bool { + + if p.TotalParentReferences == src { + return true + } else if p.TotalParentReferences == nil || src == nil { + return false + } + if *p.TotalParentReferences != *src { + return false + } + return true +} func (p *GetPromptResponse) Field255DeepEqual(src *base.BaseResp) bool { if !p.BaseResp.DeepEqual(src) { @@ -4118,15 +4351,17 @@ func (p *PromptResult_) Field2DeepEqual(src *prompt.Prompt) bool { } type ListPromptRequest struct { - WorkspaceID *int64 `thrift:"workspace_id,1,optional" frugal:"1,optional,i64" json:"workspace_id" form:"workspace_id" query:"workspace_id"` - KeyWord *string `thrift:"key_word,11,optional" frugal:"11,optional,string" form:"key_word" json:"key_word,omitempty" query:"key_word"` - CreatedBys []string `thrift:"created_bys,12,optional" frugal:"12,optional,list" form:"created_bys" json:"created_bys,omitempty" query:"created_bys"` - CommittedOnly *bool `thrift:"committed_only,13,optional" frugal:"13,optional,bool" form:"committed_only" json:"committed_only,omitempty" query:"committed_only"` - PageNum *int32 `thrift:"page_num,127,optional" frugal:"127,optional,i32" form:"page_num" json:"page_num,omitempty" query:"page_num"` - PageSize *int32 `thrift:"page_size,128,optional" frugal:"128,optional,i32" form:"page_size" json:"page_size,omitempty" query:"page_size"` - OrderBy *ListPromptOrderBy `thrift:"order_by,129,optional" frugal:"129,optional,string" form:"order_by" json:"order_by,omitempty" query:"order_by"` - Asc *bool `thrift:"asc,130,optional" frugal:"130,optional,bool" form:"asc" json:"asc,omitempty" query:"asc"` - Base *base.Base `thrift:"Base,255,optional" frugal:"255,optional,base.Base" form:"Base" json:"Base,omitempty" query:"Base"` + WorkspaceID *int64 `thrift:"workspace_id,1,optional" frugal:"1,optional,i64" json:"workspace_id" form:"workspace_id" query:"workspace_id"` + KeyWord *string `thrift:"key_word,11,optional" frugal:"11,optional,string" form:"key_word" json:"key_word,omitempty" query:"key_word"` + CreatedBys []string `thrift:"created_bys,12,optional" frugal:"12,optional,list" form:"created_bys" json:"created_bys,omitempty" query:"created_bys"` + CommittedOnly *bool `thrift:"committed_only,13,optional" frugal:"13,optional,bool" form:"committed_only" json:"committed_only,omitempty" query:"committed_only"` + // 向前兼容,如果不传,默认查询normal类型的Prompt + FilterPromptTypes []prompt.PromptType `thrift:"filter_prompt_types,14,optional" frugal:"14,optional,list" form:"filter_prompt_types" json:"filter_prompt_types,omitempty" query:"filter_prompt_types"` + PageNum *int32 `thrift:"page_num,127,optional" frugal:"127,optional,i32" form:"page_num" json:"page_num,omitempty" query:"page_num"` + PageSize *int32 `thrift:"page_size,128,optional" frugal:"128,optional,i32" form:"page_size" json:"page_size,omitempty" query:"page_size"` + OrderBy *ListPromptOrderBy `thrift:"order_by,129,optional" frugal:"129,optional,string" form:"order_by" json:"order_by,omitempty" query:"order_by"` + Asc *bool `thrift:"asc,130,optional" frugal:"130,optional,bool" form:"asc" json:"asc,omitempty" query:"asc"` + Base *base.Base `thrift:"Base,255,optional" frugal:"255,optional,base.Base" form:"Base" json:"Base,omitempty" query:"Base"` } func NewListPromptRequest() *ListPromptRequest { @@ -4184,6 +4419,18 @@ func (p *ListPromptRequest) GetCommittedOnly() (v bool) { return *p.CommittedOnly } +var ListPromptRequest_FilterPromptTypes_DEFAULT []prompt.PromptType + +func (p *ListPromptRequest) GetFilterPromptTypes() (v []prompt.PromptType) { + if p == nil { + return + } + if !p.IsSetFilterPromptTypes() { + return ListPromptRequest_FilterPromptTypes_DEFAULT + } + return p.FilterPromptTypes +} + var ListPromptRequest_PageNum_DEFAULT int32 func (p *ListPromptRequest) GetPageNum() (v int32) { @@ -4255,6 +4502,9 @@ func (p *ListPromptRequest) SetCreatedBys(val []string) { func (p *ListPromptRequest) SetCommittedOnly(val *bool) { p.CommittedOnly = val } +func (p *ListPromptRequest) SetFilterPromptTypes(val []prompt.PromptType) { + p.FilterPromptTypes = val +} func (p *ListPromptRequest) SetPageNum(val *int32) { p.PageNum = val } @@ -4276,6 +4526,7 @@ var fieldIDToName_ListPromptRequest = map[int16]string{ 11: "key_word", 12: "created_bys", 13: "committed_only", + 14: "filter_prompt_types", 127: "page_num", 128: "page_size", 129: "order_by", @@ -4299,6 +4550,10 @@ func (p *ListPromptRequest) IsSetCommittedOnly() bool { return p.CommittedOnly != nil } +func (p *ListPromptRequest) IsSetFilterPromptTypes() bool { + return p.FilterPromptTypes != nil +} + func (p *ListPromptRequest) IsSetPageNum() bool { return p.PageNum != nil } @@ -4369,6 +4624,14 @@ func (p *ListPromptRequest) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 14: + if fieldTypeId == thrift.LIST { + if err = p.ReadField14(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } case 127: if fieldTypeId == thrift.I32 { if err = p.ReadField127(iprot); err != nil { @@ -4494,6 +4757,29 @@ func (p *ListPromptRequest) ReadField13(iprot thrift.TProtocol) error { p.CommittedOnly = _field return nil } +func (p *ListPromptRequest) ReadField14(iprot thrift.TProtocol) error { + _, size, err := iprot.ReadListBegin() + if err != nil { + return err + } + _field := make([]prompt.PromptType, 0, size) + for i := 0; i < size; i++ { + + var _elem prompt.PromptType + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _elem = v + } + + _field = append(_field, _elem) + } + if err := iprot.ReadListEnd(); err != nil { + return err + } + p.FilterPromptTypes = _field + return nil +} func (p *ListPromptRequest) ReadField127(iprot thrift.TProtocol) error { var _field *int32 @@ -4569,6 +4855,10 @@ func (p *ListPromptRequest) Write(oprot thrift.TProtocol) (err error) { fieldId = 13 goto WriteFieldError } + if err = p.writeField14(oprot); err != nil { + fieldId = 14 + goto WriteFieldError + } if err = p.writeField127(oprot); err != nil { fieldId = 127 goto WriteFieldError @@ -4687,6 +4977,32 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 13 end error: ", p), err) } +func (p *ListPromptRequest) writeField14(oprot thrift.TProtocol) (err error) { + if p.IsSetFilterPromptTypes() { + if err = oprot.WriteFieldBegin("filter_prompt_types", thrift.LIST, 14); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteListBegin(thrift.STRING, len(p.FilterPromptTypes)); err != nil { + return err + } + for _, v := range p.FilterPromptTypes { + if err := oprot.WriteString(v); err != nil { + return err + } + } + if err := oprot.WriteListEnd(); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 14 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 14 end error: ", p), err) +} func (p *ListPromptRequest) writeField127(oprot thrift.TProtocol) (err error) { if p.IsSetPageNum() { if err = oprot.WriteFieldBegin("page_num", thrift.I32, 127); err != nil { @@ -4804,6 +5120,9 @@ func (p *ListPromptRequest) DeepEqual(ano *ListPromptRequest) bool { if !p.Field13DeepEqual(ano.CommittedOnly) { return false } + if !p.Field14DeepEqual(ano.FilterPromptTypes) { + return false + } if !p.Field127DeepEqual(ano.PageNum) { return false } @@ -4871,6 +5190,19 @@ func (p *ListPromptRequest) Field13DeepEqual(src *bool) bool { } return true } +func (p *ListPromptRequest) Field14DeepEqual(src []prompt.PromptType) bool { + + if len(p.FilterPromptTypes) != len(src) { + return false + } + for i, v := range p.FilterPromptTypes { + _src := src[i] + if strings.Compare(v, _src) != 0 { + return false + } + } + return true +} func (p *ListPromptRequest) Field127DeepEqual(src *int32) bool { if p.PageNum == src { @@ -7188,11 +7520,13 @@ func (p *CommitDraftResponse) Field255DeepEqual(src *base.BaseResp) bool { // 搜索Prompt提交版本 type ListCommitRequest struct { - PromptID *int64 `thrift:"prompt_id,1,optional" frugal:"1,optional,i64" json:"prompt_id" path:"prompt_id" ` - PageSize *int32 `thrift:"page_size,127,optional" frugal:"127,optional,i32" form:"page_size" json:"page_size,omitempty" query:"page_size"` - PageToken *string `thrift:"page_token,128,optional" frugal:"128,optional,string" form:"page_token" json:"page_token,omitempty" query:"page_token"` - Asc *bool `thrift:"asc,129,optional" frugal:"129,optional,bool" form:"asc" json:"asc,omitempty" query:"asc"` - Base *base.Base `thrift:"Base,255,optional" frugal:"255,optional,base.Base" form:"Base" json:"Base,omitempty" query:"Base"` + PromptID *int64 `thrift:"prompt_id,1,optional" frugal:"1,optional,i64" json:"prompt_id" path:"prompt_id" ` + // 是否查询详情 + WithCommitDetail *bool `thrift:"with_commit_detail,2,optional" frugal:"2,optional,bool" json:"with_commit_detail,omitempty" query:"with_commit_detail"` + PageSize *int32 `thrift:"page_size,127,optional" frugal:"127,optional,i32" form:"page_size" json:"page_size,omitempty" query:"page_size"` + PageToken *string `thrift:"page_token,128,optional" frugal:"128,optional,string" form:"page_token" json:"page_token,omitempty" query:"page_token"` + Asc *bool `thrift:"asc,129,optional" frugal:"129,optional,bool" form:"asc" json:"asc,omitempty" query:"asc"` + Base *base.Base `thrift:"Base,255,optional" frugal:"255,optional,base.Base" form:"Base" json:"Base,omitempty" query:"Base"` } func NewListCommitRequest() *ListCommitRequest { @@ -7214,6 +7548,18 @@ func (p *ListCommitRequest) GetPromptID() (v int64) { return *p.PromptID } +var ListCommitRequest_WithCommitDetail_DEFAULT bool + +func (p *ListCommitRequest) GetWithCommitDetail() (v bool) { + if p == nil { + return + } + if !p.IsSetWithCommitDetail() { + return ListCommitRequest_WithCommitDetail_DEFAULT + } + return *p.WithCommitDetail +} + var ListCommitRequest_PageSize_DEFAULT int32 func (p *ListCommitRequest) GetPageSize() (v int32) { @@ -7264,6 +7610,9 @@ func (p *ListCommitRequest) GetBase() (v *base.Base) { func (p *ListCommitRequest) SetPromptID(val *int64) { p.PromptID = val } +func (p *ListCommitRequest) SetWithCommitDetail(val *bool) { + p.WithCommitDetail = val +} func (p *ListCommitRequest) SetPageSize(val *int32) { p.PageSize = val } @@ -7279,6 +7628,7 @@ func (p *ListCommitRequest) SetBase(val *base.Base) { var fieldIDToName_ListCommitRequest = map[int16]string{ 1: "prompt_id", + 2: "with_commit_detail", 127: "page_size", 128: "page_token", 129: "asc", @@ -7289,6 +7639,10 @@ func (p *ListCommitRequest) IsSetPromptID() bool { return p.PromptID != nil } +func (p *ListCommitRequest) IsSetWithCommitDetail() bool { + return p.WithCommitDetail != nil +} + func (p *ListCommitRequest) IsSetPageSize() bool { return p.PageSize != nil } @@ -7331,6 +7685,14 @@ func (p *ListCommitRequest) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 2: + if fieldTypeId == thrift.BOOL { + if err = p.ReadField2(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } case 127: if fieldTypeId == thrift.I32 { if err = p.ReadField127(iprot); err != nil { @@ -7403,6 +7765,17 @@ func (p *ListCommitRequest) ReadField1(iprot thrift.TProtocol) error { p.PromptID = _field return nil } +func (p *ListCommitRequest) ReadField2(iprot thrift.TProtocol) error { + + var _field *bool + if v, err := iprot.ReadBool(); err != nil { + return err + } else { + _field = &v + } + p.WithCommitDetail = _field + return nil +} func (p *ListCommitRequest) ReadField127(iprot thrift.TProtocol) error { var _field *int32 @@ -7455,6 +7828,10 @@ func (p *ListCommitRequest) Write(oprot thrift.TProtocol) (err error) { fieldId = 1 goto WriteFieldError } + if err = p.writeField2(oprot); err != nil { + fieldId = 2 + goto WriteFieldError + } if err = p.writeField127(oprot); err != nil { fieldId = 127 goto WriteFieldError @@ -7507,8 +7884,26 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) } -func (p *ListCommitRequest) writeField127(oprot thrift.TProtocol) (err error) { - if p.IsSetPageSize() { +func (p *ListCommitRequest) writeField2(oprot thrift.TProtocol) (err error) { + if p.IsSetWithCommitDetail() { + if err = oprot.WriteFieldBegin("with_commit_detail", thrift.BOOL, 2); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteBool(*p.WithCommitDetail); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) +} +func (p *ListCommitRequest) writeField127(oprot thrift.TProtocol) (err error) { + if p.IsSetPageSize() { if err = oprot.WriteFieldBegin("page_size", thrift.I32, 127); err != nil { goto WriteFieldBeginError } @@ -7597,6 +7992,9 @@ func (p *ListCommitRequest) DeepEqual(ano *ListCommitRequest) bool { if !p.Field1DeepEqual(ano.PromptID) { return false } + if !p.Field2DeepEqual(ano.WithCommitDetail) { + return false + } if !p.Field127DeepEqual(ano.PageSize) { return false } @@ -7624,6 +8022,18 @@ func (p *ListCommitRequest) Field1DeepEqual(src *int64) bool { } return true } +func (p *ListCommitRequest) Field2DeepEqual(src *bool) bool { + + if p.WithCommitDetail == src { + return true + } else if p.WithCommitDetail == nil || src == nil { + return false + } + if *p.WithCommitDetail != *src { + return false + } + return true +} func (p *ListCommitRequest) Field127DeepEqual(src *int32) bool { if p.PageSize == src { @@ -7671,10 +8081,14 @@ func (p *ListCommitRequest) Field255DeepEqual(src *base.Base) bool { type ListCommitResponse struct { PromptCommitInfos []*prompt.CommitInfo `thrift:"prompt_commit_infos,1,optional" frugal:"1,optional,list" form:"prompt_commit_infos" json:"prompt_commit_infos,omitempty" query:"prompt_commit_infos"` CommitVersionLabelMapping map[string][]*prompt.Label `thrift:"commit_version_label_mapping,2,optional" frugal:"2,optional,map>" form:"commit_version_label_mapping" json:"commit_version_label_mapping,omitempty" query:"commit_version_label_mapping"` - Users []*user.UserInfoDetail `thrift:"users,11,optional" frugal:"11,optional,list" form:"users" json:"users,omitempty" query:"users"` - HasMore *bool `thrift:"has_more,127,optional" frugal:"127,optional,bool" form:"has_more" json:"has_more,omitempty" query:"has_more"` - NextPageToken *string `thrift:"next_page_token,128,optional" frugal:"128,optional,string" form:"next_page_token" json:"next_page_token,omitempty" query:"next_page_token"` - BaseResp *base.BaseResp `thrift:"BaseResp,255,optional" frugal:"255,optional,base.BaseResp" form:"BaseResp" json:"BaseResp,omitempty" query:"BaseResp"` + // key: version, value:被引用数 + ParentReferencesMapping map[string]int32 `thrift:"parent_references_mapping,3,optional" frugal:"3,optional,map" form:"parent_references_mapping" json:"parent_references_mapping,omitempty" query:"parent_references_mapping"` + // key:version, value:PromptDetail + PromptCommitDetailMapping map[string]*prompt.PromptDetail `thrift:"prompt_commit_detail_mapping,4,optional" frugal:"4,optional,map" form:"prompt_commit_detail_mapping" json:"prompt_commit_detail_mapping,omitempty" query:"prompt_commit_detail_mapping"` + Users []*user.UserInfoDetail `thrift:"users,11,optional" frugal:"11,optional,list" form:"users" json:"users,omitempty" query:"users"` + HasMore *bool `thrift:"has_more,127,optional" frugal:"127,optional,bool" form:"has_more" json:"has_more,omitempty" query:"has_more"` + NextPageToken *string `thrift:"next_page_token,128,optional" frugal:"128,optional,string" form:"next_page_token" json:"next_page_token,omitempty" query:"next_page_token"` + BaseResp *base.BaseResp `thrift:"BaseResp,255,optional" frugal:"255,optional,base.BaseResp" form:"BaseResp" json:"BaseResp,omitempty" query:"BaseResp"` } func NewListCommitResponse() *ListCommitResponse { @@ -7708,6 +8122,30 @@ func (p *ListCommitResponse) GetCommitVersionLabelMapping() (v map[string][]*pro return p.CommitVersionLabelMapping } +var ListCommitResponse_ParentReferencesMapping_DEFAULT map[string]int32 + +func (p *ListCommitResponse) GetParentReferencesMapping() (v map[string]int32) { + if p == nil { + return + } + if !p.IsSetParentReferencesMapping() { + return ListCommitResponse_ParentReferencesMapping_DEFAULT + } + return p.ParentReferencesMapping +} + +var ListCommitResponse_PromptCommitDetailMapping_DEFAULT map[string]*prompt.PromptDetail + +func (p *ListCommitResponse) GetPromptCommitDetailMapping() (v map[string]*prompt.PromptDetail) { + if p == nil { + return + } + if !p.IsSetPromptCommitDetailMapping() { + return ListCommitResponse_PromptCommitDetailMapping_DEFAULT + } + return p.PromptCommitDetailMapping +} + var ListCommitResponse_Users_DEFAULT []*user.UserInfoDetail func (p *ListCommitResponse) GetUsers() (v []*user.UserInfoDetail) { @@ -7761,6 +8199,12 @@ func (p *ListCommitResponse) SetPromptCommitInfos(val []*prompt.CommitInfo) { func (p *ListCommitResponse) SetCommitVersionLabelMapping(val map[string][]*prompt.Label) { p.CommitVersionLabelMapping = val } +func (p *ListCommitResponse) SetParentReferencesMapping(val map[string]int32) { + p.ParentReferencesMapping = val +} +func (p *ListCommitResponse) SetPromptCommitDetailMapping(val map[string]*prompt.PromptDetail) { + p.PromptCommitDetailMapping = val +} func (p *ListCommitResponse) SetUsers(val []*user.UserInfoDetail) { p.Users = val } @@ -7777,6 +8221,8 @@ func (p *ListCommitResponse) SetBaseResp(val *base.BaseResp) { var fieldIDToName_ListCommitResponse = map[int16]string{ 1: "prompt_commit_infos", 2: "commit_version_label_mapping", + 3: "parent_references_mapping", + 4: "prompt_commit_detail_mapping", 11: "users", 127: "has_more", 128: "next_page_token", @@ -7791,6 +8237,14 @@ func (p *ListCommitResponse) IsSetCommitVersionLabelMapping() bool { return p.CommitVersionLabelMapping != nil } +func (p *ListCommitResponse) IsSetParentReferencesMapping() bool { + return p.ParentReferencesMapping != nil +} + +func (p *ListCommitResponse) IsSetPromptCommitDetailMapping() bool { + return p.PromptCommitDetailMapping != nil +} + func (p *ListCommitResponse) IsSetUsers() bool { return p.Users != nil } @@ -7841,6 +8295,22 @@ func (p *ListCommitResponse) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 3: + if fieldTypeId == thrift.MAP { + if err = p.ReadField3(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 4: + if fieldTypeId == thrift.MAP { + if err = p.ReadField4(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } case 11: if fieldTypeId == thrift.LIST { if err = p.ReadField11(iprot); err != nil { @@ -7966,6 +8436,64 @@ func (p *ListCommitResponse) ReadField2(iprot thrift.TProtocol) error { p.CommitVersionLabelMapping = _field return nil } +func (p *ListCommitResponse) ReadField3(iprot thrift.TProtocol) error { + _, _, size, err := iprot.ReadMapBegin() + if err != nil { + return err + } + _field := make(map[string]int32, size) + for i := 0; i < size; i++ { + var _key string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _key = v + } + + var _val int32 + if v, err := iprot.ReadI32(); err != nil { + return err + } else { + _val = v + } + + _field[_key] = _val + } + if err := iprot.ReadMapEnd(); err != nil { + return err + } + p.ParentReferencesMapping = _field + return nil +} +func (p *ListCommitResponse) ReadField4(iprot thrift.TProtocol) error { + _, _, size, err := iprot.ReadMapBegin() + if err != nil { + return err + } + _field := make(map[string]*prompt.PromptDetail, size) + values := make([]prompt.PromptDetail, size) + for i := 0; i < size; i++ { + var _key string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _key = v + } + + _val := &values[i] + _val.InitDefault() + if err := _val.Read(iprot); err != nil { + return err + } + + _field[_key] = _val + } + if err := iprot.ReadMapEnd(); err != nil { + return err + } + p.PromptCommitDetailMapping = _field + return nil +} func (p *ListCommitResponse) ReadField11(iprot thrift.TProtocol) error { _, size, err := iprot.ReadListBegin() if err != nil { @@ -8034,6 +8562,14 @@ func (p *ListCommitResponse) Write(oprot thrift.TProtocol) (err error) { fieldId = 2 goto WriteFieldError } + if err = p.writeField3(oprot); err != nil { + fieldId = 3 + goto WriteFieldError + } + if err = p.writeField4(oprot); err != nil { + fieldId = 4 + goto WriteFieldError + } if err = p.writeField11(oprot); err != nil { fieldId = 11 goto WriteFieldError @@ -8131,6 +8667,64 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) } +func (p *ListCommitResponse) writeField3(oprot thrift.TProtocol) (err error) { + if p.IsSetParentReferencesMapping() { + if err = oprot.WriteFieldBegin("parent_references_mapping", thrift.MAP, 3); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteMapBegin(thrift.STRING, thrift.I32, len(p.ParentReferencesMapping)); err != nil { + return err + } + for k, v := range p.ParentReferencesMapping { + if err := oprot.WriteString(k); err != nil { + return err + } + if err := oprot.WriteI32(v); err != nil { + return err + } + } + if err := oprot.WriteMapEnd(); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) +} +func (p *ListCommitResponse) writeField4(oprot thrift.TProtocol) (err error) { + if p.IsSetPromptCommitDetailMapping() { + if err = oprot.WriteFieldBegin("prompt_commit_detail_mapping", thrift.MAP, 4); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteMapBegin(thrift.STRING, thrift.STRUCT, len(p.PromptCommitDetailMapping)); err != nil { + return err + } + for k, v := range p.PromptCommitDetailMapping { + if err := oprot.WriteString(k); err != nil { + return err + } + if err := v.Write(oprot); err != nil { + return err + } + } + if err := oprot.WriteMapEnd(); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 4 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 4 end error: ", p), err) +} func (p *ListCommitResponse) writeField11(oprot thrift.TProtocol) (err error) { if p.IsSetUsers() { if err = oprot.WriteFieldBegin("users", thrift.LIST, 11); err != nil { @@ -8232,6 +8826,12 @@ func (p *ListCommitResponse) DeepEqual(ano *ListCommitResponse) bool { if !p.Field2DeepEqual(ano.CommitVersionLabelMapping) { return false } + if !p.Field3DeepEqual(ano.ParentReferencesMapping) { + return false + } + if !p.Field4DeepEqual(ano.PromptCommitDetailMapping) { + return false + } if !p.Field11DeepEqual(ano.Users) { return false } @@ -8279,6 +8879,32 @@ func (p *ListCommitResponse) Field2DeepEqual(src map[string][]*prompt.Label) boo } return true } +func (p *ListCommitResponse) Field3DeepEqual(src map[string]int32) bool { + + if len(p.ParentReferencesMapping) != len(src) { + return false + } + for k, v := range p.ParentReferencesMapping { + _src := src[k] + if v != _src { + return false + } + } + return true +} +func (p *ListCommitResponse) Field4DeepEqual(src map[string]*prompt.PromptDetail) bool { + + if len(p.PromptCommitDetailMapping) != len(src) { + return false + } + for k, v := range p.PromptCommitDetailMapping { + _src := src[k] + if !v.DeepEqual(_src) { + return false + } + } + return true +} func (p *ListCommitResponse) Field11DeepEqual(src []*user.UserInfoDetail) bool { if len(p.Users) != len(src) { @@ -11779,72 +12405,807 @@ func (p *UpdateCommitLabelsResponse) Field255DeepEqual(src *base.BaseResp) bool return true } -type PromptManageService interface { - // --------------- Prompt管理 --------------- // - // 增 - CreatePrompt(ctx context.Context, request *CreatePromptRequest) (r *CreatePromptResponse, err error) - - ClonePrompt(ctx context.Context, request *ClonePromptRequest) (r *ClonePromptResponse, err error) - // 删 - DeletePrompt(ctx context.Context, request *DeletePromptRequest) (r *DeletePromptResponse, err error) - // 查 - GetPrompt(ctx context.Context, request *GetPromptRequest) (r *GetPromptResponse, err error) - - BatchGetPrompt(ctx context.Context, request *BatchGetPromptRequest) (r *BatchGetPromptResponse, err error) - - ListPrompt(ctx context.Context, request *ListPromptRequest) (r *ListPromptResponse, err error) - // 改 - UpdatePrompt(ctx context.Context, request *UpdatePromptRequest) (r *UpdatePromptResponse, err error) - - SaveDraft(ctx context.Context, request *SaveDraftRequest) (r *SaveDraftResponse, err error) - // --------------- Label管理 --------------- // - // Label管理 - CreateLabel(ctx context.Context, request *CreateLabelRequest) (r *CreateLabelResponse, err error) - - ListLabel(ctx context.Context, request *ListLabelRequest) (r *ListLabelResponse, err error) - - BatchGetLabel(ctx context.Context, request *BatchGetLabelRequest) (r *BatchGetLabelResponse, err error) - // --------------- Prompt版本管理 --------------- // - ListCommit(ctx context.Context, request *ListCommitRequest) (r *ListCommitResponse, err error) - - CommitDraft(ctx context.Context, request *CommitDraftRequest) (r *CommitDraftResponse, err error) - - RevertDraftFromCommit(ctx context.Context, request *RevertDraftFromCommitRequest) (r *RevertDraftFromCommitResponse, err error) +type ListParentPromptRequest struct { + WorkspaceID *int64 `thrift:"workspace_id,1,optional" frugal:"1,optional,i64" json:"workspace_id" form:"workspace_id" query:"workspace_id"` + PromptID *int64 `thrift:"prompt_id,2,optional" frugal:"2,optional,i64" json:"prompt_id" form:"prompt_id" query:"prompt_id"` + // 片段版本,不传则表示查询所有版本的引用记录 + CommitVersions []string `thrift:"commit_versions,3,optional" frugal:"3,optional,list" form:"commit_versions" json:"commit_versions,omitempty" query:"commit_versions"` + Base *base.Base `thrift:"Base,255,optional" frugal:"255,optional,base.Base" form:"Base" json:"Base,omitempty" query:"Base"` +} - UpdateCommitLabels(ctx context.Context, request *UpdateCommitLabelsRequest) (r *UpdateCommitLabelsResponse, err error) +func NewListParentPromptRequest() *ListParentPromptRequest { + return &ListParentPromptRequest{} } -type PromptManageServiceClient struct { - c thrift.TClient +func (p *ListParentPromptRequest) InitDefault() { } -func NewPromptManageServiceClientFactory(t thrift.TTransport, f thrift.TProtocolFactory) *PromptManageServiceClient { - return &PromptManageServiceClient{ - c: thrift.NewTStandardClient(f.GetProtocol(t), f.GetProtocol(t)), +var ListParentPromptRequest_WorkspaceID_DEFAULT int64 + +func (p *ListParentPromptRequest) GetWorkspaceID() (v int64) { + if p == nil { + return } + if !p.IsSetWorkspaceID() { + return ListParentPromptRequest_WorkspaceID_DEFAULT + } + return *p.WorkspaceID } -func NewPromptManageServiceClientProtocol(t thrift.TTransport, iprot thrift.TProtocol, oprot thrift.TProtocol) *PromptManageServiceClient { - return &PromptManageServiceClient{ - c: thrift.NewTStandardClient(iprot, oprot), +var ListParentPromptRequest_PromptID_DEFAULT int64 + +func (p *ListParentPromptRequest) GetPromptID() (v int64) { + if p == nil { + return + } + if !p.IsSetPromptID() { + return ListParentPromptRequest_PromptID_DEFAULT } + return *p.PromptID } -func NewPromptManageServiceClient(c thrift.TClient) *PromptManageServiceClient { - return &PromptManageServiceClient{ - c: c, +var ListParentPromptRequest_CommitVersions_DEFAULT []string + +func (p *ListParentPromptRequest) GetCommitVersions() (v []string) { + if p == nil { + return } + if !p.IsSetCommitVersions() { + return ListParentPromptRequest_CommitVersions_DEFAULT + } + return p.CommitVersions } -func (p *PromptManageServiceClient) Client_() thrift.TClient { - return p.c -} +var ListParentPromptRequest_Base_DEFAULT *base.Base -func (p *PromptManageServiceClient) CreatePrompt(ctx context.Context, request *CreatePromptRequest) (r *CreatePromptResponse, err error) { - var _args PromptManageServiceCreatePromptArgs - _args.Request = request - var _result PromptManageServiceCreatePromptResult - if err = p.Client_().Call(ctx, "CreatePrompt", &_args, &_result); err != nil { +func (p *ListParentPromptRequest) GetBase() (v *base.Base) { + if p == nil { + return + } + if !p.IsSetBase() { + return ListParentPromptRequest_Base_DEFAULT + } + return p.Base +} +func (p *ListParentPromptRequest) SetWorkspaceID(val *int64) { + p.WorkspaceID = val +} +func (p *ListParentPromptRequest) SetPromptID(val *int64) { + p.PromptID = val +} +func (p *ListParentPromptRequest) SetCommitVersions(val []string) { + p.CommitVersions = val +} +func (p *ListParentPromptRequest) SetBase(val *base.Base) { + p.Base = val +} + +var fieldIDToName_ListParentPromptRequest = map[int16]string{ + 1: "workspace_id", + 2: "prompt_id", + 3: "commit_versions", + 255: "Base", +} + +func (p *ListParentPromptRequest) IsSetWorkspaceID() bool { + return p.WorkspaceID != nil +} + +func (p *ListParentPromptRequest) IsSetPromptID() bool { + return p.PromptID != nil +} + +func (p *ListParentPromptRequest) IsSetCommitVersions() bool { + return p.CommitVersions != nil +} + +func (p *ListParentPromptRequest) IsSetBase() bool { + return p.Base != nil +} + +func (p *ListParentPromptRequest) Read(iprot thrift.TProtocol) (err error) { + var fieldTypeId thrift.TType + var fieldId int16 + + if _, err = iprot.ReadStructBegin(); err != nil { + goto ReadStructBeginError + } + + for { + _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + + switch fieldId { + case 1: + if fieldTypeId == thrift.I64 { + if err = p.ReadField1(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 2: + if fieldTypeId == thrift.I64 { + if err = p.ReadField2(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 3: + if fieldTypeId == thrift.LIST { + if err = p.ReadField3(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 255: + if fieldTypeId == thrift.STRUCT { + if err = p.ReadField255(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } + } + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError + } + + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_ListParentPromptRequest[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *ListParentPromptRequest) ReadField1(iprot thrift.TProtocol) error { + + var _field *int64 + if v, err := iprot.ReadI64(); err != nil { + return err + } else { + _field = &v + } + p.WorkspaceID = _field + return nil +} +func (p *ListParentPromptRequest) ReadField2(iprot thrift.TProtocol) error { + + var _field *int64 + if v, err := iprot.ReadI64(); err != nil { + return err + } else { + _field = &v + } + p.PromptID = _field + return nil +} +func (p *ListParentPromptRequest) ReadField3(iprot thrift.TProtocol) error { + _, size, err := iprot.ReadListBegin() + if err != nil { + return err + } + _field := make([]string, 0, size) + for i := 0; i < size; i++ { + + var _elem string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _elem = v + } + + _field = append(_field, _elem) + } + if err := iprot.ReadListEnd(); err != nil { + return err + } + p.CommitVersions = _field + return nil +} +func (p *ListParentPromptRequest) ReadField255(iprot thrift.TProtocol) error { + _field := base.NewBase() + if err := _field.Read(iprot); err != nil { + return err + } + p.Base = _field + return nil +} + +func (p *ListParentPromptRequest) Write(oprot thrift.TProtocol) (err error) { + var fieldId int16 + if err = oprot.WriteStructBegin("ListParentPromptRequest"); err != nil { + goto WriteStructBeginError + } + if p != nil { + if err = p.writeField1(oprot); err != nil { + fieldId = 1 + goto WriteFieldError + } + if err = p.writeField2(oprot); err != nil { + fieldId = 2 + goto WriteFieldError + } + if err = p.writeField3(oprot); err != nil { + fieldId = 3 + goto WriteFieldError + } + if err = p.writeField255(oprot); err != nil { + fieldId = 255 + goto WriteFieldError + } + } + if err = oprot.WriteFieldStop(); err != nil { + goto WriteFieldStopError + } + if err = oprot.WriteStructEnd(); err != nil { + goto WriteStructEndError + } + return nil +WriteStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) +WriteFieldError: + return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) +WriteFieldStopError: + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) +WriteStructEndError: + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) +} + +func (p *ListParentPromptRequest) writeField1(oprot thrift.TProtocol) (err error) { + if p.IsSetWorkspaceID() { + if err = oprot.WriteFieldBegin("workspace_id", thrift.I64, 1); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteI64(*p.WorkspaceID); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) +} +func (p *ListParentPromptRequest) writeField2(oprot thrift.TProtocol) (err error) { + if p.IsSetPromptID() { + if err = oprot.WriteFieldBegin("prompt_id", thrift.I64, 2); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteI64(*p.PromptID); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) +} +func (p *ListParentPromptRequest) writeField3(oprot thrift.TProtocol) (err error) { + if p.IsSetCommitVersions() { + if err = oprot.WriteFieldBegin("commit_versions", thrift.LIST, 3); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteListBegin(thrift.STRING, len(p.CommitVersions)); err != nil { + return err + } + for _, v := range p.CommitVersions { + if err := oprot.WriteString(v); err != nil { + return err + } + } + if err := oprot.WriteListEnd(); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) +} +func (p *ListParentPromptRequest) writeField255(oprot thrift.TProtocol) (err error) { + if p.IsSetBase() { + if err = oprot.WriteFieldBegin("Base", thrift.STRUCT, 255); err != nil { + goto WriteFieldBeginError + } + if err := p.Base.Write(oprot); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 255 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 255 end error: ", p), err) +} + +func (p *ListParentPromptRequest) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("ListParentPromptRequest(%+v)", *p) + +} + +func (p *ListParentPromptRequest) DeepEqual(ano *ListParentPromptRequest) bool { + if p == ano { + return true + } else if p == nil || ano == nil { + return false + } + if !p.Field1DeepEqual(ano.WorkspaceID) { + return false + } + if !p.Field2DeepEqual(ano.PromptID) { + return false + } + if !p.Field3DeepEqual(ano.CommitVersions) { + return false + } + if !p.Field255DeepEqual(ano.Base) { + return false + } + return true +} + +func (p *ListParentPromptRequest) Field1DeepEqual(src *int64) bool { + + if p.WorkspaceID == src { + return true + } else if p.WorkspaceID == nil || src == nil { + return false + } + if *p.WorkspaceID != *src { + return false + } + return true +} +func (p *ListParentPromptRequest) Field2DeepEqual(src *int64) bool { + + if p.PromptID == src { + return true + } else if p.PromptID == nil || src == nil { + return false + } + if *p.PromptID != *src { + return false + } + return true +} +func (p *ListParentPromptRequest) Field3DeepEqual(src []string) bool { + + if len(p.CommitVersions) != len(src) { + return false + } + for i, v := range p.CommitVersions { + _src := src[i] + if strings.Compare(v, _src) != 0 { + return false + } + } + return true +} +func (p *ListParentPromptRequest) Field255DeepEqual(src *base.Base) bool { + + if !p.Base.DeepEqual(src) { + return false + } + return true +} + +type ListParentPromptResponse struct { + // 不同片段版本被引用的父prompt记录 + ParentPrompts map[string][]*prompt.PromptCommitVersions `thrift:"parent_prompts,1,optional" frugal:"1,optional,map>" form:"parent_prompts" json:"parent_prompts,omitempty" query:"parent_prompts"` + BaseResp *base.BaseResp `thrift:"BaseResp,255,optional" frugal:"255,optional,base.BaseResp" form:"BaseResp" json:"BaseResp,omitempty" query:"BaseResp"` +} + +func NewListParentPromptResponse() *ListParentPromptResponse { + return &ListParentPromptResponse{} +} + +func (p *ListParentPromptResponse) InitDefault() { +} + +var ListParentPromptResponse_ParentPrompts_DEFAULT map[string][]*prompt.PromptCommitVersions + +func (p *ListParentPromptResponse) GetParentPrompts() (v map[string][]*prompt.PromptCommitVersions) { + if p == nil { + return + } + if !p.IsSetParentPrompts() { + return ListParentPromptResponse_ParentPrompts_DEFAULT + } + return p.ParentPrompts +} + +var ListParentPromptResponse_BaseResp_DEFAULT *base.BaseResp + +func (p *ListParentPromptResponse) GetBaseResp() (v *base.BaseResp) { + if p == nil { + return + } + if !p.IsSetBaseResp() { + return ListParentPromptResponse_BaseResp_DEFAULT + } + return p.BaseResp +} +func (p *ListParentPromptResponse) SetParentPrompts(val map[string][]*prompt.PromptCommitVersions) { + p.ParentPrompts = val +} +func (p *ListParentPromptResponse) SetBaseResp(val *base.BaseResp) { + p.BaseResp = val +} + +var fieldIDToName_ListParentPromptResponse = map[int16]string{ + 1: "parent_prompts", + 255: "BaseResp", +} + +func (p *ListParentPromptResponse) IsSetParentPrompts() bool { + return p.ParentPrompts != nil +} + +func (p *ListParentPromptResponse) IsSetBaseResp() bool { + return p.BaseResp != nil +} + +func (p *ListParentPromptResponse) Read(iprot thrift.TProtocol) (err error) { + var fieldTypeId thrift.TType + var fieldId int16 + + if _, err = iprot.ReadStructBegin(); err != nil { + goto ReadStructBeginError + } + + for { + _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + + switch fieldId { + case 1: + if fieldTypeId == thrift.MAP { + if err = p.ReadField1(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 255: + if fieldTypeId == thrift.STRUCT { + if err = p.ReadField255(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } + } + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError + } + + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_ListParentPromptResponse[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *ListParentPromptResponse) ReadField1(iprot thrift.TProtocol) error { + _, _, size, err := iprot.ReadMapBegin() + if err != nil { + return err + } + _field := make(map[string][]*prompt.PromptCommitVersions, size) + for i := 0; i < size; i++ { + var _key string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _key = v + } + _, size, err := iprot.ReadListBegin() + if err != nil { + return err + } + _val := make([]*prompt.PromptCommitVersions, 0, size) + values := make([]prompt.PromptCommitVersions, size) + for i := 0; i < size; i++ { + _elem := &values[i] + _elem.InitDefault() + + if err := _elem.Read(iprot); err != nil { + return err + } + + _val = append(_val, _elem) + } + if err := iprot.ReadListEnd(); err != nil { + return err + } + + _field[_key] = _val + } + if err := iprot.ReadMapEnd(); err != nil { + return err + } + p.ParentPrompts = _field + return nil +} +func (p *ListParentPromptResponse) ReadField255(iprot thrift.TProtocol) error { + _field := base.NewBaseResp() + if err := _field.Read(iprot); err != nil { + return err + } + p.BaseResp = _field + return nil +} + +func (p *ListParentPromptResponse) Write(oprot thrift.TProtocol) (err error) { + var fieldId int16 + if err = oprot.WriteStructBegin("ListParentPromptResponse"); err != nil { + goto WriteStructBeginError + } + if p != nil { + if err = p.writeField1(oprot); err != nil { + fieldId = 1 + goto WriteFieldError + } + if err = p.writeField255(oprot); err != nil { + fieldId = 255 + goto WriteFieldError + } + } + if err = oprot.WriteFieldStop(); err != nil { + goto WriteFieldStopError + } + if err = oprot.WriteStructEnd(); err != nil { + goto WriteStructEndError + } + return nil +WriteStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) +WriteFieldError: + return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) +WriteFieldStopError: + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) +WriteStructEndError: + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) +} + +func (p *ListParentPromptResponse) writeField1(oprot thrift.TProtocol) (err error) { + if p.IsSetParentPrompts() { + if err = oprot.WriteFieldBegin("parent_prompts", thrift.MAP, 1); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteMapBegin(thrift.STRING, thrift.LIST, len(p.ParentPrompts)); err != nil { + return err + } + for k, v := range p.ParentPrompts { + if err := oprot.WriteString(k); err != nil { + return err + } + if err := oprot.WriteListBegin(thrift.STRUCT, len(v)); err != nil { + return err + } + for _, v := range v { + if err := v.Write(oprot); err != nil { + return err + } + } + if err := oprot.WriteListEnd(); err != nil { + return err + } + } + if err := oprot.WriteMapEnd(); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) +} +func (p *ListParentPromptResponse) writeField255(oprot thrift.TProtocol) (err error) { + if p.IsSetBaseResp() { + if err = oprot.WriteFieldBegin("BaseResp", thrift.STRUCT, 255); err != nil { + goto WriteFieldBeginError + } + if err := p.BaseResp.Write(oprot); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 255 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 255 end error: ", p), err) +} + +func (p *ListParentPromptResponse) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("ListParentPromptResponse(%+v)", *p) + +} + +func (p *ListParentPromptResponse) DeepEqual(ano *ListParentPromptResponse) bool { + if p == ano { + return true + } else if p == nil || ano == nil { + return false + } + if !p.Field1DeepEqual(ano.ParentPrompts) { + return false + } + if !p.Field255DeepEqual(ano.BaseResp) { + return false + } + return true +} + +func (p *ListParentPromptResponse) Field1DeepEqual(src map[string][]*prompt.PromptCommitVersions) bool { + + if len(p.ParentPrompts) != len(src) { + return false + } + for k, v := range p.ParentPrompts { + _src := src[k] + if len(v) != len(_src) { + return false + } + for i, v := range v { + _src1 := _src[i] + if !v.DeepEqual(_src1) { + return false + } + } + } + return true +} +func (p *ListParentPromptResponse) Field255DeepEqual(src *base.BaseResp) bool { + + if !p.BaseResp.DeepEqual(src) { + return false + } + return true +} + +type PromptManageService interface { + // --------------- Prompt管理 --------------- // + // 增 + CreatePrompt(ctx context.Context, request *CreatePromptRequest) (r *CreatePromptResponse, err error) + + ClonePrompt(ctx context.Context, request *ClonePromptRequest) (r *ClonePromptResponse, err error) + // 删 + DeletePrompt(ctx context.Context, request *DeletePromptRequest) (r *DeletePromptResponse, err error) + // 查 + GetPrompt(ctx context.Context, request *GetPromptRequest) (r *GetPromptResponse, err error) + + BatchGetPrompt(ctx context.Context, request *BatchGetPromptRequest) (r *BatchGetPromptResponse, err error) + + ListPrompt(ctx context.Context, request *ListPromptRequest) (r *ListPromptResponse, err error) + // 查询片段的引用记录 + ListParentPrompt(ctx context.Context, request *ListParentPromptRequest) (r *ListParentPromptResponse, err error) + // 改 + UpdatePrompt(ctx context.Context, request *UpdatePromptRequest) (r *UpdatePromptResponse, err error) + + SaveDraft(ctx context.Context, request *SaveDraftRequest) (r *SaveDraftResponse, err error) + // --------------- Label管理 --------------- // + // Label管理 + CreateLabel(ctx context.Context, request *CreateLabelRequest) (r *CreateLabelResponse, err error) + + ListLabel(ctx context.Context, request *ListLabelRequest) (r *ListLabelResponse, err error) + + BatchGetLabel(ctx context.Context, request *BatchGetLabelRequest) (r *BatchGetLabelResponse, err error) + // --------------- Prompt版本管理 --------------- // + ListCommit(ctx context.Context, request *ListCommitRequest) (r *ListCommitResponse, err error) + + CommitDraft(ctx context.Context, request *CommitDraftRequest) (r *CommitDraftResponse, err error) + + RevertDraftFromCommit(ctx context.Context, request *RevertDraftFromCommitRequest) (r *RevertDraftFromCommitResponse, err error) + + UpdateCommitLabels(ctx context.Context, request *UpdateCommitLabelsRequest) (r *UpdateCommitLabelsResponse, err error) +} + +type PromptManageServiceClient struct { + c thrift.TClient +} + +func NewPromptManageServiceClientFactory(t thrift.TTransport, f thrift.TProtocolFactory) *PromptManageServiceClient { + return &PromptManageServiceClient{ + c: thrift.NewTStandardClient(f.GetProtocol(t), f.GetProtocol(t)), + } +} + +func NewPromptManageServiceClientProtocol(t thrift.TTransport, iprot thrift.TProtocol, oprot thrift.TProtocol) *PromptManageServiceClient { + return &PromptManageServiceClient{ + c: thrift.NewTStandardClient(iprot, oprot), + } +} + +func NewPromptManageServiceClient(c thrift.TClient) *PromptManageServiceClient { + return &PromptManageServiceClient{ + c: c, + } +} + +func (p *PromptManageServiceClient) Client_() thrift.TClient { + return p.c +} + +func (p *PromptManageServiceClient) CreatePrompt(ctx context.Context, request *CreatePromptRequest) (r *CreatePromptResponse, err error) { + var _args PromptManageServiceCreatePromptArgs + _args.Request = request + var _result PromptManageServiceCreatePromptResult + if err = p.Client_().Call(ctx, "CreatePrompt", &_args, &_result); err != nil { return } return _result.GetSuccess(), nil @@ -11894,6 +13255,15 @@ func (p *PromptManageServiceClient) ListPrompt(ctx context.Context, request *Lis } return _result.GetSuccess(), nil } +func (p *PromptManageServiceClient) ListParentPrompt(ctx context.Context, request *ListParentPromptRequest) (r *ListParentPromptResponse, err error) { + var _args PromptManageServiceListParentPromptArgs + _args.Request = request + var _result PromptManageServiceListParentPromptResult + if err = p.Client_().Call(ctx, "ListParentPrompt", &_args, &_result); err != nil { + return + } + return _result.GetSuccess(), nil +} func (p *PromptManageServiceClient) UpdatePrompt(ctx context.Context, request *UpdatePromptRequest) (r *UpdatePromptResponse, err error) { var _args PromptManageServiceUpdatePromptArgs _args.Request = request @@ -12002,6 +13372,7 @@ func NewPromptManageServiceProcessor(handler PromptManageService) *PromptManageS self.AddToProcessorMap("GetPrompt", &promptManageServiceProcessorGetPrompt{handler: handler}) self.AddToProcessorMap("BatchGetPrompt", &promptManageServiceProcessorBatchGetPrompt{handler: handler}) self.AddToProcessorMap("ListPrompt", &promptManageServiceProcessorListPrompt{handler: handler}) + self.AddToProcessorMap("ListParentPrompt", &promptManageServiceProcessorListParentPrompt{handler: handler}) self.AddToProcessorMap("UpdatePrompt", &promptManageServiceProcessorUpdatePrompt{handler: handler}) self.AddToProcessorMap("SaveDraft", &promptManageServiceProcessorSaveDraft{handler: handler}) self.AddToProcessorMap("CreateLabel", &promptManageServiceProcessorCreateLabel{handler: handler}) @@ -12319,6 +13690,54 @@ func (p *promptManageServiceProcessorListPrompt) Process(ctx context.Context, se return true, err } +type promptManageServiceProcessorListParentPrompt struct { + handler PromptManageService +} + +func (p *promptManageServiceProcessorListParentPrompt) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { + args := PromptManageServiceListParentPromptArgs{} + if err = args.Read(iprot); err != nil { + iprot.ReadMessageEnd() + x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) + oprot.WriteMessageBegin("ListParentPrompt", thrift.EXCEPTION, seqId) + x.Write(oprot) + oprot.WriteMessageEnd() + oprot.Flush(ctx) + return false, err + } + + iprot.ReadMessageEnd() + var err2 error + result := PromptManageServiceListParentPromptResult{} + var retval *ListParentPromptResponse + if retval, err2 = p.handler.ListParentPrompt(ctx, args.Request); err2 != nil { + x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing ListParentPrompt: "+err2.Error()) + oprot.WriteMessageBegin("ListParentPrompt", thrift.EXCEPTION, seqId) + x.Write(oprot) + oprot.WriteMessageEnd() + oprot.Flush(ctx) + return true, err2 + } else { + result.Success = retval + } + if err2 = oprot.WriteMessageBegin("ListParentPrompt", thrift.REPLY, seqId); err2 != nil { + err = err2 + } + if err2 = result.Write(oprot); err == nil && err2 != nil { + err = err2 + } + if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil { + err = err2 + } + if err2 = oprot.Flush(ctx); err == nil && err2 != nil { + err = err2 + } + if err != nil { + return + } + return true, err +} + type promptManageServiceProcessorUpdatePrompt struct { handler PromptManageService } @@ -12397,7 +13816,151 @@ func (p *promptManageServiceProcessorSaveDraft) Process(ctx context.Context, seq } else { result.Success = retval } - if err2 = oprot.WriteMessageBegin("SaveDraft", thrift.REPLY, seqId); err2 != nil { + if err2 = oprot.WriteMessageBegin("SaveDraft", thrift.REPLY, seqId); err2 != nil { + err = err2 + } + if err2 = result.Write(oprot); err == nil && err2 != nil { + err = err2 + } + if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil { + err = err2 + } + if err2 = oprot.Flush(ctx); err == nil && err2 != nil { + err = err2 + } + if err != nil { + return + } + return true, err +} + +type promptManageServiceProcessorCreateLabel struct { + handler PromptManageService +} + +func (p *promptManageServiceProcessorCreateLabel) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { + args := PromptManageServiceCreateLabelArgs{} + if err = args.Read(iprot); err != nil { + iprot.ReadMessageEnd() + x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) + oprot.WriteMessageBegin("CreateLabel", thrift.EXCEPTION, seqId) + x.Write(oprot) + oprot.WriteMessageEnd() + oprot.Flush(ctx) + return false, err + } + + iprot.ReadMessageEnd() + var err2 error + result := PromptManageServiceCreateLabelResult{} + var retval *CreateLabelResponse + if retval, err2 = p.handler.CreateLabel(ctx, args.Request); err2 != nil { + x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing CreateLabel: "+err2.Error()) + oprot.WriteMessageBegin("CreateLabel", thrift.EXCEPTION, seqId) + x.Write(oprot) + oprot.WriteMessageEnd() + oprot.Flush(ctx) + return true, err2 + } else { + result.Success = retval + } + if err2 = oprot.WriteMessageBegin("CreateLabel", thrift.REPLY, seqId); err2 != nil { + err = err2 + } + if err2 = result.Write(oprot); err == nil && err2 != nil { + err = err2 + } + if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil { + err = err2 + } + if err2 = oprot.Flush(ctx); err == nil && err2 != nil { + err = err2 + } + if err != nil { + return + } + return true, err +} + +type promptManageServiceProcessorListLabel struct { + handler PromptManageService +} + +func (p *promptManageServiceProcessorListLabel) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { + args := PromptManageServiceListLabelArgs{} + if err = args.Read(iprot); err != nil { + iprot.ReadMessageEnd() + x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) + oprot.WriteMessageBegin("ListLabel", thrift.EXCEPTION, seqId) + x.Write(oprot) + oprot.WriteMessageEnd() + oprot.Flush(ctx) + return false, err + } + + iprot.ReadMessageEnd() + var err2 error + result := PromptManageServiceListLabelResult{} + var retval *ListLabelResponse + if retval, err2 = p.handler.ListLabel(ctx, args.Request); err2 != nil { + x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing ListLabel: "+err2.Error()) + oprot.WriteMessageBegin("ListLabel", thrift.EXCEPTION, seqId) + x.Write(oprot) + oprot.WriteMessageEnd() + oprot.Flush(ctx) + return true, err2 + } else { + result.Success = retval + } + if err2 = oprot.WriteMessageBegin("ListLabel", thrift.REPLY, seqId); err2 != nil { + err = err2 + } + if err2 = result.Write(oprot); err == nil && err2 != nil { + err = err2 + } + if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil { + err = err2 + } + if err2 = oprot.Flush(ctx); err == nil && err2 != nil { + err = err2 + } + if err != nil { + return + } + return true, err +} + +type promptManageServiceProcessorBatchGetLabel struct { + handler PromptManageService +} + +func (p *promptManageServiceProcessorBatchGetLabel) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { + args := PromptManageServiceBatchGetLabelArgs{} + if err = args.Read(iprot); err != nil { + iprot.ReadMessageEnd() + x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) + oprot.WriteMessageBegin("BatchGetLabel", thrift.EXCEPTION, seqId) + x.Write(oprot) + oprot.WriteMessageEnd() + oprot.Flush(ctx) + return false, err + } + + iprot.ReadMessageEnd() + var err2 error + result := PromptManageServiceBatchGetLabelResult{} + var retval *BatchGetLabelResponse + if retval, err2 = p.handler.BatchGetLabel(ctx, args.Request); err2 != nil { + x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing BatchGetLabel: "+err2.Error()) + oprot.WriteMessageBegin("BatchGetLabel", thrift.EXCEPTION, seqId) + x.Write(oprot) + oprot.WriteMessageEnd() + oprot.Flush(ctx) + return true, err2 + } else { + result.Success = retval + } + if err2 = oprot.WriteMessageBegin("BatchGetLabel", thrift.REPLY, seqId); err2 != nil { err = err2 } if err2 = result.Write(oprot); err == nil && err2 != nil { @@ -12415,16 +13978,16 @@ func (p *promptManageServiceProcessorSaveDraft) Process(ctx context.Context, seq return true, err } -type promptManageServiceProcessorCreateLabel struct { +type promptManageServiceProcessorListCommit struct { handler PromptManageService } -func (p *promptManageServiceProcessorCreateLabel) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { - args := PromptManageServiceCreateLabelArgs{} +func (p *promptManageServiceProcessorListCommit) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { + args := PromptManageServiceListCommitArgs{} if err = args.Read(iprot); err != nil { iprot.ReadMessageEnd() x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) - oprot.WriteMessageBegin("CreateLabel", thrift.EXCEPTION, seqId) + oprot.WriteMessageBegin("ListCommit", thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() oprot.Flush(ctx) @@ -12433,11 +13996,11 @@ func (p *promptManageServiceProcessorCreateLabel) Process(ctx context.Context, s iprot.ReadMessageEnd() var err2 error - result := PromptManageServiceCreateLabelResult{} - var retval *CreateLabelResponse - if retval, err2 = p.handler.CreateLabel(ctx, args.Request); err2 != nil { - x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing CreateLabel: "+err2.Error()) - oprot.WriteMessageBegin("CreateLabel", thrift.EXCEPTION, seqId) + result := PromptManageServiceListCommitResult{} + var retval *ListCommitResponse + if retval, err2 = p.handler.ListCommit(ctx, args.Request); err2 != nil { + x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing ListCommit: "+err2.Error()) + oprot.WriteMessageBegin("ListCommit", thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() oprot.Flush(ctx) @@ -12445,7 +14008,7 @@ func (p *promptManageServiceProcessorCreateLabel) Process(ctx context.Context, s } else { result.Success = retval } - if err2 = oprot.WriteMessageBegin("CreateLabel", thrift.REPLY, seqId); err2 != nil { + if err2 = oprot.WriteMessageBegin("ListCommit", thrift.REPLY, seqId); err2 != nil { err = err2 } if err2 = result.Write(oprot); err == nil && err2 != nil { @@ -12463,16 +14026,16 @@ func (p *promptManageServiceProcessorCreateLabel) Process(ctx context.Context, s return true, err } -type promptManageServiceProcessorListLabel struct { +type promptManageServiceProcessorCommitDraft struct { handler PromptManageService } -func (p *promptManageServiceProcessorListLabel) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { - args := PromptManageServiceListLabelArgs{} +func (p *promptManageServiceProcessorCommitDraft) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { + args := PromptManageServiceCommitDraftArgs{} if err = args.Read(iprot); err != nil { iprot.ReadMessageEnd() x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) - oprot.WriteMessageBegin("ListLabel", thrift.EXCEPTION, seqId) + oprot.WriteMessageBegin("CommitDraft", thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() oprot.Flush(ctx) @@ -12481,11 +14044,11 @@ func (p *promptManageServiceProcessorListLabel) Process(ctx context.Context, seq iprot.ReadMessageEnd() var err2 error - result := PromptManageServiceListLabelResult{} - var retval *ListLabelResponse - if retval, err2 = p.handler.ListLabel(ctx, args.Request); err2 != nil { - x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing ListLabel: "+err2.Error()) - oprot.WriteMessageBegin("ListLabel", thrift.EXCEPTION, seqId) + result := PromptManageServiceCommitDraftResult{} + var retval *CommitDraftResponse + if retval, err2 = p.handler.CommitDraft(ctx, args.Request); err2 != nil { + x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing CommitDraft: "+err2.Error()) + oprot.WriteMessageBegin("CommitDraft", thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() oprot.Flush(ctx) @@ -12493,7 +14056,7 @@ func (p *promptManageServiceProcessorListLabel) Process(ctx context.Context, seq } else { result.Success = retval } - if err2 = oprot.WriteMessageBegin("ListLabel", thrift.REPLY, seqId); err2 != nil { + if err2 = oprot.WriteMessageBegin("CommitDraft", thrift.REPLY, seqId); err2 != nil { err = err2 } if err2 = result.Write(oprot); err == nil && err2 != nil { @@ -12511,16 +14074,16 @@ func (p *promptManageServiceProcessorListLabel) Process(ctx context.Context, seq return true, err } -type promptManageServiceProcessorBatchGetLabel struct { +type promptManageServiceProcessorRevertDraftFromCommit struct { handler PromptManageService } -func (p *promptManageServiceProcessorBatchGetLabel) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { - args := PromptManageServiceBatchGetLabelArgs{} +func (p *promptManageServiceProcessorRevertDraftFromCommit) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { + args := PromptManageServiceRevertDraftFromCommitArgs{} if err = args.Read(iprot); err != nil { iprot.ReadMessageEnd() x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) - oprot.WriteMessageBegin("BatchGetLabel", thrift.EXCEPTION, seqId) + oprot.WriteMessageBegin("RevertDraftFromCommit", thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() oprot.Flush(ctx) @@ -12529,11 +14092,11 @@ func (p *promptManageServiceProcessorBatchGetLabel) Process(ctx context.Context, iprot.ReadMessageEnd() var err2 error - result := PromptManageServiceBatchGetLabelResult{} - var retval *BatchGetLabelResponse - if retval, err2 = p.handler.BatchGetLabel(ctx, args.Request); err2 != nil { - x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing BatchGetLabel: "+err2.Error()) - oprot.WriteMessageBegin("BatchGetLabel", thrift.EXCEPTION, seqId) + result := PromptManageServiceRevertDraftFromCommitResult{} + var retval *RevertDraftFromCommitResponse + if retval, err2 = p.handler.RevertDraftFromCommit(ctx, args.Request); err2 != nil { + x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing RevertDraftFromCommit: "+err2.Error()) + oprot.WriteMessageBegin("RevertDraftFromCommit", thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() oprot.Flush(ctx) @@ -12541,7 +14104,7 @@ func (p *promptManageServiceProcessorBatchGetLabel) Process(ctx context.Context, } else { result.Success = retval } - if err2 = oprot.WriteMessageBegin("BatchGetLabel", thrift.REPLY, seqId); err2 != nil { + if err2 = oprot.WriteMessageBegin("RevertDraftFromCommit", thrift.REPLY, seqId); err2 != nil { err = err2 } if err2 = result.Write(oprot); err == nil && err2 != nil { @@ -12559,16 +14122,16 @@ func (p *promptManageServiceProcessorBatchGetLabel) Process(ctx context.Context, return true, err } -type promptManageServiceProcessorListCommit struct { +type promptManageServiceProcessorUpdateCommitLabels struct { handler PromptManageService } -func (p *promptManageServiceProcessorListCommit) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { - args := PromptManageServiceListCommitArgs{} +func (p *promptManageServiceProcessorUpdateCommitLabels) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { + args := PromptManageServiceUpdateCommitLabelsArgs{} if err = args.Read(iprot); err != nil { iprot.ReadMessageEnd() x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) - oprot.WriteMessageBegin("ListCommit", thrift.EXCEPTION, seqId) + oprot.WriteMessageBegin("UpdateCommitLabels", thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() oprot.Flush(ctx) @@ -12577,11 +14140,11 @@ func (p *promptManageServiceProcessorListCommit) Process(ctx context.Context, se iprot.ReadMessageEnd() var err2 error - result := PromptManageServiceListCommitResult{} - var retval *ListCommitResponse - if retval, err2 = p.handler.ListCommit(ctx, args.Request); err2 != nil { - x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing ListCommit: "+err2.Error()) - oprot.WriteMessageBegin("ListCommit", thrift.EXCEPTION, seqId) + result := PromptManageServiceUpdateCommitLabelsResult{} + var retval *UpdateCommitLabelsResponse + if retval, err2 = p.handler.UpdateCommitLabels(ctx, args.Request); err2 != nil { + x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing UpdateCommitLabels: "+err2.Error()) + oprot.WriteMessageBegin("UpdateCommitLabels", thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() oprot.Flush(ctx) @@ -12589,7 +14152,7 @@ func (p *promptManageServiceProcessorListCommit) Process(ctx context.Context, se } else { result.Success = retval } - if err2 = oprot.WriteMessageBegin("ListCommit", thrift.REPLY, seqId); err2 != nil { + if err2 = oprot.WriteMessageBegin("UpdateCommitLabels", thrift.REPLY, seqId); err2 != nil { err = err2 } if err2 = result.Write(oprot); err == nil && err2 != nil { @@ -12601,191 +14164,391 @@ func (p *promptManageServiceProcessorListCommit) Process(ctx context.Context, se if err2 = oprot.Flush(ctx); err == nil && err2 != nil { err = err2 } - if err != nil { - return + if err != nil { + return + } + return true, err +} + +type PromptManageServiceCreatePromptArgs struct { + Request *CreatePromptRequest `thrift:"request,1" frugal:"1,default,CreatePromptRequest"` +} + +func NewPromptManageServiceCreatePromptArgs() *PromptManageServiceCreatePromptArgs { + return &PromptManageServiceCreatePromptArgs{} +} + +func (p *PromptManageServiceCreatePromptArgs) InitDefault() { +} + +var PromptManageServiceCreatePromptArgs_Request_DEFAULT *CreatePromptRequest + +func (p *PromptManageServiceCreatePromptArgs) GetRequest() (v *CreatePromptRequest) { + if p == nil { + return + } + if !p.IsSetRequest() { + return PromptManageServiceCreatePromptArgs_Request_DEFAULT + } + return p.Request +} +func (p *PromptManageServiceCreatePromptArgs) SetRequest(val *CreatePromptRequest) { + p.Request = val +} + +var fieldIDToName_PromptManageServiceCreatePromptArgs = map[int16]string{ + 1: "request", +} + +func (p *PromptManageServiceCreatePromptArgs) IsSetRequest() bool { + return p.Request != nil +} + +func (p *PromptManageServiceCreatePromptArgs) Read(iprot thrift.TProtocol) (err error) { + var fieldTypeId thrift.TType + var fieldId int16 + + if _, err = iprot.ReadStructBegin(); err != nil { + goto ReadStructBeginError + } + + for { + _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + + switch fieldId { + case 1: + if fieldTypeId == thrift.STRUCT { + if err = p.ReadField1(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } + } + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError + } + + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptManageServiceCreatePromptArgs[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *PromptManageServiceCreatePromptArgs) ReadField1(iprot thrift.TProtocol) error { + _field := NewCreatePromptRequest() + if err := _field.Read(iprot); err != nil { + return err + } + p.Request = _field + return nil +} + +func (p *PromptManageServiceCreatePromptArgs) Write(oprot thrift.TProtocol) (err error) { + var fieldId int16 + if err = oprot.WriteStructBegin("CreatePrompt_args"); err != nil { + goto WriteStructBeginError + } + if p != nil { + if err = p.writeField1(oprot); err != nil { + fieldId = 1 + goto WriteFieldError + } + } + if err = oprot.WriteFieldStop(); err != nil { + goto WriteFieldStopError + } + if err = oprot.WriteStructEnd(); err != nil { + goto WriteStructEndError + } + return nil +WriteStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) +WriteFieldError: + return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) +WriteFieldStopError: + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) +WriteStructEndError: + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) +} + +func (p *PromptManageServiceCreatePromptArgs) writeField1(oprot thrift.TProtocol) (err error) { + if err = oprot.WriteFieldBegin("request", thrift.STRUCT, 1); err != nil { + goto WriteFieldBeginError + } + if err := p.Request.Write(oprot); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) +} + +func (p *PromptManageServiceCreatePromptArgs) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("PromptManageServiceCreatePromptArgs(%+v)", *p) + +} + +func (p *PromptManageServiceCreatePromptArgs) DeepEqual(ano *PromptManageServiceCreatePromptArgs) bool { + if p == ano { + return true + } else if p == nil || ano == nil { + return false + } + if !p.Field1DeepEqual(ano.Request) { + return false + } + return true +} + +func (p *PromptManageServiceCreatePromptArgs) Field1DeepEqual(src *CreatePromptRequest) bool { + + if !p.Request.DeepEqual(src) { + return false } - return true, err + return true } -type promptManageServiceProcessorCommitDraft struct { - handler PromptManageService +type PromptManageServiceCreatePromptResult struct { + Success *CreatePromptResponse `thrift:"success,0,optional" frugal:"0,optional,CreatePromptResponse"` } -func (p *promptManageServiceProcessorCommitDraft) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { - args := PromptManageServiceCommitDraftArgs{} - if err = args.Read(iprot); err != nil { - iprot.ReadMessageEnd() - x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) - oprot.WriteMessageBegin("CommitDraft", thrift.EXCEPTION, seqId) - x.Write(oprot) - oprot.WriteMessageEnd() - oprot.Flush(ctx) - return false, err - } +func NewPromptManageServiceCreatePromptResult() *PromptManageServiceCreatePromptResult { + return &PromptManageServiceCreatePromptResult{} +} - iprot.ReadMessageEnd() - var err2 error - result := PromptManageServiceCommitDraftResult{} - var retval *CommitDraftResponse - if retval, err2 = p.handler.CommitDraft(ctx, args.Request); err2 != nil { - x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing CommitDraft: "+err2.Error()) - oprot.WriteMessageBegin("CommitDraft", thrift.EXCEPTION, seqId) - x.Write(oprot) - oprot.WriteMessageEnd() - oprot.Flush(ctx) - return true, err2 - } else { - result.Success = retval - } - if err2 = oprot.WriteMessageBegin("CommitDraft", thrift.REPLY, seqId); err2 != nil { - err = err2 - } - if err2 = result.Write(oprot); err == nil && err2 != nil { - err = err2 - } - if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil { - err = err2 - } - if err2 = oprot.Flush(ctx); err == nil && err2 != nil { - err = err2 - } - if err != nil { +func (p *PromptManageServiceCreatePromptResult) InitDefault() { +} + +var PromptManageServiceCreatePromptResult_Success_DEFAULT *CreatePromptResponse + +func (p *PromptManageServiceCreatePromptResult) GetSuccess() (v *CreatePromptResponse) { + if p == nil { return } - return true, err + if !p.IsSetSuccess() { + return PromptManageServiceCreatePromptResult_Success_DEFAULT + } + return p.Success +} +func (p *PromptManageServiceCreatePromptResult) SetSuccess(x interface{}) { + p.Success = x.(*CreatePromptResponse) } -type promptManageServiceProcessorRevertDraftFromCommit struct { - handler PromptManageService +var fieldIDToName_PromptManageServiceCreatePromptResult = map[int16]string{ + 0: "success", } -func (p *promptManageServiceProcessorRevertDraftFromCommit) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { - args := PromptManageServiceRevertDraftFromCommitArgs{} - if err = args.Read(iprot); err != nil { - iprot.ReadMessageEnd() - x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) - oprot.WriteMessageBegin("RevertDraftFromCommit", thrift.EXCEPTION, seqId) - x.Write(oprot) - oprot.WriteMessageEnd() - oprot.Flush(ctx) - return false, err +func (p *PromptManageServiceCreatePromptResult) IsSetSuccess() bool { + return p.Success != nil +} + +func (p *PromptManageServiceCreatePromptResult) Read(iprot thrift.TProtocol) (err error) { + var fieldTypeId thrift.TType + var fieldId int16 + + if _, err = iprot.ReadStructBegin(); err != nil { + goto ReadStructBeginError } - iprot.ReadMessageEnd() - var err2 error - result := PromptManageServiceRevertDraftFromCommitResult{} - var retval *RevertDraftFromCommitResponse - if retval, err2 = p.handler.RevertDraftFromCommit(ctx, args.Request); err2 != nil { - x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing RevertDraftFromCommit: "+err2.Error()) - oprot.WriteMessageBegin("RevertDraftFromCommit", thrift.EXCEPTION, seqId) - x.Write(oprot) - oprot.WriteMessageEnd() - oprot.Flush(ctx) - return true, err2 - } else { - result.Success = retval + for { + _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + + switch fieldId { + case 0: + if fieldTypeId == thrift.STRUCT { + if err = p.ReadField0(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } } - if err2 = oprot.WriteMessageBegin("RevertDraftFromCommit", thrift.REPLY, seqId); err2 != nil { - err = err2 + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError } - if err2 = result.Write(oprot); err == nil && err2 != nil { - err = err2 + + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptManageServiceCreatePromptResult[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *PromptManageServiceCreatePromptResult) ReadField0(iprot thrift.TProtocol) error { + _field := NewCreatePromptResponse() + if err := _field.Read(iprot); err != nil { + return err } - if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil { - err = err2 + p.Success = _field + return nil +} + +func (p *PromptManageServiceCreatePromptResult) Write(oprot thrift.TProtocol) (err error) { + var fieldId int16 + if err = oprot.WriteStructBegin("CreatePrompt_result"); err != nil { + goto WriteStructBeginError } - if err2 = oprot.Flush(ctx); err == nil && err2 != nil { - err = err2 + if p != nil { + if err = p.writeField0(oprot); err != nil { + fieldId = 0 + goto WriteFieldError + } } - if err != nil { - return + if err = oprot.WriteFieldStop(); err != nil { + goto WriteFieldStopError } - return true, err + if err = oprot.WriteStructEnd(); err != nil { + goto WriteStructEndError + } + return nil +WriteStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) +WriteFieldError: + return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) +WriteFieldStopError: + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) +WriteStructEndError: + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) } -type promptManageServiceProcessorUpdateCommitLabels struct { - handler PromptManageService +func (p *PromptManageServiceCreatePromptResult) writeField0(oprot thrift.TProtocol) (err error) { + if p.IsSetSuccess() { + if err = oprot.WriteFieldBegin("success", thrift.STRUCT, 0); err != nil { + goto WriteFieldBeginError + } + if err := p.Success.Write(oprot); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 0 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 0 end error: ", p), err) } -func (p *promptManageServiceProcessorUpdateCommitLabels) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { - args := PromptManageServiceUpdateCommitLabelsArgs{} - if err = args.Read(iprot); err != nil { - iprot.ReadMessageEnd() - x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) - oprot.WriteMessageBegin("UpdateCommitLabels", thrift.EXCEPTION, seqId) - x.Write(oprot) - oprot.WriteMessageEnd() - oprot.Flush(ctx) - return false, err +func (p *PromptManageServiceCreatePromptResult) String() string { + if p == nil { + return "" } + return fmt.Sprintf("PromptManageServiceCreatePromptResult(%+v)", *p) - iprot.ReadMessageEnd() - var err2 error - result := PromptManageServiceUpdateCommitLabelsResult{} - var retval *UpdateCommitLabelsResponse - if retval, err2 = p.handler.UpdateCommitLabels(ctx, args.Request); err2 != nil { - x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing UpdateCommitLabels: "+err2.Error()) - oprot.WriteMessageBegin("UpdateCommitLabels", thrift.EXCEPTION, seqId) - x.Write(oprot) - oprot.WriteMessageEnd() - oprot.Flush(ctx) - return true, err2 - } else { - result.Success = retval - } - if err2 = oprot.WriteMessageBegin("UpdateCommitLabels", thrift.REPLY, seqId); err2 != nil { - err = err2 - } - if err2 = result.Write(oprot); err == nil && err2 != nil { - err = err2 - } - if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil { - err = err2 +} + +func (p *PromptManageServiceCreatePromptResult) DeepEqual(ano *PromptManageServiceCreatePromptResult) bool { + if p == ano { + return true + } else if p == nil || ano == nil { + return false } - if err2 = oprot.Flush(ctx); err == nil && err2 != nil { - err = err2 + if !p.Field0DeepEqual(ano.Success) { + return false } - if err != nil { - return + return true +} + +func (p *PromptManageServiceCreatePromptResult) Field0DeepEqual(src *CreatePromptResponse) bool { + + if !p.Success.DeepEqual(src) { + return false } - return true, err + return true } -type PromptManageServiceCreatePromptArgs struct { - Request *CreatePromptRequest `thrift:"request,1" frugal:"1,default,CreatePromptRequest"` +type PromptManageServiceClonePromptArgs struct { + Request *ClonePromptRequest `thrift:"request,1" frugal:"1,default,ClonePromptRequest"` } -func NewPromptManageServiceCreatePromptArgs() *PromptManageServiceCreatePromptArgs { - return &PromptManageServiceCreatePromptArgs{} +func NewPromptManageServiceClonePromptArgs() *PromptManageServiceClonePromptArgs { + return &PromptManageServiceClonePromptArgs{} } -func (p *PromptManageServiceCreatePromptArgs) InitDefault() { +func (p *PromptManageServiceClonePromptArgs) InitDefault() { } -var PromptManageServiceCreatePromptArgs_Request_DEFAULT *CreatePromptRequest +var PromptManageServiceClonePromptArgs_Request_DEFAULT *ClonePromptRequest -func (p *PromptManageServiceCreatePromptArgs) GetRequest() (v *CreatePromptRequest) { +func (p *PromptManageServiceClonePromptArgs) GetRequest() (v *ClonePromptRequest) { if p == nil { return } if !p.IsSetRequest() { - return PromptManageServiceCreatePromptArgs_Request_DEFAULT + return PromptManageServiceClonePromptArgs_Request_DEFAULT } return p.Request } -func (p *PromptManageServiceCreatePromptArgs) SetRequest(val *CreatePromptRequest) { +func (p *PromptManageServiceClonePromptArgs) SetRequest(val *ClonePromptRequest) { p.Request = val } -var fieldIDToName_PromptManageServiceCreatePromptArgs = map[int16]string{ +var fieldIDToName_PromptManageServiceClonePromptArgs = map[int16]string{ 1: "request", } -func (p *PromptManageServiceCreatePromptArgs) IsSetRequest() bool { +func (p *PromptManageServiceClonePromptArgs) IsSetRequest() bool { return p.Request != nil } -func (p *PromptManageServiceCreatePromptArgs) Read(iprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceClonePromptArgs) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -12830,7 +14593,7 @@ ReadStructBeginError: ReadFieldBeginError: return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptManageServiceCreatePromptArgs[fieldId]), err) + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptManageServiceClonePromptArgs[fieldId]), err) SkipFieldError: return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) @@ -12840,8 +14603,8 @@ ReadStructEndError: return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } -func (p *PromptManageServiceCreatePromptArgs) ReadField1(iprot thrift.TProtocol) error { - _field := NewCreatePromptRequest() +func (p *PromptManageServiceClonePromptArgs) ReadField1(iprot thrift.TProtocol) error { + _field := NewClonePromptRequest() if err := _field.Read(iprot); err != nil { return err } @@ -12849,9 +14612,9 @@ func (p *PromptManageServiceCreatePromptArgs) ReadField1(iprot thrift.TProtocol) return nil } -func (p *PromptManageServiceCreatePromptArgs) Write(oprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceClonePromptArgs) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 - if err = oprot.WriteStructBegin("CreatePrompt_args"); err != nil { + if err = oprot.WriteStructBegin("ClonePrompt_args"); err != nil { goto WriteStructBeginError } if p != nil { @@ -12877,7 +14640,7 @@ WriteStructEndError: return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) } -func (p *PromptManageServiceCreatePromptArgs) writeField1(oprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceClonePromptArgs) writeField1(oprot thrift.TProtocol) (err error) { if err = oprot.WriteFieldBegin("request", thrift.STRUCT, 1); err != nil { goto WriteFieldBeginError } @@ -12894,15 +14657,15 @@ WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) } -func (p *PromptManageServiceCreatePromptArgs) String() string { +func (p *PromptManageServiceClonePromptArgs) String() string { if p == nil { return "" } - return fmt.Sprintf("PromptManageServiceCreatePromptArgs(%+v)", *p) + return fmt.Sprintf("PromptManageServiceClonePromptArgs(%+v)", *p) } -func (p *PromptManageServiceCreatePromptArgs) DeepEqual(ano *PromptManageServiceCreatePromptArgs) bool { +func (p *PromptManageServiceClonePromptArgs) DeepEqual(ano *PromptManageServiceClonePromptArgs) bool { if p == ano { return true } else if p == nil || ano == nil { @@ -12914,7 +14677,7 @@ func (p *PromptManageServiceCreatePromptArgs) DeepEqual(ano *PromptManageService return true } -func (p *PromptManageServiceCreatePromptArgs) Field1DeepEqual(src *CreatePromptRequest) bool { +func (p *PromptManageServiceClonePromptArgs) Field1DeepEqual(src *ClonePromptRequest) bool { if !p.Request.DeepEqual(src) { return false @@ -12922,41 +14685,41 @@ func (p *PromptManageServiceCreatePromptArgs) Field1DeepEqual(src *CreatePromptR return true } -type PromptManageServiceCreatePromptResult struct { - Success *CreatePromptResponse `thrift:"success,0,optional" frugal:"0,optional,CreatePromptResponse"` +type PromptManageServiceClonePromptResult struct { + Success *ClonePromptResponse `thrift:"success,0,optional" frugal:"0,optional,ClonePromptResponse"` } -func NewPromptManageServiceCreatePromptResult() *PromptManageServiceCreatePromptResult { - return &PromptManageServiceCreatePromptResult{} +func NewPromptManageServiceClonePromptResult() *PromptManageServiceClonePromptResult { + return &PromptManageServiceClonePromptResult{} } -func (p *PromptManageServiceCreatePromptResult) InitDefault() { +func (p *PromptManageServiceClonePromptResult) InitDefault() { } -var PromptManageServiceCreatePromptResult_Success_DEFAULT *CreatePromptResponse +var PromptManageServiceClonePromptResult_Success_DEFAULT *ClonePromptResponse -func (p *PromptManageServiceCreatePromptResult) GetSuccess() (v *CreatePromptResponse) { +func (p *PromptManageServiceClonePromptResult) GetSuccess() (v *ClonePromptResponse) { if p == nil { return } if !p.IsSetSuccess() { - return PromptManageServiceCreatePromptResult_Success_DEFAULT + return PromptManageServiceClonePromptResult_Success_DEFAULT } return p.Success } -func (p *PromptManageServiceCreatePromptResult) SetSuccess(x interface{}) { - p.Success = x.(*CreatePromptResponse) +func (p *PromptManageServiceClonePromptResult) SetSuccess(x interface{}) { + p.Success = x.(*ClonePromptResponse) } -var fieldIDToName_PromptManageServiceCreatePromptResult = map[int16]string{ +var fieldIDToName_PromptManageServiceClonePromptResult = map[int16]string{ 0: "success", } -func (p *PromptManageServiceCreatePromptResult) IsSetSuccess() bool { +func (p *PromptManageServiceClonePromptResult) IsSetSuccess() bool { return p.Success != nil } -func (p *PromptManageServiceCreatePromptResult) Read(iprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceClonePromptResult) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -13001,7 +14764,7 @@ ReadStructBeginError: ReadFieldBeginError: return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptManageServiceCreatePromptResult[fieldId]), err) + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptManageServiceClonePromptResult[fieldId]), err) SkipFieldError: return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) @@ -13011,8 +14774,8 @@ ReadStructEndError: return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } -func (p *PromptManageServiceCreatePromptResult) ReadField0(iprot thrift.TProtocol) error { - _field := NewCreatePromptResponse() +func (p *PromptManageServiceClonePromptResult) ReadField0(iprot thrift.TProtocol) error { + _field := NewClonePromptResponse() if err := _field.Read(iprot); err != nil { return err } @@ -13020,9 +14783,9 @@ func (p *PromptManageServiceCreatePromptResult) ReadField0(iprot thrift.TProtoco return nil } -func (p *PromptManageServiceCreatePromptResult) Write(oprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceClonePromptResult) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 - if err = oprot.WriteStructBegin("CreatePrompt_result"); err != nil { + if err = oprot.WriteStructBegin("ClonePrompt_result"); err != nil { goto WriteStructBeginError } if p != nil { @@ -13048,7 +14811,7 @@ WriteStructEndError: return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) } -func (p *PromptManageServiceCreatePromptResult) writeField0(oprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceClonePromptResult) writeField0(oprot thrift.TProtocol) (err error) { if p.IsSetSuccess() { if err = oprot.WriteFieldBegin("success", thrift.STRUCT, 0); err != nil { goto WriteFieldBeginError @@ -13067,15 +14830,15 @@ WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 0 end error: ", p), err) } -func (p *PromptManageServiceCreatePromptResult) String() string { +func (p *PromptManageServiceClonePromptResult) String() string { if p == nil { return "" } - return fmt.Sprintf("PromptManageServiceCreatePromptResult(%+v)", *p) + return fmt.Sprintf("PromptManageServiceClonePromptResult(%+v)", *p) } -func (p *PromptManageServiceCreatePromptResult) DeepEqual(ano *PromptManageServiceCreatePromptResult) bool { +func (p *PromptManageServiceClonePromptResult) DeepEqual(ano *PromptManageServiceClonePromptResult) bool { if p == ano { return true } else if p == nil || ano == nil { @@ -13087,7 +14850,7 @@ func (p *PromptManageServiceCreatePromptResult) DeepEqual(ano *PromptManageServi return true } -func (p *PromptManageServiceCreatePromptResult) Field0DeepEqual(src *CreatePromptResponse) bool { +func (p *PromptManageServiceClonePromptResult) Field0DeepEqual(src *ClonePromptResponse) bool { if !p.Success.DeepEqual(src) { return false @@ -13095,41 +14858,41 @@ func (p *PromptManageServiceCreatePromptResult) Field0DeepEqual(src *CreatePromp return true } -type PromptManageServiceClonePromptArgs struct { - Request *ClonePromptRequest `thrift:"request,1" frugal:"1,default,ClonePromptRequest"` +type PromptManageServiceDeletePromptArgs struct { + Request *DeletePromptRequest `thrift:"request,1" frugal:"1,default,DeletePromptRequest"` } -func NewPromptManageServiceClonePromptArgs() *PromptManageServiceClonePromptArgs { - return &PromptManageServiceClonePromptArgs{} +func NewPromptManageServiceDeletePromptArgs() *PromptManageServiceDeletePromptArgs { + return &PromptManageServiceDeletePromptArgs{} } -func (p *PromptManageServiceClonePromptArgs) InitDefault() { +func (p *PromptManageServiceDeletePromptArgs) InitDefault() { } -var PromptManageServiceClonePromptArgs_Request_DEFAULT *ClonePromptRequest +var PromptManageServiceDeletePromptArgs_Request_DEFAULT *DeletePromptRequest -func (p *PromptManageServiceClonePromptArgs) GetRequest() (v *ClonePromptRequest) { +func (p *PromptManageServiceDeletePromptArgs) GetRequest() (v *DeletePromptRequest) { if p == nil { return } if !p.IsSetRequest() { - return PromptManageServiceClonePromptArgs_Request_DEFAULT + return PromptManageServiceDeletePromptArgs_Request_DEFAULT } return p.Request } -func (p *PromptManageServiceClonePromptArgs) SetRequest(val *ClonePromptRequest) { +func (p *PromptManageServiceDeletePromptArgs) SetRequest(val *DeletePromptRequest) { p.Request = val } -var fieldIDToName_PromptManageServiceClonePromptArgs = map[int16]string{ +var fieldIDToName_PromptManageServiceDeletePromptArgs = map[int16]string{ 1: "request", } -func (p *PromptManageServiceClonePromptArgs) IsSetRequest() bool { +func (p *PromptManageServiceDeletePromptArgs) IsSetRequest() bool { return p.Request != nil } -func (p *PromptManageServiceClonePromptArgs) Read(iprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceDeletePromptArgs) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -13174,7 +14937,7 @@ ReadStructBeginError: ReadFieldBeginError: return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptManageServiceClonePromptArgs[fieldId]), err) + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptManageServiceDeletePromptArgs[fieldId]), err) SkipFieldError: return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) @@ -13184,8 +14947,8 @@ ReadStructEndError: return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } -func (p *PromptManageServiceClonePromptArgs) ReadField1(iprot thrift.TProtocol) error { - _field := NewClonePromptRequest() +func (p *PromptManageServiceDeletePromptArgs) ReadField1(iprot thrift.TProtocol) error { + _field := NewDeletePromptRequest() if err := _field.Read(iprot); err != nil { return err } @@ -13193,9 +14956,9 @@ func (p *PromptManageServiceClonePromptArgs) ReadField1(iprot thrift.TProtocol) return nil } -func (p *PromptManageServiceClonePromptArgs) Write(oprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceDeletePromptArgs) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 - if err = oprot.WriteStructBegin("ClonePrompt_args"); err != nil { + if err = oprot.WriteStructBegin("DeletePrompt_args"); err != nil { goto WriteStructBeginError } if p != nil { @@ -13221,7 +14984,7 @@ WriteStructEndError: return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) } -func (p *PromptManageServiceClonePromptArgs) writeField1(oprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceDeletePromptArgs) writeField1(oprot thrift.TProtocol) (err error) { if err = oprot.WriteFieldBegin("request", thrift.STRUCT, 1); err != nil { goto WriteFieldBeginError } @@ -13238,15 +15001,15 @@ WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) } -func (p *PromptManageServiceClonePromptArgs) String() string { +func (p *PromptManageServiceDeletePromptArgs) String() string { if p == nil { return "" } - return fmt.Sprintf("PromptManageServiceClonePromptArgs(%+v)", *p) + return fmt.Sprintf("PromptManageServiceDeletePromptArgs(%+v)", *p) } -func (p *PromptManageServiceClonePromptArgs) DeepEqual(ano *PromptManageServiceClonePromptArgs) bool { +func (p *PromptManageServiceDeletePromptArgs) DeepEqual(ano *PromptManageServiceDeletePromptArgs) bool { if p == ano { return true } else if p == nil || ano == nil { @@ -13258,7 +15021,7 @@ func (p *PromptManageServiceClonePromptArgs) DeepEqual(ano *PromptManageServiceC return true } -func (p *PromptManageServiceClonePromptArgs) Field1DeepEqual(src *ClonePromptRequest) bool { +func (p *PromptManageServiceDeletePromptArgs) Field1DeepEqual(src *DeletePromptRequest) bool { if !p.Request.DeepEqual(src) { return false @@ -13266,41 +15029,41 @@ func (p *PromptManageServiceClonePromptArgs) Field1DeepEqual(src *ClonePromptReq return true } -type PromptManageServiceClonePromptResult struct { - Success *ClonePromptResponse `thrift:"success,0,optional" frugal:"0,optional,ClonePromptResponse"` +type PromptManageServiceDeletePromptResult struct { + Success *DeletePromptResponse `thrift:"success,0,optional" frugal:"0,optional,DeletePromptResponse"` } -func NewPromptManageServiceClonePromptResult() *PromptManageServiceClonePromptResult { - return &PromptManageServiceClonePromptResult{} +func NewPromptManageServiceDeletePromptResult() *PromptManageServiceDeletePromptResult { + return &PromptManageServiceDeletePromptResult{} } -func (p *PromptManageServiceClonePromptResult) InitDefault() { +func (p *PromptManageServiceDeletePromptResult) InitDefault() { } -var PromptManageServiceClonePromptResult_Success_DEFAULT *ClonePromptResponse +var PromptManageServiceDeletePromptResult_Success_DEFAULT *DeletePromptResponse -func (p *PromptManageServiceClonePromptResult) GetSuccess() (v *ClonePromptResponse) { +func (p *PromptManageServiceDeletePromptResult) GetSuccess() (v *DeletePromptResponse) { if p == nil { return } if !p.IsSetSuccess() { - return PromptManageServiceClonePromptResult_Success_DEFAULT + return PromptManageServiceDeletePromptResult_Success_DEFAULT } return p.Success } -func (p *PromptManageServiceClonePromptResult) SetSuccess(x interface{}) { - p.Success = x.(*ClonePromptResponse) +func (p *PromptManageServiceDeletePromptResult) SetSuccess(x interface{}) { + p.Success = x.(*DeletePromptResponse) } -var fieldIDToName_PromptManageServiceClonePromptResult = map[int16]string{ +var fieldIDToName_PromptManageServiceDeletePromptResult = map[int16]string{ 0: "success", } -func (p *PromptManageServiceClonePromptResult) IsSetSuccess() bool { +func (p *PromptManageServiceDeletePromptResult) IsSetSuccess() bool { return p.Success != nil } -func (p *PromptManageServiceClonePromptResult) Read(iprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceDeletePromptResult) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -13345,7 +15108,7 @@ ReadStructBeginError: ReadFieldBeginError: return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptManageServiceClonePromptResult[fieldId]), err) + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptManageServiceDeletePromptResult[fieldId]), err) SkipFieldError: return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) @@ -13355,8 +15118,8 @@ ReadStructEndError: return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } -func (p *PromptManageServiceClonePromptResult) ReadField0(iprot thrift.TProtocol) error { - _field := NewClonePromptResponse() +func (p *PromptManageServiceDeletePromptResult) ReadField0(iprot thrift.TProtocol) error { + _field := NewDeletePromptResponse() if err := _field.Read(iprot); err != nil { return err } @@ -13364,9 +15127,9 @@ func (p *PromptManageServiceClonePromptResult) ReadField0(iprot thrift.TProtocol return nil } -func (p *PromptManageServiceClonePromptResult) Write(oprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceDeletePromptResult) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 - if err = oprot.WriteStructBegin("ClonePrompt_result"); err != nil { + if err = oprot.WriteStructBegin("DeletePrompt_result"); err != nil { goto WriteStructBeginError } if p != nil { @@ -13392,7 +15155,7 @@ WriteStructEndError: return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) } -func (p *PromptManageServiceClonePromptResult) writeField0(oprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceDeletePromptResult) writeField0(oprot thrift.TProtocol) (err error) { if p.IsSetSuccess() { if err = oprot.WriteFieldBegin("success", thrift.STRUCT, 0); err != nil { goto WriteFieldBeginError @@ -13411,15 +15174,15 @@ WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 0 end error: ", p), err) } -func (p *PromptManageServiceClonePromptResult) String() string { +func (p *PromptManageServiceDeletePromptResult) String() string { if p == nil { return "" } - return fmt.Sprintf("PromptManageServiceClonePromptResult(%+v)", *p) + return fmt.Sprintf("PromptManageServiceDeletePromptResult(%+v)", *p) } -func (p *PromptManageServiceClonePromptResult) DeepEqual(ano *PromptManageServiceClonePromptResult) bool { +func (p *PromptManageServiceDeletePromptResult) DeepEqual(ano *PromptManageServiceDeletePromptResult) bool { if p == ano { return true } else if p == nil || ano == nil { @@ -13431,7 +15194,7 @@ func (p *PromptManageServiceClonePromptResult) DeepEqual(ano *PromptManageServic return true } -func (p *PromptManageServiceClonePromptResult) Field0DeepEqual(src *ClonePromptResponse) bool { +func (p *PromptManageServiceDeletePromptResult) Field0DeepEqual(src *DeletePromptResponse) bool { if !p.Success.DeepEqual(src) { return false @@ -13439,41 +15202,41 @@ func (p *PromptManageServiceClonePromptResult) Field0DeepEqual(src *ClonePromptR return true } -type PromptManageServiceDeletePromptArgs struct { - Request *DeletePromptRequest `thrift:"request,1" frugal:"1,default,DeletePromptRequest"` +type PromptManageServiceGetPromptArgs struct { + Request *GetPromptRequest `thrift:"request,1" frugal:"1,default,GetPromptRequest"` } -func NewPromptManageServiceDeletePromptArgs() *PromptManageServiceDeletePromptArgs { - return &PromptManageServiceDeletePromptArgs{} +func NewPromptManageServiceGetPromptArgs() *PromptManageServiceGetPromptArgs { + return &PromptManageServiceGetPromptArgs{} } -func (p *PromptManageServiceDeletePromptArgs) InitDefault() { +func (p *PromptManageServiceGetPromptArgs) InitDefault() { } -var PromptManageServiceDeletePromptArgs_Request_DEFAULT *DeletePromptRequest +var PromptManageServiceGetPromptArgs_Request_DEFAULT *GetPromptRequest -func (p *PromptManageServiceDeletePromptArgs) GetRequest() (v *DeletePromptRequest) { +func (p *PromptManageServiceGetPromptArgs) GetRequest() (v *GetPromptRequest) { if p == nil { return } if !p.IsSetRequest() { - return PromptManageServiceDeletePromptArgs_Request_DEFAULT + return PromptManageServiceGetPromptArgs_Request_DEFAULT } return p.Request } -func (p *PromptManageServiceDeletePromptArgs) SetRequest(val *DeletePromptRequest) { +func (p *PromptManageServiceGetPromptArgs) SetRequest(val *GetPromptRequest) { p.Request = val } -var fieldIDToName_PromptManageServiceDeletePromptArgs = map[int16]string{ +var fieldIDToName_PromptManageServiceGetPromptArgs = map[int16]string{ 1: "request", } -func (p *PromptManageServiceDeletePromptArgs) IsSetRequest() bool { +func (p *PromptManageServiceGetPromptArgs) IsSetRequest() bool { return p.Request != nil } -func (p *PromptManageServiceDeletePromptArgs) Read(iprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceGetPromptArgs) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -13518,7 +15281,7 @@ ReadStructBeginError: ReadFieldBeginError: return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptManageServiceDeletePromptArgs[fieldId]), err) + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptManageServiceGetPromptArgs[fieldId]), err) SkipFieldError: return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) @@ -13528,8 +15291,8 @@ ReadStructEndError: return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } -func (p *PromptManageServiceDeletePromptArgs) ReadField1(iprot thrift.TProtocol) error { - _field := NewDeletePromptRequest() +func (p *PromptManageServiceGetPromptArgs) ReadField1(iprot thrift.TProtocol) error { + _field := NewGetPromptRequest() if err := _field.Read(iprot); err != nil { return err } @@ -13537,9 +15300,9 @@ func (p *PromptManageServiceDeletePromptArgs) ReadField1(iprot thrift.TProtocol) return nil } -func (p *PromptManageServiceDeletePromptArgs) Write(oprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceGetPromptArgs) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 - if err = oprot.WriteStructBegin("DeletePrompt_args"); err != nil { + if err = oprot.WriteStructBegin("GetPrompt_args"); err != nil { goto WriteStructBeginError } if p != nil { @@ -13565,7 +15328,7 @@ WriteStructEndError: return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) } -func (p *PromptManageServiceDeletePromptArgs) writeField1(oprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceGetPromptArgs) writeField1(oprot thrift.TProtocol) (err error) { if err = oprot.WriteFieldBegin("request", thrift.STRUCT, 1); err != nil { goto WriteFieldBeginError } @@ -13582,15 +15345,15 @@ WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) } -func (p *PromptManageServiceDeletePromptArgs) String() string { +func (p *PromptManageServiceGetPromptArgs) String() string { if p == nil { return "" } - return fmt.Sprintf("PromptManageServiceDeletePromptArgs(%+v)", *p) + return fmt.Sprintf("PromptManageServiceGetPromptArgs(%+v)", *p) } -func (p *PromptManageServiceDeletePromptArgs) DeepEqual(ano *PromptManageServiceDeletePromptArgs) bool { +func (p *PromptManageServiceGetPromptArgs) DeepEqual(ano *PromptManageServiceGetPromptArgs) bool { if p == ano { return true } else if p == nil || ano == nil { @@ -13602,7 +15365,7 @@ func (p *PromptManageServiceDeletePromptArgs) DeepEqual(ano *PromptManageService return true } -func (p *PromptManageServiceDeletePromptArgs) Field1DeepEqual(src *DeletePromptRequest) bool { +func (p *PromptManageServiceGetPromptArgs) Field1DeepEqual(src *GetPromptRequest) bool { if !p.Request.DeepEqual(src) { return false @@ -13610,41 +15373,41 @@ func (p *PromptManageServiceDeletePromptArgs) Field1DeepEqual(src *DeletePromptR return true } -type PromptManageServiceDeletePromptResult struct { - Success *DeletePromptResponse `thrift:"success,0,optional" frugal:"0,optional,DeletePromptResponse"` +type PromptManageServiceGetPromptResult struct { + Success *GetPromptResponse `thrift:"success,0,optional" frugal:"0,optional,GetPromptResponse"` } -func NewPromptManageServiceDeletePromptResult() *PromptManageServiceDeletePromptResult { - return &PromptManageServiceDeletePromptResult{} +func NewPromptManageServiceGetPromptResult() *PromptManageServiceGetPromptResult { + return &PromptManageServiceGetPromptResult{} } -func (p *PromptManageServiceDeletePromptResult) InitDefault() { +func (p *PromptManageServiceGetPromptResult) InitDefault() { } -var PromptManageServiceDeletePromptResult_Success_DEFAULT *DeletePromptResponse +var PromptManageServiceGetPromptResult_Success_DEFAULT *GetPromptResponse -func (p *PromptManageServiceDeletePromptResult) GetSuccess() (v *DeletePromptResponse) { +func (p *PromptManageServiceGetPromptResult) GetSuccess() (v *GetPromptResponse) { if p == nil { return } if !p.IsSetSuccess() { - return PromptManageServiceDeletePromptResult_Success_DEFAULT + return PromptManageServiceGetPromptResult_Success_DEFAULT } return p.Success } -func (p *PromptManageServiceDeletePromptResult) SetSuccess(x interface{}) { - p.Success = x.(*DeletePromptResponse) +func (p *PromptManageServiceGetPromptResult) SetSuccess(x interface{}) { + p.Success = x.(*GetPromptResponse) } -var fieldIDToName_PromptManageServiceDeletePromptResult = map[int16]string{ +var fieldIDToName_PromptManageServiceGetPromptResult = map[int16]string{ 0: "success", } -func (p *PromptManageServiceDeletePromptResult) IsSetSuccess() bool { +func (p *PromptManageServiceGetPromptResult) IsSetSuccess() bool { return p.Success != nil } -func (p *PromptManageServiceDeletePromptResult) Read(iprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceGetPromptResult) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -13689,7 +15452,7 @@ ReadStructBeginError: ReadFieldBeginError: return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptManageServiceDeletePromptResult[fieldId]), err) + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptManageServiceGetPromptResult[fieldId]), err) SkipFieldError: return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) @@ -13699,8 +15462,8 @@ ReadStructEndError: return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } -func (p *PromptManageServiceDeletePromptResult) ReadField0(iprot thrift.TProtocol) error { - _field := NewDeletePromptResponse() +func (p *PromptManageServiceGetPromptResult) ReadField0(iprot thrift.TProtocol) error { + _field := NewGetPromptResponse() if err := _field.Read(iprot); err != nil { return err } @@ -13708,9 +15471,9 @@ func (p *PromptManageServiceDeletePromptResult) ReadField0(iprot thrift.TProtoco return nil } -func (p *PromptManageServiceDeletePromptResult) Write(oprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceGetPromptResult) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 - if err = oprot.WriteStructBegin("DeletePrompt_result"); err != nil { + if err = oprot.WriteStructBegin("GetPrompt_result"); err != nil { goto WriteStructBeginError } if p != nil { @@ -13736,7 +15499,7 @@ WriteStructEndError: return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) } -func (p *PromptManageServiceDeletePromptResult) writeField0(oprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceGetPromptResult) writeField0(oprot thrift.TProtocol) (err error) { if p.IsSetSuccess() { if err = oprot.WriteFieldBegin("success", thrift.STRUCT, 0); err != nil { goto WriteFieldBeginError @@ -13755,15 +15518,15 @@ WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 0 end error: ", p), err) } -func (p *PromptManageServiceDeletePromptResult) String() string { +func (p *PromptManageServiceGetPromptResult) String() string { if p == nil { return "" } - return fmt.Sprintf("PromptManageServiceDeletePromptResult(%+v)", *p) + return fmt.Sprintf("PromptManageServiceGetPromptResult(%+v)", *p) } -func (p *PromptManageServiceDeletePromptResult) DeepEqual(ano *PromptManageServiceDeletePromptResult) bool { +func (p *PromptManageServiceGetPromptResult) DeepEqual(ano *PromptManageServiceGetPromptResult) bool { if p == ano { return true } else if p == nil || ano == nil { @@ -13775,7 +15538,7 @@ func (p *PromptManageServiceDeletePromptResult) DeepEqual(ano *PromptManageServi return true } -func (p *PromptManageServiceDeletePromptResult) Field0DeepEqual(src *DeletePromptResponse) bool { +func (p *PromptManageServiceGetPromptResult) Field0DeepEqual(src *GetPromptResponse) bool { if !p.Success.DeepEqual(src) { return false @@ -13783,41 +15546,41 @@ func (p *PromptManageServiceDeletePromptResult) Field0DeepEqual(src *DeletePromp return true } -type PromptManageServiceGetPromptArgs struct { - Request *GetPromptRequest `thrift:"request,1" frugal:"1,default,GetPromptRequest"` +type PromptManageServiceBatchGetPromptArgs struct { + Request *BatchGetPromptRequest `thrift:"request,1" frugal:"1,default,BatchGetPromptRequest"` } -func NewPromptManageServiceGetPromptArgs() *PromptManageServiceGetPromptArgs { - return &PromptManageServiceGetPromptArgs{} +func NewPromptManageServiceBatchGetPromptArgs() *PromptManageServiceBatchGetPromptArgs { + return &PromptManageServiceBatchGetPromptArgs{} } -func (p *PromptManageServiceGetPromptArgs) InitDefault() { +func (p *PromptManageServiceBatchGetPromptArgs) InitDefault() { } -var PromptManageServiceGetPromptArgs_Request_DEFAULT *GetPromptRequest +var PromptManageServiceBatchGetPromptArgs_Request_DEFAULT *BatchGetPromptRequest -func (p *PromptManageServiceGetPromptArgs) GetRequest() (v *GetPromptRequest) { +func (p *PromptManageServiceBatchGetPromptArgs) GetRequest() (v *BatchGetPromptRequest) { if p == nil { return } if !p.IsSetRequest() { - return PromptManageServiceGetPromptArgs_Request_DEFAULT + return PromptManageServiceBatchGetPromptArgs_Request_DEFAULT } return p.Request } -func (p *PromptManageServiceGetPromptArgs) SetRequest(val *GetPromptRequest) { +func (p *PromptManageServiceBatchGetPromptArgs) SetRequest(val *BatchGetPromptRequest) { p.Request = val } -var fieldIDToName_PromptManageServiceGetPromptArgs = map[int16]string{ +var fieldIDToName_PromptManageServiceBatchGetPromptArgs = map[int16]string{ 1: "request", } -func (p *PromptManageServiceGetPromptArgs) IsSetRequest() bool { +func (p *PromptManageServiceBatchGetPromptArgs) IsSetRequest() bool { return p.Request != nil } -func (p *PromptManageServiceGetPromptArgs) Read(iprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceBatchGetPromptArgs) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -13862,7 +15625,7 @@ ReadStructBeginError: ReadFieldBeginError: return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptManageServiceGetPromptArgs[fieldId]), err) + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptManageServiceBatchGetPromptArgs[fieldId]), err) SkipFieldError: return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) @@ -13872,8 +15635,8 @@ ReadStructEndError: return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } -func (p *PromptManageServiceGetPromptArgs) ReadField1(iprot thrift.TProtocol) error { - _field := NewGetPromptRequest() +func (p *PromptManageServiceBatchGetPromptArgs) ReadField1(iprot thrift.TProtocol) error { + _field := NewBatchGetPromptRequest() if err := _field.Read(iprot); err != nil { return err } @@ -13881,9 +15644,9 @@ func (p *PromptManageServiceGetPromptArgs) ReadField1(iprot thrift.TProtocol) er return nil } -func (p *PromptManageServiceGetPromptArgs) Write(oprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceBatchGetPromptArgs) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 - if err = oprot.WriteStructBegin("GetPrompt_args"); err != nil { + if err = oprot.WriteStructBegin("BatchGetPrompt_args"); err != nil { goto WriteStructBeginError } if p != nil { @@ -13909,7 +15672,7 @@ WriteStructEndError: return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) } -func (p *PromptManageServiceGetPromptArgs) writeField1(oprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceBatchGetPromptArgs) writeField1(oprot thrift.TProtocol) (err error) { if err = oprot.WriteFieldBegin("request", thrift.STRUCT, 1); err != nil { goto WriteFieldBeginError } @@ -13926,15 +15689,15 @@ WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) } -func (p *PromptManageServiceGetPromptArgs) String() string { +func (p *PromptManageServiceBatchGetPromptArgs) String() string { if p == nil { return "" } - return fmt.Sprintf("PromptManageServiceGetPromptArgs(%+v)", *p) + return fmt.Sprintf("PromptManageServiceBatchGetPromptArgs(%+v)", *p) } -func (p *PromptManageServiceGetPromptArgs) DeepEqual(ano *PromptManageServiceGetPromptArgs) bool { +func (p *PromptManageServiceBatchGetPromptArgs) DeepEqual(ano *PromptManageServiceBatchGetPromptArgs) bool { if p == ano { return true } else if p == nil || ano == nil { @@ -13946,7 +15709,7 @@ func (p *PromptManageServiceGetPromptArgs) DeepEqual(ano *PromptManageServiceGet return true } -func (p *PromptManageServiceGetPromptArgs) Field1DeepEqual(src *GetPromptRequest) bool { +func (p *PromptManageServiceBatchGetPromptArgs) Field1DeepEqual(src *BatchGetPromptRequest) bool { if !p.Request.DeepEqual(src) { return false @@ -13954,41 +15717,41 @@ func (p *PromptManageServiceGetPromptArgs) Field1DeepEqual(src *GetPromptRequest return true } -type PromptManageServiceGetPromptResult struct { - Success *GetPromptResponse `thrift:"success,0,optional" frugal:"0,optional,GetPromptResponse"` +type PromptManageServiceBatchGetPromptResult struct { + Success *BatchGetPromptResponse `thrift:"success,0,optional" frugal:"0,optional,BatchGetPromptResponse"` } -func NewPromptManageServiceGetPromptResult() *PromptManageServiceGetPromptResult { - return &PromptManageServiceGetPromptResult{} +func NewPromptManageServiceBatchGetPromptResult() *PromptManageServiceBatchGetPromptResult { + return &PromptManageServiceBatchGetPromptResult{} } -func (p *PromptManageServiceGetPromptResult) InitDefault() { +func (p *PromptManageServiceBatchGetPromptResult) InitDefault() { } -var PromptManageServiceGetPromptResult_Success_DEFAULT *GetPromptResponse +var PromptManageServiceBatchGetPromptResult_Success_DEFAULT *BatchGetPromptResponse -func (p *PromptManageServiceGetPromptResult) GetSuccess() (v *GetPromptResponse) { +func (p *PromptManageServiceBatchGetPromptResult) GetSuccess() (v *BatchGetPromptResponse) { if p == nil { return } if !p.IsSetSuccess() { - return PromptManageServiceGetPromptResult_Success_DEFAULT + return PromptManageServiceBatchGetPromptResult_Success_DEFAULT } return p.Success } -func (p *PromptManageServiceGetPromptResult) SetSuccess(x interface{}) { - p.Success = x.(*GetPromptResponse) +func (p *PromptManageServiceBatchGetPromptResult) SetSuccess(x interface{}) { + p.Success = x.(*BatchGetPromptResponse) } -var fieldIDToName_PromptManageServiceGetPromptResult = map[int16]string{ +var fieldIDToName_PromptManageServiceBatchGetPromptResult = map[int16]string{ 0: "success", } -func (p *PromptManageServiceGetPromptResult) IsSetSuccess() bool { +func (p *PromptManageServiceBatchGetPromptResult) IsSetSuccess() bool { return p.Success != nil } -func (p *PromptManageServiceGetPromptResult) Read(iprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceBatchGetPromptResult) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -14033,7 +15796,7 @@ ReadStructBeginError: ReadFieldBeginError: return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptManageServiceGetPromptResult[fieldId]), err) + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptManageServiceBatchGetPromptResult[fieldId]), err) SkipFieldError: return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) @@ -14043,8 +15806,8 @@ ReadStructEndError: return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } -func (p *PromptManageServiceGetPromptResult) ReadField0(iprot thrift.TProtocol) error { - _field := NewGetPromptResponse() +func (p *PromptManageServiceBatchGetPromptResult) ReadField0(iprot thrift.TProtocol) error { + _field := NewBatchGetPromptResponse() if err := _field.Read(iprot); err != nil { return err } @@ -14052,9 +15815,9 @@ func (p *PromptManageServiceGetPromptResult) ReadField0(iprot thrift.TProtocol) return nil } -func (p *PromptManageServiceGetPromptResult) Write(oprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceBatchGetPromptResult) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 - if err = oprot.WriteStructBegin("GetPrompt_result"); err != nil { + if err = oprot.WriteStructBegin("BatchGetPrompt_result"); err != nil { goto WriteStructBeginError } if p != nil { @@ -14080,7 +15843,7 @@ WriteStructEndError: return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) } -func (p *PromptManageServiceGetPromptResult) writeField0(oprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceBatchGetPromptResult) writeField0(oprot thrift.TProtocol) (err error) { if p.IsSetSuccess() { if err = oprot.WriteFieldBegin("success", thrift.STRUCT, 0); err != nil { goto WriteFieldBeginError @@ -14099,15 +15862,15 @@ WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 0 end error: ", p), err) } -func (p *PromptManageServiceGetPromptResult) String() string { +func (p *PromptManageServiceBatchGetPromptResult) String() string { if p == nil { return "" } - return fmt.Sprintf("PromptManageServiceGetPromptResult(%+v)", *p) + return fmt.Sprintf("PromptManageServiceBatchGetPromptResult(%+v)", *p) } -func (p *PromptManageServiceGetPromptResult) DeepEqual(ano *PromptManageServiceGetPromptResult) bool { +func (p *PromptManageServiceBatchGetPromptResult) DeepEqual(ano *PromptManageServiceBatchGetPromptResult) bool { if p == ano { return true } else if p == nil || ano == nil { @@ -14119,7 +15882,7 @@ func (p *PromptManageServiceGetPromptResult) DeepEqual(ano *PromptManageServiceG return true } -func (p *PromptManageServiceGetPromptResult) Field0DeepEqual(src *GetPromptResponse) bool { +func (p *PromptManageServiceBatchGetPromptResult) Field0DeepEqual(src *BatchGetPromptResponse) bool { if !p.Success.DeepEqual(src) { return false @@ -14127,41 +15890,41 @@ func (p *PromptManageServiceGetPromptResult) Field0DeepEqual(src *GetPromptRespo return true } -type PromptManageServiceBatchGetPromptArgs struct { - Request *BatchGetPromptRequest `thrift:"request,1" frugal:"1,default,BatchGetPromptRequest"` +type PromptManageServiceListPromptArgs struct { + Request *ListPromptRequest `thrift:"request,1" frugal:"1,default,ListPromptRequest"` } -func NewPromptManageServiceBatchGetPromptArgs() *PromptManageServiceBatchGetPromptArgs { - return &PromptManageServiceBatchGetPromptArgs{} +func NewPromptManageServiceListPromptArgs() *PromptManageServiceListPromptArgs { + return &PromptManageServiceListPromptArgs{} } -func (p *PromptManageServiceBatchGetPromptArgs) InitDefault() { +func (p *PromptManageServiceListPromptArgs) InitDefault() { } -var PromptManageServiceBatchGetPromptArgs_Request_DEFAULT *BatchGetPromptRequest +var PromptManageServiceListPromptArgs_Request_DEFAULT *ListPromptRequest -func (p *PromptManageServiceBatchGetPromptArgs) GetRequest() (v *BatchGetPromptRequest) { +func (p *PromptManageServiceListPromptArgs) GetRequest() (v *ListPromptRequest) { if p == nil { return } if !p.IsSetRequest() { - return PromptManageServiceBatchGetPromptArgs_Request_DEFAULT + return PromptManageServiceListPromptArgs_Request_DEFAULT } return p.Request } -func (p *PromptManageServiceBatchGetPromptArgs) SetRequest(val *BatchGetPromptRequest) { +func (p *PromptManageServiceListPromptArgs) SetRequest(val *ListPromptRequest) { p.Request = val } -var fieldIDToName_PromptManageServiceBatchGetPromptArgs = map[int16]string{ +var fieldIDToName_PromptManageServiceListPromptArgs = map[int16]string{ 1: "request", } -func (p *PromptManageServiceBatchGetPromptArgs) IsSetRequest() bool { +func (p *PromptManageServiceListPromptArgs) IsSetRequest() bool { return p.Request != nil } -func (p *PromptManageServiceBatchGetPromptArgs) Read(iprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceListPromptArgs) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -14206,7 +15969,7 @@ ReadStructBeginError: ReadFieldBeginError: return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptManageServiceBatchGetPromptArgs[fieldId]), err) + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptManageServiceListPromptArgs[fieldId]), err) SkipFieldError: return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) @@ -14216,8 +15979,8 @@ ReadStructEndError: return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } -func (p *PromptManageServiceBatchGetPromptArgs) ReadField1(iprot thrift.TProtocol) error { - _field := NewBatchGetPromptRequest() +func (p *PromptManageServiceListPromptArgs) ReadField1(iprot thrift.TProtocol) error { + _field := NewListPromptRequest() if err := _field.Read(iprot); err != nil { return err } @@ -14225,9 +15988,9 @@ func (p *PromptManageServiceBatchGetPromptArgs) ReadField1(iprot thrift.TProtoco return nil } -func (p *PromptManageServiceBatchGetPromptArgs) Write(oprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceListPromptArgs) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 - if err = oprot.WriteStructBegin("BatchGetPrompt_args"); err != nil { + if err = oprot.WriteStructBegin("ListPrompt_args"); err != nil { goto WriteStructBeginError } if p != nil { @@ -14253,7 +16016,7 @@ WriteStructEndError: return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) } -func (p *PromptManageServiceBatchGetPromptArgs) writeField1(oprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceListPromptArgs) writeField1(oprot thrift.TProtocol) (err error) { if err = oprot.WriteFieldBegin("request", thrift.STRUCT, 1); err != nil { goto WriteFieldBeginError } @@ -14270,15 +16033,15 @@ WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) } -func (p *PromptManageServiceBatchGetPromptArgs) String() string { +func (p *PromptManageServiceListPromptArgs) String() string { if p == nil { return "" } - return fmt.Sprintf("PromptManageServiceBatchGetPromptArgs(%+v)", *p) + return fmt.Sprintf("PromptManageServiceListPromptArgs(%+v)", *p) } -func (p *PromptManageServiceBatchGetPromptArgs) DeepEqual(ano *PromptManageServiceBatchGetPromptArgs) bool { +func (p *PromptManageServiceListPromptArgs) DeepEqual(ano *PromptManageServiceListPromptArgs) bool { if p == ano { return true } else if p == nil || ano == nil { @@ -14290,7 +16053,7 @@ func (p *PromptManageServiceBatchGetPromptArgs) DeepEqual(ano *PromptManageServi return true } -func (p *PromptManageServiceBatchGetPromptArgs) Field1DeepEqual(src *BatchGetPromptRequest) bool { +func (p *PromptManageServiceListPromptArgs) Field1DeepEqual(src *ListPromptRequest) bool { if !p.Request.DeepEqual(src) { return false @@ -14298,41 +16061,41 @@ func (p *PromptManageServiceBatchGetPromptArgs) Field1DeepEqual(src *BatchGetPro return true } -type PromptManageServiceBatchGetPromptResult struct { - Success *BatchGetPromptResponse `thrift:"success,0,optional" frugal:"0,optional,BatchGetPromptResponse"` +type PromptManageServiceListPromptResult struct { + Success *ListPromptResponse `thrift:"success,0,optional" frugal:"0,optional,ListPromptResponse"` } -func NewPromptManageServiceBatchGetPromptResult() *PromptManageServiceBatchGetPromptResult { - return &PromptManageServiceBatchGetPromptResult{} +func NewPromptManageServiceListPromptResult() *PromptManageServiceListPromptResult { + return &PromptManageServiceListPromptResult{} } -func (p *PromptManageServiceBatchGetPromptResult) InitDefault() { +func (p *PromptManageServiceListPromptResult) InitDefault() { } -var PromptManageServiceBatchGetPromptResult_Success_DEFAULT *BatchGetPromptResponse +var PromptManageServiceListPromptResult_Success_DEFAULT *ListPromptResponse -func (p *PromptManageServiceBatchGetPromptResult) GetSuccess() (v *BatchGetPromptResponse) { +func (p *PromptManageServiceListPromptResult) GetSuccess() (v *ListPromptResponse) { if p == nil { return } if !p.IsSetSuccess() { - return PromptManageServiceBatchGetPromptResult_Success_DEFAULT + return PromptManageServiceListPromptResult_Success_DEFAULT } return p.Success } -func (p *PromptManageServiceBatchGetPromptResult) SetSuccess(x interface{}) { - p.Success = x.(*BatchGetPromptResponse) +func (p *PromptManageServiceListPromptResult) SetSuccess(x interface{}) { + p.Success = x.(*ListPromptResponse) } -var fieldIDToName_PromptManageServiceBatchGetPromptResult = map[int16]string{ +var fieldIDToName_PromptManageServiceListPromptResult = map[int16]string{ 0: "success", } -func (p *PromptManageServiceBatchGetPromptResult) IsSetSuccess() bool { +func (p *PromptManageServiceListPromptResult) IsSetSuccess() bool { return p.Success != nil } -func (p *PromptManageServiceBatchGetPromptResult) Read(iprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceListPromptResult) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -14377,7 +16140,7 @@ ReadStructBeginError: ReadFieldBeginError: return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptManageServiceBatchGetPromptResult[fieldId]), err) + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptManageServiceListPromptResult[fieldId]), err) SkipFieldError: return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) @@ -14387,8 +16150,8 @@ ReadStructEndError: return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } -func (p *PromptManageServiceBatchGetPromptResult) ReadField0(iprot thrift.TProtocol) error { - _field := NewBatchGetPromptResponse() +func (p *PromptManageServiceListPromptResult) ReadField0(iprot thrift.TProtocol) error { + _field := NewListPromptResponse() if err := _field.Read(iprot); err != nil { return err } @@ -14396,9 +16159,9 @@ func (p *PromptManageServiceBatchGetPromptResult) ReadField0(iprot thrift.TProto return nil } -func (p *PromptManageServiceBatchGetPromptResult) Write(oprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceListPromptResult) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 - if err = oprot.WriteStructBegin("BatchGetPrompt_result"); err != nil { + if err = oprot.WriteStructBegin("ListPrompt_result"); err != nil { goto WriteStructBeginError } if p != nil { @@ -14424,7 +16187,7 @@ WriteStructEndError: return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) } -func (p *PromptManageServiceBatchGetPromptResult) writeField0(oprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceListPromptResult) writeField0(oprot thrift.TProtocol) (err error) { if p.IsSetSuccess() { if err = oprot.WriteFieldBegin("success", thrift.STRUCT, 0); err != nil { goto WriteFieldBeginError @@ -14443,15 +16206,15 @@ WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 0 end error: ", p), err) } -func (p *PromptManageServiceBatchGetPromptResult) String() string { +func (p *PromptManageServiceListPromptResult) String() string { if p == nil { return "" } - return fmt.Sprintf("PromptManageServiceBatchGetPromptResult(%+v)", *p) + return fmt.Sprintf("PromptManageServiceListPromptResult(%+v)", *p) } -func (p *PromptManageServiceBatchGetPromptResult) DeepEqual(ano *PromptManageServiceBatchGetPromptResult) bool { +func (p *PromptManageServiceListPromptResult) DeepEqual(ano *PromptManageServiceListPromptResult) bool { if p == ano { return true } else if p == nil || ano == nil { @@ -14463,7 +16226,7 @@ func (p *PromptManageServiceBatchGetPromptResult) DeepEqual(ano *PromptManageSer return true } -func (p *PromptManageServiceBatchGetPromptResult) Field0DeepEqual(src *BatchGetPromptResponse) bool { +func (p *PromptManageServiceListPromptResult) Field0DeepEqual(src *ListPromptResponse) bool { if !p.Success.DeepEqual(src) { return false @@ -14471,41 +16234,41 @@ func (p *PromptManageServiceBatchGetPromptResult) Field0DeepEqual(src *BatchGetP return true } -type PromptManageServiceListPromptArgs struct { - Request *ListPromptRequest `thrift:"request,1" frugal:"1,default,ListPromptRequest"` +type PromptManageServiceListParentPromptArgs struct { + Request *ListParentPromptRequest `thrift:"request,1" frugal:"1,default,ListParentPromptRequest"` } -func NewPromptManageServiceListPromptArgs() *PromptManageServiceListPromptArgs { - return &PromptManageServiceListPromptArgs{} +func NewPromptManageServiceListParentPromptArgs() *PromptManageServiceListParentPromptArgs { + return &PromptManageServiceListParentPromptArgs{} } -func (p *PromptManageServiceListPromptArgs) InitDefault() { +func (p *PromptManageServiceListParentPromptArgs) InitDefault() { } -var PromptManageServiceListPromptArgs_Request_DEFAULT *ListPromptRequest +var PromptManageServiceListParentPromptArgs_Request_DEFAULT *ListParentPromptRequest -func (p *PromptManageServiceListPromptArgs) GetRequest() (v *ListPromptRequest) { +func (p *PromptManageServiceListParentPromptArgs) GetRequest() (v *ListParentPromptRequest) { if p == nil { return } if !p.IsSetRequest() { - return PromptManageServiceListPromptArgs_Request_DEFAULT + return PromptManageServiceListParentPromptArgs_Request_DEFAULT } return p.Request } -func (p *PromptManageServiceListPromptArgs) SetRequest(val *ListPromptRequest) { +func (p *PromptManageServiceListParentPromptArgs) SetRequest(val *ListParentPromptRequest) { p.Request = val } -var fieldIDToName_PromptManageServiceListPromptArgs = map[int16]string{ +var fieldIDToName_PromptManageServiceListParentPromptArgs = map[int16]string{ 1: "request", } -func (p *PromptManageServiceListPromptArgs) IsSetRequest() bool { +func (p *PromptManageServiceListParentPromptArgs) IsSetRequest() bool { return p.Request != nil } -func (p *PromptManageServiceListPromptArgs) Read(iprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceListParentPromptArgs) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -14550,7 +16313,7 @@ ReadStructBeginError: ReadFieldBeginError: return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptManageServiceListPromptArgs[fieldId]), err) + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptManageServiceListParentPromptArgs[fieldId]), err) SkipFieldError: return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) @@ -14560,8 +16323,8 @@ ReadStructEndError: return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } -func (p *PromptManageServiceListPromptArgs) ReadField1(iprot thrift.TProtocol) error { - _field := NewListPromptRequest() +func (p *PromptManageServiceListParentPromptArgs) ReadField1(iprot thrift.TProtocol) error { + _field := NewListParentPromptRequest() if err := _field.Read(iprot); err != nil { return err } @@ -14569,9 +16332,9 @@ func (p *PromptManageServiceListPromptArgs) ReadField1(iprot thrift.TProtocol) e return nil } -func (p *PromptManageServiceListPromptArgs) Write(oprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceListParentPromptArgs) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 - if err = oprot.WriteStructBegin("ListPrompt_args"); err != nil { + if err = oprot.WriteStructBegin("ListParentPrompt_args"); err != nil { goto WriteStructBeginError } if p != nil { @@ -14597,7 +16360,7 @@ WriteStructEndError: return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) } -func (p *PromptManageServiceListPromptArgs) writeField1(oprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceListParentPromptArgs) writeField1(oprot thrift.TProtocol) (err error) { if err = oprot.WriteFieldBegin("request", thrift.STRUCT, 1); err != nil { goto WriteFieldBeginError } @@ -14614,15 +16377,15 @@ WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) } -func (p *PromptManageServiceListPromptArgs) String() string { +func (p *PromptManageServiceListParentPromptArgs) String() string { if p == nil { return "" } - return fmt.Sprintf("PromptManageServiceListPromptArgs(%+v)", *p) + return fmt.Sprintf("PromptManageServiceListParentPromptArgs(%+v)", *p) } -func (p *PromptManageServiceListPromptArgs) DeepEqual(ano *PromptManageServiceListPromptArgs) bool { +func (p *PromptManageServiceListParentPromptArgs) DeepEqual(ano *PromptManageServiceListParentPromptArgs) bool { if p == ano { return true } else if p == nil || ano == nil { @@ -14634,7 +16397,7 @@ func (p *PromptManageServiceListPromptArgs) DeepEqual(ano *PromptManageServiceLi return true } -func (p *PromptManageServiceListPromptArgs) Field1DeepEqual(src *ListPromptRequest) bool { +func (p *PromptManageServiceListParentPromptArgs) Field1DeepEqual(src *ListParentPromptRequest) bool { if !p.Request.DeepEqual(src) { return false @@ -14642,41 +16405,41 @@ func (p *PromptManageServiceListPromptArgs) Field1DeepEqual(src *ListPromptReque return true } -type PromptManageServiceListPromptResult struct { - Success *ListPromptResponse `thrift:"success,0,optional" frugal:"0,optional,ListPromptResponse"` +type PromptManageServiceListParentPromptResult struct { + Success *ListParentPromptResponse `thrift:"success,0,optional" frugal:"0,optional,ListParentPromptResponse"` } -func NewPromptManageServiceListPromptResult() *PromptManageServiceListPromptResult { - return &PromptManageServiceListPromptResult{} +func NewPromptManageServiceListParentPromptResult() *PromptManageServiceListParentPromptResult { + return &PromptManageServiceListParentPromptResult{} } -func (p *PromptManageServiceListPromptResult) InitDefault() { +func (p *PromptManageServiceListParentPromptResult) InitDefault() { } -var PromptManageServiceListPromptResult_Success_DEFAULT *ListPromptResponse +var PromptManageServiceListParentPromptResult_Success_DEFAULT *ListParentPromptResponse -func (p *PromptManageServiceListPromptResult) GetSuccess() (v *ListPromptResponse) { +func (p *PromptManageServiceListParentPromptResult) GetSuccess() (v *ListParentPromptResponse) { if p == nil { return } if !p.IsSetSuccess() { - return PromptManageServiceListPromptResult_Success_DEFAULT + return PromptManageServiceListParentPromptResult_Success_DEFAULT } return p.Success } -func (p *PromptManageServiceListPromptResult) SetSuccess(x interface{}) { - p.Success = x.(*ListPromptResponse) +func (p *PromptManageServiceListParentPromptResult) SetSuccess(x interface{}) { + p.Success = x.(*ListParentPromptResponse) } -var fieldIDToName_PromptManageServiceListPromptResult = map[int16]string{ +var fieldIDToName_PromptManageServiceListParentPromptResult = map[int16]string{ 0: "success", } -func (p *PromptManageServiceListPromptResult) IsSetSuccess() bool { +func (p *PromptManageServiceListParentPromptResult) IsSetSuccess() bool { return p.Success != nil } -func (p *PromptManageServiceListPromptResult) Read(iprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceListParentPromptResult) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -14721,7 +16484,7 @@ ReadStructBeginError: ReadFieldBeginError: return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptManageServiceListPromptResult[fieldId]), err) + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptManageServiceListParentPromptResult[fieldId]), err) SkipFieldError: return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) @@ -14731,8 +16494,8 @@ ReadStructEndError: return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } -func (p *PromptManageServiceListPromptResult) ReadField0(iprot thrift.TProtocol) error { - _field := NewListPromptResponse() +func (p *PromptManageServiceListParentPromptResult) ReadField0(iprot thrift.TProtocol) error { + _field := NewListParentPromptResponse() if err := _field.Read(iprot); err != nil { return err } @@ -14740,9 +16503,9 @@ func (p *PromptManageServiceListPromptResult) ReadField0(iprot thrift.TProtocol) return nil } -func (p *PromptManageServiceListPromptResult) Write(oprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceListParentPromptResult) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 - if err = oprot.WriteStructBegin("ListPrompt_result"); err != nil { + if err = oprot.WriteStructBegin("ListParentPrompt_result"); err != nil { goto WriteStructBeginError } if p != nil { @@ -14768,7 +16531,7 @@ WriteStructEndError: return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) } -func (p *PromptManageServiceListPromptResult) writeField0(oprot thrift.TProtocol) (err error) { +func (p *PromptManageServiceListParentPromptResult) writeField0(oprot thrift.TProtocol) (err error) { if p.IsSetSuccess() { if err = oprot.WriteFieldBegin("success", thrift.STRUCT, 0); err != nil { goto WriteFieldBeginError @@ -14787,15 +16550,15 @@ WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 0 end error: ", p), err) } -func (p *PromptManageServiceListPromptResult) String() string { +func (p *PromptManageServiceListParentPromptResult) String() string { if p == nil { return "" } - return fmt.Sprintf("PromptManageServiceListPromptResult(%+v)", *p) + return fmt.Sprintf("PromptManageServiceListParentPromptResult(%+v)", *p) } -func (p *PromptManageServiceListPromptResult) DeepEqual(ano *PromptManageServiceListPromptResult) bool { +func (p *PromptManageServiceListParentPromptResult) DeepEqual(ano *PromptManageServiceListParentPromptResult) bool { if p == ano { return true } else if p == nil || ano == nil { @@ -14807,7 +16570,7 @@ func (p *PromptManageServiceListPromptResult) DeepEqual(ano *PromptManageService return true } -func (p *PromptManageServiceListPromptResult) Field0DeepEqual(src *ListPromptResponse) bool { +func (p *PromptManageServiceListParentPromptResult) Field0DeepEqual(src *ListParentPromptResponse) bool { if !p.Success.DeepEqual(src) { return false diff --git a/backend/kitex_gen/coze/loop/prompt/manage/coze.loop.prompt.manage_validator.go b/backend/kitex_gen/coze/loop/prompt/manage/coze.loop.prompt.manage_validator.go index a76c59770..be8d6c11b 100644 --- a/backend/kitex_gen/coze/loop/prompt/manage/coze.loop.prompt.manage_validator.go +++ b/backend/kitex_gen/coze/loop/prompt/manage/coze.loop.prompt.manage_validator.go @@ -483,3 +483,31 @@ func (p *UpdateCommitLabelsResponse) IsValid() error { } return nil } +func (p *ListParentPromptRequest) IsValid() error { + if p.WorkspaceID == nil { + return fmt.Errorf("field WorkspaceID not_nil rule failed") + } + if *p.WorkspaceID <= int64(0) { + return fmt.Errorf("field WorkspaceID gt rule failed, current value: %v", *p.WorkspaceID) + } + if p.PromptID == nil { + return fmt.Errorf("field PromptID not_nil rule failed") + } + if *p.PromptID <= int64(0) { + return fmt.Errorf("field PromptID gt rule failed, current value: %v", *p.PromptID) + } + if p.Base != nil { + if err := p.Base.IsValid(); err != nil { + return fmt.Errorf("field Base not valid, %w", err) + } + } + return nil +} +func (p *ListParentPromptResponse) IsValid() error { + if p.BaseResp != nil { + if err := p.BaseResp.IsValid(); err != nil { + return fmt.Errorf("field BaseResp not valid, %w", err) + } + } + return nil +} diff --git a/backend/kitex_gen/coze/loop/prompt/manage/k-coze.loop.prompt.manage.go b/backend/kitex_gen/coze/loop/prompt/manage/k-coze.loop.prompt.manage.go index 9b73ad111..b66c724f8 100644 --- a/backend/kitex_gen/coze/loop/prompt/manage/k-coze.loop.prompt.manage.go +++ b/backend/kitex_gen/coze/loop/prompt/manage/k-coze.loop.prompt.manage.go @@ -104,6 +104,20 @@ func (p *CreatePromptRequest) FastRead(buf []byte) (int, error) { goto SkipFieldError } } + case 14: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField14(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } case 21: if fieldTypeId == thrift.STRUCT { l, err = p.FastReadField21(buf[offset:]) @@ -206,6 +220,20 @@ func (p *CreatePromptRequest) FastReadField13(buf []byte) (int, error) { return offset, nil } +func (p *CreatePromptRequest) FastReadField14(buf []byte) (int, error) { + offset := 0 + + var _field *prompt.PromptType + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.PromptType = _field + return offset, nil +} + func (p *CreatePromptRequest) FastReadField21(buf []byte) (int, error) { offset := 0 _field := prompt.NewPromptDetail() @@ -241,6 +269,7 @@ func (p *CreatePromptRequest) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) offset += p.fastWriteField11(buf[offset:], w) offset += p.fastWriteField12(buf[offset:], w) offset += p.fastWriteField13(buf[offset:], w) + offset += p.fastWriteField14(buf[offset:], w) offset += p.fastWriteField21(buf[offset:], w) offset += p.fastWriteField255(buf[offset:], w) } @@ -255,6 +284,7 @@ func (p *CreatePromptRequest) BLength() int { l += p.field11Length() l += p.field12Length() l += p.field13Length() + l += p.field14Length() l += p.field21Length() l += p.field255Length() } @@ -298,6 +328,15 @@ func (p *CreatePromptRequest) fastWriteField13(buf []byte, w thrift.NocopyWriter return offset } +func (p *CreatePromptRequest) fastWriteField14(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetPromptType() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 14) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.PromptType) + } + return offset +} + func (p *CreatePromptRequest) fastWriteField21(buf []byte, w thrift.NocopyWriter) int { offset := 0 if p.IsSetDraftDetail() { @@ -352,6 +391,15 @@ func (p *CreatePromptRequest) field13Length() int { return l } +func (p *CreatePromptRequest) field14Length() int { + l := 0 + if p.IsSetPromptType() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.PromptType) + } + return l +} + func (p *CreatePromptRequest) field21Length() int { l := 0 if p.IsSetDraftDetail() { @@ -405,6 +453,11 @@ func (p *CreatePromptRequest) DeepCopy(s interface{}) error { p.PromptDescription = &tmp } + if src.PromptType != nil { + tmp := *src.PromptType + p.PromptType = &tmp + } + var _draftDetail *prompt.PromptDetail if src.DraftDetail != nil { _draftDetail = &prompt.PromptDetail{} @@ -1558,6 +1611,20 @@ func (p *GetPromptRequest) FastRead(buf []byte) (int, error) { goto SkipFieldError } } + case 32: + if fieldTypeId == thrift.BOOL { + l, err = p.FastReadField32(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } case 255: if fieldTypeId == thrift.STRUCT { l, err = p.FastReadField255(buf[offset:]) @@ -1674,6 +1741,20 @@ func (p *GetPromptRequest) FastReadField31(buf []byte) (int, error) { return offset, nil } +func (p *GetPromptRequest) FastReadField32(buf []byte) (int, error) { + offset := 0 + + var _field *bool + if v, l, err := thrift.Binary.ReadBool(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.ExpandSnippet = _field + return offset, nil +} + func (p *GetPromptRequest) FastReadField255(buf []byte) (int, error) { offset := 0 _field := base.NewBase() @@ -1698,6 +1779,7 @@ func (p *GetPromptRequest) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) in offset += p.fastWriteField11(buf[offset:], w) offset += p.fastWriteField21(buf[offset:], w) offset += p.fastWriteField31(buf[offset:], w) + offset += p.fastWriteField32(buf[offset:], w) offset += p.fastWriteField12(buf[offset:], w) offset += p.fastWriteField255(buf[offset:], w) } @@ -1714,6 +1796,7 @@ func (p *GetPromptRequest) BLength() int { l += p.field12Length() l += p.field21Length() l += p.field31Length() + l += p.field32Length() l += p.field255Length() } l += thrift.Binary.FieldStopLength() @@ -1774,6 +1857,15 @@ func (p *GetPromptRequest) fastWriteField31(buf []byte, w thrift.NocopyWriter) i return offset } +func (p *GetPromptRequest) fastWriteField32(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetExpandSnippet() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.BOOL, 32) + offset += thrift.Binary.WriteBool(buf[offset:], *p.ExpandSnippet) + } + return offset +} + func (p *GetPromptRequest) fastWriteField255(buf []byte, w thrift.NocopyWriter) int { offset := 0 if p.IsSetBase() { @@ -1837,6 +1929,15 @@ func (p *GetPromptRequest) field31Length() int { return l } +func (p *GetPromptRequest) field32Length() int { + l := 0 + if p.IsSetExpandSnippet() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.BoolLength() + } + return l +} + func (p *GetPromptRequest) field255Length() int { l := 0 if p.IsSetBase() { @@ -1885,6 +1986,11 @@ func (p *GetPromptRequest) DeepCopy(s interface{}) error { p.WithDefaultConfig = &tmp } + if src.ExpandSnippet != nil { + tmp := *src.ExpandSnippet + p.ExpandSnippet = &tmp + } + var _base *base.Base if src.Base != nil { _base = &base.Base{} @@ -1942,6 +2048,20 @@ func (p *GetPromptResponse) FastRead(buf []byte) (int, error) { goto SkipFieldError } } + case 12: + if fieldTypeId == thrift.I32 { + l, err = p.FastReadField12(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } case 255: if fieldTypeId == thrift.STRUCT { l, err = p.FastReadField255(buf[offset:]) @@ -1998,6 +2118,20 @@ func (p *GetPromptResponse) FastReadField11(buf []byte) (int, error) { return offset, nil } +func (p *GetPromptResponse) FastReadField12(buf []byte) (int, error) { + offset := 0 + + var _field *int32 + if v, l, err := thrift.Binary.ReadI32(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.TotalParentReferences = _field + return offset, nil +} + func (p *GetPromptResponse) FastReadField255(buf []byte) (int, error) { offset := 0 _field := base.NewBaseResp() @@ -2017,6 +2151,7 @@ func (p *GetPromptResponse) FastWrite(buf []byte) int { func (p *GetPromptResponse) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { offset := 0 if p != nil { + offset += p.fastWriteField12(buf[offset:], w) offset += p.fastWriteField1(buf[offset:], w) offset += p.fastWriteField11(buf[offset:], w) offset += p.fastWriteField255(buf[offset:], w) @@ -2030,6 +2165,7 @@ func (p *GetPromptResponse) BLength() int { if p != nil { l += p.field1Length() l += p.field11Length() + l += p.field12Length() l += p.field255Length() } l += thrift.Binary.FieldStopLength() @@ -2054,6 +2190,15 @@ func (p *GetPromptResponse) fastWriteField11(buf []byte, w thrift.NocopyWriter) return offset } +func (p *GetPromptResponse) fastWriteField12(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetTotalParentReferences() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.I32, 12) + offset += thrift.Binary.WriteI32(buf[offset:], *p.TotalParentReferences) + } + return offset +} + func (p *GetPromptResponse) fastWriteField255(buf []byte, w thrift.NocopyWriter) int { offset := 0 if p.IsSetBaseResp() { @@ -2081,6 +2226,15 @@ func (p *GetPromptResponse) field11Length() int { return l } +func (p *GetPromptResponse) field12Length() int { + l := 0 + if p.IsSetTotalParentReferences() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.I32Length() + } + return l +} + func (p *GetPromptResponse) field255Length() int { l := 0 if p.IsSetBaseResp() { @@ -2114,6 +2268,11 @@ func (p *GetPromptResponse) DeepCopy(s interface{}) error { } p.DefaultConfig = _defaultConfig + if src.TotalParentReferences != nil { + tmp := *src.TotalParentReferences + p.TotalParentReferences = &tmp + } + var _baseResp *base.BaseResp if src.BaseResp != nil { _baseResp = &base.BaseResp{} @@ -3007,6 +3166,20 @@ func (p *ListPromptRequest) FastRead(buf []byte) (int, error) { goto SkipFieldError } } + case 14: + if fieldTypeId == thrift.LIST { + l, err = p.FastReadField14(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } case 127: if fieldTypeId == thrift.I32 { l, err = p.FastReadField127(buf[offset:]) @@ -3161,6 +3334,30 @@ func (p *ListPromptRequest) FastReadField13(buf []byte) (int, error) { return offset, nil } +func (p *ListPromptRequest) FastReadField14(buf []byte) (int, error) { + offset := 0 + + _, size, l, err := thrift.Binary.ReadListBegin(buf[offset:]) + offset += l + if err != nil { + return offset, err + } + _field := make([]prompt.PromptType, 0, size) + for i := 0; i < size; i++ { + var _elem prompt.PromptType + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _elem = v + } + + _field = append(_field, _elem) + } + p.FilterPromptTypes = _field + return offset, nil +} + func (p *ListPromptRequest) FastReadField127(buf []byte) (int, error) { offset := 0 @@ -3243,6 +3440,7 @@ func (p *ListPromptRequest) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) i offset += p.fastWriteField130(buf[offset:], w) offset += p.fastWriteField11(buf[offset:], w) offset += p.fastWriteField12(buf[offset:], w) + offset += p.fastWriteField14(buf[offset:], w) offset += p.fastWriteField129(buf[offset:], w) offset += p.fastWriteField255(buf[offset:], w) } @@ -3257,6 +3455,7 @@ func (p *ListPromptRequest) BLength() int { l += p.field11Length() l += p.field12Length() l += p.field13Length() + l += p.field14Length() l += p.field127Length() l += p.field128Length() l += p.field129Length() @@ -3310,6 +3509,22 @@ func (p *ListPromptRequest) fastWriteField13(buf []byte, w thrift.NocopyWriter) return offset } +func (p *ListPromptRequest) fastWriteField14(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetFilterPromptTypes() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.LIST, 14) + listBeginOffset := offset + offset += thrift.Binary.ListBeginLength() + var length int + for _, v := range p.FilterPromptTypes { + length++ + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, v) + } + thrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.STRING, length) + } + return offset +} + func (p *ListPromptRequest) fastWriteField127(buf []byte, w thrift.NocopyWriter) int { offset := 0 if p.IsSetPageNum() { @@ -3395,6 +3610,19 @@ func (p *ListPromptRequest) field13Length() int { return l } +func (p *ListPromptRequest) field14Length() int { + l := 0 + if p.IsSetFilterPromptTypes() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.ListBeginLength() + for _, v := range p.FilterPromptTypes { + _ = v + l += thrift.Binary.StringLengthNocopy(v) + } + } + return l +} + func (p *ListPromptRequest) field127Length() int { l := 0 if p.IsSetPageNum() { @@ -3475,6 +3703,15 @@ func (p *ListPromptRequest) DeepCopy(s interface{}) error { p.CommittedOnly = &tmp } + if src.FilterPromptTypes != nil { + p.FilterPromptTypes = make([]prompt.PromptType, 0, len(src.FilterPromptTypes)) + for _, elem := range src.FilterPromptTypes { + var _elem prompt.PromptType + _elem = elem + p.FilterPromptTypes = append(p.FilterPromptTypes, _elem) + } + } + if src.PageNum != nil { tmp := *src.PageNum p.PageNum = &tmp @@ -5167,6 +5404,20 @@ func (p *ListCommitRequest) FastRead(buf []byte) (int, error) { goto SkipFieldError } } + case 2: + if fieldTypeId == thrift.BOOL { + l, err = p.FastReadField2(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } case 127: if fieldTypeId == thrift.I32 { l, err = p.FastReadField127(buf[offset:]) @@ -5255,6 +5506,20 @@ func (p *ListCommitRequest) FastReadField1(buf []byte) (int, error) { return offset, nil } +func (p *ListCommitRequest) FastReadField2(buf []byte) (int, error) { + offset := 0 + + var _field *bool + if v, l, err := thrift.Binary.ReadBool(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.WithCommitDetail = _field + return offset, nil +} + func (p *ListCommitRequest) FastReadField127(buf []byte) (int, error) { offset := 0 @@ -5317,6 +5582,7 @@ func (p *ListCommitRequest) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) i offset := 0 if p != nil { offset += p.fastWriteField1(buf[offset:], w) + offset += p.fastWriteField2(buf[offset:], w) offset += p.fastWriteField127(buf[offset:], w) offset += p.fastWriteField129(buf[offset:], w) offset += p.fastWriteField128(buf[offset:], w) @@ -5330,6 +5596,7 @@ func (p *ListCommitRequest) BLength() int { l := 0 if p != nil { l += p.field1Length() + l += p.field2Length() l += p.field127Length() l += p.field128Length() l += p.field129Length() @@ -5348,6 +5615,15 @@ func (p *ListCommitRequest) fastWriteField1(buf []byte, w thrift.NocopyWriter) i return offset } +func (p *ListCommitRequest) fastWriteField2(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetWithCommitDetail() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.BOOL, 2) + offset += thrift.Binary.WriteBool(buf[offset:], *p.WithCommitDetail) + } + return offset +} + func (p *ListCommitRequest) fastWriteField127(buf []byte, w thrift.NocopyWriter) int { offset := 0 if p.IsSetPageSize() { @@ -5393,6 +5669,15 @@ func (p *ListCommitRequest) field1Length() int { return l } +func (p *ListCommitRequest) field2Length() int { + l := 0 + if p.IsSetWithCommitDetail() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.BoolLength() + } + return l +} + func (p *ListCommitRequest) field127Length() int { l := 0 if p.IsSetPageSize() { @@ -5440,6 +5725,11 @@ func (p *ListCommitRequest) DeepCopy(s interface{}) error { p.PromptID = &tmp } + if src.WithCommitDetail != nil { + tmp := *src.WithCommitDetail + p.WithCommitDetail = &tmp + } + if src.PageSize != nil { tmp := *src.PageSize p.PageSize = &tmp @@ -5515,6 +5805,34 @@ func (p *ListCommitResponse) FastRead(buf []byte) (int, error) { goto SkipFieldError } } + case 3: + if fieldTypeId == thrift.MAP { + l, err = p.FastReadField3(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 4: + if fieldTypeId == thrift.MAP { + l, err = p.FastReadField4(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } case 11: if fieldTypeId == thrift.LIST { l, err = p.FastReadField11(buf[offset:]) @@ -5657,27 +5975,92 @@ func (p *ListCommitResponse) FastReadField2(buf []byte) (int, error) { return offset, nil } -func (p *ListCommitResponse) FastReadField11(buf []byte) (int, error) { +func (p *ListCommitResponse) FastReadField3(buf []byte) (int, error) { offset := 0 - _, size, l, err := thrift.Binary.ReadListBegin(buf[offset:]) + _, _, size, l, err := thrift.Binary.ReadMapBegin(buf[offset:]) offset += l if err != nil { return offset, err } - _field := make([]*user.UserInfoDetail, 0, size) - values := make([]user.UserInfoDetail, size) + _field := make(map[string]int32, size) for i := 0; i < size; i++ { - _elem := &values[i] - _elem.InitDefault() - if l, err := _elem.FastRead(buf[offset:]); err != nil { + var _key string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { return offset, err } else { offset += l + _key = v } - _field = append(_field, _elem) - } + var _val int32 + if v, l, err := thrift.Binary.ReadI32(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _val = v + } + + _field[_key] = _val + } + p.ParentReferencesMapping = _field + return offset, nil +} + +func (p *ListCommitResponse) FastReadField4(buf []byte) (int, error) { + offset := 0 + + _, _, size, l, err := thrift.Binary.ReadMapBegin(buf[offset:]) + offset += l + if err != nil { + return offset, err + } + _field := make(map[string]*prompt.PromptDetail, size) + values := make([]prompt.PromptDetail, size) + for i := 0; i < size; i++ { + var _key string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _key = v + } + + _val := &values[i] + _val.InitDefault() + if l, err := _val.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + + _field[_key] = _val + } + p.PromptCommitDetailMapping = _field + return offset, nil +} + +func (p *ListCommitResponse) FastReadField11(buf []byte) (int, error) { + offset := 0 + + _, size, l, err := thrift.Binary.ReadListBegin(buf[offset:]) + offset += l + if err != nil { + return offset, err + } + _field := make([]*user.UserInfoDetail, 0, size) + values := make([]user.UserInfoDetail, size) + for i := 0; i < size; i++ { + _elem := &values[i] + _elem.InitDefault() + if l, err := _elem.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + + _field = append(_field, _elem) + } p.Users = _field return offset, nil } @@ -5732,6 +6115,8 @@ func (p *ListCommitResponse) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) offset += p.fastWriteField127(buf[offset:], w) offset += p.fastWriteField1(buf[offset:], w) offset += p.fastWriteField2(buf[offset:], w) + offset += p.fastWriteField3(buf[offset:], w) + offset += p.fastWriteField4(buf[offset:], w) offset += p.fastWriteField11(buf[offset:], w) offset += p.fastWriteField128(buf[offset:], w) offset += p.fastWriteField255(buf[offset:], w) @@ -5745,6 +6130,8 @@ func (p *ListCommitResponse) BLength() int { if p != nil { l += p.field1Length() l += p.field2Length() + l += p.field3Length() + l += p.field4Length() l += p.field11Length() l += p.field127Length() l += p.field128Length() @@ -5794,6 +6181,40 @@ func (p *ListCommitResponse) fastWriteField2(buf []byte, w thrift.NocopyWriter) return offset } +func (p *ListCommitResponse) fastWriteField3(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetParentReferencesMapping() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.MAP, 3) + mapBeginOffset := offset + offset += thrift.Binary.MapBeginLength() + var length int + for k, v := range p.ParentReferencesMapping { + length++ + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, k) + offset += thrift.Binary.WriteI32(buf[offset:], v) + } + thrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift.STRING, thrift.I32, length) + } + return offset +} + +func (p *ListCommitResponse) fastWriteField4(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetPromptCommitDetailMapping() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.MAP, 4) + mapBeginOffset := offset + offset += thrift.Binary.MapBeginLength() + var length int + for k, v := range p.PromptCommitDetailMapping { + length++ + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, k) + offset += v.FastWriteNocopy(buf[offset:], w) + } + thrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift.STRING, thrift.STRUCT, length) + } + return offset +} + func (p *ListCommitResponse) fastWriteField11(buf []byte, w thrift.NocopyWriter) int { offset := 0 if p.IsSetUsers() { @@ -5869,6 +6290,36 @@ func (p *ListCommitResponse) field2Length() int { return l } +func (p *ListCommitResponse) field3Length() int { + l := 0 + if p.IsSetParentReferencesMapping() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.MapBeginLength() + for k, v := range p.ParentReferencesMapping { + _, _ = k, v + + l += thrift.Binary.StringLengthNocopy(k) + l += thrift.Binary.I32Length() + } + } + return l +} + +func (p *ListCommitResponse) field4Length() int { + l := 0 + if p.IsSetPromptCommitDetailMapping() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.MapBeginLength() + for k, v := range p.PromptCommitDetailMapping { + _, _ = k, v + + l += thrift.Binary.StringLengthNocopy(k) + l += v.BLength() + } + } + return l +} + func (p *ListCommitResponse) field11Length() int { l := 0 if p.IsSetUsers() { @@ -5958,6 +6409,41 @@ func (p *ListCommitResponse) DeepCopy(s interface{}) error { } } + if src.ParentReferencesMapping != nil { + p.ParentReferencesMapping = make(map[string]int32, len(src.ParentReferencesMapping)) + for key, val := range src.ParentReferencesMapping { + var _key string + if key != "" { + _key = kutils.StringDeepCopy(key) + } + + var _val int32 + _val = val + + p.ParentReferencesMapping[_key] = _val + } + } + + if src.PromptCommitDetailMapping != nil { + p.PromptCommitDetailMapping = make(map[string]*prompt.PromptDetail, len(src.PromptCommitDetailMapping)) + for key, val := range src.PromptCommitDetailMapping { + var _key string + if key != "" { + _key = kutils.StringDeepCopy(key) + } + + var _val *prompt.PromptDetail + if val != nil { + _val = &prompt.PromptDetail{} + if err := _val.DeepCopy(val); err != nil { + return err + } + } + + p.PromptCommitDetailMapping[_key] = _val + } + } + if src.Users != nil { p.Users = make([]*user.UserInfoDetail, 0, len(src.Users)) for _, elem := range src.Users { @@ -8444,7 +8930,514 @@ func (p *UpdateCommitLabelsResponse) BLength() int { return l } -func (p *UpdateCommitLabelsResponse) fastWriteField255(buf []byte, w thrift.NocopyWriter) int { +func (p *UpdateCommitLabelsResponse) fastWriteField255(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetBaseResp() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 255) + offset += p.BaseResp.FastWriteNocopy(buf[offset:], w) + } + return offset +} + +func (p *UpdateCommitLabelsResponse) field255Length() int { + l := 0 + if p.IsSetBaseResp() { + l += thrift.Binary.FieldBeginLength() + l += p.BaseResp.BLength() + } + return l +} + +func (p *UpdateCommitLabelsResponse) DeepCopy(s interface{}) error { + src, ok := s.(*UpdateCommitLabelsResponse) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + var _baseResp *base.BaseResp + if src.BaseResp != nil { + _baseResp = &base.BaseResp{} + if err := _baseResp.DeepCopy(src.BaseResp); err != nil { + return err + } + } + p.BaseResp = _baseResp + + return nil +} + +func (p *ListParentPromptRequest) FastRead(buf []byte) (int, error) { + + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + for { + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 1: + if fieldTypeId == thrift.I64 { + l, err = p.FastReadField1(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 2: + if fieldTypeId == thrift.I64 { + l, err = p.FastReadField2(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 3: + if fieldTypeId == thrift.LIST { + l, err = p.FastReadField3(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 255: + if fieldTypeId == thrift.STRUCT { + l, err = p.FastReadField255(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + } + + return offset, nil +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_ListParentPromptRequest[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +} + +func (p *ListParentPromptRequest) FastReadField1(buf []byte) (int, error) { + offset := 0 + + var _field *int64 + if v, l, err := thrift.Binary.ReadI64(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.WorkspaceID = _field + return offset, nil +} + +func (p *ListParentPromptRequest) FastReadField2(buf []byte) (int, error) { + offset := 0 + + var _field *int64 + if v, l, err := thrift.Binary.ReadI64(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.PromptID = _field + return offset, nil +} + +func (p *ListParentPromptRequest) FastReadField3(buf []byte) (int, error) { + offset := 0 + + _, size, l, err := thrift.Binary.ReadListBegin(buf[offset:]) + offset += l + if err != nil { + return offset, err + } + _field := make([]string, 0, size) + for i := 0; i < size; i++ { + var _elem string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _elem = v + } + + _field = append(_field, _elem) + } + p.CommitVersions = _field + return offset, nil +} + +func (p *ListParentPromptRequest) FastReadField255(buf []byte) (int, error) { + offset := 0 + _field := base.NewBase() + if l, err := _field.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + p.Base = _field + return offset, nil +} + +func (p *ListParentPromptRequest) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *ListParentPromptRequest) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField1(buf[offset:], w) + offset += p.fastWriteField2(buf[offset:], w) + offset += p.fastWriteField3(buf[offset:], w) + offset += p.fastWriteField255(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *ListParentPromptRequest) BLength() int { + l := 0 + if p != nil { + l += p.field1Length() + l += p.field2Length() + l += p.field3Length() + l += p.field255Length() + } + l += thrift.Binary.FieldStopLength() + return l +} + +func (p *ListParentPromptRequest) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetWorkspaceID() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.I64, 1) + offset += thrift.Binary.WriteI64(buf[offset:], *p.WorkspaceID) + } + return offset +} + +func (p *ListParentPromptRequest) fastWriteField2(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetPromptID() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.I64, 2) + offset += thrift.Binary.WriteI64(buf[offset:], *p.PromptID) + } + return offset +} + +func (p *ListParentPromptRequest) fastWriteField3(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetCommitVersions() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.LIST, 3) + listBeginOffset := offset + offset += thrift.Binary.ListBeginLength() + var length int + for _, v := range p.CommitVersions { + length++ + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, v) + } + thrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.STRING, length) + } + return offset +} + +func (p *ListParentPromptRequest) fastWriteField255(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetBase() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 255) + offset += p.Base.FastWriteNocopy(buf[offset:], w) + } + return offset +} + +func (p *ListParentPromptRequest) field1Length() int { + l := 0 + if p.IsSetWorkspaceID() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.I64Length() + } + return l +} + +func (p *ListParentPromptRequest) field2Length() int { + l := 0 + if p.IsSetPromptID() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.I64Length() + } + return l +} + +func (p *ListParentPromptRequest) field3Length() int { + l := 0 + if p.IsSetCommitVersions() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.ListBeginLength() + for _, v := range p.CommitVersions { + _ = v + l += thrift.Binary.StringLengthNocopy(v) + } + } + return l +} + +func (p *ListParentPromptRequest) field255Length() int { + l := 0 + if p.IsSetBase() { + l += thrift.Binary.FieldBeginLength() + l += p.Base.BLength() + } + return l +} + +func (p *ListParentPromptRequest) DeepCopy(s interface{}) error { + src, ok := s.(*ListParentPromptRequest) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + if src.WorkspaceID != nil { + tmp := *src.WorkspaceID + p.WorkspaceID = &tmp + } + + if src.PromptID != nil { + tmp := *src.PromptID + p.PromptID = &tmp + } + + if src.CommitVersions != nil { + p.CommitVersions = make([]string, 0, len(src.CommitVersions)) + for _, elem := range src.CommitVersions { + var _elem string + if elem != "" { + _elem = kutils.StringDeepCopy(elem) + } + p.CommitVersions = append(p.CommitVersions, _elem) + } + } + + var _base *base.Base + if src.Base != nil { + _base = &base.Base{} + if err := _base.DeepCopy(src.Base); err != nil { + return err + } + } + p.Base = _base + + return nil +} + +func (p *ListParentPromptResponse) FastRead(buf []byte) (int, error) { + + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + for { + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 1: + if fieldTypeId == thrift.MAP { + l, err = p.FastReadField1(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 255: + if fieldTypeId == thrift.STRUCT { + l, err = p.FastReadField255(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + } + + return offset, nil +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_ListParentPromptResponse[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +} + +func (p *ListParentPromptResponse) FastReadField1(buf []byte) (int, error) { + offset := 0 + + _, _, size, l, err := thrift.Binary.ReadMapBegin(buf[offset:]) + offset += l + if err != nil { + return offset, err + } + _field := make(map[string][]*prompt.PromptCommitVersions, size) + for i := 0; i < size; i++ { + var _key string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _key = v + } + + _, size, l, err := thrift.Binary.ReadListBegin(buf[offset:]) + offset += l + if err != nil { + return offset, err + } + _val := make([]*prompt.PromptCommitVersions, 0, size) + values := make([]prompt.PromptCommitVersions, size) + for i := 0; i < size; i++ { + _elem := &values[i] + _elem.InitDefault() + if l, err := _elem.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + + _val = append(_val, _elem) + } + + _field[_key] = _val + } + p.ParentPrompts = _field + return offset, nil +} + +func (p *ListParentPromptResponse) FastReadField255(buf []byte) (int, error) { + offset := 0 + _field := base.NewBaseResp() + if l, err := _field.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + p.BaseResp = _field + return offset, nil +} + +func (p *ListParentPromptResponse) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *ListParentPromptResponse) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField1(buf[offset:], w) + offset += p.fastWriteField255(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *ListParentPromptResponse) BLength() int { + l := 0 + if p != nil { + l += p.field1Length() + l += p.field255Length() + } + l += thrift.Binary.FieldStopLength() + return l +} + +func (p *ListParentPromptResponse) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetParentPrompts() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.MAP, 1) + mapBeginOffset := offset + offset += thrift.Binary.MapBeginLength() + var length int + for k, v := range p.ParentPrompts { + length++ + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, k) + listBeginOffset := offset + offset += thrift.Binary.ListBeginLength() + var length int + for _, v := range v { + length++ + offset += v.FastWriteNocopy(buf[offset:], w) + } + thrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.STRUCT, length) + } + thrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift.STRING, thrift.LIST, length) + } + return offset +} + +func (p *ListParentPromptResponse) fastWriteField255(buf []byte, w thrift.NocopyWriter) int { offset := 0 if p.IsSetBaseResp() { offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 255) @@ -8453,7 +9446,26 @@ func (p *UpdateCommitLabelsResponse) fastWriteField255(buf []byte, w thrift.Noco return offset } -func (p *UpdateCommitLabelsResponse) field255Length() int { +func (p *ListParentPromptResponse) field1Length() int { + l := 0 + if p.IsSetParentPrompts() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.MapBeginLength() + for k, v := range p.ParentPrompts { + _, _ = k, v + + l += thrift.Binary.StringLengthNocopy(k) + l += thrift.Binary.ListBeginLength() + for _, v := range v { + _ = v + l += v.BLength() + } + } + } + return l +} + +func (p *ListParentPromptResponse) field255Length() int { l := 0 if p.IsSetBaseResp() { l += thrift.Binary.FieldBeginLength() @@ -8462,12 +9474,40 @@ func (p *UpdateCommitLabelsResponse) field255Length() int { return l } -func (p *UpdateCommitLabelsResponse) DeepCopy(s interface{}) error { - src, ok := s.(*UpdateCommitLabelsResponse) +func (p *ListParentPromptResponse) DeepCopy(s interface{}) error { + src, ok := s.(*ListParentPromptResponse) if !ok { return fmt.Errorf("%T's type not matched %T", s, p) } + if src.ParentPrompts != nil { + p.ParentPrompts = make(map[string][]*prompt.PromptCommitVersions, len(src.ParentPrompts)) + for key, val := range src.ParentPrompts { + var _key string + if key != "" { + _key = kutils.StringDeepCopy(key) + } + + var _val []*prompt.PromptCommitVersions + if val != nil { + _val = make([]*prompt.PromptCommitVersions, 0, len(val)) + for _, elem := range val { + var _elem *prompt.PromptCommitVersions + if elem != nil { + _elem = &prompt.PromptCommitVersions{} + if err := _elem.DeepCopy(elem); err != nil { + return err + } + } + + _val = append(_val, _elem) + } + } + + p.ParentPrompts[_key] = _val + } + } + var _baseResp *base.BaseResp if src.BaseResp != nil { _baseResp = &base.BaseResp{} @@ -9884,6 +10924,240 @@ func (p *PromptManageServiceListPromptResult) DeepCopy(s interface{}) error { return nil } +func (p *PromptManageServiceListParentPromptArgs) FastRead(buf []byte) (int, error) { + + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + for { + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 1: + if fieldTypeId == thrift.STRUCT { + l, err = p.FastReadField1(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + } + + return offset, nil +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptManageServiceListParentPromptArgs[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +} + +func (p *PromptManageServiceListParentPromptArgs) FastReadField1(buf []byte) (int, error) { + offset := 0 + _field := NewListParentPromptRequest() + if l, err := _field.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + p.Request = _field + return offset, nil +} + +func (p *PromptManageServiceListParentPromptArgs) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *PromptManageServiceListParentPromptArgs) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField1(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *PromptManageServiceListParentPromptArgs) BLength() int { + l := 0 + if p != nil { + l += p.field1Length() + } + l += thrift.Binary.FieldStopLength() + return l +} + +func (p *PromptManageServiceListParentPromptArgs) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 1) + offset += p.Request.FastWriteNocopy(buf[offset:], w) + return offset +} + +func (p *PromptManageServiceListParentPromptArgs) field1Length() int { + l := 0 + l += thrift.Binary.FieldBeginLength() + l += p.Request.BLength() + return l +} + +func (p *PromptManageServiceListParentPromptArgs) DeepCopy(s interface{}) error { + src, ok := s.(*PromptManageServiceListParentPromptArgs) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + var _request *ListParentPromptRequest + if src.Request != nil { + _request = &ListParentPromptRequest{} + if err := _request.DeepCopy(src.Request); err != nil { + return err + } + } + p.Request = _request + + return nil +} + +func (p *PromptManageServiceListParentPromptResult) FastRead(buf []byte) (int, error) { + + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + for { + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 0: + if fieldTypeId == thrift.STRUCT { + l, err = p.FastReadField0(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + } + + return offset, nil +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_PromptManageServiceListParentPromptResult[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +} + +func (p *PromptManageServiceListParentPromptResult) FastReadField0(buf []byte) (int, error) { + offset := 0 + _field := NewListParentPromptResponse() + if l, err := _field.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + p.Success = _field + return offset, nil +} + +func (p *PromptManageServiceListParentPromptResult) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *PromptManageServiceListParentPromptResult) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField0(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *PromptManageServiceListParentPromptResult) BLength() int { + l := 0 + if p != nil { + l += p.field0Length() + } + l += thrift.Binary.FieldStopLength() + return l +} + +func (p *PromptManageServiceListParentPromptResult) fastWriteField0(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetSuccess() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 0) + offset += p.Success.FastWriteNocopy(buf[offset:], w) + } + return offset +} + +func (p *PromptManageServiceListParentPromptResult) field0Length() int { + l := 0 + if p.IsSetSuccess() { + l += thrift.Binary.FieldBeginLength() + l += p.Success.BLength() + } + return l +} + +func (p *PromptManageServiceListParentPromptResult) DeepCopy(s interface{}) error { + src, ok := s.(*PromptManageServiceListParentPromptResult) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + var _success *ListParentPromptResponse + if src.Success != nil { + _success = &ListParentPromptResponse{} + if err := _success.DeepCopy(src.Success); err != nil { + return err + } + } + p.Success = _success + + return nil +} + func (p *PromptManageServiceUpdatePromptArgs) FastRead(buf []byte) (int, error) { var err error @@ -12038,6 +13312,14 @@ func (p *PromptManageServiceListPromptResult) GetResult() interface{} { return p.Success } +func (p *PromptManageServiceListParentPromptArgs) GetFirstArgument() interface{} { + return p.Request +} + +func (p *PromptManageServiceListParentPromptResult) GetResult() interface{} { + return p.Success +} + func (p *PromptManageServiceUpdatePromptArgs) GetFirstArgument() interface{} { return p.Request } diff --git a/backend/kitex_gen/coze/loop/prompt/manage/promptmanageservice/client.go b/backend/kitex_gen/coze/loop/prompt/manage/promptmanageservice/client.go index bc2ea53f6..3e7b1c158 100644 --- a/backend/kitex_gen/coze/loop/prompt/manage/promptmanageservice/client.go +++ b/backend/kitex_gen/coze/loop/prompt/manage/promptmanageservice/client.go @@ -17,6 +17,7 @@ type Client interface { GetPrompt(ctx context.Context, request *manage.GetPromptRequest, callOptions ...callopt.Option) (r *manage.GetPromptResponse, err error) BatchGetPrompt(ctx context.Context, request *manage.BatchGetPromptRequest, callOptions ...callopt.Option) (r *manage.BatchGetPromptResponse, err error) ListPrompt(ctx context.Context, request *manage.ListPromptRequest, callOptions ...callopt.Option) (r *manage.ListPromptResponse, err error) + ListParentPrompt(ctx context.Context, request *manage.ListParentPromptRequest, callOptions ...callopt.Option) (r *manage.ListParentPromptResponse, err error) UpdatePrompt(ctx context.Context, request *manage.UpdatePromptRequest, callOptions ...callopt.Option) (r *manage.UpdatePromptResponse, err error) SaveDraft(ctx context.Context, request *manage.SaveDraftRequest, callOptions ...callopt.Option) (r *manage.SaveDraftResponse, err error) CreateLabel(ctx context.Context, request *manage.CreateLabelRequest, callOptions ...callopt.Option) (r *manage.CreateLabelResponse, err error) @@ -87,6 +88,11 @@ func (p *kPromptManageServiceClient) ListPrompt(ctx context.Context, request *ma return p.kClient.ListPrompt(ctx, request) } +func (p *kPromptManageServiceClient) ListParentPrompt(ctx context.Context, request *manage.ListParentPromptRequest, callOptions ...callopt.Option) (r *manage.ListParentPromptResponse, err error) { + ctx = client.NewCtxWithCallOptions(ctx, callOptions) + return p.kClient.ListParentPrompt(ctx, request) +} + func (p *kPromptManageServiceClient) UpdatePrompt(ctx context.Context, request *manage.UpdatePromptRequest, callOptions ...callopt.Option) (r *manage.UpdatePromptResponse, err error) { ctx = client.NewCtxWithCallOptions(ctx, callOptions) return p.kClient.UpdatePrompt(ctx, request) diff --git a/backend/kitex_gen/coze/loop/prompt/manage/promptmanageservice/promptmanageservice.go b/backend/kitex_gen/coze/loop/prompt/manage/promptmanageservice/promptmanageservice.go index dac877965..d8900b7be 100644 --- a/backend/kitex_gen/coze/loop/prompt/manage/promptmanageservice/promptmanageservice.go +++ b/backend/kitex_gen/coze/loop/prompt/manage/promptmanageservice/promptmanageservice.go @@ -55,6 +55,13 @@ var serviceMethods = map[string]kitex.MethodInfo{ false, kitex.WithStreamingMode(kitex.StreamingNone), ), + "ListParentPrompt": kitex.NewMethodInfo( + listParentPromptHandler, + newPromptManageServiceListParentPromptArgs, + newPromptManageServiceListParentPromptResult, + false, + kitex.WithStreamingMode(kitex.StreamingNone), + ), "UpdatePrompt": kitex.NewMethodInfo( updatePromptHandler, newPromptManageServiceUpdatePromptArgs, @@ -265,6 +272,25 @@ func newPromptManageServiceListPromptResult() interface{} { return manage.NewPromptManageServiceListPromptResult() } +func listParentPromptHandler(ctx context.Context, handler interface{}, arg, result interface{}) error { + realArg := arg.(*manage.PromptManageServiceListParentPromptArgs) + realResult := result.(*manage.PromptManageServiceListParentPromptResult) + success, err := handler.(manage.PromptManageService).ListParentPrompt(ctx, realArg.Request) + if err != nil { + return err + } + realResult.Success = success + return nil +} + +func newPromptManageServiceListParentPromptArgs() interface{} { + return manage.NewPromptManageServiceListParentPromptArgs() +} + +func newPromptManageServiceListParentPromptResult() interface{} { + return manage.NewPromptManageServiceListParentPromptResult() +} + func updatePromptHandler(ctx context.Context, handler interface{}, arg, result interface{}) error { realArg := arg.(*manage.PromptManageServiceUpdatePromptArgs) realResult := result.(*manage.PromptManageServiceUpdatePromptResult) @@ -508,6 +534,16 @@ func (p *kClient) ListPrompt(ctx context.Context, request *manage.ListPromptRequ return _result.GetSuccess(), nil } +func (p *kClient) ListParentPrompt(ctx context.Context, request *manage.ListParentPromptRequest) (r *manage.ListParentPromptResponse, err error) { + var _args manage.PromptManageServiceListParentPromptArgs + _args.Request = request + var _result manage.PromptManageServiceListParentPromptResult + if err = p.c.Call(ctx, "ListParentPrompt", &_args, &_result); err != nil { + return + } + return _result.GetSuccess(), nil +} + func (p *kClient) UpdatePrompt(ctx context.Context, request *manage.UpdatePromptRequest) (r *manage.UpdatePromptResponse, err error) { var _args manage.PromptManageServiceUpdatePromptArgs _args.Request = request diff --git a/backend/kitex_gen/coze/loop/prompt/promptmanageservice/client.go b/backend/kitex_gen/coze/loop/prompt/promptmanageservice/client.go index bc2ea53f6..3e7b1c158 100644 --- a/backend/kitex_gen/coze/loop/prompt/promptmanageservice/client.go +++ b/backend/kitex_gen/coze/loop/prompt/promptmanageservice/client.go @@ -17,6 +17,7 @@ type Client interface { GetPrompt(ctx context.Context, request *manage.GetPromptRequest, callOptions ...callopt.Option) (r *manage.GetPromptResponse, err error) BatchGetPrompt(ctx context.Context, request *manage.BatchGetPromptRequest, callOptions ...callopt.Option) (r *manage.BatchGetPromptResponse, err error) ListPrompt(ctx context.Context, request *manage.ListPromptRequest, callOptions ...callopt.Option) (r *manage.ListPromptResponse, err error) + ListParentPrompt(ctx context.Context, request *manage.ListParentPromptRequest, callOptions ...callopt.Option) (r *manage.ListParentPromptResponse, err error) UpdatePrompt(ctx context.Context, request *manage.UpdatePromptRequest, callOptions ...callopt.Option) (r *manage.UpdatePromptResponse, err error) SaveDraft(ctx context.Context, request *manage.SaveDraftRequest, callOptions ...callopt.Option) (r *manage.SaveDraftResponse, err error) CreateLabel(ctx context.Context, request *manage.CreateLabelRequest, callOptions ...callopt.Option) (r *manage.CreateLabelResponse, err error) @@ -87,6 +88,11 @@ func (p *kPromptManageServiceClient) ListPrompt(ctx context.Context, request *ma return p.kClient.ListPrompt(ctx, request) } +func (p *kPromptManageServiceClient) ListParentPrompt(ctx context.Context, request *manage.ListParentPromptRequest, callOptions ...callopt.Option) (r *manage.ListParentPromptResponse, err error) { + ctx = client.NewCtxWithCallOptions(ctx, callOptions) + return p.kClient.ListParentPrompt(ctx, request) +} + func (p *kPromptManageServiceClient) UpdatePrompt(ctx context.Context, request *manage.UpdatePromptRequest, callOptions ...callopt.Option) (r *manage.UpdatePromptResponse, err error) { ctx = client.NewCtxWithCallOptions(ctx, callOptions) return p.kClient.UpdatePrompt(ctx, request) diff --git a/backend/kitex_gen/coze/loop/prompt/promptmanageservice/promptmanageservice.go b/backend/kitex_gen/coze/loop/prompt/promptmanageservice/promptmanageservice.go index 82ceac44f..f2c8b070e 100644 --- a/backend/kitex_gen/coze/loop/prompt/promptmanageservice/promptmanageservice.go +++ b/backend/kitex_gen/coze/loop/prompt/promptmanageservice/promptmanageservice.go @@ -56,6 +56,13 @@ var serviceMethods = map[string]kitex.MethodInfo{ false, kitex.WithStreamingMode(kitex.StreamingNone), ), + "ListParentPrompt": kitex.NewMethodInfo( + listParentPromptHandler, + newPromptManageServiceListParentPromptArgs, + newPromptManageServiceListParentPromptResult, + false, + kitex.WithStreamingMode(kitex.StreamingNone), + ), "UpdatePrompt": kitex.NewMethodInfo( updatePromptHandler, newPromptManageServiceUpdatePromptArgs, @@ -266,6 +273,25 @@ func newPromptManageServiceListPromptResult() interface{} { return manage.NewPromptManageServiceListPromptResult() } +func listParentPromptHandler(ctx context.Context, handler interface{}, arg, result interface{}) error { + realArg := arg.(*manage.PromptManageServiceListParentPromptArgs) + realResult := result.(*manage.PromptManageServiceListParentPromptResult) + success, err := handler.(manage.PromptManageService).ListParentPrompt(ctx, realArg.Request) + if err != nil { + return err + } + realResult.Success = success + return nil +} + +func newPromptManageServiceListParentPromptArgs() interface{} { + return manage.NewPromptManageServiceListParentPromptArgs() +} + +func newPromptManageServiceListParentPromptResult() interface{} { + return manage.NewPromptManageServiceListParentPromptResult() +} + func updatePromptHandler(ctx context.Context, handler interface{}, arg, result interface{}) error { realArg := arg.(*manage.PromptManageServiceUpdatePromptArgs) realResult := result.(*manage.PromptManageServiceUpdatePromptResult) @@ -509,6 +535,16 @@ func (p *kClient) ListPrompt(ctx context.Context, request *manage.ListPromptRequ return _result.GetSuccess(), nil } +func (p *kClient) ListParentPrompt(ctx context.Context, request *manage.ListParentPromptRequest) (r *manage.ListParentPromptResponse, err error) { + var _args manage.PromptManageServiceListParentPromptArgs + _args.Request = request + var _result manage.PromptManageServiceListParentPromptResult + if err = p.c.Call(ctx, "ListParentPrompt", &_args, &_result); err != nil { + return + } + return _result.GetSuccess(), nil +} + func (p *kClient) UpdatePrompt(ctx context.Context, request *manage.UpdatePromptRequest) (r *manage.UpdatePromptResponse, err error) { var _args manage.PromptManageServiceUpdatePromptArgs _args.Request = request diff --git a/backend/loop_gen/coze/loop/prompt/lomanage/local_promptmanageservice.go b/backend/loop_gen/coze/loop/prompt/lomanage/local_promptmanageservice.go index e5bfc95d6..74d233021 100644 --- a/backend/loop_gen/coze/loop/prompt/lomanage/local_promptmanageservice.go +++ b/backend/loop_gen/coze/loop/prompt/lomanage/local_promptmanageservice.go @@ -155,6 +155,29 @@ func (l *LocalPromptManageService) ListPrompt(ctx context.Context, request *mana return result.GetSuccess(), nil } +// ListParentPrompt +// 查询片段的引用记录 +func (l *LocalPromptManageService) ListParentPrompt(ctx context.Context, request *manage.ListParentPromptRequest, callOptions ...callopt.Option) (*manage.ListParentPromptResponse, error) { + chain := l.mds(func(ctx context.Context, in, out interface{}) error { + arg := in.(*manage.PromptManageServiceListParentPromptArgs) + result := out.(*manage.PromptManageServiceListParentPromptResult) + resp, err := l.impl.ListParentPrompt(ctx, arg.Request) + if err != nil { + return err + } + result.SetSuccess(resp) + return nil + }) + + arg := &manage.PromptManageServiceListParentPromptArgs{Request: request} + result := &manage.PromptManageServiceListParentPromptResult{} + ctx = l.injectRPCInfo(ctx, "ListParentPrompt") + if err := chain(ctx, arg, result); err != nil { + return nil, err + } + return result.GetSuccess(), nil +} + // UpdatePrompt // 改 func (l *LocalPromptManageService) UpdatePrompt(ctx context.Context, request *manage.UpdatePromptRequest, callOptions ...callopt.Option) (*manage.UpdatePromptResponse, error) { diff --git a/backend/modules/data/domain/component/conf/mocks/conf.go b/backend/modules/data/domain/component/conf/mocks/conf.go index f450af452..dbbd91059 100644 --- a/backend/modules/data/domain/component/conf/mocks/conf.go +++ b/backend/modules/data/domain/component/conf/mocks/conf.go @@ -20,6 +20,7 @@ import ( type MockIConfig struct { ctrl *gomock.Controller recorder *MockIConfigMockRecorder + isgomock struct{} } // MockIConfigMockRecorder is the mock recorder for MockIConfig. diff --git a/backend/modules/data/domain/component/rpc/mocks/user_provider.go b/backend/modules/data/domain/component/rpc/mocks/user_provider.go index 699798a92..34e5f3ebd 100644 --- a/backend/modules/data/domain/component/rpc/mocks/user_provider.go +++ b/backend/modules/data/domain/component/rpc/mocks/user_provider.go @@ -21,6 +21,7 @@ import ( type MockIUserProvider struct { ctrl *gomock.Controller recorder *MockIUserProviderMockRecorder + isgomock struct{} } // MockIUserProviderMockRecorder is the mock recorder for MockIUserProvider. @@ -41,16 +42,16 @@ func (m *MockIUserProvider) EXPECT() *MockIUserProviderMockRecorder { } // MGetUserInfo mocks base method. -func (m *MockIUserProvider) MGetUserInfo(arg0 context.Context, arg1 []string) ([]*entity.UserInfo, error) { +func (m *MockIUserProvider) MGetUserInfo(ctx context.Context, userIDs []string) ([]*entity.UserInfo, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MGetUserInfo", arg0, arg1) + ret := m.ctrl.Call(m, "MGetUserInfo", ctx, userIDs) ret0, _ := ret[0].([]*entity.UserInfo) ret1, _ := ret[1].(error) return ret0, ret1 } // MGetUserInfo indicates an expected call of MGetUserInfo. -func (mr *MockIUserProviderMockRecorder) MGetUserInfo(arg0, arg1 any) *gomock.Call { +func (mr *MockIUserProviderMockRecorder) MGetUserInfo(ctx, userIDs any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetUserInfo", reflect.TypeOf((*MockIUserProvider)(nil).MGetUserInfo), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetUserInfo", reflect.TypeOf((*MockIUserProvider)(nil).MGetUserInfo), ctx, userIDs) } diff --git a/backend/modules/data/domain/component/userinfo/mocks/userinfo.go b/backend/modules/data/domain/component/userinfo/mocks/userinfo.go index 0b41e6b54..cb7ed53c8 100644 --- a/backend/modules/data/domain/component/userinfo/mocks/userinfo.go +++ b/backend/modules/data/domain/component/userinfo/mocks/userinfo.go @@ -20,6 +20,7 @@ import ( type MockUserInfoService struct { ctrl *gomock.Controller recorder *MockUserInfoServiceMockRecorder + isgomock struct{} } // MockUserInfoServiceMockRecorder is the mock recorder for MockUserInfoService. @@ -40,13 +41,13 @@ func (m *MockUserInfoService) EXPECT() *MockUserInfoServiceMockRecorder { } // PackUserInfo mocks base method. -func (m *MockUserInfoService) PackUserInfo(arg0 context.Context, arg1 any) { +func (m *MockUserInfoService) PackUserInfo(ctx context.Context, userInfoCarrier any) { m.ctrl.T.Helper() - m.ctrl.Call(m, "PackUserInfo", arg0, arg1) + m.ctrl.Call(m, "PackUserInfo", ctx, userInfoCarrier) } // PackUserInfo indicates an expected call of PackUserInfo. -func (mr *MockUserInfoServiceMockRecorder) PackUserInfo(arg0, arg1 any) *gomock.Call { +func (mr *MockUserInfoServiceMockRecorder) PackUserInfo(ctx, userInfoCarrier any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackUserInfo", reflect.TypeOf((*MockUserInfoService)(nil).PackUserInfo), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackUserInfo", reflect.TypeOf((*MockUserInfoService)(nil).PackUserInfo), ctx, userInfoCarrier) } diff --git a/backend/modules/data/domain/tag/repo/mocks/tag_mock.go b/backend/modules/data/domain/tag/repo/mocks/tag_mock.go index 813cbccdf..c4460d262 100644 --- a/backend/modules/data/domain/tag/repo/mocks/tag_mock.go +++ b/backend/modules/data/domain/tag/repo/mocks/tag_mock.go @@ -23,6 +23,7 @@ import ( type MockITagAPI struct { ctrl *gomock.Controller recorder *MockITagAPIMockRecorder + isgomock struct{} } // MockITagAPIMockRecorder is the mock recorder for MockITagAPI. @@ -43,10 +44,10 @@ func (m *MockITagAPI) EXPECT() *MockITagAPIMockRecorder { } // CountTagKeys mocks base method. -func (m *MockITagAPI) CountTagKeys(arg0 context.Context, arg1 *entity.MGetTagKeyParam, arg2 ...db.Option) (int64, error) { +func (m *MockITagAPI) CountTagKeys(ctx context.Context, param *entity.MGetTagKeyParam, opts ...db.Option) (int64, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, param} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "CountTagKeys", varargs...) @@ -56,17 +57,17 @@ func (m *MockITagAPI) CountTagKeys(arg0 context.Context, arg1 *entity.MGetTagKey } // CountTagKeys indicates an expected call of CountTagKeys. -func (mr *MockITagAPIMockRecorder) CountTagKeys(arg0, arg1 any, arg2 ...any) *gomock.Call { +func (mr *MockITagAPIMockRecorder) CountTagKeys(ctx, param any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, param}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountTagKeys", reflect.TypeOf((*MockITagAPI)(nil).CountTagKeys), varargs...) } // DeleteTagKey mocks base method. -func (m *MockITagAPI) DeleteTagKey(arg0 context.Context, arg1, arg2 int64, arg3 ...db.Option) error { +func (m *MockITagAPI) DeleteTagKey(ctx context.Context, spaceID, id int64, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []any{arg0, arg1, arg2} - for _, a := range arg3 { + varargs := []any{ctx, spaceID, id} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "DeleteTagKey", varargs...) @@ -75,17 +76,17 @@ func (m *MockITagAPI) DeleteTagKey(arg0 context.Context, arg1, arg2 int64, arg3 } // DeleteTagKey indicates an expected call of DeleteTagKey. -func (mr *MockITagAPIMockRecorder) DeleteTagKey(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call { +func (mr *MockITagAPIMockRecorder) DeleteTagKey(ctx, spaceID, id any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1, arg2}, arg3...) + varargs := append([]any{ctx, spaceID, id}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTagKey", reflect.TypeOf((*MockITagAPI)(nil).DeleteTagKey), varargs...) } // DeleteTagValue mocks base method. -func (m *MockITagAPI) DeleteTagValue(arg0 context.Context, arg1, arg2 int64, arg3 ...db.Option) error { +func (m *MockITagAPI) DeleteTagValue(ctx context.Context, spaceID, id int64, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []any{arg0, arg1, arg2} - for _, a := range arg3 { + varargs := []any{ctx, spaceID, id} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "DeleteTagValue", varargs...) @@ -94,17 +95,17 @@ func (m *MockITagAPI) DeleteTagValue(arg0 context.Context, arg1, arg2 int64, arg } // DeleteTagValue indicates an expected call of DeleteTagValue. -func (mr *MockITagAPIMockRecorder) DeleteTagValue(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call { +func (mr *MockITagAPIMockRecorder) DeleteTagValue(ctx, spaceID, id any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1, arg2}, arg3...) + varargs := append([]any{ctx, spaceID, id}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTagValue", reflect.TypeOf((*MockITagAPI)(nil).DeleteTagValue), varargs...) } // GetTagKey mocks base method. -func (m *MockITagAPI) GetTagKey(arg0 context.Context, arg1, arg2 int64, arg3 ...db.Option) (*entity.TagKey, error) { +func (m *MockITagAPI) GetTagKey(ctx context.Context, spaceID, id int64, opts ...db.Option) (*entity.TagKey, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1, arg2} - for _, a := range arg3 { + varargs := []any{ctx, spaceID, id} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "GetTagKey", varargs...) @@ -114,17 +115,17 @@ func (m *MockITagAPI) GetTagKey(arg0 context.Context, arg1, arg2 int64, arg3 ... } // GetTagKey indicates an expected call of GetTagKey. -func (mr *MockITagAPIMockRecorder) GetTagKey(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call { +func (mr *MockITagAPIMockRecorder) GetTagKey(ctx, spaceID, id any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1, arg2}, arg3...) + varargs := append([]any{ctx, spaceID, id}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTagKey", reflect.TypeOf((*MockITagAPI)(nil).GetTagKey), varargs...) } // GetTagValue mocks base method. -func (m *MockITagAPI) GetTagValue(arg0 context.Context, arg1, arg2 int64, arg3 ...db.Option) (*entity.TagValue, error) { +func (m *MockITagAPI) GetTagValue(ctx context.Context, spaceID, id int64, opts ...db.Option) (*entity.TagValue, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1, arg2} - for _, a := range arg3 { + varargs := []any{ctx, spaceID, id} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "GetTagValue", varargs...) @@ -134,17 +135,17 @@ func (m *MockITagAPI) GetTagValue(arg0 context.Context, arg1, arg2 int64, arg3 . } // GetTagValue indicates an expected call of GetTagValue. -func (mr *MockITagAPIMockRecorder) GetTagValue(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call { +func (mr *MockITagAPIMockRecorder) GetTagValue(ctx, spaceID, id any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1, arg2}, arg3...) + varargs := append([]any{ctx, spaceID, id}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTagValue", reflect.TypeOf((*MockITagAPI)(nil).GetTagValue), varargs...) } // MCreateTagKeys mocks base method. -func (m *MockITagAPI) MCreateTagKeys(arg0 context.Context, arg1 []*entity.TagKey, arg2 ...db.Option) error { +func (m *MockITagAPI) MCreateTagKeys(ctx context.Context, val []*entity.TagKey, opt ...db.Option) error { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, val} + for _, a := range opt { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "MCreateTagKeys", varargs...) @@ -153,17 +154,17 @@ func (m *MockITagAPI) MCreateTagKeys(arg0 context.Context, arg1 []*entity.TagKey } // MCreateTagKeys indicates an expected call of MCreateTagKeys. -func (mr *MockITagAPIMockRecorder) MCreateTagKeys(arg0, arg1 any, arg2 ...any) *gomock.Call { +func (mr *MockITagAPIMockRecorder) MCreateTagKeys(ctx, val any, opt ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, val}, opt...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MCreateTagKeys", reflect.TypeOf((*MockITagAPI)(nil).MCreateTagKeys), varargs...) } // MCreateTagValues mocks base method. -func (m *MockITagAPI) MCreateTagValues(arg0 context.Context, arg1 []*entity.TagValue, arg2 ...db.Option) error { +func (m *MockITagAPI) MCreateTagValues(ctx context.Context, val []*entity.TagValue, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, val} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "MCreateTagValues", varargs...) @@ -172,17 +173,17 @@ func (m *MockITagAPI) MCreateTagValues(arg0 context.Context, arg1 []*entity.TagV } // MCreateTagValues indicates an expected call of MCreateTagValues. -func (mr *MockITagAPIMockRecorder) MCreateTagValues(arg0, arg1 any, arg2 ...any) *gomock.Call { +func (mr *MockITagAPIMockRecorder) MCreateTagValues(ctx, val any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, val}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MCreateTagValues", reflect.TypeOf((*MockITagAPI)(nil).MCreateTagValues), varargs...) } // MGetTagKeys mocks base method. -func (m *MockITagAPI) MGetTagKeys(arg0 context.Context, arg1 *entity.MGetTagKeyParam, arg2 ...db.Option) ([]*entity.TagKey, *pagination.PageResult, error) { +func (m *MockITagAPI) MGetTagKeys(ctx context.Context, param *entity.MGetTagKeyParam, opts ...db.Option) ([]*entity.TagKey, *pagination.PageResult, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, param} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "MGetTagKeys", varargs...) @@ -193,17 +194,17 @@ func (m *MockITagAPI) MGetTagKeys(arg0 context.Context, arg1 *entity.MGetTagKeyP } // MGetTagKeys indicates an expected call of MGetTagKeys. -func (mr *MockITagAPIMockRecorder) MGetTagKeys(arg0, arg1 any, arg2 ...any) *gomock.Call { +func (mr *MockITagAPIMockRecorder) MGetTagKeys(ctx, param any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, param}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetTagKeys", reflect.TypeOf((*MockITagAPI)(nil).MGetTagKeys), varargs...) } // MGetTagValue mocks base method. -func (m *MockITagAPI) MGetTagValue(arg0 context.Context, arg1 *entity.MGetTagValueParam, arg2 ...db.Option) ([]*entity.TagValue, *pagination.PageResult, error) { +func (m *MockITagAPI) MGetTagValue(ctx context.Context, param *entity.MGetTagValueParam, opts ...db.Option) ([]*entity.TagValue, *pagination.PageResult, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, param} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "MGetTagValue", varargs...) @@ -214,17 +215,17 @@ func (m *MockITagAPI) MGetTagValue(arg0 context.Context, arg1 *entity.MGetTagVal } // MGetTagValue indicates an expected call of MGetTagValue. -func (mr *MockITagAPIMockRecorder) MGetTagValue(arg0, arg1 any, arg2 ...any) *gomock.Call { +func (mr *MockITagAPIMockRecorder) MGetTagValue(ctx, param any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, param}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetTagValue", reflect.TypeOf((*MockITagAPI)(nil).MGetTagValue), varargs...) } // PatchTagKey mocks base method. -func (m *MockITagAPI) PatchTagKey(arg0 context.Context, arg1, arg2 int64, arg3 *entity.TagKey, arg4 ...db.Option) error { +func (m *MockITagAPI) PatchTagKey(ctx context.Context, spaceID, id int64, patch *entity.TagKey, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []any{arg0, arg1, arg2, arg3} - for _, a := range arg4 { + varargs := []any{ctx, spaceID, id, patch} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "PatchTagKey", varargs...) @@ -233,17 +234,17 @@ func (m *MockITagAPI) PatchTagKey(arg0 context.Context, arg1, arg2 int64, arg3 * } // PatchTagKey indicates an expected call of PatchTagKey. -func (mr *MockITagAPIMockRecorder) PatchTagKey(arg0, arg1, arg2, arg3 any, arg4 ...any) *gomock.Call { +func (mr *MockITagAPIMockRecorder) PatchTagKey(ctx, spaceID, id, patch any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1, arg2, arg3}, arg4...) + varargs := append([]any{ctx, spaceID, id, patch}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PatchTagKey", reflect.TypeOf((*MockITagAPI)(nil).PatchTagKey), varargs...) } // PatchTagValue mocks base method. -func (m *MockITagAPI) PatchTagValue(arg0 context.Context, arg1, arg2 int64, arg3 *entity.TagValue, arg4 ...db.Option) error { +func (m *MockITagAPI) PatchTagValue(ctx context.Context, spaceID, id int64, patch *entity.TagValue, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []any{arg0, arg1, arg2, arg3} - for _, a := range arg4 { + varargs := []any{ctx, spaceID, id, patch} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "PatchTagValue", varargs...) @@ -252,17 +253,17 @@ func (m *MockITagAPI) PatchTagValue(arg0 context.Context, arg1, arg2 int64, arg3 } // PatchTagValue indicates an expected call of PatchTagValue. -func (mr *MockITagAPIMockRecorder) PatchTagValue(arg0, arg1, arg2, arg3 any, arg4 ...any) *gomock.Call { +func (mr *MockITagAPIMockRecorder) PatchTagValue(ctx, spaceID, id, patch any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1, arg2, arg3}, arg4...) + varargs := append([]any{ctx, spaceID, id, patch}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PatchTagValue", reflect.TypeOf((*MockITagAPI)(nil).PatchTagValue), varargs...) } // UpdateTagKeysStatus mocks base method. -func (m *MockITagAPI) UpdateTagKeysStatus(arg0 context.Context, arg1, arg2 int64, arg3 int32, arg4 entity.TagStatus, arg5 bool, arg6 ...db.Option) error { +func (m *MockITagAPI) UpdateTagKeysStatus(ctx context.Context, spaceID, tagKeyID int64, versionNum int32, toStatus entity.TagStatus, updateInfo bool, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []any{arg0, arg1, arg2, arg3, arg4, arg5} - for _, a := range arg6 { + varargs := []any{ctx, spaceID, tagKeyID, versionNum, toStatus, updateInfo} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "UpdateTagKeysStatus", varargs...) @@ -271,17 +272,17 @@ func (m *MockITagAPI) UpdateTagKeysStatus(arg0 context.Context, arg1, arg2 int64 } // UpdateTagKeysStatus indicates an expected call of UpdateTagKeysStatus. -func (mr *MockITagAPIMockRecorder) UpdateTagKeysStatus(arg0, arg1, arg2, arg3, arg4, arg5 any, arg6 ...any) *gomock.Call { +func (mr *MockITagAPIMockRecorder) UpdateTagKeysStatus(ctx, spaceID, tagKeyID, versionNum, toStatus, updateInfo any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1, arg2, arg3, arg4, arg5}, arg6...) + varargs := append([]any{ctx, spaceID, tagKeyID, versionNum, toStatus, updateInfo}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTagKeysStatus", reflect.TypeOf((*MockITagAPI)(nil).UpdateTagKeysStatus), varargs...) } // UpdateTagValuesStatus mocks base method. -func (m *MockITagAPI) UpdateTagValuesStatus(arg0 context.Context, arg1, arg2 int64, arg3 int32, arg4 entity.TagStatus, arg5 bool, arg6 ...db.Option) error { +func (m *MockITagAPI) UpdateTagValuesStatus(ctx context.Context, spaceID, tagKeyID int64, versionNum int32, toStatus entity.TagStatus, updateInfo bool, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []any{arg0, arg1, arg2, arg3, arg4, arg5} - for _, a := range arg6 { + varargs := []any{ctx, spaceID, tagKeyID, versionNum, toStatus, updateInfo} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "UpdateTagValuesStatus", varargs...) @@ -290,8 +291,8 @@ func (m *MockITagAPI) UpdateTagValuesStatus(arg0 context.Context, arg1, arg2 int } // UpdateTagValuesStatus indicates an expected call of UpdateTagValuesStatus. -func (mr *MockITagAPIMockRecorder) UpdateTagValuesStatus(arg0, arg1, arg2, arg3, arg4, arg5 any, arg6 ...any) *gomock.Call { +func (mr *MockITagAPIMockRecorder) UpdateTagValuesStatus(ctx, spaceID, tagKeyID, versionNum, toStatus, updateInfo any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1, arg2, arg3, arg4, arg5}, arg6...) + varargs := append([]any{ctx, spaceID, tagKeyID, versionNum, toStatus, updateInfo}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTagValuesStatus", reflect.TypeOf((*MockITagAPI)(nil).UpdateTagValuesStatus), varargs...) } diff --git a/backend/modules/data/domain/tag/service/mocks/tag_service_mock.go b/backend/modules/data/domain/tag/service/mocks/tag_service_mock.go index dcb71ea55..d7128a8cf 100644 --- a/backend/modules/data/domain/tag/service/mocks/tag_service_mock.go +++ b/backend/modules/data/domain/tag/service/mocks/tag_service_mock.go @@ -23,6 +23,7 @@ import ( type MockITagService struct { ctrl *gomock.Controller recorder *MockITagServiceMockRecorder + isgomock struct{} } // MockITagServiceMockRecorder is the mock recorder for MockITagService. @@ -43,10 +44,10 @@ func (m *MockITagService) EXPECT() *MockITagServiceMockRecorder { } // ArchiveOptionTag mocks base method. -func (m *MockITagService) ArchiveOptionTag(arg0 context.Context, arg1, arg2 int64, arg3 *entity.TagKey, arg4 ...db.Option) error { +func (m *MockITagService) ArchiveOptionTag(ctx context.Context, spaceID, tagKeyID int64, val *entity.TagKey, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []any{arg0, arg1, arg2, arg3} - for _, a := range arg4 { + varargs := []any{ctx, spaceID, tagKeyID, val} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "ArchiveOptionTag", varargs...) @@ -55,47 +56,47 @@ func (m *MockITagService) ArchiveOptionTag(arg0 context.Context, arg1, arg2 int6 } // ArchiveOptionTag indicates an expected call of ArchiveOptionTag. -func (mr *MockITagServiceMockRecorder) ArchiveOptionTag(arg0, arg1, arg2, arg3 any, arg4 ...any) *gomock.Call { +func (mr *MockITagServiceMockRecorder) ArchiveOptionTag(ctx, spaceID, tagKeyID, val any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1, arg2, arg3}, arg4...) + varargs := append([]any{ctx, spaceID, tagKeyID, val}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ArchiveOptionTag", reflect.TypeOf((*MockITagService)(nil).ArchiveOptionTag), varargs...) } // BatchGetTagsByTagKeyIDs mocks base method. -func (m *MockITagService) BatchGetTagsByTagKeyIDs(arg0 context.Context, arg1 int64, arg2 []int64) ([]*entity.TagKey, error) { +func (m *MockITagService) BatchGetTagsByTagKeyIDs(ctx context.Context, spaceID int64, tagKeyIDs []int64) ([]*entity.TagKey, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "BatchGetTagsByTagKeyIDs", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "BatchGetTagsByTagKeyIDs", ctx, spaceID, tagKeyIDs) ret0, _ := ret[0].([]*entity.TagKey) ret1, _ := ret[1].(error) return ret0, ret1 } // BatchGetTagsByTagKeyIDs indicates an expected call of BatchGetTagsByTagKeyIDs. -func (mr *MockITagServiceMockRecorder) BatchGetTagsByTagKeyIDs(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockITagServiceMockRecorder) BatchGetTagsByTagKeyIDs(ctx, spaceID, tagKeyIDs any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGetTagsByTagKeyIDs", reflect.TypeOf((*MockITagService)(nil).BatchGetTagsByTagKeyIDs), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGetTagsByTagKeyIDs", reflect.TypeOf((*MockITagService)(nil).BatchGetTagsByTagKeyIDs), ctx, spaceID, tagKeyIDs) } // BatchUpdateTagStatus mocks base method. -func (m *MockITagService) BatchUpdateTagStatus(arg0 context.Context, arg1 int64, arg2 []int64, arg3 entity.TagStatus) (map[int64]string, error) { +func (m *MockITagService) BatchUpdateTagStatus(ctx context.Context, spaceID int64, tagKeyIDs []int64, toStatus entity.TagStatus) (map[int64]string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "BatchUpdateTagStatus", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "BatchUpdateTagStatus", ctx, spaceID, tagKeyIDs, toStatus) ret0, _ := ret[0].(map[int64]string) ret1, _ := ret[1].(error) return ret0, ret1 } // BatchUpdateTagStatus indicates an expected call of BatchUpdateTagStatus. -func (mr *MockITagServiceMockRecorder) BatchUpdateTagStatus(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockITagServiceMockRecorder) BatchUpdateTagStatus(ctx, spaceID, tagKeyIDs, toStatus any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchUpdateTagStatus", reflect.TypeOf((*MockITagService)(nil).BatchUpdateTagStatus), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchUpdateTagStatus", reflect.TypeOf((*MockITagService)(nil).BatchUpdateTagStatus), ctx, spaceID, tagKeyIDs, toStatus) } // CreateTag mocks base method. -func (m *MockITagService) CreateTag(arg0 context.Context, arg1 int64, arg2 *entity.TagKey, arg3 ...db.Option) (int64, error) { +func (m *MockITagService) CreateTag(ctx context.Context, spaceID int64, val *entity.TagKey, opts ...db.Option) (int64, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1, arg2} - for _, a := range arg3 { + varargs := []any{ctx, spaceID, val} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "CreateTag", varargs...) @@ -105,17 +106,17 @@ func (m *MockITagService) CreateTag(arg0 context.Context, arg1 int64, arg2 *enti } // CreateTag indicates an expected call of CreateTag. -func (mr *MockITagServiceMockRecorder) CreateTag(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call { +func (mr *MockITagServiceMockRecorder) CreateTag(ctx, spaceID, val any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1, arg2}, arg3...) + varargs := append([]any{ctx, spaceID, val}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateTag", reflect.TypeOf((*MockITagService)(nil).CreateTag), varargs...) } // GetAllTagKeyVersionsByKeyID mocks base method. -func (m *MockITagService) GetAllTagKeyVersionsByKeyID(arg0 context.Context, arg1, arg2 int64, arg3 ...db.Option) ([]*entity.TagKey, error) { +func (m *MockITagService) GetAllTagKeyVersionsByKeyID(ctx context.Context, spaceID, tagKeyID int64, opts ...db.Option) ([]*entity.TagKey, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1, arg2} - for _, a := range arg3 { + varargs := []any{ctx, spaceID, tagKeyID} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "GetAllTagKeyVersionsByKeyID", varargs...) @@ -125,17 +126,17 @@ func (m *MockITagService) GetAllTagKeyVersionsByKeyID(arg0 context.Context, arg1 } // GetAllTagKeyVersionsByKeyID indicates an expected call of GetAllTagKeyVersionsByKeyID. -func (mr *MockITagServiceMockRecorder) GetAllTagKeyVersionsByKeyID(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call { +func (mr *MockITagServiceMockRecorder) GetAllTagKeyVersionsByKeyID(ctx, spaceID, tagKeyID any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1, arg2}, arg3...) + varargs := append([]any{ctx, spaceID, tagKeyID}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllTagKeyVersionsByKeyID", reflect.TypeOf((*MockITagService)(nil).GetAllTagKeyVersionsByKeyID), varargs...) } // GetAndBuildTagValues mocks base method. -func (m *MockITagService) GetAndBuildTagValues(arg0 context.Context, arg1, arg2 int64, arg3 int32, arg4 ...db.Option) ([]*entity.TagValue, error) { +func (m *MockITagService) GetAndBuildTagValues(ctx context.Context, spaceID, tagKeyID int64, versionNum int32, opts ...db.Option) ([]*entity.TagValue, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1, arg2, arg3} - for _, a := range arg4 { + varargs := []any{ctx, spaceID, tagKeyID, versionNum} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "GetAndBuildTagValues", varargs...) @@ -145,17 +146,17 @@ func (m *MockITagService) GetAndBuildTagValues(arg0 context.Context, arg1, arg2 } // GetAndBuildTagValues indicates an expected call of GetAndBuildTagValues. -func (mr *MockITagServiceMockRecorder) GetAndBuildTagValues(arg0, arg1, arg2, arg3 any, arg4 ...any) *gomock.Call { +func (mr *MockITagServiceMockRecorder) GetAndBuildTagValues(ctx, spaceID, tagKeyID, versionNum any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1, arg2, arg3}, arg4...) + varargs := append([]any{ctx, spaceID, tagKeyID, versionNum}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAndBuildTagValues", reflect.TypeOf((*MockITagService)(nil).GetAndBuildTagValues), varargs...) } // GetLatestTag mocks base method. -func (m *MockITagService) GetLatestTag(arg0 context.Context, arg1, arg2 int64, arg3 ...db.Option) (*entity.TagKey, error) { +func (m *MockITagService) GetLatestTag(ctx context.Context, spaceID, tagKeyID int64, opts ...db.Option) (*entity.TagKey, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1, arg2} - for _, a := range arg3 { + varargs := []any{ctx, spaceID, tagKeyID} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "GetLatestTag", varargs...) @@ -165,31 +166,31 @@ func (m *MockITagService) GetLatestTag(arg0 context.Context, arg1, arg2 int64, a } // GetLatestTag indicates an expected call of GetLatestTag. -func (mr *MockITagServiceMockRecorder) GetLatestTag(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call { +func (mr *MockITagServiceMockRecorder) GetLatestTag(ctx, spaceID, tagKeyID any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1, arg2}, arg3...) + varargs := append([]any{ctx, spaceID, tagKeyID}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLatestTag", reflect.TypeOf((*MockITagService)(nil).GetLatestTag), varargs...) } // GetTagDetail mocks base method. -func (m *MockITagService) GetTagDetail(arg0 context.Context, arg1 int64, arg2 *entity.GetTagDetailReq) (*entity.GetTagDetailResp, error) { +func (m *MockITagService) GetTagDetail(ctx context.Context, spaceID int64, param *entity.GetTagDetailReq) (*entity.GetTagDetailResp, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetTagDetail", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "GetTagDetail", ctx, spaceID, param) ret0, _ := ret[0].(*entity.GetTagDetailResp) ret1, _ := ret[1].(error) return ret0, ret1 } // GetTagDetail indicates an expected call of GetTagDetail. -func (mr *MockITagServiceMockRecorder) GetTagDetail(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockITagServiceMockRecorder) GetTagDetail(ctx, spaceID, param any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTagDetail", reflect.TypeOf((*MockITagService)(nil).GetTagDetail), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTagDetail", reflect.TypeOf((*MockITagService)(nil).GetTagDetail), ctx, spaceID, param) } // GetTagSpec mocks base method. -func (m *MockITagService) GetTagSpec(arg0 context.Context, arg1 int64) (int64, int64, int64, error) { +func (m *MockITagService) GetTagSpec(ctx context.Context, spaceID int64) (int64, int64, int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetTagSpec", arg0, arg1) + ret := m.ctrl.Call(m, "GetTagSpec", ctx, spaceID) ret0, _ := ret[0].(int64) ret1, _ := ret[1].(int64) ret2, _ := ret[2].(int64) @@ -198,15 +199,15 @@ func (m *MockITagService) GetTagSpec(arg0 context.Context, arg1 int64) (int64, i } // GetTagSpec indicates an expected call of GetTagSpec. -func (mr *MockITagServiceMockRecorder) GetTagSpec(arg0, arg1 any) *gomock.Call { +func (mr *MockITagServiceMockRecorder) GetTagSpec(ctx, spaceID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTagSpec", reflect.TypeOf((*MockITagService)(nil).GetTagSpec), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTagSpec", reflect.TypeOf((*MockITagService)(nil).GetTagSpec), ctx, spaceID) } // SearchTags mocks base method. -func (m *MockITagService) SearchTags(arg0 context.Context, arg1 int64, arg2 *entity.MGetTagKeyParam) ([]*entity.TagKey, *pagination.PageResult, error) { +func (m *MockITagService) SearchTags(ctx context.Context, spaceID int64, param *entity.MGetTagKeyParam) ([]*entity.TagKey, *pagination.PageResult, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SearchTags", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "SearchTags", ctx, spaceID, param) ret0, _ := ret[0].([]*entity.TagKey) ret1, _ := ret[1].(*pagination.PageResult) ret2, _ := ret[2].(error) @@ -214,16 +215,16 @@ func (m *MockITagService) SearchTags(arg0 context.Context, arg1 int64, arg2 *ent } // SearchTags indicates an expected call of SearchTags. -func (mr *MockITagServiceMockRecorder) SearchTags(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockITagServiceMockRecorder) SearchTags(ctx, spaceID, param any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SearchTags", reflect.TypeOf((*MockITagService)(nil).SearchTags), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SearchTags", reflect.TypeOf((*MockITagService)(nil).SearchTags), ctx, spaceID, param) } // UpdateOptionTag mocks base method. -func (m *MockITagService) UpdateOptionTag(arg0 context.Context, arg1, arg2 int64, arg3 *entity.TagKey, arg4 ...db.Option) error { +func (m *MockITagService) UpdateOptionTag(ctx context.Context, spaceID, tagKeyID int64, val *entity.TagKey, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []any{arg0, arg1, arg2, arg3} - for _, a := range arg4 { + varargs := []any{ctx, spaceID, tagKeyID, val} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "UpdateOptionTag", varargs...) @@ -232,17 +233,17 @@ func (m *MockITagService) UpdateOptionTag(arg0 context.Context, arg1, arg2 int64 } // UpdateOptionTag indicates an expected call of UpdateOptionTag. -func (mr *MockITagServiceMockRecorder) UpdateOptionTag(arg0, arg1, arg2, arg3 any, arg4 ...any) *gomock.Call { +func (mr *MockITagServiceMockRecorder) UpdateOptionTag(ctx, spaceID, tagKeyID, val any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1, arg2, arg3}, arg4...) + varargs := append([]any{ctx, spaceID, tagKeyID, val}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateOptionTag", reflect.TypeOf((*MockITagService)(nil).UpdateOptionTag), varargs...) } // UpdateTag mocks base method. -func (m *MockITagService) UpdateTag(arg0 context.Context, arg1, arg2 int64, arg3 *entity.TagKey, arg4 ...db.Option) error { +func (m *MockITagService) UpdateTag(ctx context.Context, spaceID, tagKeyID int64, val *entity.TagKey, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []any{arg0, arg1, arg2, arg3} - for _, a := range arg4 { + varargs := []any{ctx, spaceID, tagKeyID, val} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "UpdateTag", varargs...) @@ -251,17 +252,17 @@ func (m *MockITagService) UpdateTag(arg0 context.Context, arg1, arg2 int64, arg3 } // UpdateTag indicates an expected call of UpdateTag. -func (mr *MockITagServiceMockRecorder) UpdateTag(arg0, arg1, arg2, arg3 any, arg4 ...any) *gomock.Call { +func (mr *MockITagServiceMockRecorder) UpdateTag(ctx, spaceID, tagKeyID, val any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1, arg2, arg3}, arg4...) + varargs := append([]any{ctx, spaceID, tagKeyID, val}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTag", reflect.TypeOf((*MockITagService)(nil).UpdateTag), varargs...) } // UpdateTagStatus mocks base method. -func (m *MockITagService) UpdateTagStatus(arg0 context.Context, arg1, arg2 int64, arg3 int32, arg4 entity.TagStatus, arg5, arg6 bool, arg7 ...db.Option) error { +func (m *MockITagService) UpdateTagStatus(ctx context.Context, spaceID, tagKeyID int64, versionNum int32, status entity.TagStatus, needLock, updatedInfo bool, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []any{arg0, arg1, arg2, arg3, arg4, arg5, arg6} - for _, a := range arg7 { + varargs := []any{ctx, spaceID, tagKeyID, versionNum, status, needLock, updatedInfo} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "UpdateTagStatus", varargs...) @@ -270,22 +271,22 @@ func (m *MockITagService) UpdateTagStatus(arg0 context.Context, arg1, arg2 int64 } // UpdateTagStatus indicates an expected call of UpdateTagStatus. -func (mr *MockITagServiceMockRecorder) UpdateTagStatus(arg0, arg1, arg2, arg3, arg4, arg5, arg6 any, arg7 ...any) *gomock.Call { +func (mr *MockITagServiceMockRecorder) UpdateTagStatus(ctx, spaceID, tagKeyID, versionNum, status, needLock, updatedInfo any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1, arg2, arg3, arg4, arg5, arg6}, arg7...) + varargs := append([]any{ctx, spaceID, tagKeyID, versionNum, status, needLock, updatedInfo}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTagStatus", reflect.TypeOf((*MockITagService)(nil).UpdateTagStatus), varargs...) } // UpdateTagStatusWithNewVersion mocks base method. -func (m *MockITagService) UpdateTagStatusWithNewVersion(arg0 context.Context, arg1, arg2 int64, arg3 entity.TagStatus) error { +func (m *MockITagService) UpdateTagStatusWithNewVersion(ctx context.Context, spaceID, tagKeyID int64, status entity.TagStatus) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateTagStatusWithNewVersion", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "UpdateTagStatusWithNewVersion", ctx, spaceID, tagKeyID, status) ret0, _ := ret[0].(error) return ret0 } // UpdateTagStatusWithNewVersion indicates an expected call of UpdateTagStatusWithNewVersion. -func (mr *MockITagServiceMockRecorder) UpdateTagStatusWithNewVersion(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockITagServiceMockRecorder) UpdateTagStatusWithNewVersion(ctx, spaceID, tagKeyID, status any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTagStatusWithNewVersion", reflect.TypeOf((*MockITagService)(nil).UpdateTagStatusWithNewVersion), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTagStatusWithNewVersion", reflect.TypeOf((*MockITagService)(nil).UpdateTagStatusWithNewVersion), ctx, spaceID, tagKeyID, status) } diff --git a/backend/modules/evaluation/domain/entity/mocks/expt_scheduler_mock.go b/backend/modules/evaluation/domain/entity/mocks/expt_scheduler_mock.go index 9e2eb4a87..9d4b6e419 100644 --- a/backend/modules/evaluation/domain/entity/mocks/expt_scheduler_mock.go +++ b/backend/modules/evaluation/domain/entity/mocks/expt_scheduler_mock.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/coze-dev/coze-loop/backend/modules/evaluation/domain/entity (interfaces: ExptSchedulerMode) +// +// Generated by this command: +// +// mockgen -destination ./mocks/expt_scheduler_mock.go --package mocks . ExptSchedulerMode +// // Package mocks is a generated GoMock package. package mocks @@ -16,6 +21,7 @@ import ( type MockExptSchedulerMode struct { ctrl *gomock.Controller recorder *MockExptSchedulerModeMockRecorder + isgomock struct{} } // MockExptSchedulerModeMockRecorder is the mock recorder for MockExptSchedulerMode. @@ -36,32 +42,32 @@ func (m *MockExptSchedulerMode) EXPECT() *MockExptSchedulerModeMockRecorder { } // ExptEnd mocks base method. -func (m *MockExptSchedulerMode) ExptEnd(arg0 context.Context, arg1 *entity.ExptScheduleEvent, arg2 *entity.Experiment, arg3, arg4 int) (bool, error) { +func (m *MockExptSchedulerMode) ExptEnd(ctx context.Context, event *entity.ExptScheduleEvent, expt *entity.Experiment, toSubmit, incomplete int) (bool, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ExptEnd", arg0, arg1, arg2, arg3, arg4) + ret := m.ctrl.Call(m, "ExptEnd", ctx, event, expt, toSubmit, incomplete) ret0, _ := ret[0].(bool) ret1, _ := ret[1].(error) return ret0, ret1 } // ExptEnd indicates an expected call of ExptEnd. -func (mr *MockExptSchedulerModeMockRecorder) ExptEnd(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { +func (mr *MockExptSchedulerModeMockRecorder) ExptEnd(ctx, event, expt, toSubmit, incomplete any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExptEnd", reflect.TypeOf((*MockExptSchedulerMode)(nil).ExptEnd), arg0, arg1, arg2, arg3, arg4) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExptEnd", reflect.TypeOf((*MockExptSchedulerMode)(nil).ExptEnd), ctx, event, expt, toSubmit, incomplete) } // ExptStart mocks base method. -func (m *MockExptSchedulerMode) ExptStart(arg0 context.Context, arg1 *entity.ExptScheduleEvent, arg2 *entity.Experiment) error { +func (m *MockExptSchedulerMode) ExptStart(ctx context.Context, event *entity.ExptScheduleEvent, expt *entity.Experiment) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ExptStart", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "ExptStart", ctx, event, expt) ret0, _ := ret[0].(error) return ret0 } // ExptStart indicates an expected call of ExptStart. -func (mr *MockExptSchedulerModeMockRecorder) ExptStart(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockExptSchedulerModeMockRecorder) ExptStart(ctx, event, expt any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExptStart", reflect.TypeOf((*MockExptSchedulerMode)(nil).ExptStart), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExptStart", reflect.TypeOf((*MockExptSchedulerMode)(nil).ExptStart), ctx, event, expt) } // Mode mocks base method. @@ -79,37 +85,37 @@ func (mr *MockExptSchedulerModeMockRecorder) Mode() *gomock.Call { } // NextTick mocks base method. -func (m *MockExptSchedulerMode) NextTick(arg0 context.Context, arg1 *entity.ExptScheduleEvent, arg2 bool) error { +func (m *MockExptSchedulerMode) NextTick(ctx context.Context, event *entity.ExptScheduleEvent, nextTick bool) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NextTick", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "NextTick", ctx, event, nextTick) ret0, _ := ret[0].(error) return ret0 } // NextTick indicates an expected call of NextTick. -func (mr *MockExptSchedulerModeMockRecorder) NextTick(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockExptSchedulerModeMockRecorder) NextTick(ctx, event, nextTick any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextTick", reflect.TypeOf((*MockExptSchedulerMode)(nil).NextTick), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextTick", reflect.TypeOf((*MockExptSchedulerMode)(nil).NextTick), ctx, event, nextTick) } // PublishResult mocks base method. -func (m *MockExptSchedulerMode) PublishResult(arg0 context.Context, arg1 []*entity.ExptTurnEvaluatorResultRef, arg2 *entity.ExptScheduleEvent) error { +func (m *MockExptSchedulerMode) PublishResult(ctx context.Context, turnEvaluatorRefs []*entity.ExptTurnEvaluatorResultRef, event *entity.ExptScheduleEvent) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PublishResult", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "PublishResult", ctx, turnEvaluatorRefs, event) ret0, _ := ret[0].(error) return ret0 } // PublishResult indicates an expected call of PublishResult. -func (mr *MockExptSchedulerModeMockRecorder) PublishResult(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockExptSchedulerModeMockRecorder) PublishResult(ctx, turnEvaluatorRefs, event any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishResult", reflect.TypeOf((*MockExptSchedulerMode)(nil).PublishResult), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishResult", reflect.TypeOf((*MockExptSchedulerMode)(nil).PublishResult), ctx, turnEvaluatorRefs, event) } // ScanEvalItems mocks base method. -func (m *MockExptSchedulerMode) ScanEvalItems(arg0 context.Context, arg1 *entity.ExptScheduleEvent, arg2 *entity.Experiment) ([]*entity.ExptEvalItem, []*entity.ExptEvalItem, []*entity.ExptEvalItem, error) { +func (m *MockExptSchedulerMode) ScanEvalItems(ctx context.Context, event *entity.ExptScheduleEvent, expt *entity.Experiment) ([]*entity.ExptEvalItem, []*entity.ExptEvalItem, []*entity.ExptEvalItem, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ScanEvalItems", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "ScanEvalItems", ctx, event, expt) ret0, _ := ret[0].([]*entity.ExptEvalItem) ret1, _ := ret[1].([]*entity.ExptEvalItem) ret2, _ := ret[2].([]*entity.ExptEvalItem) @@ -118,35 +124,35 @@ func (m *MockExptSchedulerMode) ScanEvalItems(arg0 context.Context, arg1 *entity } // ScanEvalItems indicates an expected call of ScanEvalItems. -func (mr *MockExptSchedulerModeMockRecorder) ScanEvalItems(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockExptSchedulerModeMockRecorder) ScanEvalItems(ctx, event, expt any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ScanEvalItems", reflect.TypeOf((*MockExptSchedulerMode)(nil).ScanEvalItems), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ScanEvalItems", reflect.TypeOf((*MockExptSchedulerMode)(nil).ScanEvalItems), ctx, event, expt) } // ScheduleEnd mocks base method. -func (m *MockExptSchedulerMode) ScheduleEnd(arg0 context.Context, arg1 *entity.ExptScheduleEvent, arg2 *entity.Experiment, arg3, arg4 int) error { +func (m *MockExptSchedulerMode) ScheduleEnd(ctx context.Context, event *entity.ExptScheduleEvent, expt *entity.Experiment, toSubmit, incomplete int) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ScheduleEnd", arg0, arg1, arg2, arg3, arg4) + ret := m.ctrl.Call(m, "ScheduleEnd", ctx, event, expt, toSubmit, incomplete) ret0, _ := ret[0].(error) return ret0 } // ScheduleEnd indicates an expected call of ScheduleEnd. -func (mr *MockExptSchedulerModeMockRecorder) ScheduleEnd(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { +func (mr *MockExptSchedulerModeMockRecorder) ScheduleEnd(ctx, event, expt, toSubmit, incomplete any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ScheduleEnd", reflect.TypeOf((*MockExptSchedulerMode)(nil).ScheduleEnd), arg0, arg1, arg2, arg3, arg4) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ScheduleEnd", reflect.TypeOf((*MockExptSchedulerMode)(nil).ScheduleEnd), ctx, event, expt, toSubmit, incomplete) } // ScheduleStart mocks base method. -func (m *MockExptSchedulerMode) ScheduleStart(arg0 context.Context, arg1 *entity.ExptScheduleEvent, arg2 *entity.Experiment) error { +func (m *MockExptSchedulerMode) ScheduleStart(ctx context.Context, event *entity.ExptScheduleEvent, expt *entity.Experiment) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ScheduleStart", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "ScheduleStart", ctx, event, expt) ret0, _ := ret[0].(error) return ret0 } // ScheduleStart indicates an expected call of ScheduleStart. -func (mr *MockExptSchedulerModeMockRecorder) ScheduleStart(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockExptSchedulerModeMockRecorder) ScheduleStart(ctx, event, expt any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ScheduleStart", reflect.TypeOf((*MockExptSchedulerMode)(nil).ScheduleStart), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ScheduleStart", reflect.TypeOf((*MockExptSchedulerMode)(nil).ScheduleStart), ctx, event, expt) } diff --git a/backend/modules/evaluation/domain/events/mocks/evaluator_event_publisher_mock.go b/backend/modules/evaluation/domain/events/mocks/evaluator_event_publisher_mock.go index 6d15b6e2b..e97b66442 100644 --- a/backend/modules/evaluation/domain/events/mocks/evaluator_event_publisher_mock.go +++ b/backend/modules/evaluation/domain/events/mocks/evaluator_event_publisher_mock.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/coze-dev/coze-loop/backend/modules/evaluation/domain/events (interfaces: EvaluatorEventPublisher) +// +// Generated by this command: +// +// mockgen -destination mocks/evaluator_event_publisher_mock.go -package mocks . EvaluatorEventPublisher +// // Package mocks is a generated GoMock package. package mocks @@ -17,6 +22,7 @@ import ( type MockEvaluatorEventPublisher struct { ctrl *gomock.Controller recorder *MockEvaluatorEventPublisherMockRecorder + isgomock struct{} } // MockEvaluatorEventPublisherMockRecorder is the mock recorder for MockEvaluatorEventPublisher. @@ -37,15 +43,15 @@ func (m *MockEvaluatorEventPublisher) EXPECT() *MockEvaluatorEventPublisherMockR } // PublishEvaluatorRecordCorrection mocks base method. -func (m *MockEvaluatorEventPublisher) PublishEvaluatorRecordCorrection(arg0 context.Context, arg1 *entity.EvaluatorRecordCorrectionEvent, arg2 *time.Duration) error { +func (m *MockEvaluatorEventPublisher) PublishEvaluatorRecordCorrection(ctx context.Context, events *entity.EvaluatorRecordCorrectionEvent, duration *time.Duration) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PublishEvaluatorRecordCorrection", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "PublishEvaluatorRecordCorrection", ctx, events, duration) ret0, _ := ret[0].(error) return ret0 } // PublishEvaluatorRecordCorrection indicates an expected call of PublishEvaluatorRecordCorrection. -func (mr *MockEvaluatorEventPublisherMockRecorder) PublishEvaluatorRecordCorrection(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockEvaluatorEventPublisherMockRecorder) PublishEvaluatorRecordCorrection(ctx, events, duration any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishEvaluatorRecordCorrection", reflect.TypeOf((*MockEvaluatorEventPublisher)(nil).PublishEvaluatorRecordCorrection), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishEvaluatorRecordCorrection", reflect.TypeOf((*MockEvaluatorEventPublisher)(nil).PublishEvaluatorRecordCorrection), ctx, events, duration) } diff --git a/backend/modules/evaluation/domain/events/mocks/expt_event_publisher_mock.go b/backend/modules/evaluation/domain/events/mocks/expt_event_publisher_mock.go index d14d4577c..86cc51fde 100644 --- a/backend/modules/evaluation/domain/events/mocks/expt_event_publisher_mock.go +++ b/backend/modules/evaluation/domain/events/mocks/expt_event_publisher_mock.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/coze-dev/coze-loop/backend/modules/evaluation/domain/events (interfaces: ExptEventPublisher) +// +// Generated by this command: +// +// mockgen -destination mocks/expt_event_publisher_mock.go -package mocks . ExptEventPublisher +// // Package mocks is a generated GoMock package. package mocks @@ -9,15 +14,15 @@ import ( reflect "reflect" time "time" - "go.uber.org/mock/gomock" - entity "github.com/coze-dev/coze-loop/backend/modules/evaluation/domain/entity" + gomock "go.uber.org/mock/gomock" ) // MockExptEventPublisher is a mock of ExptEventPublisher interface. type MockExptEventPublisher struct { ctrl *gomock.Controller recorder *MockExptEventPublisherMockRecorder + isgomock struct{} } // MockExptEventPublisherMockRecorder is the mock recorder for MockExptEventPublisher. @@ -38,99 +43,99 @@ func (m *MockExptEventPublisher) EXPECT() *MockExptEventPublisherMockRecorder { } // BatchPublishExptRecordEvalEvent mocks base method. -func (m *MockExptEventPublisher) BatchPublishExptRecordEvalEvent(arg0 context.Context, arg1 []*entity.ExptItemEvalEvent, arg2 *time.Duration) error { +func (m *MockExptEventPublisher) BatchPublishExptRecordEvalEvent(ctx context.Context, events []*entity.ExptItemEvalEvent, duration *time.Duration) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "BatchPublishExptRecordEvalEvent", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "BatchPublishExptRecordEvalEvent", ctx, events, duration) ret0, _ := ret[0].(error) return ret0 } // BatchPublishExptRecordEvalEvent indicates an expected call of BatchPublishExptRecordEvalEvent. -func (mr *MockExptEventPublisherMockRecorder) BatchPublishExptRecordEvalEvent(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockExptEventPublisherMockRecorder) BatchPublishExptRecordEvalEvent(ctx, events, duration any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchPublishExptRecordEvalEvent", reflect.TypeOf((*MockExptEventPublisher)(nil).BatchPublishExptRecordEvalEvent), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchPublishExptRecordEvalEvent", reflect.TypeOf((*MockExptEventPublisher)(nil).BatchPublishExptRecordEvalEvent), ctx, events, duration) } // PublishExptAggrCalculateEvent mocks base method. -func (m *MockExptEventPublisher) PublishExptAggrCalculateEvent(arg0 context.Context, arg1 []*entity.AggrCalculateEvent, arg2 *time.Duration) error { +func (m *MockExptEventPublisher) PublishExptAggrCalculateEvent(ctx context.Context, events []*entity.AggrCalculateEvent, duration *time.Duration) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PublishExptAggrCalculateEvent", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "PublishExptAggrCalculateEvent", ctx, events, duration) ret0, _ := ret[0].(error) return ret0 } // PublishExptAggrCalculateEvent indicates an expected call of PublishExptAggrCalculateEvent. -func (mr *MockExptEventPublisherMockRecorder) PublishExptAggrCalculateEvent(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockExptEventPublisherMockRecorder) PublishExptAggrCalculateEvent(ctx, events, duration any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishExptAggrCalculateEvent", reflect.TypeOf((*MockExptEventPublisher)(nil).PublishExptAggrCalculateEvent), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishExptAggrCalculateEvent", reflect.TypeOf((*MockExptEventPublisher)(nil).PublishExptAggrCalculateEvent), ctx, events, duration) } // PublishExptExportCSVEvent mocks base method. -func (m *MockExptEventPublisher) PublishExptExportCSVEvent(arg0 context.Context, arg1 *entity.ExportCSVEvent, arg2 *time.Duration) error { +func (m *MockExptEventPublisher) PublishExptExportCSVEvent(ctx context.Context, events *entity.ExportCSVEvent, duration *time.Duration) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PublishExptExportCSVEvent", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "PublishExptExportCSVEvent", ctx, events, duration) ret0, _ := ret[0].(error) return ret0 } // PublishExptExportCSVEvent indicates an expected call of PublishExptExportCSVEvent. -func (mr *MockExptEventPublisherMockRecorder) PublishExptExportCSVEvent(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockExptEventPublisherMockRecorder) PublishExptExportCSVEvent(ctx, events, duration any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishExptExportCSVEvent", reflect.TypeOf((*MockExptEventPublisher)(nil).PublishExptExportCSVEvent), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishExptExportCSVEvent", reflect.TypeOf((*MockExptEventPublisher)(nil).PublishExptExportCSVEvent), ctx, events, duration) } // PublishExptOnlineEvalResult mocks base method. -func (m *MockExptEventPublisher) PublishExptOnlineEvalResult(arg0 context.Context, arg1 *entity.OnlineExptEvalResultEvent, arg2 *time.Duration) error { +func (m *MockExptEventPublisher) PublishExptOnlineEvalResult(ctx context.Context, events *entity.OnlineExptEvalResultEvent, duration *time.Duration) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PublishExptOnlineEvalResult", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "PublishExptOnlineEvalResult", ctx, events, duration) ret0, _ := ret[0].(error) return ret0 } // PublishExptOnlineEvalResult indicates an expected call of PublishExptOnlineEvalResult. -func (mr *MockExptEventPublisherMockRecorder) PublishExptOnlineEvalResult(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockExptEventPublisherMockRecorder) PublishExptOnlineEvalResult(ctx, events, duration any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishExptOnlineEvalResult", reflect.TypeOf((*MockExptEventPublisher)(nil).PublishExptOnlineEvalResult), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishExptOnlineEvalResult", reflect.TypeOf((*MockExptEventPublisher)(nil).PublishExptOnlineEvalResult), ctx, events, duration) } // PublishExptRecordEvalEvent mocks base method. -func (m *MockExptEventPublisher) PublishExptRecordEvalEvent(arg0 context.Context, arg1 *entity.ExptItemEvalEvent, arg2 *time.Duration) error { +func (m *MockExptEventPublisher) PublishExptRecordEvalEvent(ctx context.Context, event *entity.ExptItemEvalEvent, duration *time.Duration) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PublishExptRecordEvalEvent", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "PublishExptRecordEvalEvent", ctx, event, duration) ret0, _ := ret[0].(error) return ret0 } // PublishExptRecordEvalEvent indicates an expected call of PublishExptRecordEvalEvent. -func (mr *MockExptEventPublisherMockRecorder) PublishExptRecordEvalEvent(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockExptEventPublisherMockRecorder) PublishExptRecordEvalEvent(ctx, event, duration any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishExptRecordEvalEvent", reflect.TypeOf((*MockExptEventPublisher)(nil).PublishExptRecordEvalEvent), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishExptRecordEvalEvent", reflect.TypeOf((*MockExptEventPublisher)(nil).PublishExptRecordEvalEvent), ctx, event, duration) } // PublishExptScheduleEvent mocks base method. -func (m *MockExptEventPublisher) PublishExptScheduleEvent(arg0 context.Context, arg1 *entity.ExptScheduleEvent, arg2 *time.Duration) error { +func (m *MockExptEventPublisher) PublishExptScheduleEvent(ctx context.Context, event *entity.ExptScheduleEvent, duration *time.Duration) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PublishExptScheduleEvent", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "PublishExptScheduleEvent", ctx, event, duration) ret0, _ := ret[0].(error) return ret0 } // PublishExptScheduleEvent indicates an expected call of PublishExptScheduleEvent. -func (mr *MockExptEventPublisherMockRecorder) PublishExptScheduleEvent(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockExptEventPublisherMockRecorder) PublishExptScheduleEvent(ctx, event, duration any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishExptScheduleEvent", reflect.TypeOf((*MockExptEventPublisher)(nil).PublishExptScheduleEvent), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishExptScheduleEvent", reflect.TypeOf((*MockExptEventPublisher)(nil).PublishExptScheduleEvent), ctx, event, duration) } // PublishExptTurnResultFilterEvent mocks base method. -func (m *MockExptEventPublisher) PublishExptTurnResultFilterEvent(arg0 context.Context, arg1 *entity.ExptTurnResultFilterEvent, arg2 *time.Duration) error { +func (m *MockExptEventPublisher) PublishExptTurnResultFilterEvent(ctx context.Context, event *entity.ExptTurnResultFilterEvent, duration *time.Duration) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PublishExptTurnResultFilterEvent", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "PublishExptTurnResultFilterEvent", ctx, event, duration) ret0, _ := ret[0].(error) return ret0 } // PublishExptTurnResultFilterEvent indicates an expected call of PublishExptTurnResultFilterEvent. -func (mr *MockExptEventPublisherMockRecorder) PublishExptTurnResultFilterEvent(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockExptEventPublisherMockRecorder) PublishExptTurnResultFilterEvent(ctx, event, duration any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishExptTurnResultFilterEvent", reflect.TypeOf((*MockExptEventPublisher)(nil).PublishExptTurnResultFilterEvent), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishExptTurnResultFilterEvent", reflect.TypeOf((*MockExptEventPublisher)(nil).PublishExptTurnResultFilterEvent), ctx, event, duration) } diff --git a/backend/modules/evaluation/domain/repo/mocks/evaluator_mock.go b/backend/modules/evaluation/domain/repo/mocks/evaluator_mock.go index 29a9ac313..5177a5010 100644 --- a/backend/modules/evaluation/domain/repo/mocks/evaluator_mock.go +++ b/backend/modules/evaluation/domain/repo/mocks/evaluator_mock.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/coze-dev/coze-loop/backend/modules/evaluation/domain/repo (interfaces: IEvaluatorRepo) +// +// Generated by this command: +// +// mockgen -destination mocks/evaluator_mock.go -package mocks . IEvaluatorRepo +// // Package mocks is a generated GoMock package. package mocks @@ -17,6 +22,7 @@ import ( type MockIEvaluatorRepo struct { ctrl *gomock.Controller recorder *MockIEvaluatorRepoMockRecorder + isgomock struct{} } // MockIEvaluatorRepoMockRecorder is the mock recorder for MockIEvaluatorRepo. @@ -37,192 +43,192 @@ func (m *MockIEvaluatorRepo) EXPECT() *MockIEvaluatorRepoMockRecorder { } // BatchDeleteEvaluator mocks base method. -func (m *MockIEvaluatorRepo) BatchDeleteEvaluator(arg0 context.Context, arg1 []int64, arg2 string) error { +func (m *MockIEvaluatorRepo) BatchDeleteEvaluator(ctx context.Context, ids []int64, userID string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "BatchDeleteEvaluator", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "BatchDeleteEvaluator", ctx, ids, userID) ret0, _ := ret[0].(error) return ret0 } // BatchDeleteEvaluator indicates an expected call of BatchDeleteEvaluator. -func (mr *MockIEvaluatorRepoMockRecorder) BatchDeleteEvaluator(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockIEvaluatorRepoMockRecorder) BatchDeleteEvaluator(ctx, ids, userID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchDeleteEvaluator", reflect.TypeOf((*MockIEvaluatorRepo)(nil).BatchDeleteEvaluator), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchDeleteEvaluator", reflect.TypeOf((*MockIEvaluatorRepo)(nil).BatchDeleteEvaluator), ctx, ids, userID) } // BatchGetEvaluatorByVersionID mocks base method. -func (m *MockIEvaluatorRepo) BatchGetEvaluatorByVersionID(arg0 context.Context, arg1 *int64, arg2 []int64, arg3 bool) ([]*entity.Evaluator, error) { +func (m *MockIEvaluatorRepo) BatchGetEvaluatorByVersionID(ctx context.Context, spaceID *int64, ids []int64, includeDeleted bool) ([]*entity.Evaluator, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "BatchGetEvaluatorByVersionID", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "BatchGetEvaluatorByVersionID", ctx, spaceID, ids, includeDeleted) ret0, _ := ret[0].([]*entity.Evaluator) ret1, _ := ret[1].(error) return ret0, ret1 } // BatchGetEvaluatorByVersionID indicates an expected call of BatchGetEvaluatorByVersionID. -func (mr *MockIEvaluatorRepoMockRecorder) BatchGetEvaluatorByVersionID(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockIEvaluatorRepoMockRecorder) BatchGetEvaluatorByVersionID(ctx, spaceID, ids, includeDeleted any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGetEvaluatorByVersionID", reflect.TypeOf((*MockIEvaluatorRepo)(nil).BatchGetEvaluatorByVersionID), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGetEvaluatorByVersionID", reflect.TypeOf((*MockIEvaluatorRepo)(nil).BatchGetEvaluatorByVersionID), ctx, spaceID, ids, includeDeleted) } // BatchGetEvaluatorDraftByEvaluatorID mocks base method. -func (m *MockIEvaluatorRepo) BatchGetEvaluatorDraftByEvaluatorID(arg0 context.Context, arg1 int64, arg2 []int64, arg3 bool) ([]*entity.Evaluator, error) { +func (m *MockIEvaluatorRepo) BatchGetEvaluatorDraftByEvaluatorID(ctx context.Context, spaceID int64, ids []int64, includeDeleted bool) ([]*entity.Evaluator, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "BatchGetEvaluatorDraftByEvaluatorID", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "BatchGetEvaluatorDraftByEvaluatorID", ctx, spaceID, ids, includeDeleted) ret0, _ := ret[0].([]*entity.Evaluator) ret1, _ := ret[1].(error) return ret0, ret1 } // BatchGetEvaluatorDraftByEvaluatorID indicates an expected call of BatchGetEvaluatorDraftByEvaluatorID. -func (mr *MockIEvaluatorRepoMockRecorder) BatchGetEvaluatorDraftByEvaluatorID(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockIEvaluatorRepoMockRecorder) BatchGetEvaluatorDraftByEvaluatorID(ctx, spaceID, ids, includeDeleted any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGetEvaluatorDraftByEvaluatorID", reflect.TypeOf((*MockIEvaluatorRepo)(nil).BatchGetEvaluatorDraftByEvaluatorID), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGetEvaluatorDraftByEvaluatorID", reflect.TypeOf((*MockIEvaluatorRepo)(nil).BatchGetEvaluatorDraftByEvaluatorID), ctx, spaceID, ids, includeDeleted) } // BatchGetEvaluatorMetaByID mocks base method. -func (m *MockIEvaluatorRepo) BatchGetEvaluatorMetaByID(arg0 context.Context, arg1 []int64, arg2 bool) ([]*entity.Evaluator, error) { +func (m *MockIEvaluatorRepo) BatchGetEvaluatorMetaByID(ctx context.Context, ids []int64, includeDeleted bool) ([]*entity.Evaluator, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "BatchGetEvaluatorMetaByID", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "BatchGetEvaluatorMetaByID", ctx, ids, includeDeleted) ret0, _ := ret[0].([]*entity.Evaluator) ret1, _ := ret[1].(error) return ret0, ret1 } // BatchGetEvaluatorMetaByID indicates an expected call of BatchGetEvaluatorMetaByID. -func (mr *MockIEvaluatorRepoMockRecorder) BatchGetEvaluatorMetaByID(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockIEvaluatorRepoMockRecorder) BatchGetEvaluatorMetaByID(ctx, ids, includeDeleted any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGetEvaluatorMetaByID", reflect.TypeOf((*MockIEvaluatorRepo)(nil).BatchGetEvaluatorMetaByID), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGetEvaluatorMetaByID", reflect.TypeOf((*MockIEvaluatorRepo)(nil).BatchGetEvaluatorMetaByID), ctx, ids, includeDeleted) } // BatchGetEvaluatorVersionsByEvaluatorIDs mocks base method. -func (m *MockIEvaluatorRepo) BatchGetEvaluatorVersionsByEvaluatorIDs(arg0 context.Context, arg1 []int64, arg2 bool) ([]*entity.Evaluator, error) { +func (m *MockIEvaluatorRepo) BatchGetEvaluatorVersionsByEvaluatorIDs(ctx context.Context, evaluatorIDs []int64, includeDeleted bool) ([]*entity.Evaluator, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "BatchGetEvaluatorVersionsByEvaluatorIDs", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "BatchGetEvaluatorVersionsByEvaluatorIDs", ctx, evaluatorIDs, includeDeleted) ret0, _ := ret[0].([]*entity.Evaluator) ret1, _ := ret[1].(error) return ret0, ret1 } // BatchGetEvaluatorVersionsByEvaluatorIDs indicates an expected call of BatchGetEvaluatorVersionsByEvaluatorIDs. -func (mr *MockIEvaluatorRepoMockRecorder) BatchGetEvaluatorVersionsByEvaluatorIDs(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockIEvaluatorRepoMockRecorder) BatchGetEvaluatorVersionsByEvaluatorIDs(ctx, evaluatorIDs, includeDeleted any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGetEvaluatorVersionsByEvaluatorIDs", reflect.TypeOf((*MockIEvaluatorRepo)(nil).BatchGetEvaluatorVersionsByEvaluatorIDs), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGetEvaluatorVersionsByEvaluatorIDs", reflect.TypeOf((*MockIEvaluatorRepo)(nil).BatchGetEvaluatorVersionsByEvaluatorIDs), ctx, evaluatorIDs, includeDeleted) } // CheckNameExist mocks base method. -func (m *MockIEvaluatorRepo) CheckNameExist(arg0 context.Context, arg1, arg2 int64, arg3 string) (bool, error) { +func (m *MockIEvaluatorRepo) CheckNameExist(ctx context.Context, spaceID, evaluatorID int64, name string) (bool, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CheckNameExist", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "CheckNameExist", ctx, spaceID, evaluatorID, name) ret0, _ := ret[0].(bool) ret1, _ := ret[1].(error) return ret0, ret1 } // CheckNameExist indicates an expected call of CheckNameExist. -func (mr *MockIEvaluatorRepoMockRecorder) CheckNameExist(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockIEvaluatorRepoMockRecorder) CheckNameExist(ctx, spaceID, evaluatorID, name any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckNameExist", reflect.TypeOf((*MockIEvaluatorRepo)(nil).CheckNameExist), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckNameExist", reflect.TypeOf((*MockIEvaluatorRepo)(nil).CheckNameExist), ctx, spaceID, evaluatorID, name) } // CheckVersionExist mocks base method. -func (m *MockIEvaluatorRepo) CheckVersionExist(arg0 context.Context, arg1 int64, arg2 string) (bool, error) { +func (m *MockIEvaluatorRepo) CheckVersionExist(ctx context.Context, evaluatorID int64, version string) (bool, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CheckVersionExist", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "CheckVersionExist", ctx, evaluatorID, version) ret0, _ := ret[0].(bool) ret1, _ := ret[1].(error) return ret0, ret1 } // CheckVersionExist indicates an expected call of CheckVersionExist. -func (mr *MockIEvaluatorRepoMockRecorder) CheckVersionExist(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockIEvaluatorRepoMockRecorder) CheckVersionExist(ctx, evaluatorID, version any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckVersionExist", reflect.TypeOf((*MockIEvaluatorRepo)(nil).CheckVersionExist), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckVersionExist", reflect.TypeOf((*MockIEvaluatorRepo)(nil).CheckVersionExist), ctx, evaluatorID, version) } // CreateEvaluator mocks base method. -func (m *MockIEvaluatorRepo) CreateEvaluator(arg0 context.Context, arg1 *entity.Evaluator) (int64, error) { +func (m *MockIEvaluatorRepo) CreateEvaluator(ctx context.Context, evaluator *entity.Evaluator) (int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateEvaluator", arg0, arg1) + ret := m.ctrl.Call(m, "CreateEvaluator", ctx, evaluator) ret0, _ := ret[0].(int64) ret1, _ := ret[1].(error) return ret0, ret1 } // CreateEvaluator indicates an expected call of CreateEvaluator. -func (mr *MockIEvaluatorRepoMockRecorder) CreateEvaluator(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockIEvaluatorRepoMockRecorder) CreateEvaluator(ctx, evaluator any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateEvaluator", reflect.TypeOf((*MockIEvaluatorRepo)(nil).CreateEvaluator), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateEvaluator", reflect.TypeOf((*MockIEvaluatorRepo)(nil).CreateEvaluator), ctx, evaluator) } // ListEvaluator mocks base method. -func (m *MockIEvaluatorRepo) ListEvaluator(arg0 context.Context, arg1 *repo.ListEvaluatorRequest) (*repo.ListEvaluatorResponse, error) { +func (m *MockIEvaluatorRepo) ListEvaluator(ctx context.Context, req *repo.ListEvaluatorRequest) (*repo.ListEvaluatorResponse, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ListEvaluator", arg0, arg1) + ret := m.ctrl.Call(m, "ListEvaluator", ctx, req) ret0, _ := ret[0].(*repo.ListEvaluatorResponse) ret1, _ := ret[1].(error) return ret0, ret1 } // ListEvaluator indicates an expected call of ListEvaluator. -func (mr *MockIEvaluatorRepoMockRecorder) ListEvaluator(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockIEvaluatorRepoMockRecorder) ListEvaluator(ctx, req any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListEvaluator", reflect.TypeOf((*MockIEvaluatorRepo)(nil).ListEvaluator), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListEvaluator", reflect.TypeOf((*MockIEvaluatorRepo)(nil).ListEvaluator), ctx, req) } // ListEvaluatorVersion mocks base method. -func (m *MockIEvaluatorRepo) ListEvaluatorVersion(arg0 context.Context, arg1 *repo.ListEvaluatorVersionRequest) (*repo.ListEvaluatorVersionResponse, error) { +func (m *MockIEvaluatorRepo) ListEvaluatorVersion(ctx context.Context, req *repo.ListEvaluatorVersionRequest) (*repo.ListEvaluatorVersionResponse, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ListEvaluatorVersion", arg0, arg1) + ret := m.ctrl.Call(m, "ListEvaluatorVersion", ctx, req) ret0, _ := ret[0].(*repo.ListEvaluatorVersionResponse) ret1, _ := ret[1].(error) return ret0, ret1 } // ListEvaluatorVersion indicates an expected call of ListEvaluatorVersion. -func (mr *MockIEvaluatorRepoMockRecorder) ListEvaluatorVersion(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockIEvaluatorRepoMockRecorder) ListEvaluatorVersion(ctx, req any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListEvaluatorVersion", reflect.TypeOf((*MockIEvaluatorRepo)(nil).ListEvaluatorVersion), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListEvaluatorVersion", reflect.TypeOf((*MockIEvaluatorRepo)(nil).ListEvaluatorVersion), ctx, req) } // SubmitEvaluatorVersion mocks base method. -func (m *MockIEvaluatorRepo) SubmitEvaluatorVersion(arg0 context.Context, arg1 *entity.Evaluator) error { +func (m *MockIEvaluatorRepo) SubmitEvaluatorVersion(ctx context.Context, evaluatorVersionDO *entity.Evaluator) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SubmitEvaluatorVersion", arg0, arg1) + ret := m.ctrl.Call(m, "SubmitEvaluatorVersion", ctx, evaluatorVersionDO) ret0, _ := ret[0].(error) return ret0 } // SubmitEvaluatorVersion indicates an expected call of SubmitEvaluatorVersion. -func (mr *MockIEvaluatorRepoMockRecorder) SubmitEvaluatorVersion(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockIEvaluatorRepoMockRecorder) SubmitEvaluatorVersion(ctx, evaluatorVersionDO any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubmitEvaluatorVersion", reflect.TypeOf((*MockIEvaluatorRepo)(nil).SubmitEvaluatorVersion), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubmitEvaluatorVersion", reflect.TypeOf((*MockIEvaluatorRepo)(nil).SubmitEvaluatorVersion), ctx, evaluatorVersionDO) } // UpdateEvaluatorDraft mocks base method. -func (m *MockIEvaluatorRepo) UpdateEvaluatorDraft(arg0 context.Context, arg1 *entity.Evaluator) error { +func (m *MockIEvaluatorRepo) UpdateEvaluatorDraft(ctx context.Context, version *entity.Evaluator) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateEvaluatorDraft", arg0, arg1) + ret := m.ctrl.Call(m, "UpdateEvaluatorDraft", ctx, version) ret0, _ := ret[0].(error) return ret0 } // UpdateEvaluatorDraft indicates an expected call of UpdateEvaluatorDraft. -func (mr *MockIEvaluatorRepoMockRecorder) UpdateEvaluatorDraft(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockIEvaluatorRepoMockRecorder) UpdateEvaluatorDraft(ctx, version any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateEvaluatorDraft", reflect.TypeOf((*MockIEvaluatorRepo)(nil).UpdateEvaluatorDraft), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateEvaluatorDraft", reflect.TypeOf((*MockIEvaluatorRepo)(nil).UpdateEvaluatorDraft), ctx, version) } // UpdateEvaluatorMeta mocks base method. -func (m *MockIEvaluatorRepo) UpdateEvaluatorMeta(arg0 context.Context, arg1 int64, arg2, arg3, arg4 string) error { +func (m *MockIEvaluatorRepo) UpdateEvaluatorMeta(ctx context.Context, id int64, name, description, userID string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateEvaluatorMeta", arg0, arg1, arg2, arg3, arg4) + ret := m.ctrl.Call(m, "UpdateEvaluatorMeta", ctx, id, name, description, userID) ret0, _ := ret[0].(error) return ret0 } // UpdateEvaluatorMeta indicates an expected call of UpdateEvaluatorMeta. -func (mr *MockIEvaluatorRepoMockRecorder) UpdateEvaluatorMeta(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { +func (mr *MockIEvaluatorRepoMockRecorder) UpdateEvaluatorMeta(ctx, id, name, description, userID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateEvaluatorMeta", reflect.TypeOf((*MockIEvaluatorRepo)(nil).UpdateEvaluatorMeta), arg0, arg1, arg2, arg3, arg4) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateEvaluatorMeta", reflect.TypeOf((*MockIEvaluatorRepo)(nil).UpdateEvaluatorMeta), ctx, id, name, description, userID) } diff --git a/backend/modules/evaluation/domain/repo/mocks/evaluator_record_mock.go b/backend/modules/evaluation/domain/repo/mocks/evaluator_record_mock.go index 30bd9b3a2..e0ca5d9e1 100644 --- a/backend/modules/evaluation/domain/repo/mocks/evaluator_record_mock.go +++ b/backend/modules/evaluation/domain/repo/mocks/evaluator_record_mock.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/coze-dev/coze-loop/backend/modules/evaluation/domain/repo (interfaces: IEvaluatorRecordRepo) +// +// Generated by this command: +// +// mockgen -destination mocks/evaluator_record_mock.go -package mocks . IEvaluatorRecordRepo +// // Package mocks is a generated GoMock package. package mocks @@ -16,6 +21,7 @@ import ( type MockIEvaluatorRecordRepo struct { ctrl *gomock.Controller recorder *MockIEvaluatorRecordRepoMockRecorder + isgomock struct{} } // MockIEvaluatorRecordRepoMockRecorder is the mock recorder for MockIEvaluatorRecordRepo. @@ -36,59 +42,59 @@ func (m *MockIEvaluatorRecordRepo) EXPECT() *MockIEvaluatorRecordRepoMockRecorde } // BatchGetEvaluatorRecord mocks base method. -func (m *MockIEvaluatorRecordRepo) BatchGetEvaluatorRecord(arg0 context.Context, arg1 []int64, arg2 bool) ([]*entity.EvaluatorRecord, error) { +func (m *MockIEvaluatorRecordRepo) BatchGetEvaluatorRecord(ctx context.Context, evaluatorRecordIDs []int64, includeDeleted bool) ([]*entity.EvaluatorRecord, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "BatchGetEvaluatorRecord", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "BatchGetEvaluatorRecord", ctx, evaluatorRecordIDs, includeDeleted) ret0, _ := ret[0].([]*entity.EvaluatorRecord) ret1, _ := ret[1].(error) return ret0, ret1 } // BatchGetEvaluatorRecord indicates an expected call of BatchGetEvaluatorRecord. -func (mr *MockIEvaluatorRecordRepoMockRecorder) BatchGetEvaluatorRecord(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockIEvaluatorRecordRepoMockRecorder) BatchGetEvaluatorRecord(ctx, evaluatorRecordIDs, includeDeleted any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGetEvaluatorRecord", reflect.TypeOf((*MockIEvaluatorRecordRepo)(nil).BatchGetEvaluatorRecord), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGetEvaluatorRecord", reflect.TypeOf((*MockIEvaluatorRecordRepo)(nil).BatchGetEvaluatorRecord), ctx, evaluatorRecordIDs, includeDeleted) } // CorrectEvaluatorRecord mocks base method. -func (m *MockIEvaluatorRecordRepo) CorrectEvaluatorRecord(arg0 context.Context, arg1 *entity.EvaluatorRecord) error { +func (m *MockIEvaluatorRecordRepo) CorrectEvaluatorRecord(ctx context.Context, evaluatorRecordDO *entity.EvaluatorRecord) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CorrectEvaluatorRecord", arg0, arg1) + ret := m.ctrl.Call(m, "CorrectEvaluatorRecord", ctx, evaluatorRecordDO) ret0, _ := ret[0].(error) return ret0 } // CorrectEvaluatorRecord indicates an expected call of CorrectEvaluatorRecord. -func (mr *MockIEvaluatorRecordRepoMockRecorder) CorrectEvaluatorRecord(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockIEvaluatorRecordRepoMockRecorder) CorrectEvaluatorRecord(ctx, evaluatorRecordDO any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CorrectEvaluatorRecord", reflect.TypeOf((*MockIEvaluatorRecordRepo)(nil).CorrectEvaluatorRecord), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CorrectEvaluatorRecord", reflect.TypeOf((*MockIEvaluatorRecordRepo)(nil).CorrectEvaluatorRecord), ctx, evaluatorRecordDO) } // CreateEvaluatorRecord mocks base method. -func (m *MockIEvaluatorRecordRepo) CreateEvaluatorRecord(arg0 context.Context, arg1 *entity.EvaluatorRecord) error { +func (m *MockIEvaluatorRecordRepo) CreateEvaluatorRecord(ctx context.Context, evaluatorRecord *entity.EvaluatorRecord) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateEvaluatorRecord", arg0, arg1) + ret := m.ctrl.Call(m, "CreateEvaluatorRecord", ctx, evaluatorRecord) ret0, _ := ret[0].(error) return ret0 } // CreateEvaluatorRecord indicates an expected call of CreateEvaluatorRecord. -func (mr *MockIEvaluatorRecordRepoMockRecorder) CreateEvaluatorRecord(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockIEvaluatorRecordRepoMockRecorder) CreateEvaluatorRecord(ctx, evaluatorRecord any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateEvaluatorRecord", reflect.TypeOf((*MockIEvaluatorRecordRepo)(nil).CreateEvaluatorRecord), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateEvaluatorRecord", reflect.TypeOf((*MockIEvaluatorRecordRepo)(nil).CreateEvaluatorRecord), ctx, evaluatorRecord) } // GetEvaluatorRecord mocks base method. -func (m *MockIEvaluatorRecordRepo) GetEvaluatorRecord(arg0 context.Context, arg1 int64, arg2 bool) (*entity.EvaluatorRecord, error) { +func (m *MockIEvaluatorRecordRepo) GetEvaluatorRecord(ctx context.Context, evaluatorRecordID int64, includeDeleted bool) (*entity.EvaluatorRecord, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetEvaluatorRecord", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "GetEvaluatorRecord", ctx, evaluatorRecordID, includeDeleted) ret0, _ := ret[0].(*entity.EvaluatorRecord) ret1, _ := ret[1].(error) return ret0, ret1 } // GetEvaluatorRecord indicates an expected call of GetEvaluatorRecord. -func (mr *MockIEvaluatorRecordRepoMockRecorder) GetEvaluatorRecord(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockIEvaluatorRecordRepoMockRecorder) GetEvaluatorRecord(ctx, evaluatorRecordID, includeDeleted any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEvaluatorRecord", reflect.TypeOf((*MockIEvaluatorRecordRepo)(nil).GetEvaluatorRecord), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEvaluatorRecord", reflect.TypeOf((*MockIEvaluatorRecordRepo)(nil).GetEvaluatorRecord), ctx, evaluatorRecordID, includeDeleted) } diff --git a/backend/modules/evaluation/domain/repo/mocks/ratelimiter_mock.go b/backend/modules/evaluation/domain/repo/mocks/ratelimiter_mock.go index 4303f9745..55d4b8839 100644 --- a/backend/modules/evaluation/domain/repo/mocks/ratelimiter_mock.go +++ b/backend/modules/evaluation/domain/repo/mocks/ratelimiter_mock.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/coze-dev/coze-loop/backend/modules/evaluation/domain/repo (interfaces: RateLimiter) +// +// Generated by this command: +// +// mockgen -destination mocks/ratelimiter_mock.go -package mocks . RateLimiter +// // Package mocks is a generated GoMock package. package mocks @@ -15,6 +20,7 @@ import ( type MockRateLimiter struct { ctrl *gomock.Controller recorder *MockRateLimiterMockRecorder + isgomock struct{} } // MockRateLimiterMockRecorder is the mock recorder for MockRateLimiter. @@ -35,15 +41,15 @@ func (m *MockRateLimiter) EXPECT() *MockRateLimiterMockRecorder { } // AllowInvoke mocks base method. -func (m *MockRateLimiter) AllowInvoke(arg0 context.Context, arg1 int64) bool { +func (m *MockRateLimiter) AllowInvoke(ctx context.Context, spaceID int64) bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AllowInvoke", arg0, arg1) + ret := m.ctrl.Call(m, "AllowInvoke", ctx, spaceID) ret0, _ := ret[0].(bool) return ret0 } // AllowInvoke indicates an expected call of AllowInvoke. -func (mr *MockRateLimiterMockRecorder) AllowInvoke(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockRateLimiterMockRecorder) AllowInvoke(ctx, spaceID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowInvoke", reflect.TypeOf((*MockRateLimiter)(nil).AllowInvoke), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowInvoke", reflect.TypeOf((*MockRateLimiter)(nil).AllowInvoke), ctx, spaceID) } diff --git a/backend/modules/evaluation/infra/repo/evaluator/mysql/mocks/evaluator_mock.go b/backend/modules/evaluation/infra/repo/evaluator/mysql/mocks/evaluator_mock.go index 7214747f3..50db0136d 100644 --- a/backend/modules/evaluation/infra/repo/evaluator/mysql/mocks/evaluator_mock.go +++ b/backend/modules/evaluation/infra/repo/evaluator/mysql/mocks/evaluator_mock.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/coze-dev/coze-loop/backend/modules/evaluation/infra/repo/evaluator/mysql (interfaces: EvaluatorDAO) +// +// Generated by this command: +// +// mockgen -destination mocks/evaluator_mock.go -package=mocks . EvaluatorDAO +// // Package mocks is a generated GoMock package. package mocks @@ -18,6 +23,7 @@ import ( type MockEvaluatorDAO struct { ctrl *gomock.Controller recorder *MockEvaluatorDAOMockRecorder + isgomock struct{} } // MockEvaluatorDAOMockRecorder is the mock recorder for MockEvaluatorDAO. @@ -38,10 +44,10 @@ func (m *MockEvaluatorDAO) EXPECT() *MockEvaluatorDAOMockRecorder { } // BatchDeleteEvaluator mocks base method. -func (m *MockEvaluatorDAO) BatchDeleteEvaluator(arg0 context.Context, arg1 []int64, arg2 string, arg3 ...db.Option) error { +func (m *MockEvaluatorDAO) BatchDeleteEvaluator(ctx context.Context, ids []int64, userID string, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2} - for _, a := range arg3 { + varargs := []any{ctx, ids, userID} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "BatchDeleteEvaluator", varargs...) @@ -50,17 +56,17 @@ func (m *MockEvaluatorDAO) BatchDeleteEvaluator(arg0 context.Context, arg1 []int } // BatchDeleteEvaluator indicates an expected call of BatchDeleteEvaluator. -func (mr *MockEvaluatorDAOMockRecorder) BatchDeleteEvaluator(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { +func (mr *MockEvaluatorDAOMockRecorder) BatchDeleteEvaluator(ctx, ids, userID any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) + varargs := append([]any{ctx, ids, userID}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchDeleteEvaluator", reflect.TypeOf((*MockEvaluatorDAO)(nil).BatchDeleteEvaluator), varargs...) } // BatchGetEvaluatorByID mocks base method. -func (m *MockEvaluatorDAO) BatchGetEvaluatorByID(arg0 context.Context, arg1 []int64, arg2 bool, arg3 ...db.Option) ([]*model.Evaluator, error) { +func (m *MockEvaluatorDAO) BatchGetEvaluatorByID(ctx context.Context, ids []int64, includeDeleted bool, opts ...db.Option) ([]*model.Evaluator, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2} - for _, a := range arg3 { + varargs := []any{ctx, ids, includeDeleted} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "BatchGetEvaluatorByID", varargs...) @@ -70,17 +76,17 @@ func (m *MockEvaluatorDAO) BatchGetEvaluatorByID(arg0 context.Context, arg1 []in } // BatchGetEvaluatorByID indicates an expected call of BatchGetEvaluatorByID. -func (mr *MockEvaluatorDAOMockRecorder) BatchGetEvaluatorByID(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { +func (mr *MockEvaluatorDAOMockRecorder) BatchGetEvaluatorByID(ctx, ids, includeDeleted any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) + varargs := append([]any{ctx, ids, includeDeleted}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGetEvaluatorByID", reflect.TypeOf((*MockEvaluatorDAO)(nil).BatchGetEvaluatorByID), varargs...) } // CheckNameExist mocks base method. -func (m *MockEvaluatorDAO) CheckNameExist(arg0 context.Context, arg1, arg2 int64, arg3 string, arg4 ...db.Option) (bool, error) { +func (m *MockEvaluatorDAO) CheckNameExist(ctx context.Context, spaceID, evaluatorID int64, name string, opts ...db.Option) (bool, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3} - for _, a := range arg4 { + varargs := []any{ctx, spaceID, evaluatorID, name} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "CheckNameExist", varargs...) @@ -90,17 +96,17 @@ func (m *MockEvaluatorDAO) CheckNameExist(arg0 context.Context, arg1, arg2 int64 } // CheckNameExist indicates an expected call of CheckNameExist. -func (mr *MockEvaluatorDAOMockRecorder) CheckNameExist(arg0, arg1, arg2, arg3 interface{}, arg4 ...interface{}) *gomock.Call { +func (mr *MockEvaluatorDAOMockRecorder) CheckNameExist(ctx, spaceID, evaluatorID, name any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3}, arg4...) + varargs := append([]any{ctx, spaceID, evaluatorID, name}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckNameExist", reflect.TypeOf((*MockEvaluatorDAO)(nil).CheckNameExist), varargs...) } // CreateEvaluator mocks base method. -func (m *MockEvaluatorDAO) CreateEvaluator(arg0 context.Context, arg1 *model.Evaluator, arg2 ...db.Option) error { +func (m *MockEvaluatorDAO) CreateEvaluator(ctx context.Context, evaluator *model.Evaluator, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, evaluator} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "CreateEvaluator", varargs...) @@ -109,17 +115,17 @@ func (m *MockEvaluatorDAO) CreateEvaluator(arg0 context.Context, arg1 *model.Eva } // CreateEvaluator indicates an expected call of CreateEvaluator. -func (mr *MockEvaluatorDAOMockRecorder) CreateEvaluator(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockEvaluatorDAOMockRecorder) CreateEvaluator(ctx, evaluator any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, evaluator}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateEvaluator", reflect.TypeOf((*MockEvaluatorDAO)(nil).CreateEvaluator), varargs...) } // GetEvaluatorByID mocks base method. -func (m *MockEvaluatorDAO) GetEvaluatorByID(arg0 context.Context, arg1 int64, arg2 bool, arg3 ...db.Option) (*model.Evaluator, error) { +func (m *MockEvaluatorDAO) GetEvaluatorByID(ctx context.Context, id int64, includeDeleted bool, opts ...db.Option) (*model.Evaluator, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2} - for _, a := range arg3 { + varargs := []any{ctx, id, includeDeleted} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "GetEvaluatorByID", varargs...) @@ -129,17 +135,17 @@ func (m *MockEvaluatorDAO) GetEvaluatorByID(arg0 context.Context, arg1 int64, ar } // GetEvaluatorByID indicates an expected call of GetEvaluatorByID. -func (mr *MockEvaluatorDAOMockRecorder) GetEvaluatorByID(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { +func (mr *MockEvaluatorDAOMockRecorder) GetEvaluatorByID(ctx, id, includeDeleted any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) + varargs := append([]any{ctx, id, includeDeleted}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEvaluatorByID", reflect.TypeOf((*MockEvaluatorDAO)(nil).GetEvaluatorByID), varargs...) } // ListEvaluator mocks base method. -func (m *MockEvaluatorDAO) ListEvaluator(arg0 context.Context, arg1 *mysql.ListEvaluatorRequest, arg2 ...db.Option) (*mysql.ListEvaluatorResponse, error) { +func (m *MockEvaluatorDAO) ListEvaluator(ctx context.Context, req *mysql.ListEvaluatorRequest, opts ...db.Option) (*mysql.ListEvaluatorResponse, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, req} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "ListEvaluator", varargs...) @@ -149,17 +155,17 @@ func (m *MockEvaluatorDAO) ListEvaluator(arg0 context.Context, arg1 *mysql.ListE } // ListEvaluator indicates an expected call of ListEvaluator. -func (mr *MockEvaluatorDAOMockRecorder) ListEvaluator(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockEvaluatorDAOMockRecorder) ListEvaluator(ctx, req any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, req}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListEvaluator", reflect.TypeOf((*MockEvaluatorDAO)(nil).ListEvaluator), varargs...) } // UpdateEvaluatorDraftSubmitted mocks base method. -func (m *MockEvaluatorDAO) UpdateEvaluatorDraftSubmitted(arg0 context.Context, arg1 int64, arg2 bool, arg3 string, arg4 ...db.Option) error { +func (m *MockEvaluatorDAO) UpdateEvaluatorDraftSubmitted(ctx context.Context, evaluatorID int64, draftSubmitted bool, userID string, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3} - for _, a := range arg4 { + varargs := []any{ctx, evaluatorID, draftSubmitted, userID} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "UpdateEvaluatorDraftSubmitted", varargs...) @@ -168,17 +174,17 @@ func (m *MockEvaluatorDAO) UpdateEvaluatorDraftSubmitted(arg0 context.Context, a } // UpdateEvaluatorDraftSubmitted indicates an expected call of UpdateEvaluatorDraftSubmitted. -func (mr *MockEvaluatorDAOMockRecorder) UpdateEvaluatorDraftSubmitted(arg0, arg1, arg2, arg3 interface{}, arg4 ...interface{}) *gomock.Call { +func (mr *MockEvaluatorDAOMockRecorder) UpdateEvaluatorDraftSubmitted(ctx, evaluatorID, draftSubmitted, userID any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3}, arg4...) + varargs := append([]any{ctx, evaluatorID, draftSubmitted, userID}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateEvaluatorDraftSubmitted", reflect.TypeOf((*MockEvaluatorDAO)(nil).UpdateEvaluatorDraftSubmitted), varargs...) } // UpdateEvaluatorLatestVersion mocks base method. -func (m *MockEvaluatorDAO) UpdateEvaluatorLatestVersion(arg0 context.Context, arg1 int64, arg2, arg3 string, arg4 ...db.Option) error { +func (m *MockEvaluatorDAO) UpdateEvaluatorLatestVersion(ctx context.Context, evaluatorID int64, version, userID string, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3} - for _, a := range arg4 { + varargs := []any{ctx, evaluatorID, version, userID} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "UpdateEvaluatorLatestVersion", varargs...) @@ -187,17 +193,17 @@ func (m *MockEvaluatorDAO) UpdateEvaluatorLatestVersion(arg0 context.Context, ar } // UpdateEvaluatorLatestVersion indicates an expected call of UpdateEvaluatorLatestVersion. -func (mr *MockEvaluatorDAOMockRecorder) UpdateEvaluatorLatestVersion(arg0, arg1, arg2, arg3 interface{}, arg4 ...interface{}) *gomock.Call { +func (mr *MockEvaluatorDAOMockRecorder) UpdateEvaluatorLatestVersion(ctx, evaluatorID, version, userID any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3}, arg4...) + varargs := append([]any{ctx, evaluatorID, version, userID}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateEvaluatorLatestVersion", reflect.TypeOf((*MockEvaluatorDAO)(nil).UpdateEvaluatorLatestVersion), varargs...) } // UpdateEvaluatorMeta mocks base method. -func (m *MockEvaluatorDAO) UpdateEvaluatorMeta(arg0 context.Context, arg1 *model.Evaluator, arg2 ...db.Option) error { +func (m *MockEvaluatorDAO) UpdateEvaluatorMeta(ctx context.Context, do *model.Evaluator, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, do} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "UpdateEvaluatorMeta", varargs...) @@ -206,8 +212,8 @@ func (m *MockEvaluatorDAO) UpdateEvaluatorMeta(arg0 context.Context, arg1 *model } // UpdateEvaluatorMeta indicates an expected call of UpdateEvaluatorMeta. -func (mr *MockEvaluatorDAOMockRecorder) UpdateEvaluatorMeta(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockEvaluatorDAOMockRecorder) UpdateEvaluatorMeta(ctx, do any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, do}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateEvaluatorMeta", reflect.TypeOf((*MockEvaluatorDAO)(nil).UpdateEvaluatorMeta), varargs...) } diff --git a/backend/modules/evaluation/infra/repo/evaluator/mysql/mocks/evaluator_record_mock.go b/backend/modules/evaluation/infra/repo/evaluator/mysql/mocks/evaluator_record_mock.go index bc48d275e..90fda9a1b 100644 --- a/backend/modules/evaluation/infra/repo/evaluator/mysql/mocks/evaluator_record_mock.go +++ b/backend/modules/evaluation/infra/repo/evaluator/mysql/mocks/evaluator_record_mock.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/coze-dev/coze-loop/backend/modules/evaluation/infra/repo/evaluator/mysql (interfaces: EvaluatorRecordDAO) +// +// Generated by this command: +// +// mockgen -destination mocks/evaluator_record_mock.go -package=mocks . EvaluatorRecordDAO +// // Package mocks is a generated GoMock package. package mocks @@ -17,6 +22,7 @@ import ( type MockEvaluatorRecordDAO struct { ctrl *gomock.Controller recorder *MockEvaluatorRecordDAOMockRecorder + isgomock struct{} } // MockEvaluatorRecordDAOMockRecorder is the mock recorder for MockEvaluatorRecordDAO. @@ -37,10 +43,10 @@ func (m *MockEvaluatorRecordDAO) EXPECT() *MockEvaluatorRecordDAOMockRecorder { } // BatchGetEvaluatorRecord mocks base method. -func (m *MockEvaluatorRecordDAO) BatchGetEvaluatorRecord(arg0 context.Context, arg1 []int64, arg2 bool, arg3 ...db.Option) ([]*model.EvaluatorRecord, error) { +func (m *MockEvaluatorRecordDAO) BatchGetEvaluatorRecord(ctx context.Context, evaluatorRecordIDs []int64, includeDeleted bool, opts ...db.Option) ([]*model.EvaluatorRecord, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2} - for _, a := range arg3 { + varargs := []any{ctx, evaluatorRecordIDs, includeDeleted} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "BatchGetEvaluatorRecord", varargs...) @@ -50,17 +56,17 @@ func (m *MockEvaluatorRecordDAO) BatchGetEvaluatorRecord(arg0 context.Context, a } // BatchGetEvaluatorRecord indicates an expected call of BatchGetEvaluatorRecord. -func (mr *MockEvaluatorRecordDAOMockRecorder) BatchGetEvaluatorRecord(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { +func (mr *MockEvaluatorRecordDAOMockRecorder) BatchGetEvaluatorRecord(ctx, evaluatorRecordIDs, includeDeleted any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) + varargs := append([]any{ctx, evaluatorRecordIDs, includeDeleted}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGetEvaluatorRecord", reflect.TypeOf((*MockEvaluatorRecordDAO)(nil).BatchGetEvaluatorRecord), varargs...) } // CreateEvaluatorRecord mocks base method. -func (m *MockEvaluatorRecordDAO) CreateEvaluatorRecord(arg0 context.Context, arg1 *model.EvaluatorRecord, arg2 ...db.Option) error { +func (m *MockEvaluatorRecordDAO) CreateEvaluatorRecord(ctx context.Context, evaluatorRecord *model.EvaluatorRecord, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, evaluatorRecord} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "CreateEvaluatorRecord", varargs...) @@ -69,17 +75,17 @@ func (m *MockEvaluatorRecordDAO) CreateEvaluatorRecord(arg0 context.Context, arg } // CreateEvaluatorRecord indicates an expected call of CreateEvaluatorRecord. -func (mr *MockEvaluatorRecordDAOMockRecorder) CreateEvaluatorRecord(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockEvaluatorRecordDAOMockRecorder) CreateEvaluatorRecord(ctx, evaluatorRecord any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, evaluatorRecord}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateEvaluatorRecord", reflect.TypeOf((*MockEvaluatorRecordDAO)(nil).CreateEvaluatorRecord), varargs...) } // GetEvaluatorRecord mocks base method. -func (m *MockEvaluatorRecordDAO) GetEvaluatorRecord(arg0 context.Context, arg1 int64, arg2 bool, arg3 ...db.Option) (*model.EvaluatorRecord, error) { +func (m *MockEvaluatorRecordDAO) GetEvaluatorRecord(ctx context.Context, evaluatorRecordID int64, includeDeleted bool, opts ...db.Option) (*model.EvaluatorRecord, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2} - for _, a := range arg3 { + varargs := []any{ctx, evaluatorRecordID, includeDeleted} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "GetEvaluatorRecord", varargs...) @@ -89,17 +95,17 @@ func (m *MockEvaluatorRecordDAO) GetEvaluatorRecord(arg0 context.Context, arg1 i } // GetEvaluatorRecord indicates an expected call of GetEvaluatorRecord. -func (mr *MockEvaluatorRecordDAOMockRecorder) GetEvaluatorRecord(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { +func (mr *MockEvaluatorRecordDAOMockRecorder) GetEvaluatorRecord(ctx, evaluatorRecordID, includeDeleted any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) + varargs := append([]any{ctx, evaluatorRecordID, includeDeleted}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEvaluatorRecord", reflect.TypeOf((*MockEvaluatorRecordDAO)(nil).GetEvaluatorRecord), varargs...) } // UpdateEvaluatorRecord mocks base method. -func (m *MockEvaluatorRecordDAO) UpdateEvaluatorRecord(arg0 context.Context, arg1 *model.EvaluatorRecord, arg2 ...db.Option) error { +func (m *MockEvaluatorRecordDAO) UpdateEvaluatorRecord(ctx context.Context, evaluatorRecord *model.EvaluatorRecord, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, evaluatorRecord} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "UpdateEvaluatorRecord", varargs...) @@ -108,8 +114,8 @@ func (m *MockEvaluatorRecordDAO) UpdateEvaluatorRecord(arg0 context.Context, arg } // UpdateEvaluatorRecord indicates an expected call of UpdateEvaluatorRecord. -func (mr *MockEvaluatorRecordDAOMockRecorder) UpdateEvaluatorRecord(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockEvaluatorRecordDAOMockRecorder) UpdateEvaluatorRecord(ctx, evaluatorRecord any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, evaluatorRecord}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateEvaluatorRecord", reflect.TypeOf((*MockEvaluatorRecordDAO)(nil).UpdateEvaluatorRecord), varargs...) } diff --git a/backend/modules/evaluation/infra/repo/evaluator/mysql/mocks/evaluator_version_mock.go b/backend/modules/evaluation/infra/repo/evaluator/mysql/mocks/evaluator_version_mock.go index 2922b3b2f..a0d7965eb 100644 --- a/backend/modules/evaluation/infra/repo/evaluator/mysql/mocks/evaluator_version_mock.go +++ b/backend/modules/evaluation/infra/repo/evaluator/mysql/mocks/evaluator_version_mock.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/coze-dev/coze-loop/backend/modules/evaluation/infra/repo/evaluator/mysql (interfaces: EvaluatorVersionDAO) +// +// Generated by this command: +// +// mockgen -destination mocks/evaluator_version_mock.go -package=mocks . EvaluatorVersionDAO +// // Package mocks is a generated GoMock package. package mocks @@ -18,6 +23,7 @@ import ( type MockEvaluatorVersionDAO struct { ctrl *gomock.Controller recorder *MockEvaluatorVersionDAOMockRecorder + isgomock struct{} } // MockEvaluatorVersionDAOMockRecorder is the mock recorder for MockEvaluatorVersionDAO. @@ -38,10 +44,10 @@ func (m *MockEvaluatorVersionDAO) EXPECT() *MockEvaluatorVersionDAOMockRecorder } // BatchDeleteEvaluatorVersionByEvaluatorIDs mocks base method. -func (m *MockEvaluatorVersionDAO) BatchDeleteEvaluatorVersionByEvaluatorIDs(arg0 context.Context, arg1 []int64, arg2 string, arg3 ...db.Option) error { +func (m *MockEvaluatorVersionDAO) BatchDeleteEvaluatorVersionByEvaluatorIDs(ctx context.Context, evaluatorIDs []int64, userID string, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2} - for _, a := range arg3 { + varargs := []any{ctx, evaluatorIDs, userID} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "BatchDeleteEvaluatorVersionByEvaluatorIDs", varargs...) @@ -50,17 +56,17 @@ func (m *MockEvaluatorVersionDAO) BatchDeleteEvaluatorVersionByEvaluatorIDs(arg0 } // BatchDeleteEvaluatorVersionByEvaluatorIDs indicates an expected call of BatchDeleteEvaluatorVersionByEvaluatorIDs. -func (mr *MockEvaluatorVersionDAOMockRecorder) BatchDeleteEvaluatorVersionByEvaluatorIDs(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { +func (mr *MockEvaluatorVersionDAOMockRecorder) BatchDeleteEvaluatorVersionByEvaluatorIDs(ctx, evaluatorIDs, userID any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) + varargs := append([]any{ctx, evaluatorIDs, userID}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchDeleteEvaluatorVersionByEvaluatorIDs", reflect.TypeOf((*MockEvaluatorVersionDAO)(nil).BatchDeleteEvaluatorVersionByEvaluatorIDs), varargs...) } // BatchGetEvaluatorDraftByEvaluatorID mocks base method. -func (m *MockEvaluatorVersionDAO) BatchGetEvaluatorDraftByEvaluatorID(arg0 context.Context, arg1 []int64, arg2 bool, arg3 ...db.Option) ([]*model.EvaluatorVersion, error) { +func (m *MockEvaluatorVersionDAO) BatchGetEvaluatorDraftByEvaluatorID(ctx context.Context, evaluatorIDs []int64, includeDeleted bool, opts ...db.Option) ([]*model.EvaluatorVersion, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2} - for _, a := range arg3 { + varargs := []any{ctx, evaluatorIDs, includeDeleted} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "BatchGetEvaluatorDraftByEvaluatorID", varargs...) @@ -70,17 +76,17 @@ func (m *MockEvaluatorVersionDAO) BatchGetEvaluatorDraftByEvaluatorID(arg0 conte } // BatchGetEvaluatorDraftByEvaluatorID indicates an expected call of BatchGetEvaluatorDraftByEvaluatorID. -func (mr *MockEvaluatorVersionDAOMockRecorder) BatchGetEvaluatorDraftByEvaluatorID(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { +func (mr *MockEvaluatorVersionDAOMockRecorder) BatchGetEvaluatorDraftByEvaluatorID(ctx, evaluatorIDs, includeDeleted any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) + varargs := append([]any{ctx, evaluatorIDs, includeDeleted}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGetEvaluatorDraftByEvaluatorID", reflect.TypeOf((*MockEvaluatorVersionDAO)(nil).BatchGetEvaluatorDraftByEvaluatorID), varargs...) } // BatchGetEvaluatorVersionByID mocks base method. -func (m *MockEvaluatorVersionDAO) BatchGetEvaluatorVersionByID(arg0 context.Context, arg1 *int64, arg2 []int64, arg3 bool, arg4 ...db.Option) ([]*model.EvaluatorVersion, error) { +func (m *MockEvaluatorVersionDAO) BatchGetEvaluatorVersionByID(ctx context.Context, spaceID *int64, ids []int64, includeDeleted bool, opts ...db.Option) ([]*model.EvaluatorVersion, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3} - for _, a := range arg4 { + varargs := []any{ctx, spaceID, ids, includeDeleted} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "BatchGetEvaluatorVersionByID", varargs...) @@ -90,17 +96,17 @@ func (m *MockEvaluatorVersionDAO) BatchGetEvaluatorVersionByID(arg0 context.Cont } // BatchGetEvaluatorVersionByID indicates an expected call of BatchGetEvaluatorVersionByID. -func (mr *MockEvaluatorVersionDAOMockRecorder) BatchGetEvaluatorVersionByID(arg0, arg1, arg2, arg3 interface{}, arg4 ...interface{}) *gomock.Call { +func (mr *MockEvaluatorVersionDAOMockRecorder) BatchGetEvaluatorVersionByID(ctx, spaceID, ids, includeDeleted any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3}, arg4...) + varargs := append([]any{ctx, spaceID, ids, includeDeleted}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGetEvaluatorVersionByID", reflect.TypeOf((*MockEvaluatorVersionDAO)(nil).BatchGetEvaluatorVersionByID), varargs...) } // BatchGetEvaluatorVersionsByEvaluatorIDs mocks base method. -func (m *MockEvaluatorVersionDAO) BatchGetEvaluatorVersionsByEvaluatorIDs(arg0 context.Context, arg1 []int64, arg2 bool, arg3 ...db.Option) ([]*model.EvaluatorVersion, error) { +func (m *MockEvaluatorVersionDAO) BatchGetEvaluatorVersionsByEvaluatorIDs(ctx context.Context, evaluatorIDs []int64, includeDeleted bool, opts ...db.Option) ([]*model.EvaluatorVersion, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2} - for _, a := range arg3 { + varargs := []any{ctx, evaluatorIDs, includeDeleted} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "BatchGetEvaluatorVersionsByEvaluatorIDs", varargs...) @@ -110,17 +116,17 @@ func (m *MockEvaluatorVersionDAO) BatchGetEvaluatorVersionsByEvaluatorIDs(arg0 c } // BatchGetEvaluatorVersionsByEvaluatorIDs indicates an expected call of BatchGetEvaluatorVersionsByEvaluatorIDs. -func (mr *MockEvaluatorVersionDAOMockRecorder) BatchGetEvaluatorVersionsByEvaluatorIDs(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { +func (mr *MockEvaluatorVersionDAOMockRecorder) BatchGetEvaluatorVersionsByEvaluatorIDs(ctx, evaluatorIDs, includeDeleted any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) + varargs := append([]any{ctx, evaluatorIDs, includeDeleted}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGetEvaluatorVersionsByEvaluatorIDs", reflect.TypeOf((*MockEvaluatorVersionDAO)(nil).BatchGetEvaluatorVersionsByEvaluatorIDs), varargs...) } // CheckVersionExist mocks base method. -func (m *MockEvaluatorVersionDAO) CheckVersionExist(arg0 context.Context, arg1 int64, arg2 string, arg3 ...db.Option) (bool, error) { +func (m *MockEvaluatorVersionDAO) CheckVersionExist(ctx context.Context, evaluatorID int64, version string, opts ...db.Option) (bool, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2} - for _, a := range arg3 { + varargs := []any{ctx, evaluatorID, version} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "CheckVersionExist", varargs...) @@ -130,17 +136,17 @@ func (m *MockEvaluatorVersionDAO) CheckVersionExist(arg0 context.Context, arg1 i } // CheckVersionExist indicates an expected call of CheckVersionExist. -func (mr *MockEvaluatorVersionDAOMockRecorder) CheckVersionExist(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { +func (mr *MockEvaluatorVersionDAOMockRecorder) CheckVersionExist(ctx, evaluatorID, version any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) + varargs := append([]any{ctx, evaluatorID, version}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckVersionExist", reflect.TypeOf((*MockEvaluatorVersionDAO)(nil).CheckVersionExist), varargs...) } // CreateEvaluatorVersion mocks base method. -func (m *MockEvaluatorVersionDAO) CreateEvaluatorVersion(arg0 context.Context, arg1 *model.EvaluatorVersion, arg2 ...db.Option) error { +func (m *MockEvaluatorVersionDAO) CreateEvaluatorVersion(ctx context.Context, version *model.EvaluatorVersion, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, version} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "CreateEvaluatorVersion", varargs...) @@ -149,17 +155,17 @@ func (m *MockEvaluatorVersionDAO) CreateEvaluatorVersion(arg0 context.Context, a } // CreateEvaluatorVersion indicates an expected call of CreateEvaluatorVersion. -func (mr *MockEvaluatorVersionDAOMockRecorder) CreateEvaluatorVersion(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockEvaluatorVersionDAOMockRecorder) CreateEvaluatorVersion(ctx, version any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, version}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateEvaluatorVersion", reflect.TypeOf((*MockEvaluatorVersionDAO)(nil).CreateEvaluatorVersion), varargs...) } // DeleteEvaluatorVersion mocks base method. -func (m *MockEvaluatorVersionDAO) DeleteEvaluatorVersion(arg0 context.Context, arg1 int64, arg2 string, arg3 ...db.Option) error { +func (m *MockEvaluatorVersionDAO) DeleteEvaluatorVersion(ctx context.Context, id int64, userID string, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2} - for _, a := range arg3 { + varargs := []any{ctx, id, userID} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "DeleteEvaluatorVersion", varargs...) @@ -168,17 +174,17 @@ func (m *MockEvaluatorVersionDAO) DeleteEvaluatorVersion(arg0 context.Context, a } // DeleteEvaluatorVersion indicates an expected call of DeleteEvaluatorVersion. -func (mr *MockEvaluatorVersionDAOMockRecorder) DeleteEvaluatorVersion(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { +func (mr *MockEvaluatorVersionDAOMockRecorder) DeleteEvaluatorVersion(ctx, id, userID any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) + varargs := append([]any{ctx, id, userID}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteEvaluatorVersion", reflect.TypeOf((*MockEvaluatorVersionDAO)(nil).DeleteEvaluatorVersion), varargs...) } // ListEvaluatorVersion mocks base method. -func (m *MockEvaluatorVersionDAO) ListEvaluatorVersion(arg0 context.Context, arg1 *mysql.ListEvaluatorVersionRequest, arg2 ...db.Option) (*mysql.ListEvaluatorVersionResponse, error) { +func (m *MockEvaluatorVersionDAO) ListEvaluatorVersion(ctx context.Context, req *mysql.ListEvaluatorVersionRequest, opts ...db.Option) (*mysql.ListEvaluatorVersionResponse, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, req} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "ListEvaluatorVersion", varargs...) @@ -188,17 +194,17 @@ func (m *MockEvaluatorVersionDAO) ListEvaluatorVersion(arg0 context.Context, arg } // ListEvaluatorVersion indicates an expected call of ListEvaluatorVersion. -func (mr *MockEvaluatorVersionDAOMockRecorder) ListEvaluatorVersion(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockEvaluatorVersionDAOMockRecorder) ListEvaluatorVersion(ctx, req any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, req}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListEvaluatorVersion", reflect.TypeOf((*MockEvaluatorVersionDAO)(nil).ListEvaluatorVersion), varargs...) } // UpdateEvaluatorDraft mocks base method. -func (m *MockEvaluatorVersionDAO) UpdateEvaluatorDraft(arg0 context.Context, arg1 *model.EvaluatorVersion, arg2 ...db.Option) error { +func (m *MockEvaluatorVersionDAO) UpdateEvaluatorDraft(ctx context.Context, version *model.EvaluatorVersion, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, version} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "UpdateEvaluatorDraft", varargs...) @@ -207,8 +213,8 @@ func (m *MockEvaluatorVersionDAO) UpdateEvaluatorDraft(arg0 context.Context, arg } // UpdateEvaluatorDraft indicates an expected call of UpdateEvaluatorDraft. -func (mr *MockEvaluatorVersionDAOMockRecorder) UpdateEvaluatorDraft(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockEvaluatorVersionDAOMockRecorder) UpdateEvaluatorDraft(ctx, version any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, version}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateEvaluatorDraft", reflect.TypeOf((*MockEvaluatorVersionDAO)(nil).UpdateEvaluatorDraft), varargs...) } diff --git a/backend/modules/evaluation/infra/repo/experiment/ck/mocks/expt.go b/backend/modules/evaluation/infra/repo/experiment/ck/mocks/expt.go index e11be126e..242e6ce26 100644 --- a/backend/modules/evaluation/infra/repo/experiment/ck/mocks/expt.go +++ b/backend/modules/evaluation/infra/repo/experiment/ck/mocks/expt.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/coze-dev/coze-loop/backend/modules/evaluation/infra/repo/experiment/ck (interfaces: IExptTurnResultFilterDAO) +// +// Generated by this command: +// +// mockgen -destination=mocks/expt.go -package=mocks . IExptTurnResultFilterDAO +// // Package mocks is a generated GoMock package. package mocks @@ -18,6 +23,7 @@ import ( type MockIExptTurnResultFilterDAO struct { ctrl *gomock.Controller recorder *MockIExptTurnResultFilterDAOMockRecorder + isgomock struct{} } // MockIExptTurnResultFilterDAOMockRecorder is the mock recorder for MockIExptTurnResultFilterDAO. @@ -38,24 +44,24 @@ func (m *MockIExptTurnResultFilterDAO) EXPECT() *MockIExptTurnResultFilterDAOMoc } // GetByExptIDItemIDs mocks base method. -func (m *MockIExptTurnResultFilterDAO) GetByExptIDItemIDs(arg0 context.Context, arg1, arg2, arg3 string, arg4 []string) ([]*model0.ExptTurnResultFilter, error) { +func (m *MockIExptTurnResultFilterDAO) GetByExptIDItemIDs(ctx context.Context, spaceID, exptID, createdDate string, itemIDs []string) ([]*model0.ExptTurnResultFilter, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetByExptIDItemIDs", arg0, arg1, arg2, arg3, arg4) + ret := m.ctrl.Call(m, "GetByExptIDItemIDs", ctx, spaceID, exptID, createdDate, itemIDs) ret0, _ := ret[0].([]*model0.ExptTurnResultFilter) ret1, _ := ret[1].(error) return ret0, ret1 } // GetByExptIDItemIDs indicates an expected call of GetByExptIDItemIDs. -func (mr *MockIExptTurnResultFilterDAOMockRecorder) GetByExptIDItemIDs(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { +func (mr *MockIExptTurnResultFilterDAOMockRecorder) GetByExptIDItemIDs(ctx, spaceID, exptID, createdDate, itemIDs any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByExptIDItemIDs", reflect.TypeOf((*MockIExptTurnResultFilterDAO)(nil).GetByExptIDItemIDs), arg0, arg1, arg2, arg3, arg4) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByExptIDItemIDs", reflect.TypeOf((*MockIExptTurnResultFilterDAO)(nil).GetByExptIDItemIDs), ctx, spaceID, exptID, createdDate, itemIDs) } // QueryItemIDStates mocks base method. -func (m *MockIExptTurnResultFilterDAO) QueryItemIDStates(arg0 context.Context, arg1 *ck.ExptTurnResultFilterQueryCond) (map[string]int32, int64, error) { +func (m *MockIExptTurnResultFilterDAO) QueryItemIDStates(ctx context.Context, cond *ck.ExptTurnResultFilterQueryCond) (map[string]int32, int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "QueryItemIDStates", arg0, arg1) + ret := m.ctrl.Call(m, "QueryItemIDStates", ctx, cond) ret0, _ := ret[0].(map[string]int32) ret1, _ := ret[1].(int64) ret2, _ := ret[2].(error) @@ -63,21 +69,21 @@ func (m *MockIExptTurnResultFilterDAO) QueryItemIDStates(arg0 context.Context, a } // QueryItemIDStates indicates an expected call of QueryItemIDStates. -func (mr *MockIExptTurnResultFilterDAOMockRecorder) QueryItemIDStates(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockIExptTurnResultFilterDAOMockRecorder) QueryItemIDStates(ctx, cond any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryItemIDStates", reflect.TypeOf((*MockIExptTurnResultFilterDAO)(nil).QueryItemIDStates), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryItemIDStates", reflect.TypeOf((*MockIExptTurnResultFilterDAO)(nil).QueryItemIDStates), ctx, cond) } // Save mocks base method. -func (m *MockIExptTurnResultFilterDAO) Save(arg0 context.Context, arg1 []*model.ExptTurnResultFilter) error { +func (m *MockIExptTurnResultFilterDAO) Save(ctx context.Context, filter []*model.ExptTurnResultFilter) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Save", arg0, arg1) + ret := m.ctrl.Call(m, "Save", ctx, filter) ret0, _ := ret[0].(error) return ret0 } // Save indicates an expected call of Save. -func (mr *MockIExptTurnResultFilterDAOMockRecorder) Save(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockIExptTurnResultFilterDAOMockRecorder) Save(ctx, filter any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Save", reflect.TypeOf((*MockIExptTurnResultFilterDAO)(nil).Save), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Save", reflect.TypeOf((*MockIExptTurnResultFilterDAO)(nil).Save), ctx, filter) } diff --git a/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/annotate_record.go b/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/annotate_record.go index c1059477a..60c0fe23c 100644 --- a/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/annotate_record.go +++ b/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/annotate_record.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/coze-dev/coze-loop/backend/modules/evaluation/infra/repo/experiment/mysql (interfaces: IAnnotateRecordDAO) +// +// Generated by this command: +// +// mockgen -destination=mocks/annotate_record.go -package=mocks . IAnnotateRecordDAO +// // Package mocks is a generated GoMock package. package mocks @@ -8,16 +13,16 @@ import ( context "context" reflect "reflect" - "go.uber.org/mock/gomock" - db "github.com/coze-dev/coze-loop/backend/infra/db" model "github.com/coze-dev/coze-loop/backend/modules/evaluation/infra/repo/experiment/mysql/gorm_gen/model" + gomock "go.uber.org/mock/gomock" ) // MockIAnnotateRecordDAO is a mock of IAnnotateRecordDAO interface. type MockIAnnotateRecordDAO struct { ctrl *gomock.Controller recorder *MockIAnnotateRecordDAOMockRecorder + isgomock struct{} } // MockIAnnotateRecordDAOMockRecorder is the mock recorder for MockIAnnotateRecordDAO. @@ -38,10 +43,10 @@ func (m *MockIAnnotateRecordDAO) EXPECT() *MockIAnnotateRecordDAOMockRecorder { } // BatchSave mocks base method. -func (m *MockIAnnotateRecordDAO) BatchSave(arg0 context.Context, arg1 []*model.AnnotateRecord, arg2 ...db.Option) error { +func (m *MockIAnnotateRecordDAO) BatchSave(ctx context.Context, annotateRecord []*model.AnnotateRecord, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, annotateRecord} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "BatchSave", varargs...) @@ -50,32 +55,32 @@ func (m *MockIAnnotateRecordDAO) BatchSave(arg0 context.Context, arg1 []*model.A } // BatchSave indicates an expected call of BatchSave. -func (mr *MockIAnnotateRecordDAOMockRecorder) BatchSave(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockIAnnotateRecordDAOMockRecorder) BatchSave(ctx, annotateRecord any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, annotateRecord}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchSave", reflect.TypeOf((*MockIAnnotateRecordDAO)(nil).BatchSave), varargs...) } // MGetByID mocks base method. -func (m *MockIAnnotateRecordDAO) MGetByID(arg0 context.Context, arg1 []int64) ([]*model.AnnotateRecord, error) { +func (m *MockIAnnotateRecordDAO) MGetByID(ctx context.Context, ids []int64) ([]*model.AnnotateRecord, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MGetByID", arg0, arg1) + ret := m.ctrl.Call(m, "MGetByID", ctx, ids) ret0, _ := ret[0].([]*model.AnnotateRecord) ret1, _ := ret[1].(error) return ret0, ret1 } // MGetByID indicates an expected call of MGetByID. -func (mr *MockIAnnotateRecordDAOMockRecorder) MGetByID(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockIAnnotateRecordDAOMockRecorder) MGetByID(ctx, ids any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetByID", reflect.TypeOf((*MockIAnnotateRecordDAO)(nil).MGetByID), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetByID", reflect.TypeOf((*MockIAnnotateRecordDAO)(nil).MGetByID), ctx, ids) } // Save mocks base method. -func (m *MockIAnnotateRecordDAO) Save(arg0 context.Context, arg1 *model.AnnotateRecord, arg2 ...db.Option) error { +func (m *MockIAnnotateRecordDAO) Save(ctx context.Context, annotateRecord *model.AnnotateRecord, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, annotateRecord} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "Save", varargs...) @@ -84,17 +89,17 @@ func (m *MockIAnnotateRecordDAO) Save(arg0 context.Context, arg1 *model.Annotate } // Save indicates an expected call of Save. -func (mr *MockIAnnotateRecordDAOMockRecorder) Save(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockIAnnotateRecordDAOMockRecorder) Save(ctx, annotateRecord any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, annotateRecord}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Save", reflect.TypeOf((*MockIAnnotateRecordDAO)(nil).Save), varargs...) } // Update mocks base method. -func (m *MockIAnnotateRecordDAO) Update(arg0 context.Context, arg1 *model.AnnotateRecord, arg2 ...db.Option) error { +func (m *MockIAnnotateRecordDAO) Update(ctx context.Context, annotateRecord *model.AnnotateRecord, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, annotateRecord} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "Update", varargs...) @@ -103,8 +108,8 @@ func (m *MockIAnnotateRecordDAO) Update(arg0 context.Context, arg1 *model.Annota } // Update indicates an expected call of Update. -func (mr *MockIAnnotateRecordDAOMockRecorder) Update(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockIAnnotateRecordDAOMockRecorder) Update(ctx, annotateRecord any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, annotateRecord}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockIAnnotateRecordDAO)(nil).Update), varargs...) } diff --git a/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_aggr_result.go b/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_aggr_result.go index bff63e345..e4d87aaca 100644 --- a/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_aggr_result.go +++ b/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_aggr_result.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/coze-dev/coze-loop/backend/modules/evaluation/infra/repo/experiment/mysql (interfaces: ExptAggrResultDAO) +// +// Generated by this command: +// +// mockgen -destination=mocks/expt_aggr_result.go -package mocks . ExptAggrResultDAO +// // Package mocks is a generated GoMock package. package mocks @@ -8,16 +13,16 @@ import ( context "context" reflect "reflect" - "go.uber.org/mock/gomock" - db "github.com/coze-dev/coze-loop/backend/infra/db" model "github.com/coze-dev/coze-loop/backend/modules/evaluation/infra/repo/experiment/mysql/gorm_gen/model" + gomock "go.uber.org/mock/gomock" ) // MockExptAggrResultDAO is a mock of ExptAggrResultDAO interface. type MockExptAggrResultDAO struct { ctrl *gomock.Controller recorder *MockExptAggrResultDAOMockRecorder + isgomock struct{} } // MockExptAggrResultDAOMockRecorder is the mock recorder for MockExptAggrResultDAO. @@ -38,10 +43,10 @@ func (m *MockExptAggrResultDAO) EXPECT() *MockExptAggrResultDAOMockRecorder { } // BatchCreateExptAggrResult mocks base method. -func (m *MockExptAggrResultDAO) BatchCreateExptAggrResult(arg0 context.Context, arg1 []*model.ExptAggrResult, arg2 ...db.Option) error { +func (m *MockExptAggrResultDAO) BatchCreateExptAggrResult(ctx context.Context, exptAggrResults []*model.ExptAggrResult, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, exptAggrResults} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "BatchCreateExptAggrResult", varargs...) @@ -50,17 +55,17 @@ func (m *MockExptAggrResultDAO) BatchCreateExptAggrResult(arg0 context.Context, } // BatchCreateExptAggrResult indicates an expected call of BatchCreateExptAggrResult. -func (mr *MockExptAggrResultDAOMockRecorder) BatchCreateExptAggrResult(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockExptAggrResultDAOMockRecorder) BatchCreateExptAggrResult(ctx, exptAggrResults any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, exptAggrResults}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchCreateExptAggrResult", reflect.TypeOf((*MockExptAggrResultDAO)(nil).BatchCreateExptAggrResult), varargs...) } // BatchGetExptAggrResultByExperimentIDs mocks base method. -func (m *MockExptAggrResultDAO) BatchGetExptAggrResultByExperimentIDs(arg0 context.Context, arg1 []int64, arg2 ...db.Option) ([]*model.ExptAggrResult, error) { +func (m *MockExptAggrResultDAO) BatchGetExptAggrResultByExperimentIDs(ctx context.Context, experimentIDs []int64, opts ...db.Option) ([]*model.ExptAggrResult, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, experimentIDs} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "BatchGetExptAggrResultByExperimentIDs", varargs...) @@ -70,17 +75,17 @@ func (m *MockExptAggrResultDAO) BatchGetExptAggrResultByExperimentIDs(arg0 conte } // BatchGetExptAggrResultByExperimentIDs indicates an expected call of BatchGetExptAggrResultByExperimentIDs. -func (mr *MockExptAggrResultDAOMockRecorder) BatchGetExptAggrResultByExperimentIDs(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockExptAggrResultDAOMockRecorder) BatchGetExptAggrResultByExperimentIDs(ctx, experimentIDs any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, experimentIDs}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGetExptAggrResultByExperimentIDs", reflect.TypeOf((*MockExptAggrResultDAO)(nil).BatchGetExptAggrResultByExperimentIDs), varargs...) } // CreateExptAggrResult mocks base method. -func (m *MockExptAggrResultDAO) CreateExptAggrResult(arg0 context.Context, arg1 *model.ExptAggrResult, arg2 ...db.Option) error { +func (m *MockExptAggrResultDAO) CreateExptAggrResult(ctx context.Context, exptAggrResult *model.ExptAggrResult, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, exptAggrResult} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "CreateExptAggrResult", varargs...) @@ -89,17 +94,17 @@ func (m *MockExptAggrResultDAO) CreateExptAggrResult(arg0 context.Context, arg1 } // CreateExptAggrResult indicates an expected call of CreateExptAggrResult. -func (mr *MockExptAggrResultDAOMockRecorder) CreateExptAggrResult(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockExptAggrResultDAOMockRecorder) CreateExptAggrResult(ctx, exptAggrResult any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, exptAggrResult}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateExptAggrResult", reflect.TypeOf((*MockExptAggrResultDAO)(nil).CreateExptAggrResult), varargs...) } // DeleteExptAggrResult mocks base method. -func (m *MockExptAggrResultDAO) DeleteExptAggrResult(arg0 context.Context, arg1 *model.ExptAggrResult, arg2 ...db.Option) error { +func (m *MockExptAggrResultDAO) DeleteExptAggrResult(ctx context.Context, exptAggrResult *model.ExptAggrResult, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, exptAggrResult} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "DeleteExptAggrResult", varargs...) @@ -108,17 +113,17 @@ func (m *MockExptAggrResultDAO) DeleteExptAggrResult(arg0 context.Context, arg1 } // DeleteExptAggrResult indicates an expected call of DeleteExptAggrResult. -func (mr *MockExptAggrResultDAOMockRecorder) DeleteExptAggrResult(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockExptAggrResultDAOMockRecorder) DeleteExptAggrResult(ctx, exptAggrResult any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, exptAggrResult}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteExptAggrResult", reflect.TypeOf((*MockExptAggrResultDAO)(nil).DeleteExptAggrResult), varargs...) } // GetExptAggrResult mocks base method. -func (m *MockExptAggrResultDAO) GetExptAggrResult(arg0 context.Context, arg1 int64, arg2 int32, arg3 string, arg4 ...db.Option) (*model.ExptAggrResult, error) { +func (m *MockExptAggrResultDAO) GetExptAggrResult(ctx context.Context, experimentID int64, fieldType int32, fieldKey string, opts ...db.Option) (*model.ExptAggrResult, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3} - for _, a := range arg4 { + varargs := []any{ctx, experimentID, fieldType, fieldKey} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "GetExptAggrResult", varargs...) @@ -128,17 +133,17 @@ func (m *MockExptAggrResultDAO) GetExptAggrResult(arg0 context.Context, arg1 int } // GetExptAggrResult indicates an expected call of GetExptAggrResult. -func (mr *MockExptAggrResultDAOMockRecorder) GetExptAggrResult(arg0, arg1, arg2, arg3 interface{}, arg4 ...interface{}) *gomock.Call { +func (mr *MockExptAggrResultDAOMockRecorder) GetExptAggrResult(ctx, experimentID, fieldType, fieldKey any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3}, arg4...) + varargs := append([]any{ctx, experimentID, fieldType, fieldKey}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetExptAggrResult", reflect.TypeOf((*MockExptAggrResultDAO)(nil).GetExptAggrResult), varargs...) } // GetExptAggrResultByExperimentID mocks base method. -func (m *MockExptAggrResultDAO) GetExptAggrResultByExperimentID(arg0 context.Context, arg1 int64, arg2 ...db.Option) ([]*model.ExptAggrResult, error) { +func (m *MockExptAggrResultDAO) GetExptAggrResultByExperimentID(ctx context.Context, experimentID int64, opts ...db.Option) ([]*model.ExptAggrResult, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, experimentID} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "GetExptAggrResultByExperimentID", varargs...) @@ -148,17 +153,17 @@ func (m *MockExptAggrResultDAO) GetExptAggrResultByExperimentID(arg0 context.Con } // GetExptAggrResultByExperimentID indicates an expected call of GetExptAggrResultByExperimentID. -func (mr *MockExptAggrResultDAOMockRecorder) GetExptAggrResultByExperimentID(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockExptAggrResultDAOMockRecorder) GetExptAggrResultByExperimentID(ctx, experimentID any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, experimentID}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetExptAggrResultByExperimentID", reflect.TypeOf((*MockExptAggrResultDAO)(nil).GetExptAggrResultByExperimentID), varargs...) } // UpdateAndGetLatestVersion mocks base method. -func (m *MockExptAggrResultDAO) UpdateAndGetLatestVersion(arg0 context.Context, arg1 int64, arg2 int32, arg3 string, arg4 ...db.Option) (int64, error) { +func (m *MockExptAggrResultDAO) UpdateAndGetLatestVersion(ctx context.Context, experimentID int64, fieldType int32, fieldKey string, opts ...db.Option) (int64, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3} - for _, a := range arg4 { + varargs := []any{ctx, experimentID, fieldType, fieldKey} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "UpdateAndGetLatestVersion", varargs...) @@ -168,17 +173,17 @@ func (m *MockExptAggrResultDAO) UpdateAndGetLatestVersion(arg0 context.Context, } // UpdateAndGetLatestVersion indicates an expected call of UpdateAndGetLatestVersion. -func (mr *MockExptAggrResultDAOMockRecorder) UpdateAndGetLatestVersion(arg0, arg1, arg2, arg3 interface{}, arg4 ...interface{}) *gomock.Call { +func (mr *MockExptAggrResultDAOMockRecorder) UpdateAndGetLatestVersion(ctx, experimentID, fieldType, fieldKey any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3}, arg4...) + varargs := append([]any{ctx, experimentID, fieldType, fieldKey}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAndGetLatestVersion", reflect.TypeOf((*MockExptAggrResultDAO)(nil).UpdateAndGetLatestVersion), varargs...) } // UpdateExptAggrResultByVersion mocks base method. -func (m *MockExptAggrResultDAO) UpdateExptAggrResultByVersion(arg0 context.Context, arg1 *model.ExptAggrResult, arg2 int64, arg3 ...db.Option) error { +func (m *MockExptAggrResultDAO) UpdateExptAggrResultByVersion(ctx context.Context, exptAggrResult *model.ExptAggrResult, taskVersion int64, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2} - for _, a := range arg3 { + varargs := []any{ctx, exptAggrResult, taskVersion} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "UpdateExptAggrResultByVersion", varargs...) @@ -187,8 +192,8 @@ func (m *MockExptAggrResultDAO) UpdateExptAggrResultByVersion(arg0 context.Conte } // UpdateExptAggrResultByVersion indicates an expected call of UpdateExptAggrResultByVersion. -func (mr *MockExptAggrResultDAOMockRecorder) UpdateExptAggrResultByVersion(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { +func (mr *MockExptAggrResultDAOMockRecorder) UpdateExptAggrResultByVersion(ctx, exptAggrResult, taskVersion any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) + varargs := append([]any{ctx, exptAggrResult, taskVersion}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateExptAggrResultByVersion", reflect.TypeOf((*MockExptAggrResultDAO)(nil).UpdateExptAggrResultByVersion), varargs...) } diff --git a/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_evaluator_ref.go b/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_evaluator_ref.go index a7a8b3e4d..ca7ae9367 100644 --- a/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_evaluator_ref.go +++ b/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_evaluator_ref.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/coze-dev/coze-loop/backend/modules/evaluation/infra/repo/experiment/mysql (interfaces: IExptEvaluatorRefDAO) +// +// Generated by this command: +// +// mockgen -destination=mocks/expt_evaluator_ref.go -package mocks . IExptEvaluatorRefDAO +// // Package mocks is a generated GoMock package. package mocks @@ -16,6 +21,7 @@ import ( type MockIExptEvaluatorRefDAO struct { ctrl *gomock.Controller recorder *MockIExptEvaluatorRefDAOMockRecorder + isgomock struct{} } // MockIExptEvaluatorRefDAOMockRecorder is the mock recorder for MockIExptEvaluatorRefDAO. @@ -36,30 +42,30 @@ func (m *MockIExptEvaluatorRefDAO) EXPECT() *MockIExptEvaluatorRefDAOMockRecorde } // Create mocks base method. -func (m *MockIExptEvaluatorRefDAO) Create(arg0 context.Context, arg1 []*model.ExptEvaluatorRef) error { +func (m *MockIExptEvaluatorRefDAO) Create(ctx context.Context, exptEvaluatorRef []*model.ExptEvaluatorRef) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Create", arg0, arg1) + ret := m.ctrl.Call(m, "Create", ctx, exptEvaluatorRef) ret0, _ := ret[0].(error) return ret0 } // Create indicates an expected call of Create. -func (mr *MockIExptEvaluatorRefDAOMockRecorder) Create(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockIExptEvaluatorRefDAOMockRecorder) Create(ctx, exptEvaluatorRef any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockIExptEvaluatorRefDAO)(nil).Create), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockIExptEvaluatorRefDAO)(nil).Create), ctx, exptEvaluatorRef) } // MGetByExptID mocks base method. -func (m *MockIExptEvaluatorRefDAO) MGetByExptID(arg0 context.Context, arg1 []int64, arg2 int64) ([]*model.ExptEvaluatorRef, error) { +func (m *MockIExptEvaluatorRefDAO) MGetByExptID(ctx context.Context, exptIDs []int64, spaceID int64) ([]*model.ExptEvaluatorRef, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MGetByExptID", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "MGetByExptID", ctx, exptIDs, spaceID) ret0, _ := ret[0].([]*model.ExptEvaluatorRef) ret1, _ := ret[1].(error) return ret0, ret1 } // MGetByExptID indicates an expected call of MGetByExptID. -func (mr *MockIExptEvaluatorRefDAOMockRecorder) MGetByExptID(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockIExptEvaluatorRefDAOMockRecorder) MGetByExptID(ctx, exptIDs, spaceID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetByExptID", reflect.TypeOf((*MockIExptEvaluatorRefDAO)(nil).MGetByExptID), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetByExptID", reflect.TypeOf((*MockIExptEvaluatorRefDAO)(nil).MGetByExptID), ctx, exptIDs, spaceID) } diff --git a/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_insight_analysis_feedback_comment.go b/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_insight_analysis_feedback_comment.go index faee0a332..691ec4388 100644 --- a/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_insight_analysis_feedback_comment.go +++ b/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_insight_analysis_feedback_comment.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/coze-dev/coze-loop/backend/modules/evaluation/infra/repo/experiment/mysql (interfaces: IExptInsightAnalysisFeedbackCommentDAO) +// +// Generated by this command: +// +// mockgen -destination=mocks/expt_insight_analysis_feedback_comment.go -package mocks . IExptInsightAnalysisFeedbackCommentDAO +// // Package mocks is a generated GoMock package. package mocks @@ -18,6 +23,7 @@ import ( type MockIExptInsightAnalysisFeedbackCommentDAO struct { ctrl *gomock.Controller recorder *MockIExptInsightAnalysisFeedbackCommentDAOMockRecorder + isgomock struct{} } // MockIExptInsightAnalysisFeedbackCommentDAOMockRecorder is the mock recorder for MockIExptInsightAnalysisFeedbackCommentDAO. @@ -38,10 +44,10 @@ func (m *MockIExptInsightAnalysisFeedbackCommentDAO) EXPECT() *MockIExptInsightA } // Create mocks base method. -func (m *MockIExptInsightAnalysisFeedbackCommentDAO) Create(arg0 context.Context, arg1 *model.ExptInsightAnalysisFeedbackComment, arg2 ...db.Option) error { +func (m *MockIExptInsightAnalysisFeedbackCommentDAO) Create(ctx context.Context, feedbackComment *model.ExptInsightAnalysisFeedbackComment, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, feedbackComment} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "Create", varargs...) @@ -50,31 +56,31 @@ func (m *MockIExptInsightAnalysisFeedbackCommentDAO) Create(arg0 context.Context } // Create indicates an expected call of Create. -func (mr *MockIExptInsightAnalysisFeedbackCommentDAOMockRecorder) Create(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockIExptInsightAnalysisFeedbackCommentDAOMockRecorder) Create(ctx, feedbackComment any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, feedbackComment}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockIExptInsightAnalysisFeedbackCommentDAO)(nil).Create), varargs...) } // Delete mocks base method. -func (m *MockIExptInsightAnalysisFeedbackCommentDAO) Delete(arg0 context.Context, arg1, arg2, arg3 int64) error { +func (m *MockIExptInsightAnalysisFeedbackCommentDAO) Delete(ctx context.Context, spaceID, exptID, commentID int64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Delete", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "Delete", ctx, spaceID, exptID, commentID) ret0, _ := ret[0].(error) return ret0 } // Delete indicates an expected call of Delete. -func (mr *MockIExptInsightAnalysisFeedbackCommentDAOMockRecorder) Delete(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockIExptInsightAnalysisFeedbackCommentDAOMockRecorder) Delete(ctx, spaceID, exptID, commentID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockIExptInsightAnalysisFeedbackCommentDAO)(nil).Delete), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockIExptInsightAnalysisFeedbackCommentDAO)(nil).Delete), ctx, spaceID, exptID, commentID) } // GetByRecordID mocks base method. -func (m *MockIExptInsightAnalysisFeedbackCommentDAO) GetByRecordID(arg0 context.Context, arg1, arg2, arg3 int64, arg4 ...db.Option) (*model.ExptInsightAnalysisFeedbackComment, error) { +func (m *MockIExptInsightAnalysisFeedbackCommentDAO) GetByRecordID(ctx context.Context, spaceID, exptID, recordID int64, opts ...db.Option) (*model.ExptInsightAnalysisFeedbackComment, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3} - for _, a := range arg4 { + varargs := []any{ctx, spaceID, exptID, recordID} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "GetByRecordID", varargs...) @@ -84,16 +90,16 @@ func (m *MockIExptInsightAnalysisFeedbackCommentDAO) GetByRecordID(arg0 context. } // GetByRecordID indicates an expected call of GetByRecordID. -func (mr *MockIExptInsightAnalysisFeedbackCommentDAOMockRecorder) GetByRecordID(arg0, arg1, arg2, arg3 interface{}, arg4 ...interface{}) *gomock.Call { +func (mr *MockIExptInsightAnalysisFeedbackCommentDAOMockRecorder) GetByRecordID(ctx, spaceID, exptID, recordID any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3}, arg4...) + varargs := append([]any{ctx, spaceID, exptID, recordID}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByRecordID", reflect.TypeOf((*MockIExptInsightAnalysisFeedbackCommentDAO)(nil).GetByRecordID), varargs...) } // List mocks base method. -func (m *MockIExptInsightAnalysisFeedbackCommentDAO) List(arg0 context.Context, arg1, arg2, arg3 int64, arg4 entity.Page) ([]*model.ExptInsightAnalysisFeedbackComment, int64, error) { +func (m *MockIExptInsightAnalysisFeedbackCommentDAO) List(ctx context.Context, spaceID, exptID, recordID int64, page entity.Page) ([]*model.ExptInsightAnalysisFeedbackComment, int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "List", arg0, arg1, arg2, arg3, arg4) + ret := m.ctrl.Call(m, "List", ctx, spaceID, exptID, recordID, page) ret0, _ := ret[0].([]*model.ExptInsightAnalysisFeedbackComment) ret1, _ := ret[1].(int64) ret2, _ := ret[2].(error) @@ -101,16 +107,16 @@ func (m *MockIExptInsightAnalysisFeedbackCommentDAO) List(arg0 context.Context, } // List indicates an expected call of List. -func (mr *MockIExptInsightAnalysisFeedbackCommentDAOMockRecorder) List(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { +func (mr *MockIExptInsightAnalysisFeedbackCommentDAOMockRecorder) List(ctx, spaceID, exptID, recordID, page any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockIExptInsightAnalysisFeedbackCommentDAO)(nil).List), arg0, arg1, arg2, arg3, arg4) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockIExptInsightAnalysisFeedbackCommentDAO)(nil).List), ctx, spaceID, exptID, recordID, page) } // Update mocks base method. -func (m *MockIExptInsightAnalysisFeedbackCommentDAO) Update(arg0 context.Context, arg1 *model.ExptInsightAnalysisFeedbackComment, arg2 ...db.Option) error { +func (m *MockIExptInsightAnalysisFeedbackCommentDAO) Update(ctx context.Context, feedbackComment *model.ExptInsightAnalysisFeedbackComment, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, feedbackComment} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "Update", varargs...) @@ -119,8 +125,8 @@ func (m *MockIExptInsightAnalysisFeedbackCommentDAO) Update(arg0 context.Context } // Update indicates an expected call of Update. -func (mr *MockIExptInsightAnalysisFeedbackCommentDAOMockRecorder) Update(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockIExptInsightAnalysisFeedbackCommentDAOMockRecorder) Update(ctx, feedbackComment any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, feedbackComment}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockIExptInsightAnalysisFeedbackCommentDAO)(nil).Update), varargs...) } diff --git a/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_insight_analysis_feedback_vote.go b/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_insight_analysis_feedback_vote.go index addd54589..13964dd1b 100644 --- a/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_insight_analysis_feedback_vote.go +++ b/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_insight_analysis_feedback_vote.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/coze-dev/coze-loop/backend/modules/evaluation/infra/repo/experiment/mysql (interfaces: IExptInsightAnalysisFeedbackVoteDAO) +// +// Generated by this command: +// +// mockgen -destination=mocks/expt_insight_analysis_feedback_vote.go -package mocks . IExptInsightAnalysisFeedbackVoteDAO +// // Package mocks is a generated GoMock package. package mocks @@ -17,6 +22,7 @@ import ( type MockIExptInsightAnalysisFeedbackVoteDAO struct { ctrl *gomock.Controller recorder *MockIExptInsightAnalysisFeedbackVoteDAOMockRecorder + isgomock struct{} } // MockIExptInsightAnalysisFeedbackVoteDAOMockRecorder is the mock recorder for MockIExptInsightAnalysisFeedbackVoteDAO. @@ -37,9 +43,9 @@ func (m *MockIExptInsightAnalysisFeedbackVoteDAO) EXPECT() *MockIExptInsightAnal } // Count mocks base method. -func (m *MockIExptInsightAnalysisFeedbackVoteDAO) Count(arg0 context.Context, arg1, arg2, arg3 int64) (int64, int64, error) { +func (m *MockIExptInsightAnalysisFeedbackVoteDAO) Count(ctx context.Context, spaceID, exptID, recordID int64) (int64, int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Count", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "Count", ctx, spaceID, exptID, recordID) ret0, _ := ret[0].(int64) ret1, _ := ret[1].(int64) ret2, _ := ret[2].(error) @@ -47,16 +53,16 @@ func (m *MockIExptInsightAnalysisFeedbackVoteDAO) Count(arg0 context.Context, ar } // Count indicates an expected call of Count. -func (mr *MockIExptInsightAnalysisFeedbackVoteDAOMockRecorder) Count(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockIExptInsightAnalysisFeedbackVoteDAOMockRecorder) Count(ctx, spaceID, exptID, recordID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Count", reflect.TypeOf((*MockIExptInsightAnalysisFeedbackVoteDAO)(nil).Count), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Count", reflect.TypeOf((*MockIExptInsightAnalysisFeedbackVoteDAO)(nil).Count), ctx, spaceID, exptID, recordID) } // Create mocks base method. -func (m *MockIExptInsightAnalysisFeedbackVoteDAO) Create(arg0 context.Context, arg1 *model.ExptInsightAnalysisFeedbackVote, arg2 ...db.Option) error { +func (m *MockIExptInsightAnalysisFeedbackVoteDAO) Create(ctx context.Context, feedbackVote *model.ExptInsightAnalysisFeedbackVote, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, feedbackVote} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "Create", varargs...) @@ -65,17 +71,17 @@ func (m *MockIExptInsightAnalysisFeedbackVoteDAO) Create(arg0 context.Context, a } // Create indicates an expected call of Create. -func (mr *MockIExptInsightAnalysisFeedbackVoteDAOMockRecorder) Create(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockIExptInsightAnalysisFeedbackVoteDAOMockRecorder) Create(ctx, feedbackVote any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, feedbackVote}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockIExptInsightAnalysisFeedbackVoteDAO)(nil).Create), varargs...) } // GetByUser mocks base method. -func (m *MockIExptInsightAnalysisFeedbackVoteDAO) GetByUser(arg0 context.Context, arg1, arg2, arg3 int64, arg4 string, arg5 ...db.Option) (*model.ExptInsightAnalysisFeedbackVote, error) { +func (m *MockIExptInsightAnalysisFeedbackVoteDAO) GetByUser(ctx context.Context, spaceID, exptID, recordID int64, userID string, opts ...db.Option) (*model.ExptInsightAnalysisFeedbackVote, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3, arg4} - for _, a := range arg5 { + varargs := []any{ctx, spaceID, exptID, recordID, userID} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "GetByUser", varargs...) @@ -85,17 +91,17 @@ func (m *MockIExptInsightAnalysisFeedbackVoteDAO) GetByUser(arg0 context.Context } // GetByUser indicates an expected call of GetByUser. -func (mr *MockIExptInsightAnalysisFeedbackVoteDAOMockRecorder) GetByUser(arg0, arg1, arg2, arg3, arg4 interface{}, arg5 ...interface{}) *gomock.Call { +func (mr *MockIExptInsightAnalysisFeedbackVoteDAOMockRecorder) GetByUser(ctx, spaceID, exptID, recordID, userID any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3, arg4}, arg5...) + varargs := append([]any{ctx, spaceID, exptID, recordID, userID}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByUser", reflect.TypeOf((*MockIExptInsightAnalysisFeedbackVoteDAO)(nil).GetByUser), varargs...) } // Update mocks base method. -func (m *MockIExptInsightAnalysisFeedbackVoteDAO) Update(arg0 context.Context, arg1 *model.ExptInsightAnalysisFeedbackVote, arg2 ...db.Option) error { +func (m *MockIExptInsightAnalysisFeedbackVoteDAO) Update(ctx context.Context, feedbackVote *model.ExptInsightAnalysisFeedbackVote, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, feedbackVote} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "Update", varargs...) @@ -104,8 +110,8 @@ func (m *MockIExptInsightAnalysisFeedbackVoteDAO) Update(arg0 context.Context, a } // Update indicates an expected call of Update. -func (mr *MockIExptInsightAnalysisFeedbackVoteDAOMockRecorder) Update(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockIExptInsightAnalysisFeedbackVoteDAOMockRecorder) Update(ctx, feedbackVote any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, feedbackVote}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockIExptInsightAnalysisFeedbackVoteDAO)(nil).Update), varargs...) } diff --git a/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_insight_analysis_record.go b/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_insight_analysis_record.go index ccffbe7fc..63def3504 100644 --- a/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_insight_analysis_record.go +++ b/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_insight_analysis_record.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/coze-dev/coze-loop/backend/modules/evaluation/infra/repo/experiment/mysql (interfaces: IExptInsightAnalysisRecordDAO) +// +// Generated by this command: +// +// mockgen -destination=mocks/expt_insight_analysis_record.go -package mocks . IExptInsightAnalysisRecordDAO +// // Package mocks is a generated GoMock package. package mocks @@ -18,6 +23,7 @@ import ( type MockIExptInsightAnalysisRecordDAO struct { ctrl *gomock.Controller recorder *MockIExptInsightAnalysisRecordDAOMockRecorder + isgomock struct{} } // MockIExptInsightAnalysisRecordDAOMockRecorder is the mock recorder for MockIExptInsightAnalysisRecordDAO. @@ -38,10 +44,10 @@ func (m *MockIExptInsightAnalysisRecordDAO) EXPECT() *MockIExptInsightAnalysisRe } // Create mocks base method. -func (m *MockIExptInsightAnalysisRecordDAO) Create(arg0 context.Context, arg1 *model.ExptInsightAnalysisRecord, arg2 ...db.Option) error { +func (m *MockIExptInsightAnalysisRecordDAO) Create(ctx context.Context, record *model.ExptInsightAnalysisRecord, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, record} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "Create", varargs...) @@ -50,31 +56,31 @@ func (m *MockIExptInsightAnalysisRecordDAO) Create(arg0 context.Context, arg1 *m } // Create indicates an expected call of Create. -func (mr *MockIExptInsightAnalysisRecordDAOMockRecorder) Create(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockIExptInsightAnalysisRecordDAOMockRecorder) Create(ctx, record any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, record}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockIExptInsightAnalysisRecordDAO)(nil).Create), varargs...) } // Delete mocks base method. -func (m *MockIExptInsightAnalysisRecordDAO) Delete(arg0 context.Context, arg1, arg2, arg3 int64) error { +func (m *MockIExptInsightAnalysisRecordDAO) Delete(ctx context.Context, spaceID, exptID, recordID int64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Delete", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "Delete", ctx, spaceID, exptID, recordID) ret0, _ := ret[0].(error) return ret0 } // Delete indicates an expected call of Delete. -func (mr *MockIExptInsightAnalysisRecordDAOMockRecorder) Delete(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockIExptInsightAnalysisRecordDAOMockRecorder) Delete(ctx, spaceID, exptID, recordID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockIExptInsightAnalysisRecordDAO)(nil).Delete), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockIExptInsightAnalysisRecordDAO)(nil).Delete), ctx, spaceID, exptID, recordID) } // GetByID mocks base method. -func (m *MockIExptInsightAnalysisRecordDAO) GetByID(arg0 context.Context, arg1, arg2, arg3 int64, arg4 ...db.Option) (*model.ExptInsightAnalysisRecord, error) { +func (m *MockIExptInsightAnalysisRecordDAO) GetByID(ctx context.Context, spaceID, exptID, recordID int64, opts ...db.Option) (*model.ExptInsightAnalysisRecord, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3} - for _, a := range arg4 { + varargs := []any{ctx, spaceID, exptID, recordID} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "GetByID", varargs...) @@ -84,16 +90,16 @@ func (m *MockIExptInsightAnalysisRecordDAO) GetByID(arg0 context.Context, arg1, } // GetByID indicates an expected call of GetByID. -func (mr *MockIExptInsightAnalysisRecordDAOMockRecorder) GetByID(arg0, arg1, arg2, arg3 interface{}, arg4 ...interface{}) *gomock.Call { +func (mr *MockIExptInsightAnalysisRecordDAOMockRecorder) GetByID(ctx, spaceID, exptID, recordID any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3}, arg4...) + varargs := append([]any{ctx, spaceID, exptID, recordID}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByID", reflect.TypeOf((*MockIExptInsightAnalysisRecordDAO)(nil).GetByID), varargs...) } // List mocks base method. -func (m *MockIExptInsightAnalysisRecordDAO) List(arg0 context.Context, arg1, arg2 int64, arg3 entity.Page) ([]*model.ExptInsightAnalysisRecord, int64, error) { +func (m *MockIExptInsightAnalysisRecordDAO) List(ctx context.Context, spaceID, exptID int64, page entity.Page) ([]*model.ExptInsightAnalysisRecord, int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "List", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "List", ctx, spaceID, exptID, page) ret0, _ := ret[0].([]*model.ExptInsightAnalysisRecord) ret1, _ := ret[1].(int64) ret2, _ := ret[2].(error) @@ -101,16 +107,16 @@ func (m *MockIExptInsightAnalysisRecordDAO) List(arg0 context.Context, arg1, arg } // List indicates an expected call of List. -func (mr *MockIExptInsightAnalysisRecordDAOMockRecorder) List(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockIExptInsightAnalysisRecordDAOMockRecorder) List(ctx, spaceID, exptID, page any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockIExptInsightAnalysisRecordDAO)(nil).List), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockIExptInsightAnalysisRecordDAO)(nil).List), ctx, spaceID, exptID, page) } // Update mocks base method. -func (m *MockIExptInsightAnalysisRecordDAO) Update(arg0 context.Context, arg1 *model.ExptInsightAnalysisRecord, arg2 ...db.Option) error { +func (m *MockIExptInsightAnalysisRecordDAO) Update(ctx context.Context, record *model.ExptInsightAnalysisRecord, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, record} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "Update", varargs...) @@ -119,8 +125,8 @@ func (m *MockIExptInsightAnalysisRecordDAO) Update(arg0 context.Context, arg1 *m } // Update indicates an expected call of Update. -func (mr *MockIExptInsightAnalysisRecordDAOMockRecorder) Update(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockIExptInsightAnalysisRecordDAOMockRecorder) Update(ctx, record any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, record}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockIExptInsightAnalysisRecordDAO)(nil).Update), varargs...) } diff --git a/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_item_result.go b/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_item_result.go index 03d486f73..70685b05c 100644 --- a/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_item_result.go +++ b/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_item_result.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/coze-dev/coze-loop/backend/modules/evaluation/infra/repo/experiment/mysql (interfaces: IExptItemResultDAO) +// +// Generated by this command: +// +// mockgen -destination=mocks/expt_item_result.go -package mocks . IExptItemResultDAO +// // Package mocks is a generated GoMock package. package mocks @@ -18,6 +23,7 @@ import ( type MockIExptItemResultDAO struct { ctrl *gomock.Controller recorder *MockIExptItemResultDAOMockRecorder + isgomock struct{} } // MockIExptItemResultDAOMockRecorder is the mock recorder for MockIExptItemResultDAO. @@ -38,10 +44,10 @@ func (m *MockIExptItemResultDAO) EXPECT() *MockIExptItemResultDAOMockRecorder { } // BatchCreateNX mocks base method. -func (m *MockIExptItemResultDAO) BatchCreateNX(arg0 context.Context, arg1 []*model.ExptItemResult, arg2 ...db.Option) error { +func (m *MockIExptItemResultDAO) BatchCreateNX(ctx context.Context, itemResults []*model.ExptItemResult, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, itemResults} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "BatchCreateNX", varargs...) @@ -50,17 +56,17 @@ func (m *MockIExptItemResultDAO) BatchCreateNX(arg0 context.Context, arg1 []*mod } // BatchCreateNX indicates an expected call of BatchCreateNX. -func (mr *MockIExptItemResultDAOMockRecorder) BatchCreateNX(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockIExptItemResultDAOMockRecorder) BatchCreateNX(ctx, itemResults any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, itemResults}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchCreateNX", reflect.TypeOf((*MockIExptItemResultDAO)(nil).BatchCreateNX), varargs...) } // BatchCreateNXRunLogs mocks base method. -func (m *MockIExptItemResultDAO) BatchCreateNXRunLogs(arg0 context.Context, arg1 []*model.ExptItemResultRunLog, arg2 ...db.Option) error { +func (m *MockIExptItemResultDAO) BatchCreateNXRunLogs(ctx context.Context, itemRunLogs []*model.ExptItemResultRunLog, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, itemRunLogs} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "BatchCreateNXRunLogs", varargs...) @@ -69,17 +75,17 @@ func (m *MockIExptItemResultDAO) BatchCreateNXRunLogs(arg0 context.Context, arg1 } // BatchCreateNXRunLogs indicates an expected call of BatchCreateNXRunLogs. -func (mr *MockIExptItemResultDAOMockRecorder) BatchCreateNXRunLogs(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockIExptItemResultDAOMockRecorder) BatchCreateNXRunLogs(ctx, itemRunLogs any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, itemRunLogs}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchCreateNXRunLogs", reflect.TypeOf((*MockIExptItemResultDAO)(nil).BatchCreateNXRunLogs), varargs...) } // BatchGet mocks base method. -func (m *MockIExptItemResultDAO) BatchGet(arg0 context.Context, arg1, arg2 int64, arg3 []int64, arg4 ...db.Option) ([]*model.ExptItemResult, error) { +func (m *MockIExptItemResultDAO) BatchGet(ctx context.Context, spaceID, exptID int64, itemIDs []int64, opts ...db.Option) ([]*model.ExptItemResult, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3} - for _, a := range arg4 { + varargs := []any{ctx, spaceID, exptID, itemIDs} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "BatchGet", varargs...) @@ -89,32 +95,32 @@ func (m *MockIExptItemResultDAO) BatchGet(arg0 context.Context, arg1, arg2 int64 } // BatchGet indicates an expected call of BatchGet. -func (mr *MockIExptItemResultDAOMockRecorder) BatchGet(arg0, arg1, arg2, arg3 interface{}, arg4 ...interface{}) *gomock.Call { +func (mr *MockIExptItemResultDAOMockRecorder) BatchGet(ctx, spaceID, exptID, itemIDs any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3}, arg4...) + varargs := append([]any{ctx, spaceID, exptID, itemIDs}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGet", reflect.TypeOf((*MockIExptItemResultDAO)(nil).BatchGet), varargs...) } // GetItemIDListByExptID mocks base method. -func (m *MockIExptItemResultDAO) GetItemIDListByExptID(arg0 context.Context, arg1, arg2 int64) ([]int64, error) { +func (m *MockIExptItemResultDAO) GetItemIDListByExptID(ctx context.Context, exptID, spaceID int64) ([]int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetItemIDListByExptID", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "GetItemIDListByExptID", ctx, exptID, spaceID) ret0, _ := ret[0].([]int64) ret1, _ := ret[1].(error) return ret0, ret1 } // GetItemIDListByExptID indicates an expected call of GetItemIDListByExptID. -func (mr *MockIExptItemResultDAOMockRecorder) GetItemIDListByExptID(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockIExptItemResultDAOMockRecorder) GetItemIDListByExptID(ctx, exptID, spaceID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetItemIDListByExptID", reflect.TypeOf((*MockIExptItemResultDAO)(nil).GetItemIDListByExptID), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetItemIDListByExptID", reflect.TypeOf((*MockIExptItemResultDAO)(nil).GetItemIDListByExptID), ctx, exptID, spaceID) } // GetItemRunLog mocks base method. -func (m *MockIExptItemResultDAO) GetItemRunLog(arg0 context.Context, arg1, arg2, arg3, arg4 int64, arg5 ...db.Option) (*model.ExptItemResultRunLog, error) { +func (m *MockIExptItemResultDAO) GetItemRunLog(ctx context.Context, exptID, exptRunID, itemID, spaceID int64, opts ...db.Option) (*model.ExptItemResultRunLog, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3, arg4} - for _, a := range arg5 { + varargs := []any{ctx, exptID, exptRunID, itemID, spaceID} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "GetItemRunLog", varargs...) @@ -124,17 +130,17 @@ func (m *MockIExptItemResultDAO) GetItemRunLog(arg0 context.Context, arg1, arg2, } // GetItemRunLog indicates an expected call of GetItemRunLog. -func (mr *MockIExptItemResultDAOMockRecorder) GetItemRunLog(arg0, arg1, arg2, arg3, arg4 interface{}, arg5 ...interface{}) *gomock.Call { +func (mr *MockIExptItemResultDAOMockRecorder) GetItemRunLog(ctx, exptID, exptRunID, itemID, spaceID any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3, arg4}, arg5...) + varargs := append([]any{ctx, exptID, exptRunID, itemID, spaceID}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetItemRunLog", reflect.TypeOf((*MockIExptItemResultDAO)(nil).GetItemRunLog), varargs...) } // GetItemTurnResults mocks base method. -func (m *MockIExptItemResultDAO) GetItemTurnResults(arg0 context.Context, arg1, arg2, arg3 int64, arg4 ...db.Option) ([]*model.ExptTurnResult, error) { +func (m *MockIExptItemResultDAO) GetItemTurnResults(ctx context.Context, spaceID, exptID, itemID int64, opts ...db.Option) ([]*model.ExptTurnResult, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3} - for _, a := range arg4 { + varargs := []any{ctx, spaceID, exptID, itemID} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "GetItemTurnResults", varargs...) @@ -144,17 +150,17 @@ func (m *MockIExptItemResultDAO) GetItemTurnResults(arg0 context.Context, arg1, } // GetItemTurnResults indicates an expected call of GetItemTurnResults. -func (mr *MockIExptItemResultDAOMockRecorder) GetItemTurnResults(arg0, arg1, arg2, arg3 interface{}, arg4 ...interface{}) *gomock.Call { +func (mr *MockIExptItemResultDAOMockRecorder) GetItemTurnResults(ctx, spaceID, exptID, itemID any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3}, arg4...) + varargs := append([]any{ctx, spaceID, exptID, itemID}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetItemTurnResults", reflect.TypeOf((*MockIExptItemResultDAO)(nil).GetItemTurnResults), varargs...) } // GetMaxItemIdxByExptID mocks base method. -func (m *MockIExptItemResultDAO) GetMaxItemIdxByExptID(arg0 context.Context, arg1, arg2 int64, arg3 ...db.Option) (int32, error) { +func (m *MockIExptItemResultDAO) GetMaxItemIdxByExptID(ctx context.Context, exptID, spaceID int64, opts ...db.Option) (int32, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2} - for _, a := range arg3 { + varargs := []any{ctx, exptID, spaceID} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "GetMaxItemIdxByExptID", varargs...) @@ -164,16 +170,16 @@ func (m *MockIExptItemResultDAO) GetMaxItemIdxByExptID(arg0 context.Context, arg } // GetMaxItemIdxByExptID indicates an expected call of GetMaxItemIdxByExptID. -func (mr *MockIExptItemResultDAOMockRecorder) GetMaxItemIdxByExptID(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { +func (mr *MockIExptItemResultDAOMockRecorder) GetMaxItemIdxByExptID(ctx, exptID, spaceID any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) + varargs := append([]any{ctx, exptID, spaceID}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMaxItemIdxByExptID", reflect.TypeOf((*MockIExptItemResultDAO)(nil).GetMaxItemIdxByExptID), varargs...) } // ListItemResultsByExptID mocks base method. -func (m *MockIExptItemResultDAO) ListItemResultsByExptID(arg0 context.Context, arg1, arg2 int64, arg3 entity.Page, arg4 bool) ([]*model.ExptItemResult, int64, error) { +func (m *MockIExptItemResultDAO) ListItemResultsByExptID(ctx context.Context, exptID, spaceID int64, page entity.Page, desc bool) ([]*model.ExptItemResult, int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ListItemResultsByExptID", arg0, arg1, arg2, arg3, arg4) + ret := m.ctrl.Call(m, "ListItemResultsByExptID", ctx, exptID, spaceID, page, desc) ret0, _ := ret[0].([]*model.ExptItemResult) ret1, _ := ret[1].(int64) ret2, _ := ret[2].(error) @@ -181,16 +187,16 @@ func (m *MockIExptItemResultDAO) ListItemResultsByExptID(arg0 context.Context, a } // ListItemResultsByExptID indicates an expected call of ListItemResultsByExptID. -func (mr *MockIExptItemResultDAOMockRecorder) ListItemResultsByExptID(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { +func (mr *MockIExptItemResultDAOMockRecorder) ListItemResultsByExptID(ctx, exptID, spaceID, page, desc any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListItemResultsByExptID", reflect.TypeOf((*MockIExptItemResultDAO)(nil).ListItemResultsByExptID), arg0, arg1, arg2, arg3, arg4) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListItemResultsByExptID", reflect.TypeOf((*MockIExptItemResultDAO)(nil).ListItemResultsByExptID), ctx, exptID, spaceID, page, desc) } // MGetItemRunLog mocks base method. -func (m *MockIExptItemResultDAO) MGetItemRunLog(arg0 context.Context, arg1, arg2 int64, arg3 []int64, arg4 int64, arg5 ...db.Option) ([]*model.ExptItemResultRunLog, error) { +func (m *MockIExptItemResultDAO) MGetItemRunLog(ctx context.Context, exptID, exptRunID int64, itemIDs []int64, spaceID int64, opts ...db.Option) ([]*model.ExptItemResultRunLog, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3, arg4} - for _, a := range arg5 { + varargs := []any{ctx, exptID, exptRunID, itemIDs, spaceID} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "MGetItemRunLog", varargs...) @@ -200,17 +206,17 @@ func (m *MockIExptItemResultDAO) MGetItemRunLog(arg0 context.Context, arg1, arg2 } // MGetItemRunLog indicates an expected call of MGetItemRunLog. -func (mr *MockIExptItemResultDAOMockRecorder) MGetItemRunLog(arg0, arg1, arg2, arg3, arg4 interface{}, arg5 ...interface{}) *gomock.Call { +func (mr *MockIExptItemResultDAOMockRecorder) MGetItemRunLog(ctx, exptID, exptRunID, itemIDs, spaceID any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3, arg4}, arg5...) + varargs := append([]any{ctx, exptID, exptRunID, itemIDs, spaceID}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetItemRunLog", reflect.TypeOf((*MockIExptItemResultDAO)(nil).MGetItemRunLog), varargs...) } // SaveItemResults mocks base method. -func (m *MockIExptItemResultDAO) SaveItemResults(arg0 context.Context, arg1 []*model.ExptItemResult, arg2 ...db.Option) error { +func (m *MockIExptItemResultDAO) SaveItemResults(ctx context.Context, itemResults []*model.ExptItemResult, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, itemResults} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "SaveItemResults", varargs...) @@ -219,17 +225,17 @@ func (m *MockIExptItemResultDAO) SaveItemResults(arg0 context.Context, arg1 []*m } // SaveItemResults indicates an expected call of SaveItemResults. -func (mr *MockIExptItemResultDAOMockRecorder) SaveItemResults(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockIExptItemResultDAOMockRecorder) SaveItemResults(ctx, itemResults any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, itemResults}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveItemResults", reflect.TypeOf((*MockIExptItemResultDAO)(nil).SaveItemResults), varargs...) } // ScanItemResults mocks base method. -func (m *MockIExptItemResultDAO) ScanItemResults(arg0 context.Context, arg1, arg2, arg3 int64, arg4 []int32, arg5 int64, arg6 ...db.Option) ([]*model.ExptItemResult, int64, error) { +func (m *MockIExptItemResultDAO) ScanItemResults(ctx context.Context, exptID, cursor, limit int64, status []int32, spaceID int64, opts ...db.Option) ([]*model.ExptItemResult, int64, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3, arg4, arg5} - for _, a := range arg6 { + varargs := []any{ctx, exptID, cursor, limit, status, spaceID} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "ScanItemResults", varargs...) @@ -240,17 +246,17 @@ func (m *MockIExptItemResultDAO) ScanItemResults(arg0 context.Context, arg1, arg } // ScanItemResults indicates an expected call of ScanItemResults. -func (mr *MockIExptItemResultDAOMockRecorder) ScanItemResults(arg0, arg1, arg2, arg3, arg4, arg5 interface{}, arg6 ...interface{}) *gomock.Call { +func (mr *MockIExptItemResultDAOMockRecorder) ScanItemResults(ctx, exptID, cursor, limit, status, spaceID any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3, arg4, arg5}, arg6...) + varargs := append([]any{ctx, exptID, cursor, limit, status, spaceID}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ScanItemResults", reflect.TypeOf((*MockIExptItemResultDAO)(nil).ScanItemResults), varargs...) } // ScanItemRunLogs mocks base method. -func (m *MockIExptItemResultDAO) ScanItemRunLogs(arg0 context.Context, arg1, arg2 int64, arg3 *entity.ExptItemRunLogFilter, arg4, arg5, arg6 int64, arg7 ...db.Option) ([]*model.ExptItemResultRunLog, int64, error) { +func (m *MockIExptItemResultDAO) ScanItemRunLogs(ctx context.Context, exptID, exptRunID int64, filter *entity.ExptItemRunLogFilter, cursor, limit, spaceID int64, opts ...db.Option) ([]*model.ExptItemResultRunLog, int64, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3, arg4, arg5, arg6} - for _, a := range arg7 { + varargs := []any{ctx, exptID, exptRunID, filter, cursor, limit, spaceID} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "ScanItemRunLogs", varargs...) @@ -261,17 +267,17 @@ func (m *MockIExptItemResultDAO) ScanItemRunLogs(arg0 context.Context, arg1, arg } // ScanItemRunLogs indicates an expected call of ScanItemRunLogs. -func (mr *MockIExptItemResultDAOMockRecorder) ScanItemRunLogs(arg0, arg1, arg2, arg3, arg4, arg5, arg6 interface{}, arg7 ...interface{}) *gomock.Call { +func (mr *MockIExptItemResultDAOMockRecorder) ScanItemRunLogs(ctx, exptID, exptRunID, filter, cursor, limit, spaceID any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3, arg4, arg5, arg6}, arg7...) + varargs := append([]any{ctx, exptID, exptRunID, filter, cursor, limit, spaceID}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ScanItemRunLogs", reflect.TypeOf((*MockIExptItemResultDAO)(nil).ScanItemRunLogs), varargs...) } // UpdateItemRunLog mocks base method. -func (m *MockIExptItemResultDAO) UpdateItemRunLog(arg0 context.Context, arg1, arg2 int64, arg3 []int64, arg4 map[string]interface{}, arg5 int64, arg6 ...db.Option) error { +func (m *MockIExptItemResultDAO) UpdateItemRunLog(ctx context.Context, exptID, exptRunID int64, itemID []int64, ufields map[string]any, spaceID int64, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3, arg4, arg5} - for _, a := range arg6 { + varargs := []any{ctx, exptID, exptRunID, itemID, ufields, spaceID} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "UpdateItemRunLog", varargs...) @@ -280,17 +286,17 @@ func (m *MockIExptItemResultDAO) UpdateItemRunLog(arg0 context.Context, arg1, ar } // UpdateItemRunLog indicates an expected call of UpdateItemRunLog. -func (mr *MockIExptItemResultDAOMockRecorder) UpdateItemRunLog(arg0, arg1, arg2, arg3, arg4, arg5 interface{}, arg6 ...interface{}) *gomock.Call { +func (mr *MockIExptItemResultDAOMockRecorder) UpdateItemRunLog(ctx, exptID, exptRunID, itemID, ufields, spaceID any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3, arg4, arg5}, arg6...) + varargs := append([]any{ctx, exptID, exptRunID, itemID, ufields, spaceID}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateItemRunLog", reflect.TypeOf((*MockIExptItemResultDAO)(nil).UpdateItemRunLog), varargs...) } // UpdateItemsResult mocks base method. -func (m *MockIExptItemResultDAO) UpdateItemsResult(arg0 context.Context, arg1, arg2 int64, arg3 []int64, arg4 map[string]interface{}, arg5 ...db.Option) error { +func (m *MockIExptItemResultDAO) UpdateItemsResult(ctx context.Context, spaceID, exptID int64, itemID []int64, ufields map[string]any, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3, arg4} - for _, a := range arg5 { + varargs := []any{ctx, spaceID, exptID, itemID, ufields} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "UpdateItemsResult", varargs...) @@ -299,8 +305,8 @@ func (m *MockIExptItemResultDAO) UpdateItemsResult(arg0 context.Context, arg1, a } // UpdateItemsResult indicates an expected call of UpdateItemsResult. -func (mr *MockIExptItemResultDAOMockRecorder) UpdateItemsResult(arg0, arg1, arg2, arg3, arg4 interface{}, arg5 ...interface{}) *gomock.Call { +func (mr *MockIExptItemResultDAOMockRecorder) UpdateItemsResult(ctx, spaceID, exptID, itemID, ufields any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3, arg4}, arg5...) + varargs := append([]any{ctx, spaceID, exptID, itemID, ufields}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateItemsResult", reflect.TypeOf((*MockIExptItemResultDAO)(nil).UpdateItemsResult), varargs...) } diff --git a/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_result_export_record.go b/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_result_export_record.go index 6385624a6..663371ba0 100644 --- a/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_result_export_record.go +++ b/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_result_export_record.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/coze-dev/coze-loop/backend/modules/evaluation/infra/repo/experiment/mysql (interfaces: ExptResultExportRecordDAO) +// +// Generated by this command: +// +// mockgen -destination=mocks/expt_result_export_record.go -package=mocks . ExptResultExportRecordDAO +// // Package mocks is a generated GoMock package. package mocks @@ -18,6 +23,7 @@ import ( type MockExptResultExportRecordDAO struct { ctrl *gomock.Controller recorder *MockExptResultExportRecordDAOMockRecorder + isgomock struct{} } // MockExptResultExportRecordDAOMockRecorder is the mock recorder for MockExptResultExportRecordDAO. @@ -38,10 +44,10 @@ func (m *MockExptResultExportRecordDAO) EXPECT() *MockExptResultExportRecordDAOM } // Create mocks base method. -func (m *MockExptResultExportRecordDAO) Create(arg0 context.Context, arg1 *model.ExptResultExportRecord, arg2 ...db.Option) error { +func (m *MockExptResultExportRecordDAO) Create(ctx context.Context, exptResultExportRecord *model.ExptResultExportRecord, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, exptResultExportRecord} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "Create", varargs...) @@ -50,17 +56,17 @@ func (m *MockExptResultExportRecordDAO) Create(arg0 context.Context, arg1 *model } // Create indicates an expected call of Create. -func (mr *MockExptResultExportRecordDAOMockRecorder) Create(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockExptResultExportRecordDAOMockRecorder) Create(ctx, exptResultExportRecord any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, exptResultExportRecord}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockExptResultExportRecordDAO)(nil).Create), varargs...) } // Get mocks base method. -func (m *MockExptResultExportRecordDAO) Get(arg0 context.Context, arg1, arg2 int64, arg3 ...db.Option) (*model.ExptResultExportRecord, error) { +func (m *MockExptResultExportRecordDAO) Get(ctx context.Context, spaceID, exportID int64, opts ...db.Option) (*model.ExptResultExportRecord, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2} - for _, a := range arg3 { + varargs := []any{ctx, spaceID, exportID} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "Get", varargs...) @@ -70,16 +76,16 @@ func (m *MockExptResultExportRecordDAO) Get(arg0 context.Context, arg1, arg2 int } // Get indicates an expected call of Get. -func (mr *MockExptResultExportRecordDAOMockRecorder) Get(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { +func (mr *MockExptResultExportRecordDAOMockRecorder) Get(ctx, spaceID, exportID any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) + varargs := append([]any{ctx, spaceID, exportID}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockExptResultExportRecordDAO)(nil).Get), varargs...) } // List mocks base method. -func (m *MockExptResultExportRecordDAO) List(arg0 context.Context, arg1, arg2 int64, arg3 entity.Page, arg4 *int32) ([]*model.ExptResultExportRecord, int64, error) { +func (m *MockExptResultExportRecordDAO) List(ctx context.Context, spaceID, exptID int64, page entity.Page, csvExportStatus *int32) ([]*model.ExptResultExportRecord, int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "List", arg0, arg1, arg2, arg3, arg4) + ret := m.ctrl.Call(m, "List", ctx, spaceID, exptID, page, csvExportStatus) ret0, _ := ret[0].([]*model.ExptResultExportRecord) ret1, _ := ret[1].(int64) ret2, _ := ret[2].(error) @@ -87,16 +93,16 @@ func (m *MockExptResultExportRecordDAO) List(arg0 context.Context, arg1, arg2 in } // List indicates an expected call of List. -func (mr *MockExptResultExportRecordDAOMockRecorder) List(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { +func (mr *MockExptResultExportRecordDAOMockRecorder) List(ctx, spaceID, exptID, page, csvExportStatus any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockExptResultExportRecordDAO)(nil).List), arg0, arg1, arg2, arg3, arg4) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockExptResultExportRecordDAO)(nil).List), ctx, spaceID, exptID, page, csvExportStatus) } // Update mocks base method. -func (m *MockExptResultExportRecordDAO) Update(arg0 context.Context, arg1 *model.ExptResultExportRecord, arg2 ...db.Option) error { +func (m *MockExptResultExportRecordDAO) Update(ctx context.Context, exptResultExportRecord *model.ExptResultExportRecord, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, exptResultExportRecord} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "Update", varargs...) @@ -105,8 +111,8 @@ func (m *MockExptResultExportRecordDAO) Update(arg0 context.Context, arg1 *model } // Update indicates an expected call of Update. -func (mr *MockExptResultExportRecordDAOMockRecorder) Update(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockExptResultExportRecordDAOMockRecorder) Update(ctx, exptResultExportRecord any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, exptResultExportRecord}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockExptResultExportRecordDAO)(nil).Update), varargs...) } diff --git a/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_run_log.go b/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_run_log.go index 156e67cc6..ef1c1d21b 100644 --- a/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_run_log.go +++ b/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_run_log.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/coze-dev/coze-loop/backend/modules/evaluation/infra/repo/experiment/mysql (interfaces: IExptRunLogDAO) +// +// Generated by this command: +// +// mockgen -destination=mocks/expt_run_log.go -package mocks . IExptRunLogDAO +// // Package mocks is a generated GoMock package. package mocks @@ -17,6 +22,7 @@ import ( type MockIExptRunLogDAO struct { ctrl *gomock.Controller recorder *MockIExptRunLogDAOMockRecorder + isgomock struct{} } // MockIExptRunLogDAOMockRecorder is the mock recorder for MockIExptRunLogDAO. @@ -37,10 +43,10 @@ func (m *MockIExptRunLogDAO) EXPECT() *MockIExptRunLogDAOMockRecorder { } // Create mocks base method. -func (m *MockIExptRunLogDAO) Create(arg0 context.Context, arg1 *model.ExptRunLog, arg2 ...db.Option) error { +func (m *MockIExptRunLogDAO) Create(ctx context.Context, exptRunLog *model.ExptRunLog, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, exptRunLog} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "Create", varargs...) @@ -49,17 +55,17 @@ func (m *MockIExptRunLogDAO) Create(arg0 context.Context, arg1 *model.ExptRunLog } // Create indicates an expected call of Create. -func (mr *MockIExptRunLogDAOMockRecorder) Create(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockIExptRunLogDAOMockRecorder) Create(ctx, exptRunLog any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, exptRunLog}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockIExptRunLogDAO)(nil).Create), varargs...) } // Get mocks base method. -func (m *MockIExptRunLogDAO) Get(arg0 context.Context, arg1, arg2 int64, arg3 ...db.Option) (*model.ExptRunLog, error) { +func (m *MockIExptRunLogDAO) Get(ctx context.Context, exptID, exptRunID int64, opts ...db.Option) (*model.ExptRunLog, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2} - for _, a := range arg3 { + varargs := []any{ctx, exptID, exptRunID} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "Get", varargs...) @@ -69,17 +75,17 @@ func (m *MockIExptRunLogDAO) Get(arg0 context.Context, arg1, arg2 int64, arg3 .. } // Get indicates an expected call of Get. -func (mr *MockIExptRunLogDAOMockRecorder) Get(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { +func (mr *MockIExptRunLogDAOMockRecorder) Get(ctx, exptID, exptRunID any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) + varargs := append([]any{ctx, exptID, exptRunID}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockIExptRunLogDAO)(nil).Get), varargs...) } // Save mocks base method. -func (m *MockIExptRunLogDAO) Save(arg0 context.Context, arg1 *model.ExptRunLog, arg2 ...db.Option) error { +func (m *MockIExptRunLogDAO) Save(ctx context.Context, exptRunLog *model.ExptRunLog, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, exptRunLog} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "Save", varargs...) @@ -88,17 +94,17 @@ func (m *MockIExptRunLogDAO) Save(arg0 context.Context, arg1 *model.ExptRunLog, } // Save indicates an expected call of Save. -func (mr *MockIExptRunLogDAOMockRecorder) Save(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockIExptRunLogDAOMockRecorder) Save(ctx, exptRunLog any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, exptRunLog}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Save", reflect.TypeOf((*MockIExptRunLogDAO)(nil).Save), varargs...) } // Update mocks base method. -func (m *MockIExptRunLogDAO) Update(arg0 context.Context, arg1, arg2 int64, arg3 map[string]interface{}, arg4 ...db.Option) error { +func (m *MockIExptRunLogDAO) Update(ctx context.Context, exptID, exptRunID int64, ufields map[string]any, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3} - for _, a := range arg4 { + varargs := []any{ctx, exptID, exptRunID, ufields} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "Update", varargs...) @@ -107,8 +113,8 @@ func (m *MockIExptRunLogDAO) Update(arg0 context.Context, arg1, arg2 int64, arg3 } // Update indicates an expected call of Update. -func (mr *MockIExptRunLogDAOMockRecorder) Update(arg0, arg1, arg2, arg3 interface{}, arg4 ...interface{}) *gomock.Call { +func (mr *MockIExptRunLogDAOMockRecorder) Update(ctx, exptID, exptRunID, ufields any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3}, arg4...) + varargs := append([]any{ctx, exptID, exptRunID, ufields}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockIExptRunLogDAO)(nil).Update), varargs...) } diff --git a/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_stats.go b/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_stats.go index cb02bfedb..b2c61129c 100644 --- a/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_stats.go +++ b/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_stats.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/coze-dev/coze-loop/backend/modules/evaluation/infra/repo/experiment/mysql (interfaces: IExptStatsDAO) +// +// Generated by this command: +// +// mockgen -destination=mocks/expt_stats.go -package mocks . IExptStatsDAO +// // Package mocks is a generated GoMock package. package mocks @@ -17,6 +22,7 @@ import ( type MockIExptStatsDAO struct { ctrl *gomock.Controller recorder *MockIExptStatsDAOMockRecorder + isgomock struct{} } // MockIExptStatsDAOMockRecorder is the mock recorder for MockIExptStatsDAO. @@ -37,87 +43,87 @@ func (m *MockIExptStatsDAO) EXPECT() *MockIExptStatsDAOMockRecorder { } // ArithOperateCount mocks base method. -func (m *MockIExptStatsDAO) ArithOperateCount(arg0 context.Context, arg1, arg2 int64, arg3 *entity.StatsCntArithOp) error { +func (m *MockIExptStatsDAO) ArithOperateCount(ctx context.Context, exptID, spaceID int64, cntArithOp *entity.StatsCntArithOp) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ArithOperateCount", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "ArithOperateCount", ctx, exptID, spaceID, cntArithOp) ret0, _ := ret[0].(error) return ret0 } // ArithOperateCount indicates an expected call of ArithOperateCount. -func (mr *MockIExptStatsDAOMockRecorder) ArithOperateCount(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockIExptStatsDAOMockRecorder) ArithOperateCount(ctx, exptID, spaceID, cntArithOp any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ArithOperateCount", reflect.TypeOf((*MockIExptStatsDAO)(nil).ArithOperateCount), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ArithOperateCount", reflect.TypeOf((*MockIExptStatsDAO)(nil).ArithOperateCount), ctx, exptID, spaceID, cntArithOp) } // Create mocks base method. -func (m *MockIExptStatsDAO) Create(arg0 context.Context, arg1 *model.ExptStats) error { +func (m *MockIExptStatsDAO) Create(ctx context.Context, stats *model.ExptStats) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Create", arg0, arg1) + ret := m.ctrl.Call(m, "Create", ctx, stats) ret0, _ := ret[0].(error) return ret0 } // Create indicates an expected call of Create. -func (mr *MockIExptStatsDAOMockRecorder) Create(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockIExptStatsDAOMockRecorder) Create(ctx, stats any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockIExptStatsDAO)(nil).Create), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockIExptStatsDAO)(nil).Create), ctx, stats) } // Get mocks base method. -func (m *MockIExptStatsDAO) Get(arg0 context.Context, arg1, arg2 int64) (*model.ExptStats, error) { +func (m *MockIExptStatsDAO) Get(ctx context.Context, exptID, spaceID int64) (*model.ExptStats, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "Get", ctx, exptID, spaceID) ret0, _ := ret[0].(*model.ExptStats) ret1, _ := ret[1].(error) return ret0, ret1 } // Get indicates an expected call of Get. -func (mr *MockIExptStatsDAOMockRecorder) Get(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockIExptStatsDAOMockRecorder) Get(ctx, exptID, spaceID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockIExptStatsDAO)(nil).Get), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockIExptStatsDAO)(nil).Get), ctx, exptID, spaceID) } // MGet mocks base method. -func (m *MockIExptStatsDAO) MGet(arg0 context.Context, arg1 []int64, arg2 int64) ([]*model.ExptStats, error) { +func (m *MockIExptStatsDAO) MGet(ctx context.Context, exptIDs []int64, spaceID int64) ([]*model.ExptStats, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MGet", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "MGet", ctx, exptIDs, spaceID) ret0, _ := ret[0].([]*model.ExptStats) ret1, _ := ret[1].(error) return ret0, ret1 } // MGet indicates an expected call of MGet. -func (mr *MockIExptStatsDAOMockRecorder) MGet(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockIExptStatsDAOMockRecorder) MGet(ctx, exptIDs, spaceID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGet", reflect.TypeOf((*MockIExptStatsDAO)(nil).MGet), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGet", reflect.TypeOf((*MockIExptStatsDAO)(nil).MGet), ctx, exptIDs, spaceID) } // Save mocks base method. -func (m *MockIExptStatsDAO) Save(arg0 context.Context, arg1 *model.ExptStats) error { +func (m *MockIExptStatsDAO) Save(ctx context.Context, stats *model.ExptStats) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Save", arg0, arg1) + ret := m.ctrl.Call(m, "Save", ctx, stats) ret0, _ := ret[0].(error) return ret0 } // Save indicates an expected call of Save. -func (mr *MockIExptStatsDAOMockRecorder) Save(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockIExptStatsDAOMockRecorder) Save(ctx, stats any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Save", reflect.TypeOf((*MockIExptStatsDAO)(nil).Save), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Save", reflect.TypeOf((*MockIExptStatsDAO)(nil).Save), ctx, stats) } // UpdateByExptID mocks base method. -func (m *MockIExptStatsDAO) UpdateByExptID(arg0 context.Context, arg1, arg2 int64, arg3 *model.ExptStats) error { +func (m *MockIExptStatsDAO) UpdateByExptID(ctx context.Context, exptID, spaceID int64, stats *model.ExptStats) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateByExptID", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "UpdateByExptID", ctx, exptID, spaceID, stats) ret0, _ := ret[0].(error) return ret0 } // UpdateByExptID indicates an expected call of UpdateByExptID. -func (mr *MockIExptStatsDAOMockRecorder) UpdateByExptID(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockIExptStatsDAOMockRecorder) UpdateByExptID(ctx, exptID, spaceID, stats any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateByExptID", reflect.TypeOf((*MockIExptStatsDAO)(nil).UpdateByExptID), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateByExptID", reflect.TypeOf((*MockIExptStatsDAO)(nil).UpdateByExptID), ctx, exptID, spaceID, stats) } diff --git a/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_turn_annotate_record_ref.go b/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_turn_annotate_record_ref.go index 178987666..f291d781e 100644 --- a/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_turn_annotate_record_ref.go +++ b/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_turn_annotate_record_ref.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/coze-dev/coze-loop/backend/modules/evaluation/infra/repo/experiment/mysql (interfaces: IExptTurnAnnotateRecordRefDAO) +// +// Generated by this command: +// +// mockgen -destination=mocks/expt_turn_annotate_record_ref.go -package=mocks . IExptTurnAnnotateRecordRefDAO +// // Package mocks is a generated GoMock package. package mocks @@ -10,13 +15,14 @@ import ( db "github.com/coze-dev/coze-loop/backend/infra/db" model "github.com/coze-dev/coze-loop/backend/modules/evaluation/infra/repo/experiment/mysql/gorm_gen/model" - "go.uber.org/mock/gomock" + gomock "go.uber.org/mock/gomock" ) // MockIExptTurnAnnotateRecordRefDAO is a mock of IExptTurnAnnotateRecordRefDAO interface. type MockIExptTurnAnnotateRecordRefDAO struct { ctrl *gomock.Controller recorder *MockIExptTurnAnnotateRecordRefDAOMockRecorder + isgomock struct{} } // MockIExptTurnAnnotateRecordRefDAOMockRecorder is the mock recorder for MockIExptTurnAnnotateRecordRefDAO. @@ -37,10 +43,10 @@ func (m *MockIExptTurnAnnotateRecordRefDAO) EXPECT() *MockIExptTurnAnnotateRecor } // BatchGet mocks base method. -func (m *MockIExptTurnAnnotateRecordRefDAO) BatchGet(arg0 context.Context, arg1 int64, arg2 []int64, arg3 ...db.Option) ([]*model.ExptTurnAnnotateRecordRef, error) { +func (m *MockIExptTurnAnnotateRecordRefDAO) BatchGet(ctx context.Context, spaceID int64, exptTurnResultIDs []int64, opts ...db.Option) ([]*model.ExptTurnAnnotateRecordRef, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2} - for _, a := range arg3 { + varargs := []any{ctx, spaceID, exptTurnResultIDs} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "BatchGet", varargs...) @@ -50,17 +56,17 @@ func (m *MockIExptTurnAnnotateRecordRefDAO) BatchGet(arg0 context.Context, arg1 } // BatchGet indicates an expected call of BatchGet. -func (mr *MockIExptTurnAnnotateRecordRefDAOMockRecorder) BatchGet(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { +func (mr *MockIExptTurnAnnotateRecordRefDAOMockRecorder) BatchGet(ctx, spaceID, exptTurnResultIDs any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) + varargs := append([]any{ctx, spaceID, exptTurnResultIDs}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGet", reflect.TypeOf((*MockIExptTurnAnnotateRecordRefDAO)(nil).BatchGet), varargs...) } // BatchGetByExptIDs mocks base method. -func (m *MockIExptTurnAnnotateRecordRefDAO) BatchGetByExptIDs(arg0 context.Context, arg1 int64, arg2 []int64, arg3 ...db.Option) ([]*model.ExptTurnAnnotateRecordRef, error) { +func (m *MockIExptTurnAnnotateRecordRefDAO) BatchGetByExptIDs(ctx context.Context, spaceID int64, exptIDs []int64, opts ...db.Option) ([]*model.ExptTurnAnnotateRecordRef, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2} - for _, a := range arg3 { + varargs := []any{ctx, spaceID, exptIDs} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "BatchGetByExptIDs", varargs...) @@ -70,17 +76,17 @@ func (m *MockIExptTurnAnnotateRecordRefDAO) BatchGetByExptIDs(arg0 context.Conte } // BatchGetByExptIDs indicates an expected call of BatchGetByExptIDs. -func (mr *MockIExptTurnAnnotateRecordRefDAOMockRecorder) BatchGetByExptIDs(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { +func (mr *MockIExptTurnAnnotateRecordRefDAOMockRecorder) BatchGetByExptIDs(ctx, spaceID, exptIDs any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) + varargs := append([]any{ctx, spaceID, exptIDs}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGetByExptIDs", reflect.TypeOf((*MockIExptTurnAnnotateRecordRefDAO)(nil).BatchGetByExptIDs), varargs...) } // DeleteByTagKeyID mocks base method. -func (m *MockIExptTurnAnnotateRecordRefDAO) DeleteByTagKeyID(arg0 context.Context, arg1, arg2, arg3 int64, arg4 ...db.Option) error { +func (m *MockIExptTurnAnnotateRecordRefDAO) DeleteByTagKeyID(ctx context.Context, spaceID, exptID, tagKeyID int64, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3} - for _, a := range arg4 { + varargs := []any{ctx, spaceID, exptID, tagKeyID} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "DeleteByTagKeyID", varargs...) @@ -89,17 +95,17 @@ func (m *MockIExptTurnAnnotateRecordRefDAO) DeleteByTagKeyID(arg0 context.Contex } // DeleteByTagKeyID indicates an expected call of DeleteByTagKeyID. -func (mr *MockIExptTurnAnnotateRecordRefDAOMockRecorder) DeleteByTagKeyID(arg0, arg1, arg2, arg3 interface{}, arg4 ...interface{}) *gomock.Call { +func (mr *MockIExptTurnAnnotateRecordRefDAOMockRecorder) DeleteByTagKeyID(ctx, spaceID, exptID, tagKeyID any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3}, arg4...) + varargs := append([]any{ctx, spaceID, exptID, tagKeyID}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteByTagKeyID", reflect.TypeOf((*MockIExptTurnAnnotateRecordRefDAO)(nil).DeleteByTagKeyID), varargs...) } // GetByExptID mocks base method. -func (m *MockIExptTurnAnnotateRecordRefDAO) GetByExptID(arg0 context.Context, arg1, arg2 int64, arg3 ...db.Option) ([]*model.ExptTurnAnnotateRecordRef, error) { +func (m *MockIExptTurnAnnotateRecordRefDAO) GetByExptID(ctx context.Context, spaceID, exptID int64, opts ...db.Option) ([]*model.ExptTurnAnnotateRecordRef, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2} - for _, a := range arg3 { + varargs := []any{ctx, spaceID, exptID} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "GetByExptID", varargs...) @@ -109,17 +115,17 @@ func (m *MockIExptTurnAnnotateRecordRefDAO) GetByExptID(arg0 context.Context, ar } // GetByExptID indicates an expected call of GetByExptID. -func (mr *MockIExptTurnAnnotateRecordRefDAOMockRecorder) GetByExptID(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { +func (mr *MockIExptTurnAnnotateRecordRefDAOMockRecorder) GetByExptID(ctx, spaceID, exptID any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) + varargs := append([]any{ctx, spaceID, exptID}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByExptID", reflect.TypeOf((*MockIExptTurnAnnotateRecordRefDAO)(nil).GetByExptID), varargs...) } // GetByTagKeyID mocks base method. -func (m *MockIExptTurnAnnotateRecordRefDAO) GetByTagKeyID(arg0 context.Context, arg1, arg2, arg3 int64, arg4 ...db.Option) ([]*model.ExptTurnAnnotateRecordRef, error) { +func (m *MockIExptTurnAnnotateRecordRefDAO) GetByTagKeyID(ctx context.Context, spaceID, exptID, tagKeyID int64, opts ...db.Option) ([]*model.ExptTurnAnnotateRecordRef, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3} - for _, a := range arg4 { + varargs := []any{ctx, spaceID, exptID, tagKeyID} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "GetByTagKeyID", varargs...) @@ -129,17 +135,17 @@ func (m *MockIExptTurnAnnotateRecordRefDAO) GetByTagKeyID(arg0 context.Context, } // GetByTagKeyID indicates an expected call of GetByTagKeyID. -func (mr *MockIExptTurnAnnotateRecordRefDAOMockRecorder) GetByTagKeyID(arg0, arg1, arg2, arg3 interface{}, arg4 ...interface{}) *gomock.Call { +func (mr *MockIExptTurnAnnotateRecordRefDAOMockRecorder) GetByTagKeyID(ctx, spaceID, exptID, tagKeyID any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3}, arg4...) + varargs := append([]any{ctx, spaceID, exptID, tagKeyID}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByTagKeyID", reflect.TypeOf((*MockIExptTurnAnnotateRecordRefDAO)(nil).GetByTagKeyID), varargs...) } // Save mocks base method. -func (m *MockIExptTurnAnnotateRecordRefDAO) Save(arg0 context.Context, arg1 *model.ExptTurnAnnotateRecordRef, arg2 ...db.Option) error { +func (m *MockIExptTurnAnnotateRecordRefDAO) Save(ctx context.Context, refs *model.ExptTurnAnnotateRecordRef, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, refs} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "Save", varargs...) @@ -148,8 +154,8 @@ func (m *MockIExptTurnAnnotateRecordRefDAO) Save(arg0 context.Context, arg1 *mod } // Save indicates an expected call of Save. -func (mr *MockIExptTurnAnnotateRecordRefDAOMockRecorder) Save(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockIExptTurnAnnotateRecordRefDAOMockRecorder) Save(ctx, refs any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, refs}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Save", reflect.TypeOf((*MockIExptTurnAnnotateRecordRefDAO)(nil).Save), varargs...) } diff --git a/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_turn_evaluator_result_ref.go b/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_turn_evaluator_result_ref.go index 4341e56db..ff61bfd9f 100644 --- a/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_turn_evaluator_result_ref.go +++ b/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_turn_evaluator_result_ref.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/coze-dev/coze-loop/backend/modules/evaluation/infra/repo/experiment/mysql (interfaces: IExptTurnEvaluatorResultRefDAO) +// +// Generated by this command: +// +// mockgen -destination=mocks/expt_turn_evaluator_result_ref.go -package mocks . IExptTurnEvaluatorResultRefDAO +// // Package mocks is a generated GoMock package. package mocks @@ -17,6 +22,7 @@ import ( type MockIExptTurnEvaluatorResultRefDAO struct { ctrl *gomock.Controller recorder *MockIExptTurnEvaluatorResultRefDAOMockRecorder + isgomock struct{} } // MockIExptTurnEvaluatorResultRefDAOMockRecorder is the mock recorder for MockIExptTurnEvaluatorResultRefDAO. @@ -37,10 +43,10 @@ func (m *MockIExptTurnEvaluatorResultRefDAO) EXPECT() *MockIExptTurnEvaluatorRes } // BatchGet mocks base method. -func (m *MockIExptTurnEvaluatorResultRefDAO) BatchGet(arg0 context.Context, arg1 int64, arg2 []int64, arg3 ...db.Option) ([]*model.ExptTurnEvaluatorResultRef, error) { +func (m *MockIExptTurnEvaluatorResultRefDAO) BatchGet(ctx context.Context, spaceID int64, exptTurnResultIDs []int64, opts ...db.Option) ([]*model.ExptTurnEvaluatorResultRef, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2} - for _, a := range arg3 { + varargs := []any{ctx, spaceID, exptTurnResultIDs} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "BatchGet", varargs...) @@ -50,17 +56,17 @@ func (m *MockIExptTurnEvaluatorResultRefDAO) BatchGet(arg0 context.Context, arg1 } // BatchGet indicates an expected call of BatchGet. -func (mr *MockIExptTurnEvaluatorResultRefDAOMockRecorder) BatchGet(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { +func (mr *MockIExptTurnEvaluatorResultRefDAOMockRecorder) BatchGet(ctx, spaceID, exptTurnResultIDs any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) + varargs := append([]any{ctx, spaceID, exptTurnResultIDs}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGet", reflect.TypeOf((*MockIExptTurnEvaluatorResultRefDAO)(nil).BatchGet), varargs...) } // GetByExptEvaluatorVersionID mocks base method. -func (m *MockIExptTurnEvaluatorResultRefDAO) GetByExptEvaluatorVersionID(arg0 context.Context, arg1, arg2, arg3 int64, arg4 ...db.Option) ([]*model.ExptTurnEvaluatorResultRef, error) { +func (m *MockIExptTurnEvaluatorResultRefDAO) GetByExptEvaluatorVersionID(ctx context.Context, spaceID, exptID, evaluatorVersionID int64, opts ...db.Option) ([]*model.ExptTurnEvaluatorResultRef, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3} - for _, a := range arg4 { + varargs := []any{ctx, spaceID, exptID, evaluatorVersionID} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "GetByExptEvaluatorVersionID", varargs...) @@ -70,17 +76,17 @@ func (m *MockIExptTurnEvaluatorResultRefDAO) GetByExptEvaluatorVersionID(arg0 co } // GetByExptEvaluatorVersionID indicates an expected call of GetByExptEvaluatorVersionID. -func (mr *MockIExptTurnEvaluatorResultRefDAOMockRecorder) GetByExptEvaluatorVersionID(arg0, arg1, arg2, arg3 interface{}, arg4 ...interface{}) *gomock.Call { +func (mr *MockIExptTurnEvaluatorResultRefDAOMockRecorder) GetByExptEvaluatorVersionID(ctx, spaceID, exptID, evaluatorVersionID any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3}, arg4...) + varargs := append([]any{ctx, spaceID, exptID, evaluatorVersionID}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByExptEvaluatorVersionID", reflect.TypeOf((*MockIExptTurnEvaluatorResultRefDAO)(nil).GetByExptEvaluatorVersionID), varargs...) } // GetByExptID mocks base method. -func (m *MockIExptTurnEvaluatorResultRefDAO) GetByExptID(arg0 context.Context, arg1, arg2 int64, arg3 ...db.Option) ([]*model.ExptTurnEvaluatorResultRef, error) { +func (m *MockIExptTurnEvaluatorResultRefDAO) GetByExptID(ctx context.Context, spaceID, exptID int64, opts ...db.Option) ([]*model.ExptTurnEvaluatorResultRef, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2} - for _, a := range arg3 { + varargs := []any{ctx, spaceID, exptID} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "GetByExptID", varargs...) @@ -90,8 +96,8 @@ func (m *MockIExptTurnEvaluatorResultRefDAO) GetByExptID(arg0 context.Context, a } // GetByExptID indicates an expected call of GetByExptID. -func (mr *MockIExptTurnEvaluatorResultRefDAOMockRecorder) GetByExptID(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { +func (mr *MockIExptTurnEvaluatorResultRefDAOMockRecorder) GetByExptID(ctx, spaceID, exptID any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) + varargs := append([]any{ctx, spaceID, exptID}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByExptID", reflect.TypeOf((*MockIExptTurnEvaluatorResultRefDAO)(nil).GetByExptID), varargs...) } diff --git a/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_turn_result.go b/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_turn_result.go index f3d2e1e6b..ef34df51e 100644 --- a/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_turn_result.go +++ b/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_turn_result.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/coze-dev/coze-loop/backend/modules/evaluation/infra/repo/experiment/mysql (interfaces: ExptTurnResultDAO) +// +// Generated by this command: +// +// mockgen -destination=mocks/expt_turn_result.go -package mocks . ExptTurnResultDAO +// // Package mocks is a generated GoMock package. package mocks @@ -8,17 +13,17 @@ import ( context "context" reflect "reflect" - "go.uber.org/mock/gomock" - db "github.com/coze-dev/coze-loop/backend/infra/db" entity "github.com/coze-dev/coze-loop/backend/modules/evaluation/domain/entity" model "github.com/coze-dev/coze-loop/backend/modules/evaluation/infra/repo/experiment/mysql/gorm_gen/model" + gomock "go.uber.org/mock/gomock" ) // MockExptTurnResultDAO is a mock of ExptTurnResultDAO interface. type MockExptTurnResultDAO struct { ctrl *gomock.Controller recorder *MockExptTurnResultDAOMockRecorder + isgomock struct{} } // MockExptTurnResultDAOMockRecorder is the mock recorder for MockExptTurnResultDAO. @@ -39,10 +44,10 @@ func (m *MockExptTurnResultDAO) EXPECT() *MockExptTurnResultDAOMockRecorder { } // BatchCreateNX mocks base method. -func (m *MockExptTurnResultDAO) BatchCreateNX(arg0 context.Context, arg1 []*model.ExptTurnResult, arg2 ...db.Option) error { +func (m *MockExptTurnResultDAO) BatchCreateNX(ctx context.Context, turnResults []*model.ExptTurnResult, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, turnResults} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "BatchCreateNX", varargs...) @@ -51,17 +56,17 @@ func (m *MockExptTurnResultDAO) BatchCreateNX(arg0 context.Context, arg1 []*mode } // BatchCreateNX indicates an expected call of BatchCreateNX. -func (mr *MockExptTurnResultDAOMockRecorder) BatchCreateNX(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockExptTurnResultDAOMockRecorder) BatchCreateNX(ctx, turnResults any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, turnResults}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchCreateNX", reflect.TypeOf((*MockExptTurnResultDAO)(nil).BatchCreateNX), varargs...) } // BatchCreateNXRunLog mocks base method. -func (m *MockExptTurnResultDAO) BatchCreateNXRunLog(arg0 context.Context, arg1 []*model.ExptTurnResultRunLog, arg2 ...db.Option) error { +func (m *MockExptTurnResultDAO) BatchCreateNXRunLog(ctx context.Context, turnResults []*model.ExptTurnResultRunLog, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, turnResults} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "BatchCreateNXRunLog", varargs...) @@ -70,17 +75,17 @@ func (m *MockExptTurnResultDAO) BatchCreateNXRunLog(arg0 context.Context, arg1 [ } // BatchCreateNXRunLog indicates an expected call of BatchCreateNXRunLog. -func (mr *MockExptTurnResultDAOMockRecorder) BatchCreateNXRunLog(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockExptTurnResultDAOMockRecorder) BatchCreateNXRunLog(ctx, turnResults any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, turnResults}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchCreateNXRunLog", reflect.TypeOf((*MockExptTurnResultDAO)(nil).BatchCreateNXRunLog), varargs...) } // BatchGet mocks base method. -func (m *MockExptTurnResultDAO) BatchGet(arg0 context.Context, arg1, arg2 int64, arg3 []int64, arg4 ...db.Option) ([]*model.ExptTurnResult, error) { +func (m *MockExptTurnResultDAO) BatchGet(ctx context.Context, spaceID, exptID int64, itemIDs []int64, opts ...db.Option) ([]*model.ExptTurnResult, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3} - for _, a := range arg4 { + varargs := []any{ctx, spaceID, exptID, itemIDs} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "BatchGet", varargs...) @@ -90,17 +95,17 @@ func (m *MockExptTurnResultDAO) BatchGet(arg0 context.Context, arg1, arg2 int64, } // BatchGet indicates an expected call of BatchGet. -func (mr *MockExptTurnResultDAOMockRecorder) BatchGet(arg0, arg1, arg2, arg3 interface{}, arg4 ...interface{}) *gomock.Call { +func (mr *MockExptTurnResultDAOMockRecorder) BatchGet(ctx, spaceID, exptID, itemIDs any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3}, arg4...) + varargs := append([]any{ctx, spaceID, exptID, itemIDs}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGet", reflect.TypeOf((*MockExptTurnResultDAO)(nil).BatchGet), varargs...) } // CreateTurnEvaluatorRefs mocks base method. -func (m *MockExptTurnResultDAO) CreateTurnEvaluatorRefs(arg0 context.Context, arg1 []*model.ExptTurnEvaluatorResultRef, arg2 ...db.Option) error { +func (m *MockExptTurnResultDAO) CreateTurnEvaluatorRefs(ctx context.Context, turnResults []*model.ExptTurnEvaluatorResultRef, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, turnResults} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "CreateTurnEvaluatorRefs", varargs...) @@ -109,17 +114,17 @@ func (m *MockExptTurnResultDAO) CreateTurnEvaluatorRefs(arg0 context.Context, ar } // CreateTurnEvaluatorRefs indicates an expected call of CreateTurnEvaluatorRefs. -func (mr *MockExptTurnResultDAOMockRecorder) CreateTurnEvaluatorRefs(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockExptTurnResultDAOMockRecorder) CreateTurnEvaluatorRefs(ctx, turnResults any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, turnResults}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateTurnEvaluatorRefs", reflect.TypeOf((*MockExptTurnResultDAO)(nil).CreateTurnEvaluatorRefs), varargs...) } // Get mocks base method. -func (m *MockExptTurnResultDAO) Get(arg0 context.Context, arg1, arg2, arg3, arg4 int64, arg5 ...db.Option) (*model.ExptTurnResult, error) { +func (m *MockExptTurnResultDAO) Get(ctx context.Context, spaceID, exptID, itemID, turnID int64, opts ...db.Option) (*model.ExptTurnResult, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3, arg4} - for _, a := range arg5 { + varargs := []any{ctx, spaceID, exptID, itemID, turnID} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "Get", varargs...) @@ -129,17 +134,17 @@ func (m *MockExptTurnResultDAO) Get(arg0 context.Context, arg1, arg2, arg3, arg4 } // Get indicates an expected call of Get. -func (mr *MockExptTurnResultDAOMockRecorder) Get(arg0, arg1, arg2, arg3, arg4 interface{}, arg5 ...interface{}) *gomock.Call { +func (mr *MockExptTurnResultDAOMockRecorder) Get(ctx, spaceID, exptID, itemID, turnID any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3, arg4}, arg5...) + varargs := append([]any{ctx, spaceID, exptID, itemID, turnID}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockExptTurnResultDAO)(nil).Get), varargs...) } // GetItemTurnResults mocks base method. -func (m *MockExptTurnResultDAO) GetItemTurnResults(arg0 context.Context, arg1, arg2, arg3 int64, arg4 ...db.Option) ([]*model.ExptTurnResult, error) { +func (m *MockExptTurnResultDAO) GetItemTurnResults(ctx context.Context, exptID, itemID, spaceID int64, opts ...db.Option) ([]*model.ExptTurnResult, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3} - for _, a := range arg4 { + varargs := []any{ctx, exptID, itemID, spaceID} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "GetItemTurnResults", varargs...) @@ -149,17 +154,17 @@ func (m *MockExptTurnResultDAO) GetItemTurnResults(arg0 context.Context, arg1, a } // GetItemTurnResults indicates an expected call of GetItemTurnResults. -func (mr *MockExptTurnResultDAOMockRecorder) GetItemTurnResults(arg0, arg1, arg2, arg3 interface{}, arg4 ...interface{}) *gomock.Call { +func (mr *MockExptTurnResultDAOMockRecorder) GetItemTurnResults(ctx, exptID, itemID, spaceID any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3}, arg4...) + varargs := append([]any{ctx, exptID, itemID, spaceID}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetItemTurnResults", reflect.TypeOf((*MockExptTurnResultDAO)(nil).GetItemTurnResults), varargs...) } // GetItemTurnRunLogs mocks base method. -func (m *MockExptTurnResultDAO) GetItemTurnRunLogs(arg0 context.Context, arg1, arg2, arg3, arg4 int64, arg5 ...db.Option) ([]*model.ExptTurnResultRunLog, error) { +func (m *MockExptTurnResultDAO) GetItemTurnRunLogs(ctx context.Context, exptID, exptRunID, itemID, spaceID int64, opts ...db.Option) ([]*model.ExptTurnResultRunLog, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3, arg4} - for _, a := range arg5 { + varargs := []any{ctx, exptID, exptRunID, itemID, spaceID} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "GetItemTurnRunLogs", varargs...) @@ -169,17 +174,17 @@ func (m *MockExptTurnResultDAO) GetItemTurnRunLogs(arg0 context.Context, arg1, a } // GetItemTurnRunLogs indicates an expected call of GetItemTurnRunLogs. -func (mr *MockExptTurnResultDAOMockRecorder) GetItemTurnRunLogs(arg0, arg1, arg2, arg3, arg4 interface{}, arg5 ...interface{}) *gomock.Call { +func (mr *MockExptTurnResultDAOMockRecorder) GetItemTurnRunLogs(ctx, exptID, exptRunID, itemID, spaceID any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3, arg4}, arg5...) + varargs := append([]any{ctx, exptID, exptRunID, itemID, spaceID}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetItemTurnRunLogs", reflect.TypeOf((*MockExptTurnResultDAO)(nil).GetItemTurnRunLogs), varargs...) } // ListTurnResult mocks base method. -func (m *MockExptTurnResultDAO) ListTurnResult(arg0 context.Context, arg1, arg2 int64, arg3 *entity.ExptTurnResultFilter, arg4 entity.Page, arg5 bool, arg6 ...db.Option) ([]*model.ExptTurnResult, int64, error) { +func (m *MockExptTurnResultDAO) ListTurnResult(ctx context.Context, spaceID, exptID int64, filter *entity.ExptTurnResultFilter, page entity.Page, desc bool, opts ...db.Option) ([]*model.ExptTurnResult, int64, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3, arg4, arg5} - for _, a := range arg6 { + varargs := []any{ctx, spaceID, exptID, filter, page, desc} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "ListTurnResult", varargs...) @@ -190,17 +195,17 @@ func (m *MockExptTurnResultDAO) ListTurnResult(arg0 context.Context, arg1, arg2 } // ListTurnResult indicates an expected call of ListTurnResult. -func (mr *MockExptTurnResultDAOMockRecorder) ListTurnResult(arg0, arg1, arg2, arg3, arg4, arg5 interface{}, arg6 ...interface{}) *gomock.Call { +func (mr *MockExptTurnResultDAOMockRecorder) ListTurnResult(ctx, spaceID, exptID, filter, page, desc any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3, arg4, arg5}, arg6...) + varargs := append([]any{ctx, spaceID, exptID, filter, page, desc}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListTurnResult", reflect.TypeOf((*MockExptTurnResultDAO)(nil).ListTurnResult), varargs...) } // ListTurnResultByItemIDs mocks base method. -func (m *MockExptTurnResultDAO) ListTurnResultByItemIDs(arg0 context.Context, arg1, arg2 int64, arg3 []int64, arg4 entity.Page, arg5 bool, arg6 ...db.Option) ([]*model.ExptTurnResult, int64, error) { +func (m *MockExptTurnResultDAO) ListTurnResultByItemIDs(ctx context.Context, spaceID, exptID int64, itemIDs []int64, page entity.Page, desc bool, opts ...db.Option) ([]*model.ExptTurnResult, int64, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3, arg4, arg5} - for _, a := range arg6 { + varargs := []any{ctx, spaceID, exptID, itemIDs, page, desc} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "ListTurnResultByItemIDs", varargs...) @@ -211,17 +216,17 @@ func (m *MockExptTurnResultDAO) ListTurnResultByItemIDs(arg0 context.Context, ar } // ListTurnResultByItemIDs indicates an expected call of ListTurnResultByItemIDs. -func (mr *MockExptTurnResultDAOMockRecorder) ListTurnResultByItemIDs(arg0, arg1, arg2, arg3, arg4, arg5 interface{}, arg6 ...interface{}) *gomock.Call { +func (mr *MockExptTurnResultDAOMockRecorder) ListTurnResultByItemIDs(ctx, spaceID, exptID, itemIDs, page, desc any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3, arg4, arg5}, arg6...) + varargs := append([]any{ctx, spaceID, exptID, itemIDs, page, desc}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListTurnResultByItemIDs", reflect.TypeOf((*MockExptTurnResultDAO)(nil).ListTurnResultByItemIDs), varargs...) } // MGetItemTurnRunLogs mocks base method. -func (m *MockExptTurnResultDAO) MGetItemTurnRunLogs(arg0 context.Context, arg1, arg2 int64, arg3 []int64, arg4 int64, arg5 ...db.Option) ([]*model.ExptTurnResultRunLog, error) { +func (m *MockExptTurnResultDAO) MGetItemTurnRunLogs(ctx context.Context, exptID, exptRunID int64, itemIDs []int64, spaceID int64, opts ...db.Option) ([]*model.ExptTurnResultRunLog, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3, arg4} - for _, a := range arg5 { + varargs := []any{ctx, exptID, exptRunID, itemIDs, spaceID} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "MGetItemTurnRunLogs", varargs...) @@ -231,17 +236,17 @@ func (m *MockExptTurnResultDAO) MGetItemTurnRunLogs(arg0 context.Context, arg1, } // MGetItemTurnRunLogs indicates an expected call of MGetItemTurnRunLogs. -func (mr *MockExptTurnResultDAOMockRecorder) MGetItemTurnRunLogs(arg0, arg1, arg2, arg3, arg4 interface{}, arg5 ...interface{}) *gomock.Call { +func (mr *MockExptTurnResultDAOMockRecorder) MGetItemTurnRunLogs(ctx, exptID, exptRunID, itemIDs, spaceID any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3, arg4}, arg5...) + varargs := append([]any{ctx, exptID, exptRunID, itemIDs, spaceID}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetItemTurnRunLogs", reflect.TypeOf((*MockExptTurnResultDAO)(nil).MGetItemTurnRunLogs), varargs...) } // SaveTurnResults mocks base method. -func (m *MockExptTurnResultDAO) SaveTurnResults(arg0 context.Context, arg1 []*model.ExptTurnResult, arg2 ...db.Option) error { +func (m *MockExptTurnResultDAO) SaveTurnResults(ctx context.Context, turnResults []*model.ExptTurnResult, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, turnResults} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "SaveTurnResults", varargs...) @@ -250,17 +255,17 @@ func (m *MockExptTurnResultDAO) SaveTurnResults(arg0 context.Context, arg1 []*mo } // SaveTurnResults indicates an expected call of SaveTurnResults. -func (mr *MockExptTurnResultDAOMockRecorder) SaveTurnResults(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockExptTurnResultDAOMockRecorder) SaveTurnResults(ctx, turnResults any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, turnResults}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveTurnResults", reflect.TypeOf((*MockExptTurnResultDAO)(nil).SaveTurnResults), varargs...) } // SaveTurnRunLogs mocks base method. -func (m *MockExptTurnResultDAO) SaveTurnRunLogs(arg0 context.Context, arg1 []*model.ExptTurnResultRunLog, arg2 ...db.Option) error { +func (m *MockExptTurnResultDAO) SaveTurnRunLogs(ctx context.Context, turnResults []*model.ExptTurnResultRunLog, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, turnResults} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "SaveTurnRunLogs", varargs...) @@ -269,17 +274,17 @@ func (m *MockExptTurnResultDAO) SaveTurnRunLogs(arg0 context.Context, arg1 []*mo } // SaveTurnRunLogs indicates an expected call of SaveTurnRunLogs. -func (mr *MockExptTurnResultDAOMockRecorder) SaveTurnRunLogs(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockExptTurnResultDAOMockRecorder) SaveTurnRunLogs(ctx, turnResults any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, turnResults}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveTurnRunLogs", reflect.TypeOf((*MockExptTurnResultDAO)(nil).SaveTurnRunLogs), varargs...) } // ScanTurnResults mocks base method. -func (m *MockExptTurnResultDAO) ScanTurnResults(arg0 context.Context, arg1 int64, arg2 []int32, arg3, arg4, arg5 int64, arg6 ...db.Option) ([]*model.ExptTurnResult, int64, error) { +func (m *MockExptTurnResultDAO) ScanTurnResults(ctx context.Context, exptID int64, status []int32, cursor, limit, spaceID int64, opts ...db.Option) ([]*model.ExptTurnResult, int64, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3, arg4, arg5} - for _, a := range arg6 { + varargs := []any{ctx, exptID, status, cursor, limit, spaceID} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "ScanTurnResults", varargs...) @@ -290,17 +295,17 @@ func (m *MockExptTurnResultDAO) ScanTurnResults(arg0 context.Context, arg1 int64 } // ScanTurnResults indicates an expected call of ScanTurnResults. -func (mr *MockExptTurnResultDAOMockRecorder) ScanTurnResults(arg0, arg1, arg2, arg3, arg4, arg5 interface{}, arg6 ...interface{}) *gomock.Call { +func (mr *MockExptTurnResultDAOMockRecorder) ScanTurnResults(ctx, exptID, status, cursor, limit, spaceID any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3, arg4, arg5}, arg6...) + varargs := append([]any{ctx, exptID, status, cursor, limit, spaceID}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ScanTurnResults", reflect.TypeOf((*MockExptTurnResultDAO)(nil).ScanTurnResults), varargs...) } // ScanTurnRunLogs mocks base method. -func (m *MockExptTurnResultDAO) ScanTurnRunLogs(arg0 context.Context, arg1, arg2, arg3, arg4 int64, arg5 ...db.Option) ([]*model.ExptTurnResultRunLog, int64, error) { +func (m *MockExptTurnResultDAO) ScanTurnRunLogs(ctx context.Context, exptID, cursor, limit, spaceID int64, opts ...db.Option) ([]*model.ExptTurnResultRunLog, int64, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3, arg4} - for _, a := range arg5 { + varargs := []any{ctx, exptID, cursor, limit, spaceID} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "ScanTurnRunLogs", varargs...) @@ -311,17 +316,17 @@ func (m *MockExptTurnResultDAO) ScanTurnRunLogs(arg0 context.Context, arg1, arg2 } // ScanTurnRunLogs indicates an expected call of ScanTurnRunLogs. -func (mr *MockExptTurnResultDAOMockRecorder) ScanTurnRunLogs(arg0, arg1, arg2, arg3, arg4 interface{}, arg5 ...interface{}) *gomock.Call { +func (mr *MockExptTurnResultDAOMockRecorder) ScanTurnRunLogs(ctx, exptID, cursor, limit, spaceID any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3, arg4}, arg5...) + varargs := append([]any{ctx, exptID, cursor, limit, spaceID}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ScanTurnRunLogs", reflect.TypeOf((*MockExptTurnResultDAO)(nil).ScanTurnRunLogs), varargs...) } // UpdateTurnResults mocks base method. -func (m *MockExptTurnResultDAO) UpdateTurnResults(arg0 context.Context, arg1 int64, arg2 []*entity.ItemTurnID, arg3 int64, arg4 map[string]interface{}, arg5 ...db.Option) error { +func (m *MockExptTurnResultDAO) UpdateTurnResults(ctx context.Context, exptID int64, itemTurnIDs []*entity.ItemTurnID, spaceID int64, ufields map[string]any, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3, arg4} - for _, a := range arg5 { + varargs := []any{ctx, exptID, itemTurnIDs, spaceID, ufields} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "UpdateTurnResults", varargs...) @@ -330,17 +335,17 @@ func (m *MockExptTurnResultDAO) UpdateTurnResults(arg0 context.Context, arg1 int } // UpdateTurnResults indicates an expected call of UpdateTurnResults. -func (mr *MockExptTurnResultDAOMockRecorder) UpdateTurnResults(arg0, arg1, arg2, arg3, arg4 interface{}, arg5 ...interface{}) *gomock.Call { +func (mr *MockExptTurnResultDAOMockRecorder) UpdateTurnResults(ctx, exptID, itemTurnIDs, spaceID, ufields any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3, arg4}, arg5...) + varargs := append([]any{ctx, exptID, itemTurnIDs, spaceID, ufields}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTurnResults", reflect.TypeOf((*MockExptTurnResultDAO)(nil).UpdateTurnResults), varargs...) } // UpdateTurnResultsWithItemIDs mocks base method. -func (m *MockExptTurnResultDAO) UpdateTurnResultsWithItemIDs(arg0 context.Context, arg1 int64, arg2 []int64, arg3 int64, arg4 map[string]interface{}, arg5 ...db.Option) error { +func (m *MockExptTurnResultDAO) UpdateTurnResultsWithItemIDs(ctx context.Context, exptID int64, itemIDs []int64, spaceID int64, ufields map[string]any, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3, arg4} - for _, a := range arg5 { + varargs := []any{ctx, exptID, itemIDs, spaceID, ufields} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "UpdateTurnResultsWithItemIDs", varargs...) @@ -349,17 +354,17 @@ func (m *MockExptTurnResultDAO) UpdateTurnResultsWithItemIDs(arg0 context.Contex } // UpdateTurnResultsWithItemIDs indicates an expected call of UpdateTurnResultsWithItemIDs. -func (mr *MockExptTurnResultDAOMockRecorder) UpdateTurnResultsWithItemIDs(arg0, arg1, arg2, arg3, arg4 interface{}, arg5 ...interface{}) *gomock.Call { +func (mr *MockExptTurnResultDAOMockRecorder) UpdateTurnResultsWithItemIDs(ctx, exptID, itemIDs, spaceID, ufields any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3, arg4}, arg5...) + varargs := append([]any{ctx, exptID, itemIDs, spaceID, ufields}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTurnResultsWithItemIDs", reflect.TypeOf((*MockExptTurnResultDAO)(nil).UpdateTurnResultsWithItemIDs), varargs...) } // UpdateTurnRunLogWithItemIDs mocks base method. -func (m *MockExptTurnResultDAO) UpdateTurnRunLogWithItemIDs(arg0 context.Context, arg1, arg2, arg3 int64, arg4 []int64, arg5 map[string]interface{}, arg6 ...db.Option) error { +func (m *MockExptTurnResultDAO) UpdateTurnRunLogWithItemIDs(ctx context.Context, spaceID, exptID, exptRunID int64, itemIDs []int64, ufields map[string]any, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3, arg4, arg5} - for _, a := range arg6 { + varargs := []any{ctx, spaceID, exptID, exptRunID, itemIDs, ufields} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "UpdateTurnRunLogWithItemIDs", varargs...) @@ -368,8 +373,8 @@ func (m *MockExptTurnResultDAO) UpdateTurnRunLogWithItemIDs(arg0 context.Context } // UpdateTurnRunLogWithItemIDs indicates an expected call of UpdateTurnRunLogWithItemIDs. -func (mr *MockExptTurnResultDAOMockRecorder) UpdateTurnRunLogWithItemIDs(arg0, arg1, arg2, arg3, arg4, arg5 interface{}, arg6 ...interface{}) *gomock.Call { +func (mr *MockExptTurnResultDAOMockRecorder) UpdateTurnRunLogWithItemIDs(ctx, spaceID, exptID, exptRunID, itemIDs, ufields any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3, arg4, arg5}, arg6...) + varargs := append([]any{ctx, spaceID, exptID, exptRunID, itemIDs, ufields}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTurnRunLogWithItemIDs", reflect.TypeOf((*MockExptTurnResultDAO)(nil).UpdateTurnRunLogWithItemIDs), varargs...) } diff --git a/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_turn_result_filter_key_mapping.go b/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_turn_result_filter_key_mapping.go index 7207c8597..541a91ea4 100644 --- a/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_turn_result_filter_key_mapping.go +++ b/backend/modules/evaluation/infra/repo/experiment/mysql/mocks/expt_turn_result_filter_key_mapping.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/coze-dev/coze-loop/backend/modules/evaluation/infra/repo/experiment/mysql (interfaces: IExptTurnResultFilterKeyMappingDAO) +// +// Generated by this command: +// +// mockgen -destination=mocks/expt_turn_result_filter_key_mapping.go -package mocks . IExptTurnResultFilterKeyMappingDAO +// // Package mocks is a generated GoMock package. package mocks @@ -10,13 +15,14 @@ import ( db "github.com/coze-dev/coze-loop/backend/infra/db" model "github.com/coze-dev/coze-loop/backend/modules/evaluation/infra/repo/experiment/mysql/gorm_gen/model" - "go.uber.org/mock/gomock" + gomock "go.uber.org/mock/gomock" ) // MockIExptTurnResultFilterKeyMappingDAO is a mock of IExptTurnResultFilterKeyMappingDAO interface. type MockIExptTurnResultFilterKeyMappingDAO struct { ctrl *gomock.Controller recorder *MockIExptTurnResultFilterKeyMappingDAOMockRecorder + isgomock struct{} } // MockIExptTurnResultFilterKeyMappingDAOMockRecorder is the mock recorder for MockIExptTurnResultFilterKeyMappingDAO. @@ -37,10 +43,10 @@ func (m *MockIExptTurnResultFilterKeyMappingDAO) EXPECT() *MockIExptTurnResultFi } // Delete mocks base method. -func (m *MockIExptTurnResultFilterKeyMappingDAO) Delete(arg0 context.Context, arg1 *model.ExptTurnResultFilterKeyMapping, arg2 ...db.Option) error { +func (m *MockIExptTurnResultFilterKeyMappingDAO) Delete(ctx context.Context, mapping *model.ExptTurnResultFilterKeyMapping, opts ...db.Option) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, mapping} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "Delete", varargs...) @@ -49,17 +55,17 @@ func (m *MockIExptTurnResultFilterKeyMappingDAO) Delete(arg0 context.Context, ar } // Delete indicates an expected call of Delete. -func (mr *MockIExptTurnResultFilterKeyMappingDAOMockRecorder) Delete(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockIExptTurnResultFilterKeyMappingDAOMockRecorder) Delete(ctx, mapping any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]any{ctx, mapping}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockIExptTurnResultFilterKeyMappingDAO)(nil).Delete), varargs...) } // GetByExptID mocks base method. -func (m *MockIExptTurnResultFilterKeyMappingDAO) GetByExptID(arg0 context.Context, arg1, arg2 int64, arg3 ...db.Option) ([]*model.ExptTurnResultFilterKeyMapping, error) { +func (m *MockIExptTurnResultFilterKeyMappingDAO) GetByExptID(ctx context.Context, spaceID, exptID int64, opts ...db.Option) ([]*model.ExptTurnResultFilterKeyMapping, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2} - for _, a := range arg3 { + varargs := []any{ctx, spaceID, exptID} + for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "GetByExptID", varargs...) @@ -69,22 +75,22 @@ func (m *MockIExptTurnResultFilterKeyMappingDAO) GetByExptID(arg0 context.Contex } // GetByExptID indicates an expected call of GetByExptID. -func (mr *MockIExptTurnResultFilterKeyMappingDAOMockRecorder) GetByExptID(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { +func (mr *MockIExptTurnResultFilterKeyMappingDAOMockRecorder) GetByExptID(ctx, spaceID, exptID any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) + varargs := append([]any{ctx, spaceID, exptID}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByExptID", reflect.TypeOf((*MockIExptTurnResultFilterKeyMappingDAO)(nil).GetByExptID), varargs...) } // Insert mocks base method. -func (m *MockIExptTurnResultFilterKeyMappingDAO) Insert(arg0 context.Context, arg1 []*model.ExptTurnResultFilterKeyMapping) error { +func (m *MockIExptTurnResultFilterKeyMappingDAO) Insert(ctx context.Context, exptTurnResultFilterKeyMappings []*model.ExptTurnResultFilterKeyMapping) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Insert", arg0, arg1) + ret := m.ctrl.Call(m, "Insert", ctx, exptTurnResultFilterKeyMappings) ret0, _ := ret[0].(error) return ret0 } // Insert indicates an expected call of Insert. -func (mr *MockIExptTurnResultFilterKeyMappingDAOMockRecorder) Insert(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockIExptTurnResultFilterKeyMappingDAOMockRecorder) Insert(ctx, exptTurnResultFilterKeyMappings any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockIExptTurnResultFilterKeyMappingDAO)(nil).Insert), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockIExptTurnResultFilterKeyMappingDAO)(nil).Insert), ctx, exptTurnResultFilterKeyMappings) } diff --git a/backend/modules/evaluation/infra/repo/experiment/redis/dao/mocks/quota.go b/backend/modules/evaluation/infra/repo/experiment/redis/dao/mocks/quota.go index a7e3a9faf..dd4916691 100644 --- a/backend/modules/evaluation/infra/repo/experiment/redis/dao/mocks/quota.go +++ b/backend/modules/evaluation/infra/repo/experiment/redis/dao/mocks/quota.go @@ -1,22 +1,27 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/coze-dev/coze-loop/backend/modules/evaluation/infra/repo/experiment/redis/dao (interfaces: IQuotaDAO) +// Source: quota.go +// +// Generated by this command: +// +// mockgen -source=quota.go -destination=mocks/quota.go -package=mocks -mock_names=IQuotaDAO=MockIQuotaDAO +// // Package mocks is a generated GoMock package. package mocks import ( - "context" - "reflect" + context "context" + reflect "reflect" - "go.uber.org/mock/gomock" - - "github.com/coze-dev/coze-loop/backend/modules/evaluation/domain/entity" + entity "github.com/coze-dev/coze-loop/backend/modules/evaluation/domain/entity" + gomock "go.uber.org/mock/gomock" ) // MockIQuotaDAO is a mock of IQuotaDAO interface. type MockIQuotaDAO struct { ctrl *gomock.Controller recorder *MockIQuotaDAOMockRecorder + isgomock struct{} } // MockIQuotaDAOMockRecorder is the mock recorder for MockIQuotaDAO. @@ -37,30 +42,30 @@ func (m *MockIQuotaDAO) EXPECT() *MockIQuotaDAOMockRecorder { } // GetQuotaSpaceExpt mocks base method. -func (m *MockIQuotaDAO) GetQuotaSpaceExpt(arg0 context.Context, arg1 int64) (*entity.QuotaSpaceExpt, error) { +func (m *MockIQuotaDAO) GetQuotaSpaceExpt(ctx context.Context, spaceID int64) (*entity.QuotaSpaceExpt, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetQuotaSpaceExpt", arg0, arg1) + ret := m.ctrl.Call(m, "GetQuotaSpaceExpt", ctx, spaceID) ret0, _ := ret[0].(*entity.QuotaSpaceExpt) ret1, _ := ret[1].(error) return ret0, ret1 } // GetQuotaSpaceExpt indicates an expected call of GetQuotaSpaceExpt. -func (mr *MockIQuotaDAOMockRecorder) GetQuotaSpaceExpt(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockIQuotaDAOMockRecorder) GetQuotaSpaceExpt(ctx, spaceID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetQuotaSpaceExpt", reflect.TypeOf((*MockIQuotaDAO)(nil).GetQuotaSpaceExpt), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetQuotaSpaceExpt", reflect.TypeOf((*MockIQuotaDAO)(nil).GetQuotaSpaceExpt), ctx, spaceID) } // SetQuotaSpaceExpt mocks base method. -func (m *MockIQuotaDAO) SetQuotaSpaceExpt(arg0 context.Context, arg1 int64, arg2 *entity.QuotaSpaceExpt) error { +func (m *MockIQuotaDAO) SetQuotaSpaceExpt(ctx context.Context, spaceID int64, qse *entity.QuotaSpaceExpt) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetQuotaSpaceExpt", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "SetQuotaSpaceExpt", ctx, spaceID, qse) ret0, _ := ret[0].(error) return ret0 } // SetQuotaSpaceExpt indicates an expected call of SetQuotaSpaceExpt. -func (mr *MockIQuotaDAOMockRecorder) SetQuotaSpaceExpt(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockIQuotaDAOMockRecorder) SetQuotaSpaceExpt(ctx, spaceID, qse any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetQuotaSpaceExpt", reflect.TypeOf((*MockIQuotaDAO)(nil).SetQuotaSpaceExpt), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetQuotaSpaceExpt", reflect.TypeOf((*MockIQuotaDAO)(nil).SetQuotaSpaceExpt), ctx, spaceID, qse) } diff --git a/backend/modules/foundation/domain/authn/repo/mocks/authn_repo.go b/backend/modules/foundation/domain/authn/repo/mocks/authn_repo.go index 1c8fe1cf5..83dd9ecc2 100644 --- a/backend/modules/foundation/domain/authn/repo/mocks/authn_repo.go +++ b/backend/modules/foundation/domain/authn/repo/mocks/authn_repo.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/coze-dev/coze-loop/backend/modules/foundation/domain/authn/repo (interfaces: IAuthNRepo) +// +// Generated by this command: +// +// mockgen -destination=mocks/authn_repo.go -package=mocks . IAuthNRepo +// // Package mocks is a generated GoMock package. package mocks @@ -16,6 +21,7 @@ import ( type MockIAuthNRepo struct { ctrl *gomock.Controller recorder *MockIAuthNRepoMockRecorder + isgomock struct{} } // MockIAuthNRepoMockRecorder is the mock recorder for MockIAuthNRepo. @@ -36,9 +42,9 @@ func (m *MockIAuthNRepo) EXPECT() *MockIAuthNRepoMockRecorder { } // CreateAPIKey mocks base method. -func (m *MockIAuthNRepo) CreateAPIKey(arg0 context.Context, arg1 *entity.APIKey) (int64, string, error) { +func (m *MockIAuthNRepo) CreateAPIKey(ctx context.Context, apiKeyEntity *entity.APIKey) (int64, string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateAPIKey", arg0, arg1) + ret := m.ctrl.Call(m, "CreateAPIKey", ctx, apiKeyEntity) ret0, _ := ret[0].(int64) ret1, _ := ret[1].(string) ret2, _ := ret[2].(error) @@ -46,94 +52,94 @@ func (m *MockIAuthNRepo) CreateAPIKey(arg0 context.Context, arg1 *entity.APIKey) } // CreateAPIKey indicates an expected call of CreateAPIKey. -func (mr *MockIAuthNRepoMockRecorder) CreateAPIKey(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockIAuthNRepoMockRecorder) CreateAPIKey(ctx, apiKeyEntity any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAPIKey", reflect.TypeOf((*MockIAuthNRepo)(nil).CreateAPIKey), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAPIKey", reflect.TypeOf((*MockIAuthNRepo)(nil).CreateAPIKey), ctx, apiKeyEntity) } // DeleteAPIKey mocks base method. -func (m *MockIAuthNRepo) DeleteAPIKey(arg0 context.Context, arg1 int64) error { +func (m *MockIAuthNRepo) DeleteAPIKey(ctx context.Context, apiKeyID int64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteAPIKey", arg0, arg1) + ret := m.ctrl.Call(m, "DeleteAPIKey", ctx, apiKeyID) ret0, _ := ret[0].(error) return ret0 } // DeleteAPIKey indicates an expected call of DeleteAPIKey. -func (mr *MockIAuthNRepoMockRecorder) DeleteAPIKey(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockIAuthNRepoMockRecorder) DeleteAPIKey(ctx, apiKeyID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAPIKey", reflect.TypeOf((*MockIAuthNRepo)(nil).DeleteAPIKey), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAPIKey", reflect.TypeOf((*MockIAuthNRepo)(nil).DeleteAPIKey), ctx, apiKeyID) } // FlushAPIKeyUsedTime mocks base method. -func (m *MockIAuthNRepo) FlushAPIKeyUsedTime(arg0 context.Context, arg1 int64) error { +func (m *MockIAuthNRepo) FlushAPIKeyUsedTime(ctx context.Context, apiKeyID int64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "FlushAPIKeyUsedTime", arg0, arg1) + ret := m.ctrl.Call(m, "FlushAPIKeyUsedTime", ctx, apiKeyID) ret0, _ := ret[0].(error) return ret0 } // FlushAPIKeyUsedTime indicates an expected call of FlushAPIKeyUsedTime. -func (mr *MockIAuthNRepoMockRecorder) FlushAPIKeyUsedTime(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockIAuthNRepoMockRecorder) FlushAPIKeyUsedTime(ctx, apiKeyID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FlushAPIKeyUsedTime", reflect.TypeOf((*MockIAuthNRepo)(nil).FlushAPIKeyUsedTime), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FlushAPIKeyUsedTime", reflect.TypeOf((*MockIAuthNRepo)(nil).FlushAPIKeyUsedTime), ctx, apiKeyID) } // GetAPIKeyByIDs mocks base method. -func (m *MockIAuthNRepo) GetAPIKeyByIDs(arg0 context.Context, arg1 []int64) ([]*entity.APIKey, error) { +func (m *MockIAuthNRepo) GetAPIKeyByIDs(ctx context.Context, apiKeyIDs []int64) ([]*entity.APIKey, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAPIKeyByIDs", arg0, arg1) + ret := m.ctrl.Call(m, "GetAPIKeyByIDs", ctx, apiKeyIDs) ret0, _ := ret[0].([]*entity.APIKey) ret1, _ := ret[1].(error) return ret0, ret1 } // GetAPIKeyByIDs indicates an expected call of GetAPIKeyByIDs. -func (mr *MockIAuthNRepoMockRecorder) GetAPIKeyByIDs(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockIAuthNRepoMockRecorder) GetAPIKeyByIDs(ctx, apiKeyIDs any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAPIKeyByIDs", reflect.TypeOf((*MockIAuthNRepo)(nil).GetAPIKeyByIDs), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAPIKeyByIDs", reflect.TypeOf((*MockIAuthNRepo)(nil).GetAPIKeyByIDs), ctx, apiKeyIDs) } // GetAPIKeyByKey mocks base method. -func (m *MockIAuthNRepo) GetAPIKeyByKey(arg0 context.Context, arg1 string) (*entity.APIKey, error) { +func (m *MockIAuthNRepo) GetAPIKeyByKey(ctx context.Context, key string) (*entity.APIKey, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAPIKeyByKey", arg0, arg1) + ret := m.ctrl.Call(m, "GetAPIKeyByKey", ctx, key) ret0, _ := ret[0].(*entity.APIKey) ret1, _ := ret[1].(error) return ret0, ret1 } // GetAPIKeyByKey indicates an expected call of GetAPIKeyByKey. -func (mr *MockIAuthNRepoMockRecorder) GetAPIKeyByKey(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockIAuthNRepoMockRecorder) GetAPIKeyByKey(ctx, key any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAPIKeyByKey", reflect.TypeOf((*MockIAuthNRepo)(nil).GetAPIKeyByKey), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAPIKeyByKey", reflect.TypeOf((*MockIAuthNRepo)(nil).GetAPIKeyByKey), ctx, key) } // GetAPIKeyByUser mocks base method. -func (m *MockIAuthNRepo) GetAPIKeyByUser(arg0 context.Context, arg1 int64, arg2, arg3 int) ([]*entity.APIKey, error) { +func (m *MockIAuthNRepo) GetAPIKeyByUser(ctx context.Context, userID int64, pageNumber, pageSize int) ([]*entity.APIKey, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAPIKeyByUser", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "GetAPIKeyByUser", ctx, userID, pageNumber, pageSize) ret0, _ := ret[0].([]*entity.APIKey) ret1, _ := ret[1].(error) return ret0, ret1 } // GetAPIKeyByUser indicates an expected call of GetAPIKeyByUser. -func (mr *MockIAuthNRepoMockRecorder) GetAPIKeyByUser(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockIAuthNRepoMockRecorder) GetAPIKeyByUser(ctx, userID, pageNumber, pageSize any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAPIKeyByUser", reflect.TypeOf((*MockIAuthNRepo)(nil).GetAPIKeyByUser), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAPIKeyByUser", reflect.TypeOf((*MockIAuthNRepo)(nil).GetAPIKeyByUser), ctx, userID, pageNumber, pageSize) } // UpdateAPIKeyName mocks base method. -func (m *MockIAuthNRepo) UpdateAPIKeyName(arg0 context.Context, arg1 int64, arg2 string) error { +func (m *MockIAuthNRepo) UpdateAPIKeyName(ctx context.Context, apiKeyID int64, name string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateAPIKeyName", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "UpdateAPIKeyName", ctx, apiKeyID, name) ret0, _ := ret[0].(error) return ret0 } // UpdateAPIKeyName indicates an expected call of UpdateAPIKeyName. -func (mr *MockIAuthNRepoMockRecorder) UpdateAPIKeyName(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockIAuthNRepoMockRecorder) UpdateAPIKeyName(ctx, apiKeyID, name any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAPIKeyName", reflect.TypeOf((*MockIAuthNRepo)(nil).UpdateAPIKeyName), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAPIKeyName", reflect.TypeOf((*MockIAuthNRepo)(nil).UpdateAPIKeyName), ctx, apiKeyID, name) } diff --git a/backend/modules/foundation/domain/user/service/mocks/user_service.go b/backend/modules/foundation/domain/user/service/mocks/user_service.go index df25c9f34..e9eb2d800 100644 --- a/backend/modules/foundation/domain/user/service/mocks/user_service.go +++ b/backend/modules/foundation/domain/user/service/mocks/user_service.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/coze-dev/coze-loop/backend/modules/foundation/domain/user/service (interfaces: IUserService) +// Source: interface.go +// +// Generated by this command: +// +// mockgen -source=interface.go -destination=mocks/user_service.go -package=mocks -mock_names=IUserService=MockIUserService +// // Package mocks is a generated GoMock package. package mocks @@ -17,6 +22,7 @@ import ( type MockIUserService struct { ctrl *gomock.Controller recorder *MockIUserServiceMockRecorder + isgomock struct{} } // MockIUserServiceMockRecorder is the mock recorder for MockIUserService. @@ -37,54 +43,54 @@ func (m *MockIUserService) EXPECT() *MockIUserServiceMockRecorder { } // Create mocks base method. -func (m *MockIUserService) Create(arg0 context.Context, arg1 *service.CreateUserRequest) (*entity.User, error) { +func (m *MockIUserService) Create(ctx context.Context, req *service.CreateUserRequest) (*entity.User, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Create", arg0, arg1) + ret := m.ctrl.Call(m, "Create", ctx, req) ret0, _ := ret[0].(*entity.User) ret1, _ := ret[1].(error) return ret0, ret1 } // Create indicates an expected call of Create. -func (mr *MockIUserServiceMockRecorder) Create(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockIUserServiceMockRecorder) Create(ctx, req any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockIUserService)(nil).Create), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockIUserService)(nil).Create), ctx, req) } // CreateSession mocks base method. -func (m *MockIUserService) CreateSession(arg0 context.Context, arg1 int64) (string, error) { +func (m *MockIUserService) CreateSession(ctx context.Context, userID int64) (string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateSession", arg0, arg1) + ret := m.ctrl.Call(m, "CreateSession", ctx, userID) ret0, _ := ret[0].(string) ret1, _ := ret[1].(error) return ret0, ret1 } // CreateSession indicates an expected call of CreateSession. -func (mr *MockIUserServiceMockRecorder) CreateSession(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockIUserServiceMockRecorder) CreateSession(ctx, userID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSession", reflect.TypeOf((*MockIUserService)(nil).CreateSession), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSession", reflect.TypeOf((*MockIUserService)(nil).CreateSession), ctx, userID) } // GetUserProfile mocks base method. -func (m *MockIUserService) GetUserProfile(arg0 context.Context, arg1 int64) (*entity.User, error) { +func (m *MockIUserService) GetUserProfile(ctx context.Context, userID int64) (*entity.User, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetUserProfile", arg0, arg1) + ret := m.ctrl.Call(m, "GetUserProfile", ctx, userID) ret0, _ := ret[0].(*entity.User) ret1, _ := ret[1].(error) return ret0, ret1 } // GetUserProfile indicates an expected call of GetUserProfile. -func (mr *MockIUserServiceMockRecorder) GetUserProfile(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockIUserServiceMockRecorder) GetUserProfile(ctx, userID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserProfile", reflect.TypeOf((*MockIUserService)(nil).GetUserProfile), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserProfile", reflect.TypeOf((*MockIUserService)(nil).GetUserProfile), ctx, userID) } // GetUserSpaceList mocks base method. -func (m *MockIUserService) GetUserSpaceList(arg0 context.Context, arg1 *service.ListUserSpaceRequest) ([]*entity.Space, int32, error) { +func (m *MockIUserService) GetUserSpaceList(ctx context.Context, req *service.ListUserSpaceRequest) ([]*entity.Space, int32, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetUserSpaceList", arg0, arg1) + ret := m.ctrl.Call(m, "GetUserSpaceList", ctx, req) ret0, _ := ret[0].([]*entity.Space) ret1, _ := ret[1].(int32) ret2, _ := ret[2].(error) @@ -92,80 +98,80 @@ func (m *MockIUserService) GetUserSpaceList(arg0 context.Context, arg1 *service. } // GetUserSpaceList indicates an expected call of GetUserSpaceList. -func (mr *MockIUserServiceMockRecorder) GetUserSpaceList(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockIUserServiceMockRecorder) GetUserSpaceList(ctx, req any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserSpaceList", reflect.TypeOf((*MockIUserService)(nil).GetUserSpaceList), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserSpaceList", reflect.TypeOf((*MockIUserService)(nil).GetUserSpaceList), ctx, req) } // Login mocks base method. -func (m *MockIUserService) Login(arg0 context.Context, arg1, arg2 string) (*entity.User, error) { +func (m *MockIUserService) Login(ctx context.Context, email, password string) (*entity.User, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Login", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "Login", ctx, email, password) ret0, _ := ret[0].(*entity.User) ret1, _ := ret[1].(error) return ret0, ret1 } // Login indicates an expected call of Login. -func (mr *MockIUserServiceMockRecorder) Login(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockIUserServiceMockRecorder) Login(ctx, email, password any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Login", reflect.TypeOf((*MockIUserService)(nil).Login), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Login", reflect.TypeOf((*MockIUserService)(nil).Login), ctx, email, password) } // Logout mocks base method. -func (m *MockIUserService) Logout(arg0 context.Context, arg1 int64) error { +func (m *MockIUserService) Logout(ctx context.Context, userID int64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Logout", arg0, arg1) + ret := m.ctrl.Call(m, "Logout", ctx, userID) ret0, _ := ret[0].(error) return ret0 } // Logout indicates an expected call of Logout. -func (mr *MockIUserServiceMockRecorder) Logout(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockIUserServiceMockRecorder) Logout(ctx, userID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Logout", reflect.TypeOf((*MockIUserService)(nil).Logout), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Logout", reflect.TypeOf((*MockIUserService)(nil).Logout), ctx, userID) } // MGetUserProfiles mocks base method. -func (m *MockIUserService) MGetUserProfiles(arg0 context.Context, arg1 []int64) ([]*entity.User, error) { +func (m *MockIUserService) MGetUserProfiles(ctx context.Context, userIDs []int64) ([]*entity.User, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MGetUserProfiles", arg0, arg1) + ret := m.ctrl.Call(m, "MGetUserProfiles", ctx, userIDs) ret0, _ := ret[0].([]*entity.User) ret1, _ := ret[1].(error) return ret0, ret1 } // MGetUserProfiles indicates an expected call of MGetUserProfiles. -func (mr *MockIUserServiceMockRecorder) MGetUserProfiles(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockIUserServiceMockRecorder) MGetUserProfiles(ctx, userIDs any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetUserProfiles", reflect.TypeOf((*MockIUserService)(nil).MGetUserProfiles), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetUserProfiles", reflect.TypeOf((*MockIUserService)(nil).MGetUserProfiles), ctx, userIDs) } // ResetPassword mocks base method. -func (m *MockIUserService) ResetPassword(arg0 context.Context, arg1, arg2 string) error { +func (m *MockIUserService) ResetPassword(ctx context.Context, email, password string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ResetPassword", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "ResetPassword", ctx, email, password) ret0, _ := ret[0].(error) return ret0 } // ResetPassword indicates an expected call of ResetPassword. -func (mr *MockIUserServiceMockRecorder) ResetPassword(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockIUserServiceMockRecorder) ResetPassword(ctx, email, password any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetPassword", reflect.TypeOf((*MockIUserService)(nil).ResetPassword), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetPassword", reflect.TypeOf((*MockIUserService)(nil).ResetPassword), ctx, email, password) } // UpdateProfile mocks base method. -func (m *MockIUserService) UpdateProfile(arg0 context.Context, arg1 *service.UpdateProfileRequest) (*entity.User, error) { +func (m *MockIUserService) UpdateProfile(ctx context.Context, req *service.UpdateProfileRequest) (*entity.User, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateProfile", arg0, arg1) + ret := m.ctrl.Call(m, "UpdateProfile", ctx, req) ret0, _ := ret[0].(*entity.User) ret1, _ := ret[1].(error) return ret0, ret1 } // UpdateProfile indicates an expected call of UpdateProfile. -func (mr *MockIUserServiceMockRecorder) UpdateProfile(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockIUserServiceMockRecorder) UpdateProfile(ctx, req any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProfile", reflect.TypeOf((*MockIUserService)(nil).UpdateProfile), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProfile", reflect.TypeOf((*MockIUserService)(nil).UpdateProfile), ctx, req) } diff --git a/backend/modules/observability/domain/component/metrics/mocks/metrics.go b/backend/modules/observability/domain/component/metrics/mocks/metrics.go index 3bf3770b8..b6ed0a1ce 100644 --- a/backend/modules/observability/domain/component/metrics/mocks/metrics.go +++ b/backend/modules/observability/domain/component/metrics/mocks/metrics.go @@ -71,7 +71,7 @@ func (m *MockITraceMetrics) EmitTraceOapi(method string, workspaceId int64, plat } // EmitTraceOapi indicates an expected call of EmitTraceOapi. -func (mr *MockITraceMetricsMockRecorder) EmitTraceOapi(method, workspaceId, platformType, spanType, spanSize, errorCode, start, isError any) *gomock.Call { +func (mr *MockITraceMetricsMockRecorder) EmitTraceOapi(method, workspaceId, platformType, spanListType, spanSize, errorCode, start, isError any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EmitTraceOapi", reflect.TypeOf((*MockITraceMetrics)(nil).EmitTraceOapi), method, workspaceId, platformType, spanType, spanSize, errorCode, start, isError) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EmitTraceOapi", reflect.TypeOf((*MockITraceMetrics)(nil).EmitTraceOapi), method, workspaceId, platformType, spanListType, spanSize, errorCode, start, isError) } diff --git a/backend/modules/observability/domain/component/rpc/mocks/dataset_provider_mock.go b/backend/modules/observability/domain/component/rpc/mocks/dataset_provider_mock.go index efebc6c00..e59e59f4d 100644 --- a/backend/modules/observability/domain/component/rpc/mocks/dataset_provider_mock.go +++ b/backend/modules/observability/domain/component/rpc/mocks/dataset_provider_mock.go @@ -21,6 +21,7 @@ import ( type MockIDatasetProvider struct { ctrl *gomock.Controller recorder *MockIDatasetProviderMockRecorder + isgomock struct{} } // MockIDatasetProviderMockRecorder is the mock recorder for MockIDatasetProvider. @@ -41,9 +42,9 @@ func (m *MockIDatasetProvider) EXPECT() *MockIDatasetProviderMockRecorder { } // AddDatasetItems mocks base method. -func (m *MockIDatasetProvider) AddDatasetItems(arg0 context.Context, arg1 int64, arg2 entity.DatasetCategory, arg3 []*entity.DatasetItem) ([]*entity.DatasetItem, []entity.ItemErrorGroup, error) { +func (m *MockIDatasetProvider) AddDatasetItems(ctx context.Context, datasetID int64, category entity.DatasetCategory, items []*entity.DatasetItem) ([]*entity.DatasetItem, []entity.ItemErrorGroup, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AddDatasetItems", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "AddDatasetItems", ctx, datasetID, category, items) ret0, _ := ret[0].([]*entity.DatasetItem) ret1, _ := ret[1].([]entity.ItemErrorGroup) ret2, _ := ret[2].(error) @@ -51,53 +52,53 @@ func (m *MockIDatasetProvider) AddDatasetItems(arg0 context.Context, arg1 int64, } // AddDatasetItems indicates an expected call of AddDatasetItems. -func (mr *MockIDatasetProviderMockRecorder) AddDatasetItems(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockIDatasetProviderMockRecorder) AddDatasetItems(ctx, datasetID, category, items any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddDatasetItems", reflect.TypeOf((*MockIDatasetProvider)(nil).AddDatasetItems), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddDatasetItems", reflect.TypeOf((*MockIDatasetProvider)(nil).AddDatasetItems), ctx, datasetID, category, items) } // ClearDatasetItems mocks base method. -func (m *MockIDatasetProvider) ClearDatasetItems(arg0 context.Context, arg1, arg2 int64, arg3 entity.DatasetCategory) error { +func (m *MockIDatasetProvider) ClearDatasetItems(ctx context.Context, workspaceID, datasetID int64, category entity.DatasetCategory) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ClearDatasetItems", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "ClearDatasetItems", ctx, workspaceID, datasetID, category) ret0, _ := ret[0].(error) return ret0 } // ClearDatasetItems indicates an expected call of ClearDatasetItems. -func (mr *MockIDatasetProviderMockRecorder) ClearDatasetItems(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockIDatasetProviderMockRecorder) ClearDatasetItems(ctx, workspaceID, datasetID, category any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClearDatasetItems", reflect.TypeOf((*MockIDatasetProvider)(nil).ClearDatasetItems), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClearDatasetItems", reflect.TypeOf((*MockIDatasetProvider)(nil).ClearDatasetItems), ctx, workspaceID, datasetID, category) } // CreateDataset mocks base method. -func (m *MockIDatasetProvider) CreateDataset(arg0 context.Context, arg1 *entity.Dataset) (int64, error) { +func (m *MockIDatasetProvider) CreateDataset(ctx context.Context, dataset *entity.Dataset) (int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateDataset", arg0, arg1) + ret := m.ctrl.Call(m, "CreateDataset", ctx, dataset) ret0, _ := ret[0].(int64) ret1, _ := ret[1].(error) return ret0, ret1 } // CreateDataset indicates an expected call of CreateDataset. -func (mr *MockIDatasetProviderMockRecorder) CreateDataset(arg0, arg1 any) *gomock.Call { +func (mr *MockIDatasetProviderMockRecorder) CreateDataset(ctx, dataset any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateDataset", reflect.TypeOf((*MockIDatasetProvider)(nil).CreateDataset), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateDataset", reflect.TypeOf((*MockIDatasetProvider)(nil).CreateDataset), ctx, dataset) } // GetDataset mocks base method. -func (m *MockIDatasetProvider) GetDataset(arg0 context.Context, arg1, arg2 int64, arg3 entity.DatasetCategory) (*entity.Dataset, error) { +func (m *MockIDatasetProvider) GetDataset(ctx context.Context, workspaceID, datasetID int64, category entity.DatasetCategory) (*entity.Dataset, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetDataset", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "GetDataset", ctx, workspaceID, datasetID, category) ret0, _ := ret[0].(*entity.Dataset) ret1, _ := ret[1].(error) return ret0, ret1 } // GetDataset indicates an expected call of GetDataset. -func (mr *MockIDatasetProviderMockRecorder) GetDataset(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockIDatasetProviderMockRecorder) GetDataset(ctx, workspaceID, datasetID, category any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDataset", reflect.TypeOf((*MockIDatasetProvider)(nil).GetDataset), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDataset", reflect.TypeOf((*MockIDatasetProvider)(nil).GetDataset), ctx, workspaceID, datasetID, category) } // SearchDatasets mocks base method. @@ -116,23 +117,23 @@ func (mr *MockIDatasetProviderMockRecorder) SearchDatasets(arg0, arg1, arg2, arg } // UpdateDatasetSchema mocks base method. -func (m *MockIDatasetProvider) UpdateDatasetSchema(arg0 context.Context, arg1 *entity.Dataset) error { +func (m *MockIDatasetProvider) UpdateDatasetSchema(ctx context.Context, dataset *entity.Dataset) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateDatasetSchema", arg0, arg1) + ret := m.ctrl.Call(m, "UpdateDatasetSchema", ctx, dataset) ret0, _ := ret[0].(error) return ret0 } // UpdateDatasetSchema indicates an expected call of UpdateDatasetSchema. -func (mr *MockIDatasetProviderMockRecorder) UpdateDatasetSchema(arg0, arg1 any) *gomock.Call { +func (mr *MockIDatasetProviderMockRecorder) UpdateDatasetSchema(ctx, dataset any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateDatasetSchema", reflect.TypeOf((*MockIDatasetProvider)(nil).UpdateDatasetSchema), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateDatasetSchema", reflect.TypeOf((*MockIDatasetProvider)(nil).UpdateDatasetSchema), ctx, dataset) } // ValidateDatasetItems mocks base method. -func (m *MockIDatasetProvider) ValidateDatasetItems(arg0 context.Context, arg1 *entity.Dataset, arg2 []*entity.DatasetItem, arg3 *bool) ([]*entity.DatasetItem, []entity.ItemErrorGroup, error) { +func (m *MockIDatasetProvider) ValidateDatasetItems(ctx context.Context, dataset *entity.Dataset, items []*entity.DatasetItem, ignoreCurrentCount *bool) ([]*entity.DatasetItem, []entity.ItemErrorGroup, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ValidateDatasetItems", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "ValidateDatasetItems", ctx, dataset, items, ignoreCurrentCount) ret0, _ := ret[0].([]*entity.DatasetItem) ret1, _ := ret[1].([]entity.ItemErrorGroup) ret2, _ := ret[2].(error) @@ -140,7 +141,7 @@ func (m *MockIDatasetProvider) ValidateDatasetItems(arg0 context.Context, arg1 * } // ValidateDatasetItems indicates an expected call of ValidateDatasetItems. -func (mr *MockIDatasetProviderMockRecorder) ValidateDatasetItems(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockIDatasetProviderMockRecorder) ValidateDatasetItems(ctx, dataset, items, ignoreCurrentCount any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateDatasetItems", reflect.TypeOf((*MockIDatasetProvider)(nil).ValidateDatasetItems), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateDatasetItems", reflect.TypeOf((*MockIDatasetProvider)(nil).ValidateDatasetItems), ctx, dataset, items, ignoreCurrentCount) } diff --git a/backend/modules/observability/domain/trace/entity/collector/confmap/mocks/provider.go b/backend/modules/observability/domain/trace/entity/collector/confmap/mocks/provider.go new file mode 100644 index 000000000..d51b72206 --- /dev/null +++ b/backend/modules/observability/domain/trace/entity/collector/confmap/mocks/provider.go @@ -0,0 +1,85 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/collector/confmap (interfaces: Provider) +// +// Generated by this command: +// +// mockgen -destination=mocks/provider.go -package=mocks . Provider +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + confmap "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/collector/confmap" + gomock "go.uber.org/mock/gomock" +) + +// MockProvider is a mock of Provider interface. +type MockProvider struct { + ctrl *gomock.Controller + recorder *MockProviderMockRecorder + isgomock struct{} +} + +// MockProviderMockRecorder is the mock recorder for MockProvider. +type MockProviderMockRecorder struct { + mock *MockProvider +} + +// NewMockProvider creates a new mock instance. +func NewMockProvider(ctrl *gomock.Controller) *MockProvider { + mock := &MockProvider{ctrl: ctrl} + mock.recorder = &MockProviderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockProvider) EXPECT() *MockProviderMockRecorder { + return m.recorder +} + +// Retrieve mocks base method. +func (m *MockProvider) Retrieve(ctx context.Context, path string) (*confmap.Retrieved, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Retrieve", ctx, path) + ret0, _ := ret[0].(*confmap.Retrieved) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Retrieve indicates an expected call of Retrieve. +func (mr *MockProviderMockRecorder) Retrieve(ctx, path any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retrieve", reflect.TypeOf((*MockProvider)(nil).Retrieve), ctx, path) +} + +// Scheme mocks base method. +func (m *MockProvider) Scheme() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Scheme") + ret0, _ := ret[0].(string) + return ret0 +} + +// Scheme indicates an expected call of Scheme. +func (mr *MockProviderMockRecorder) Scheme() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scheme", reflect.TypeOf((*MockProvider)(nil).Scheme)) +} + +// Shutdown mocks base method. +func (m *MockProvider) Shutdown(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Shutdown", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// Shutdown indicates an expected call of Shutdown. +func (mr *MockProviderMockRecorder) Shutdown(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Shutdown", reflect.TypeOf((*MockProvider)(nil).Shutdown), ctx) +} diff --git a/backend/modules/observability/domain/trace/entity/collector/mocks/conf_provider.go b/backend/modules/observability/domain/trace/entity/collector/mocks/conf_provider.go new file mode 100644 index 000000000..afcfa7256 --- /dev/null +++ b/backend/modules/observability/domain/trace/entity/collector/mocks/conf_provider.go @@ -0,0 +1,57 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/collector (interfaces: ConfigProvider) +// +// Generated by this command: +// +// mockgen -destination=mocks/conf_provider.go -package=mocks . ConfigProvider +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + collector "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/collector" + gomock "go.uber.org/mock/gomock" +) + +// MockConfigProvider is a mock of ConfigProvider interface. +type MockConfigProvider struct { + ctrl *gomock.Controller + recorder *MockConfigProviderMockRecorder + isgomock struct{} +} + +// MockConfigProviderMockRecorder is the mock recorder for MockConfigProvider. +type MockConfigProviderMockRecorder struct { + mock *MockConfigProvider +} + +// NewMockConfigProvider creates a new mock instance. +func NewMockConfigProvider(ctrl *gomock.Controller) *MockConfigProvider { + mock := &MockConfigProvider{ctrl: ctrl} + mock.recorder = &MockConfigProviderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockConfigProvider) EXPECT() *MockConfigProviderMockRecorder { + return m.recorder +} + +// Get mocks base method. +func (m *MockConfigProvider) Get(ctx context.Context, factories collector.Factories) (*collector.Config, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", ctx, factories) + ret0, _ := ret[0].(*collector.Config) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockConfigProviderMockRecorder) Get(ctx, factories any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockConfigProvider)(nil).Get), ctx, factories) +} diff --git a/backend/modules/observability/domain/trace/service/mocks/trace_export_service_mock.go b/backend/modules/observability/domain/trace/service/mocks/trace_export_service_mock.go index eb5b009e4..79734a876 100644 --- a/backend/modules/observability/domain/trace/service/mocks/trace_export_service_mock.go +++ b/backend/modules/observability/domain/trace/service/mocks/trace_export_service_mock.go @@ -21,6 +21,7 @@ import ( type MockITraceExportService struct { ctrl *gomock.Controller recorder *MockITraceExportServiceMockRecorder + isgomock struct{} } // MockITraceExportServiceMockRecorder is the mock recorder for MockITraceExportService. @@ -41,31 +42,31 @@ func (m *MockITraceExportService) EXPECT() *MockITraceExportServiceMockRecorder } // ExportTracesToDataset mocks base method. -func (m *MockITraceExportService) ExportTracesToDataset(arg0 context.Context, arg1 *service.ExportTracesToDatasetRequest) (*service.ExportTracesToDatasetResponse, error) { +func (m *MockITraceExportService) ExportTracesToDataset(ctx context.Context, req *service.ExportTracesToDatasetRequest) (*service.ExportTracesToDatasetResponse, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ExportTracesToDataset", arg0, arg1) + ret := m.ctrl.Call(m, "ExportTracesToDataset", ctx, req) ret0, _ := ret[0].(*service.ExportTracesToDatasetResponse) ret1, _ := ret[1].(error) return ret0, ret1 } // ExportTracesToDataset indicates an expected call of ExportTracesToDataset. -func (mr *MockITraceExportServiceMockRecorder) ExportTracesToDataset(arg0, arg1 any) *gomock.Call { +func (mr *MockITraceExportServiceMockRecorder) ExportTracesToDataset(ctx, req any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExportTracesToDataset", reflect.TypeOf((*MockITraceExportService)(nil).ExportTracesToDataset), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExportTracesToDataset", reflect.TypeOf((*MockITraceExportService)(nil).ExportTracesToDataset), ctx, req) } // PreviewExportTracesToDataset mocks base method. -func (m *MockITraceExportService) PreviewExportTracesToDataset(arg0 context.Context, arg1 *service.ExportTracesToDatasetRequest) (*service.PreviewExportTracesToDatasetResponse, error) { +func (m *MockITraceExportService) PreviewExportTracesToDataset(ctx context.Context, req *service.ExportTracesToDatasetRequest) (*service.PreviewExportTracesToDatasetResponse, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PreviewExportTracesToDataset", arg0, arg1) + ret := m.ctrl.Call(m, "PreviewExportTracesToDataset", ctx, req) ret0, _ := ret[0].(*service.PreviewExportTracesToDatasetResponse) ret1, _ := ret[1].(error) return ret0, ret1 } // PreviewExportTracesToDataset indicates an expected call of PreviewExportTracesToDataset. -func (mr *MockITraceExportServiceMockRecorder) PreviewExportTracesToDataset(arg0, arg1 any) *gomock.Call { +func (mr *MockITraceExportServiceMockRecorder) PreviewExportTracesToDataset(ctx, req any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PreviewExportTracesToDataset", reflect.TypeOf((*MockITraceExportService)(nil).PreviewExportTracesToDataset), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PreviewExportTracesToDataset", reflect.TypeOf((*MockITraceExportService)(nil).PreviewExportTracesToDataset), ctx, req) } diff --git a/backend/modules/observability/infra/rpc/dataset/dataset_test.go b/backend/modules/observability/infra/rpc/dataset/dataset_test.go index 19d4fc60c..1b37b70d7 100644 --- a/backend/modules/observability/infra/rpc/dataset/dataset_test.go +++ b/backend/modules/observability/infra/rpc/dataset/dataset_test.go @@ -8,17 +8,16 @@ import ( "fmt" "testing" + "github.com/bytedance/gg/gptr" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" - "github.com/bytedance/gg/gptr" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/data/dataset" dataset_domain "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/data/domain/dataset" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity" "github.com/coze-dev/coze-loop/backend/modules/observability/infra/rpc/dataset/mocks" ) -//go:generate mockgen -source=dataset.go -destination=mocks/mock_dataset.go //go:generate mockgen -package=mocks -destination=mocks/mock_datasetservice_client.go github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/data/datasetservice Client // Test helper functions diff --git a/backend/modules/observability/infra/rpc/dataset/mocks/mock_datasetservice_client.go b/backend/modules/observability/infra/rpc/dataset/mocks/mock_datasetservice_client.go index e36f65e95..59bd1819e 100644 --- a/backend/modules/observability/infra/rpc/dataset/mocks/mock_datasetservice_client.go +++ b/backend/modules/observability/infra/rpc/dataset/mocks/mock_datasetservice_client.go @@ -3,7 +3,7 @@ // // Generated by this command: // -// mockgen -package=mocks -destination=modules/observability/infra/rpc/dataset/mocks/mock_datasetservice_client.go github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/data/datasetservice Client +// mockgen -package=mocks -destination=mocks/mock_datasetservice_client.go github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/data/datasetservice Client // // Package mocks is a generated GoMock package. @@ -22,6 +22,7 @@ import ( type MockClient struct { ctrl *gomock.Controller recorder *MockClientMockRecorder + isgomock struct{} } // MockClientMockRecorder is the mock recorder for MockClient. @@ -42,10 +43,10 @@ func (m *MockClient) EXPECT() *MockClientMockRecorder { } // BatchCreateDatasetItems mocks base method. -func (m *MockClient) BatchCreateDatasetItems(arg0 context.Context, arg1 *dataset.BatchCreateDatasetItemsRequest, arg2 ...callopt.Option) (*dataset.BatchCreateDatasetItemsResponse, error) { +func (m *MockClient) BatchCreateDatasetItems(ctx context.Context, req *dataset.BatchCreateDatasetItemsRequest, callOptions ...callopt.Option) (*dataset.BatchCreateDatasetItemsResponse, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, req} + for _, a := range callOptions { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "BatchCreateDatasetItems", varargs...) @@ -55,17 +56,17 @@ func (m *MockClient) BatchCreateDatasetItems(arg0 context.Context, arg1 *dataset } // BatchCreateDatasetItems indicates an expected call of BatchCreateDatasetItems. -func (mr *MockClientMockRecorder) BatchCreateDatasetItems(arg0, arg1 any, arg2 ...any) *gomock.Call { +func (mr *MockClientMockRecorder) BatchCreateDatasetItems(ctx, req any, callOptions ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, req}, callOptions...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchCreateDatasetItems", reflect.TypeOf((*MockClient)(nil).BatchCreateDatasetItems), varargs...) } // BatchDeleteDatasetItems mocks base method. -func (m *MockClient) BatchDeleteDatasetItems(arg0 context.Context, arg1 *dataset.BatchDeleteDatasetItemsRequest, arg2 ...callopt.Option) (*dataset.BatchDeleteDatasetItemsResponse, error) { +func (m *MockClient) BatchDeleteDatasetItems(ctx context.Context, req *dataset.BatchDeleteDatasetItemsRequest, callOptions ...callopt.Option) (*dataset.BatchDeleteDatasetItemsResponse, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, req} + for _, a := range callOptions { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "BatchDeleteDatasetItems", varargs...) @@ -75,17 +76,17 @@ func (m *MockClient) BatchDeleteDatasetItems(arg0 context.Context, arg1 *dataset } // BatchDeleteDatasetItems indicates an expected call of BatchDeleteDatasetItems. -func (mr *MockClientMockRecorder) BatchDeleteDatasetItems(arg0, arg1 any, arg2 ...any) *gomock.Call { +func (mr *MockClientMockRecorder) BatchDeleteDatasetItems(ctx, req any, callOptions ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, req}, callOptions...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchDeleteDatasetItems", reflect.TypeOf((*MockClient)(nil).BatchDeleteDatasetItems), varargs...) } // BatchGetDatasetItems mocks base method. -func (m *MockClient) BatchGetDatasetItems(arg0 context.Context, arg1 *dataset.BatchGetDatasetItemsRequest, arg2 ...callopt.Option) (*dataset.BatchGetDatasetItemsResponse, error) { +func (m *MockClient) BatchGetDatasetItems(ctx context.Context, req *dataset.BatchGetDatasetItemsRequest, callOptions ...callopt.Option) (*dataset.BatchGetDatasetItemsResponse, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, req} + for _, a := range callOptions { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "BatchGetDatasetItems", varargs...) @@ -95,17 +96,17 @@ func (m *MockClient) BatchGetDatasetItems(arg0 context.Context, arg1 *dataset.Ba } // BatchGetDatasetItems indicates an expected call of BatchGetDatasetItems. -func (mr *MockClientMockRecorder) BatchGetDatasetItems(arg0, arg1 any, arg2 ...any) *gomock.Call { +func (mr *MockClientMockRecorder) BatchGetDatasetItems(ctx, req any, callOptions ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, req}, callOptions...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGetDatasetItems", reflect.TypeOf((*MockClient)(nil).BatchGetDatasetItems), varargs...) } // BatchGetDatasetItemsByVersion mocks base method. -func (m *MockClient) BatchGetDatasetItemsByVersion(arg0 context.Context, arg1 *dataset.BatchGetDatasetItemsByVersionRequest, arg2 ...callopt.Option) (*dataset.BatchGetDatasetItemsByVersionResponse, error) { +func (m *MockClient) BatchGetDatasetItemsByVersion(ctx context.Context, req *dataset.BatchGetDatasetItemsByVersionRequest, callOptions ...callopt.Option) (*dataset.BatchGetDatasetItemsByVersionResponse, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, req} + for _, a := range callOptions { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "BatchGetDatasetItemsByVersion", varargs...) @@ -115,17 +116,17 @@ func (m *MockClient) BatchGetDatasetItemsByVersion(arg0 context.Context, arg1 *d } // BatchGetDatasetItemsByVersion indicates an expected call of BatchGetDatasetItemsByVersion. -func (mr *MockClientMockRecorder) BatchGetDatasetItemsByVersion(arg0, arg1 any, arg2 ...any) *gomock.Call { +func (mr *MockClientMockRecorder) BatchGetDatasetItemsByVersion(ctx, req any, callOptions ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, req}, callOptions...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGetDatasetItemsByVersion", reflect.TypeOf((*MockClient)(nil).BatchGetDatasetItemsByVersion), varargs...) } // BatchGetDatasetVersions mocks base method. -func (m *MockClient) BatchGetDatasetVersions(arg0 context.Context, arg1 *dataset.BatchGetDatasetVersionsRequest, arg2 ...callopt.Option) (*dataset.BatchGetDatasetVersionsResponse, error) { +func (m *MockClient) BatchGetDatasetVersions(ctx context.Context, req *dataset.BatchGetDatasetVersionsRequest, callOptions ...callopt.Option) (*dataset.BatchGetDatasetVersionsResponse, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, req} + for _, a := range callOptions { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "BatchGetDatasetVersions", varargs...) @@ -135,17 +136,17 @@ func (m *MockClient) BatchGetDatasetVersions(arg0 context.Context, arg1 *dataset } // BatchGetDatasetVersions indicates an expected call of BatchGetDatasetVersions. -func (mr *MockClientMockRecorder) BatchGetDatasetVersions(arg0, arg1 any, arg2 ...any) *gomock.Call { +func (mr *MockClientMockRecorder) BatchGetDatasetVersions(ctx, req any, callOptions ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, req}, callOptions...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGetDatasetVersions", reflect.TypeOf((*MockClient)(nil).BatchGetDatasetVersions), varargs...) } // BatchGetDatasets mocks base method. -func (m *MockClient) BatchGetDatasets(arg0 context.Context, arg1 *dataset.BatchGetDatasetsRequest, arg2 ...callopt.Option) (*dataset.BatchGetDatasetsResponse, error) { +func (m *MockClient) BatchGetDatasets(ctx context.Context, req *dataset.BatchGetDatasetsRequest, callOptions ...callopt.Option) (*dataset.BatchGetDatasetsResponse, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, req} + for _, a := range callOptions { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "BatchGetDatasets", varargs...) @@ -155,17 +156,17 @@ func (m *MockClient) BatchGetDatasets(arg0 context.Context, arg1 *dataset.BatchG } // BatchGetDatasets indicates an expected call of BatchGetDatasets. -func (mr *MockClientMockRecorder) BatchGetDatasets(arg0, arg1 any, arg2 ...any) *gomock.Call { +func (mr *MockClientMockRecorder) BatchGetDatasets(ctx, req any, callOptions ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, req}, callOptions...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGetDatasets", reflect.TypeOf((*MockClient)(nil).BatchGetDatasets), varargs...) } // ClearDatasetItem mocks base method. -func (m *MockClient) ClearDatasetItem(arg0 context.Context, arg1 *dataset.ClearDatasetItemRequest, arg2 ...callopt.Option) (*dataset.ClearDatasetItemResponse, error) { +func (m *MockClient) ClearDatasetItem(ctx context.Context, req *dataset.ClearDatasetItemRequest, callOptions ...callopt.Option) (*dataset.ClearDatasetItemResponse, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, req} + for _, a := range callOptions { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "ClearDatasetItem", varargs...) @@ -175,17 +176,17 @@ func (m *MockClient) ClearDatasetItem(arg0 context.Context, arg1 *dataset.ClearD } // ClearDatasetItem indicates an expected call of ClearDatasetItem. -func (mr *MockClientMockRecorder) ClearDatasetItem(arg0, arg1 any, arg2 ...any) *gomock.Call { +func (mr *MockClientMockRecorder) ClearDatasetItem(ctx, req any, callOptions ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, req}, callOptions...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClearDatasetItem", reflect.TypeOf((*MockClient)(nil).ClearDatasetItem), varargs...) } // CreateDataset mocks base method. -func (m *MockClient) CreateDataset(arg0 context.Context, arg1 *dataset.CreateDatasetRequest, arg2 ...callopt.Option) (*dataset.CreateDatasetResponse, error) { +func (m *MockClient) CreateDataset(ctx context.Context, req *dataset.CreateDatasetRequest, callOptions ...callopt.Option) (*dataset.CreateDatasetResponse, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, req} + for _, a := range callOptions { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "CreateDataset", varargs...) @@ -195,17 +196,17 @@ func (m *MockClient) CreateDataset(arg0 context.Context, arg1 *dataset.CreateDat } // CreateDataset indicates an expected call of CreateDataset. -func (mr *MockClientMockRecorder) CreateDataset(arg0, arg1 any, arg2 ...any) *gomock.Call { +func (mr *MockClientMockRecorder) CreateDataset(ctx, req any, callOptions ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, req}, callOptions...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateDataset", reflect.TypeOf((*MockClient)(nil).CreateDataset), varargs...) } // CreateDatasetVersion mocks base method. -func (m *MockClient) CreateDatasetVersion(arg0 context.Context, arg1 *dataset.CreateDatasetVersionRequest, arg2 ...callopt.Option) (*dataset.CreateDatasetVersionResponse, error) { +func (m *MockClient) CreateDatasetVersion(ctx context.Context, req *dataset.CreateDatasetVersionRequest, callOptions ...callopt.Option) (*dataset.CreateDatasetVersionResponse, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, req} + for _, a := range callOptions { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "CreateDatasetVersion", varargs...) @@ -215,17 +216,17 @@ func (m *MockClient) CreateDatasetVersion(arg0 context.Context, arg1 *dataset.Cr } // CreateDatasetVersion indicates an expected call of CreateDatasetVersion. -func (mr *MockClientMockRecorder) CreateDatasetVersion(arg0, arg1 any, arg2 ...any) *gomock.Call { +func (mr *MockClientMockRecorder) CreateDatasetVersion(ctx, req any, callOptions ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, req}, callOptions...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateDatasetVersion", reflect.TypeOf((*MockClient)(nil).CreateDatasetVersion), varargs...) } // DeleteDataset mocks base method. -func (m *MockClient) DeleteDataset(arg0 context.Context, arg1 *dataset.DeleteDatasetRequest, arg2 ...callopt.Option) (*dataset.DeleteDatasetResponse, error) { +func (m *MockClient) DeleteDataset(ctx context.Context, req *dataset.DeleteDatasetRequest, callOptions ...callopt.Option) (*dataset.DeleteDatasetResponse, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, req} + for _, a := range callOptions { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "DeleteDataset", varargs...) @@ -235,17 +236,17 @@ func (m *MockClient) DeleteDataset(arg0 context.Context, arg1 *dataset.DeleteDat } // DeleteDataset indicates an expected call of DeleteDataset. -func (mr *MockClientMockRecorder) DeleteDataset(arg0, arg1 any, arg2 ...any) *gomock.Call { +func (mr *MockClientMockRecorder) DeleteDataset(ctx, req any, callOptions ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, req}, callOptions...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteDataset", reflect.TypeOf((*MockClient)(nil).DeleteDataset), varargs...) } // DeleteDatasetItem mocks base method. -func (m *MockClient) DeleteDatasetItem(arg0 context.Context, arg1 *dataset.DeleteDatasetItemRequest, arg2 ...callopt.Option) (*dataset.DeleteDatasetItemResponse, error) { +func (m *MockClient) DeleteDatasetItem(ctx context.Context, req *dataset.DeleteDatasetItemRequest, callOptions ...callopt.Option) (*dataset.DeleteDatasetItemResponse, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, req} + for _, a := range callOptions { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "DeleteDatasetItem", varargs...) @@ -255,17 +256,17 @@ func (m *MockClient) DeleteDatasetItem(arg0 context.Context, arg1 *dataset.Delet } // DeleteDatasetItem indicates an expected call of DeleteDatasetItem. -func (mr *MockClientMockRecorder) DeleteDatasetItem(arg0, arg1 any, arg2 ...any) *gomock.Call { +func (mr *MockClientMockRecorder) DeleteDatasetItem(ctx, req any, callOptions ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, req}, callOptions...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteDatasetItem", reflect.TypeOf((*MockClient)(nil).DeleteDatasetItem), varargs...) } // GetDataset mocks base method. -func (m *MockClient) GetDataset(arg0 context.Context, arg1 *dataset.GetDatasetRequest, arg2 ...callopt.Option) (*dataset.GetDatasetResponse, error) { +func (m *MockClient) GetDataset(ctx context.Context, req *dataset.GetDatasetRequest, callOptions ...callopt.Option) (*dataset.GetDatasetResponse, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, req} + for _, a := range callOptions { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "GetDataset", varargs...) @@ -275,17 +276,17 @@ func (m *MockClient) GetDataset(arg0 context.Context, arg1 *dataset.GetDatasetRe } // GetDataset indicates an expected call of GetDataset. -func (mr *MockClientMockRecorder) GetDataset(arg0, arg1 any, arg2 ...any) *gomock.Call { +func (mr *MockClientMockRecorder) GetDataset(ctx, req any, callOptions ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, req}, callOptions...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDataset", reflect.TypeOf((*MockClient)(nil).GetDataset), varargs...) } // GetDatasetIOJob mocks base method. -func (m *MockClient) GetDatasetIOJob(arg0 context.Context, arg1 *dataset.GetDatasetIOJobRequest, arg2 ...callopt.Option) (*dataset.GetDatasetIOJobResponse, error) { +func (m *MockClient) GetDatasetIOJob(ctx context.Context, req *dataset.GetDatasetIOJobRequest, callOptions ...callopt.Option) (*dataset.GetDatasetIOJobResponse, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, req} + for _, a := range callOptions { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "GetDatasetIOJob", varargs...) @@ -295,17 +296,17 @@ func (m *MockClient) GetDatasetIOJob(arg0 context.Context, arg1 *dataset.GetData } // GetDatasetIOJob indicates an expected call of GetDatasetIOJob. -func (mr *MockClientMockRecorder) GetDatasetIOJob(arg0, arg1 any, arg2 ...any) *gomock.Call { +func (mr *MockClientMockRecorder) GetDatasetIOJob(ctx, req any, callOptions ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, req}, callOptions...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDatasetIOJob", reflect.TypeOf((*MockClient)(nil).GetDatasetIOJob), varargs...) } // GetDatasetItem mocks base method. -func (m *MockClient) GetDatasetItem(arg0 context.Context, arg1 *dataset.GetDatasetItemRequest, arg2 ...callopt.Option) (*dataset.GetDatasetItemResponse, error) { +func (m *MockClient) GetDatasetItem(ctx context.Context, req *dataset.GetDatasetItemRequest, callOptions ...callopt.Option) (*dataset.GetDatasetItemResponse, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, req} + for _, a := range callOptions { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "GetDatasetItem", varargs...) @@ -315,17 +316,17 @@ func (m *MockClient) GetDatasetItem(arg0 context.Context, arg1 *dataset.GetDatas } // GetDatasetItem indicates an expected call of GetDatasetItem. -func (mr *MockClientMockRecorder) GetDatasetItem(arg0, arg1 any, arg2 ...any) *gomock.Call { +func (mr *MockClientMockRecorder) GetDatasetItem(ctx, req any, callOptions ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, req}, callOptions...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDatasetItem", reflect.TypeOf((*MockClient)(nil).GetDatasetItem), varargs...) } // GetDatasetSchema mocks base method. -func (m *MockClient) GetDatasetSchema(arg0 context.Context, arg1 *dataset.GetDatasetSchemaRequest, arg2 ...callopt.Option) (*dataset.GetDatasetSchemaResponse, error) { +func (m *MockClient) GetDatasetSchema(ctx context.Context, req *dataset.GetDatasetSchemaRequest, callOptions ...callopt.Option) (*dataset.GetDatasetSchemaResponse, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, req} + for _, a := range callOptions { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "GetDatasetSchema", varargs...) @@ -335,17 +336,17 @@ func (m *MockClient) GetDatasetSchema(arg0 context.Context, arg1 *dataset.GetDat } // GetDatasetSchema indicates an expected call of GetDatasetSchema. -func (mr *MockClientMockRecorder) GetDatasetSchema(arg0, arg1 any, arg2 ...any) *gomock.Call { +func (mr *MockClientMockRecorder) GetDatasetSchema(ctx, req any, callOptions ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, req}, callOptions...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDatasetSchema", reflect.TypeOf((*MockClient)(nil).GetDatasetSchema), varargs...) } // GetDatasetVersion mocks base method. -func (m *MockClient) GetDatasetVersion(arg0 context.Context, arg1 *dataset.GetDatasetVersionRequest, arg2 ...callopt.Option) (*dataset.GetDatasetVersionResponse, error) { +func (m *MockClient) GetDatasetVersion(ctx context.Context, req *dataset.GetDatasetVersionRequest, callOptions ...callopt.Option) (*dataset.GetDatasetVersionResponse, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, req} + for _, a := range callOptions { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "GetDatasetVersion", varargs...) @@ -355,17 +356,17 @@ func (m *MockClient) GetDatasetVersion(arg0 context.Context, arg1 *dataset.GetDa } // GetDatasetVersion indicates an expected call of GetDatasetVersion. -func (mr *MockClientMockRecorder) GetDatasetVersion(arg0, arg1 any, arg2 ...any) *gomock.Call { +func (mr *MockClientMockRecorder) GetDatasetVersion(ctx, req any, callOptions ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, req}, callOptions...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDatasetVersion", reflect.TypeOf((*MockClient)(nil).GetDatasetVersion), varargs...) } // ImportDataset mocks base method. -func (m *MockClient) ImportDataset(arg0 context.Context, arg1 *dataset.ImportDatasetRequest, arg2 ...callopt.Option) (*dataset.ImportDatasetResponse, error) { +func (m *MockClient) ImportDataset(ctx context.Context, req *dataset.ImportDatasetRequest, callOptions ...callopt.Option) (*dataset.ImportDatasetResponse, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, req} + for _, a := range callOptions { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "ImportDataset", varargs...) @@ -375,17 +376,17 @@ func (m *MockClient) ImportDataset(arg0 context.Context, arg1 *dataset.ImportDat } // ImportDataset indicates an expected call of ImportDataset. -func (mr *MockClientMockRecorder) ImportDataset(arg0, arg1 any, arg2 ...any) *gomock.Call { +func (mr *MockClientMockRecorder) ImportDataset(ctx, req any, callOptions ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, req}, callOptions...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ImportDataset", reflect.TypeOf((*MockClient)(nil).ImportDataset), varargs...) } // ListDatasetIOJobs mocks base method. -func (m *MockClient) ListDatasetIOJobs(arg0 context.Context, arg1 *dataset.ListDatasetIOJobsRequest, arg2 ...callopt.Option) (*dataset.ListDatasetIOJobsResponse, error) { +func (m *MockClient) ListDatasetIOJobs(ctx context.Context, req *dataset.ListDatasetIOJobsRequest, callOptions ...callopt.Option) (*dataset.ListDatasetIOJobsResponse, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, req} + for _, a := range callOptions { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "ListDatasetIOJobs", varargs...) @@ -395,17 +396,17 @@ func (m *MockClient) ListDatasetIOJobs(arg0 context.Context, arg1 *dataset.ListD } // ListDatasetIOJobs indicates an expected call of ListDatasetIOJobs. -func (mr *MockClientMockRecorder) ListDatasetIOJobs(arg0, arg1 any, arg2 ...any) *gomock.Call { +func (mr *MockClientMockRecorder) ListDatasetIOJobs(ctx, req any, callOptions ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, req}, callOptions...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListDatasetIOJobs", reflect.TypeOf((*MockClient)(nil).ListDatasetIOJobs), varargs...) } // ListDatasetItems mocks base method. -func (m *MockClient) ListDatasetItems(arg0 context.Context, arg1 *dataset.ListDatasetItemsRequest, arg2 ...callopt.Option) (*dataset.ListDatasetItemsResponse, error) { +func (m *MockClient) ListDatasetItems(ctx context.Context, req *dataset.ListDatasetItemsRequest, callOptions ...callopt.Option) (*dataset.ListDatasetItemsResponse, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, req} + for _, a := range callOptions { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "ListDatasetItems", varargs...) @@ -415,17 +416,17 @@ func (m *MockClient) ListDatasetItems(arg0 context.Context, arg1 *dataset.ListDa } // ListDatasetItems indicates an expected call of ListDatasetItems. -func (mr *MockClientMockRecorder) ListDatasetItems(arg0, arg1 any, arg2 ...any) *gomock.Call { +func (mr *MockClientMockRecorder) ListDatasetItems(ctx, req any, callOptions ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, req}, callOptions...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListDatasetItems", reflect.TypeOf((*MockClient)(nil).ListDatasetItems), varargs...) } // ListDatasetItemsByVersion mocks base method. -func (m *MockClient) ListDatasetItemsByVersion(arg0 context.Context, arg1 *dataset.ListDatasetItemsByVersionRequest, arg2 ...callopt.Option) (*dataset.ListDatasetItemsByVersionResponse, error) { +func (m *MockClient) ListDatasetItemsByVersion(ctx context.Context, req *dataset.ListDatasetItemsByVersionRequest, callOptions ...callopt.Option) (*dataset.ListDatasetItemsByVersionResponse, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, req} + for _, a := range callOptions { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "ListDatasetItemsByVersion", varargs...) @@ -435,17 +436,17 @@ func (m *MockClient) ListDatasetItemsByVersion(arg0 context.Context, arg1 *datas } // ListDatasetItemsByVersion indicates an expected call of ListDatasetItemsByVersion. -func (mr *MockClientMockRecorder) ListDatasetItemsByVersion(arg0, arg1 any, arg2 ...any) *gomock.Call { +func (mr *MockClientMockRecorder) ListDatasetItemsByVersion(ctx, req any, callOptions ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, req}, callOptions...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListDatasetItemsByVersion", reflect.TypeOf((*MockClient)(nil).ListDatasetItemsByVersion), varargs...) } // ListDatasetVersions mocks base method. -func (m *MockClient) ListDatasetVersions(arg0 context.Context, arg1 *dataset.ListDatasetVersionsRequest, arg2 ...callopt.Option) (*dataset.ListDatasetVersionsResponse, error) { +func (m *MockClient) ListDatasetVersions(ctx context.Context, req *dataset.ListDatasetVersionsRequest, callOptions ...callopt.Option) (*dataset.ListDatasetVersionsResponse, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, req} + for _, a := range callOptions { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "ListDatasetVersions", varargs...) @@ -455,17 +456,17 @@ func (m *MockClient) ListDatasetVersions(arg0 context.Context, arg1 *dataset.Lis } // ListDatasetVersions indicates an expected call of ListDatasetVersions. -func (mr *MockClientMockRecorder) ListDatasetVersions(arg0, arg1 any, arg2 ...any) *gomock.Call { +func (mr *MockClientMockRecorder) ListDatasetVersions(ctx, req any, callOptions ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, req}, callOptions...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListDatasetVersions", reflect.TypeOf((*MockClient)(nil).ListDatasetVersions), varargs...) } // ListDatasets mocks base method. -func (m *MockClient) ListDatasets(arg0 context.Context, arg1 *dataset.ListDatasetsRequest, arg2 ...callopt.Option) (*dataset.ListDatasetsResponse, error) { +func (m *MockClient) ListDatasets(ctx context.Context, req *dataset.ListDatasetsRequest, callOptions ...callopt.Option) (*dataset.ListDatasetsResponse, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, req} + for _, a := range callOptions { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "ListDatasets", varargs...) @@ -475,17 +476,17 @@ func (m *MockClient) ListDatasets(arg0 context.Context, arg1 *dataset.ListDatase } // ListDatasets indicates an expected call of ListDatasets. -func (mr *MockClientMockRecorder) ListDatasets(arg0, arg1 any, arg2 ...any) *gomock.Call { +func (mr *MockClientMockRecorder) ListDatasets(ctx, req any, callOptions ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, req}, callOptions...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListDatasets", reflect.TypeOf((*MockClient)(nil).ListDatasets), varargs...) } // UpdateDataset mocks base method. -func (m *MockClient) UpdateDataset(arg0 context.Context, arg1 *dataset.UpdateDatasetRequest, arg2 ...callopt.Option) (*dataset.UpdateDatasetResponse, error) { +func (m *MockClient) UpdateDataset(ctx context.Context, req *dataset.UpdateDatasetRequest, callOptions ...callopt.Option) (*dataset.UpdateDatasetResponse, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, req} + for _, a := range callOptions { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "UpdateDataset", varargs...) @@ -495,17 +496,17 @@ func (m *MockClient) UpdateDataset(arg0 context.Context, arg1 *dataset.UpdateDat } // UpdateDataset indicates an expected call of UpdateDataset. -func (mr *MockClientMockRecorder) UpdateDataset(arg0, arg1 any, arg2 ...any) *gomock.Call { +func (mr *MockClientMockRecorder) UpdateDataset(ctx, req any, callOptions ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, req}, callOptions...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateDataset", reflect.TypeOf((*MockClient)(nil).UpdateDataset), varargs...) } // UpdateDatasetItem mocks base method. -func (m *MockClient) UpdateDatasetItem(arg0 context.Context, arg1 *dataset.UpdateDatasetItemRequest, arg2 ...callopt.Option) (*dataset.UpdateDatasetItemResponse, error) { +func (m *MockClient) UpdateDatasetItem(ctx context.Context, req *dataset.UpdateDatasetItemRequest, callOptions ...callopt.Option) (*dataset.UpdateDatasetItemResponse, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, req} + for _, a := range callOptions { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "UpdateDatasetItem", varargs...) @@ -515,17 +516,17 @@ func (m *MockClient) UpdateDatasetItem(arg0 context.Context, arg1 *dataset.Updat } // UpdateDatasetItem indicates an expected call of UpdateDatasetItem. -func (mr *MockClientMockRecorder) UpdateDatasetItem(arg0, arg1 any, arg2 ...any) *gomock.Call { +func (mr *MockClientMockRecorder) UpdateDatasetItem(ctx, req any, callOptions ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, req}, callOptions...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateDatasetItem", reflect.TypeOf((*MockClient)(nil).UpdateDatasetItem), varargs...) } // UpdateDatasetSchema mocks base method. -func (m *MockClient) UpdateDatasetSchema(arg0 context.Context, arg1 *dataset.UpdateDatasetSchemaRequest, arg2 ...callopt.Option) (*dataset.UpdateDatasetSchemaResponse, error) { +func (m *MockClient) UpdateDatasetSchema(ctx context.Context, req *dataset.UpdateDatasetSchemaRequest, callOptions ...callopt.Option) (*dataset.UpdateDatasetSchemaResponse, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, req} + for _, a := range callOptions { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "UpdateDatasetSchema", varargs...) @@ -535,17 +536,17 @@ func (m *MockClient) UpdateDatasetSchema(arg0 context.Context, arg1 *dataset.Upd } // UpdateDatasetSchema indicates an expected call of UpdateDatasetSchema. -func (mr *MockClientMockRecorder) UpdateDatasetSchema(arg0, arg1 any, arg2 ...any) *gomock.Call { +func (mr *MockClientMockRecorder) UpdateDatasetSchema(ctx, req any, callOptions ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, req}, callOptions...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateDatasetSchema", reflect.TypeOf((*MockClient)(nil).UpdateDatasetSchema), varargs...) } // ValidateDatasetItems mocks base method. -func (m *MockClient) ValidateDatasetItems(arg0 context.Context, arg1 *dataset.ValidateDatasetItemsReq, arg2 ...callopt.Option) (*dataset.ValidateDatasetItemsResp, error) { +func (m *MockClient) ValidateDatasetItems(ctx context.Context, req *dataset.ValidateDatasetItemsReq, callOptions ...callopt.Option) (*dataset.ValidateDatasetItemsResp, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, req} + for _, a := range callOptions { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "ValidateDatasetItems", varargs...) @@ -555,8 +556,8 @@ func (m *MockClient) ValidateDatasetItems(arg0 context.Context, arg1 *dataset.Va } // ValidateDatasetItems indicates an expected call of ValidateDatasetItems. -func (mr *MockClientMockRecorder) ValidateDatasetItems(arg0, arg1 any, arg2 ...any) *gomock.Call { +func (mr *MockClientMockRecorder) ValidateDatasetItems(ctx, req any, callOptions ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, req}, callOptions...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateDatasetItems", reflect.TypeOf((*MockClient)(nil).ValidateDatasetItems), varargs...) } diff --git a/backend/modules/observability/infra/rpc/evaluationset/evaluation_set_test.go b/backend/modules/observability/infra/rpc/evaluationset/evaluation_set_test.go index 611c07e9b..342030819 100644 --- a/backend/modules/observability/infra/rpc/evaluationset/evaluation_set_test.go +++ b/backend/modules/observability/infra/rpc/evaluationset/evaluation_set_test.go @@ -8,18 +8,17 @@ import ( "fmt" "testing" + "github.com/bytedance/gg/gptr" + "github.com/samber/lo" "github.com/stretchr/testify/assert" - "github.com/bytedance/gg/gptr" "github.com/coze-dev/coze-loop/backend/infra/middleware/session" dataset_domain "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/data/domain/dataset" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/evaluation/domain/common" eval_set_domain "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/evaluation/domain/eval_set" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity" - "github.com/samber/lo" ) -//go:generate mockgen -source=evaluation_set.go -destination=mocks/mock_evaluation_set.go //go:generate mockgen -package=mocks -destination=mocks/mock_evaluationsetservice_client.go github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/evaluation/evaluationsetservice Client // Test helper functions diff --git a/backend/modules/observability/infra/rpc/evaluationset/mocks/mock_evaluationsetservice_client.go b/backend/modules/observability/infra/rpc/evaluationset/mocks/mock_evaluationsetservice_client.go new file mode 100644 index 000000000..9dc9e7daa --- /dev/null +++ b/backend/modules/observability/infra/rpc/evaluationset/mocks/mock_evaluationsetservice_client.go @@ -0,0 +1,363 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/evaluation/evaluationsetservice (interfaces: Client) +// +// Generated by this command: +// +// mockgen -package=mocks -destination=mocks/mock_evaluationsetservice_client.go github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/evaluation/evaluationsetservice Client +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + callopt "github.com/cloudwego/kitex/client/callopt" + eval_set "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/evaluation/eval_set" + gomock "go.uber.org/mock/gomock" +) + +// MockClient is a mock of Client interface. +type MockClient struct { + ctrl *gomock.Controller + recorder *MockClientMockRecorder + isgomock struct{} +} + +// MockClientMockRecorder is the mock recorder for MockClient. +type MockClientMockRecorder struct { + mock *MockClient +} + +// NewMockClient creates a new mock instance. +func NewMockClient(ctrl *gomock.Controller) *MockClient { + mock := &MockClient{ctrl: ctrl} + mock.recorder = &MockClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockClient) EXPECT() *MockClientMockRecorder { + return m.recorder +} + +// BatchCreateEvaluationSetItems mocks base method. +func (m *MockClient) BatchCreateEvaluationSetItems(ctx context.Context, req *eval_set.BatchCreateEvaluationSetItemsRequest, callOptions ...callopt.Option) (*eval_set.BatchCreateEvaluationSetItemsResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, req} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "BatchCreateEvaluationSetItems", varargs...) + ret0, _ := ret[0].(*eval_set.BatchCreateEvaluationSetItemsResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BatchCreateEvaluationSetItems indicates an expected call of BatchCreateEvaluationSetItems. +func (mr *MockClientMockRecorder) BatchCreateEvaluationSetItems(ctx, req any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, req}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchCreateEvaluationSetItems", reflect.TypeOf((*MockClient)(nil).BatchCreateEvaluationSetItems), varargs...) +} + +// BatchDeleteEvaluationSetItems mocks base method. +func (m *MockClient) BatchDeleteEvaluationSetItems(ctx context.Context, req *eval_set.BatchDeleteEvaluationSetItemsRequest, callOptions ...callopt.Option) (*eval_set.BatchDeleteEvaluationSetItemsResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, req} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "BatchDeleteEvaluationSetItems", varargs...) + ret0, _ := ret[0].(*eval_set.BatchDeleteEvaluationSetItemsResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BatchDeleteEvaluationSetItems indicates an expected call of BatchDeleteEvaluationSetItems. +func (mr *MockClientMockRecorder) BatchDeleteEvaluationSetItems(ctx, req any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, req}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchDeleteEvaluationSetItems", reflect.TypeOf((*MockClient)(nil).BatchDeleteEvaluationSetItems), varargs...) +} + +// BatchGetEvaluationSetItems mocks base method. +func (m *MockClient) BatchGetEvaluationSetItems(ctx context.Context, req *eval_set.BatchGetEvaluationSetItemsRequest, callOptions ...callopt.Option) (*eval_set.BatchGetEvaluationSetItemsResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, req} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "BatchGetEvaluationSetItems", varargs...) + ret0, _ := ret[0].(*eval_set.BatchGetEvaluationSetItemsResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BatchGetEvaluationSetItems indicates an expected call of BatchGetEvaluationSetItems. +func (mr *MockClientMockRecorder) BatchGetEvaluationSetItems(ctx, req any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, req}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGetEvaluationSetItems", reflect.TypeOf((*MockClient)(nil).BatchGetEvaluationSetItems), varargs...) +} + +// BatchGetEvaluationSetVersions mocks base method. +func (m *MockClient) BatchGetEvaluationSetVersions(ctx context.Context, req *eval_set.BatchGetEvaluationSetVersionsRequest, callOptions ...callopt.Option) (*eval_set.BatchGetEvaluationSetVersionsResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, req} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "BatchGetEvaluationSetVersions", varargs...) + ret0, _ := ret[0].(*eval_set.BatchGetEvaluationSetVersionsResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BatchGetEvaluationSetVersions indicates an expected call of BatchGetEvaluationSetVersions. +func (mr *MockClientMockRecorder) BatchGetEvaluationSetVersions(ctx, req any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, req}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGetEvaluationSetVersions", reflect.TypeOf((*MockClient)(nil).BatchGetEvaluationSetVersions), varargs...) +} + +// ClearEvaluationSetDraftItem mocks base method. +func (m *MockClient) ClearEvaluationSetDraftItem(ctx context.Context, req *eval_set.ClearEvaluationSetDraftItemRequest, callOptions ...callopt.Option) (*eval_set.ClearEvaluationSetDraftItemResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, req} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "ClearEvaluationSetDraftItem", varargs...) + ret0, _ := ret[0].(*eval_set.ClearEvaluationSetDraftItemResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ClearEvaluationSetDraftItem indicates an expected call of ClearEvaluationSetDraftItem. +func (mr *MockClientMockRecorder) ClearEvaluationSetDraftItem(ctx, req any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, req}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClearEvaluationSetDraftItem", reflect.TypeOf((*MockClient)(nil).ClearEvaluationSetDraftItem), varargs...) +} + +// CreateEvaluationSet mocks base method. +func (m *MockClient) CreateEvaluationSet(ctx context.Context, req *eval_set.CreateEvaluationSetRequest, callOptions ...callopt.Option) (*eval_set.CreateEvaluationSetResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, req} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "CreateEvaluationSet", varargs...) + ret0, _ := ret[0].(*eval_set.CreateEvaluationSetResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateEvaluationSet indicates an expected call of CreateEvaluationSet. +func (mr *MockClientMockRecorder) CreateEvaluationSet(ctx, req any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, req}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateEvaluationSet", reflect.TypeOf((*MockClient)(nil).CreateEvaluationSet), varargs...) +} + +// CreateEvaluationSetVersion mocks base method. +func (m *MockClient) CreateEvaluationSetVersion(ctx context.Context, req *eval_set.CreateEvaluationSetVersionRequest, callOptions ...callopt.Option) (*eval_set.CreateEvaluationSetVersionResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, req} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "CreateEvaluationSetVersion", varargs...) + ret0, _ := ret[0].(*eval_set.CreateEvaluationSetVersionResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateEvaluationSetVersion indicates an expected call of CreateEvaluationSetVersion. +func (mr *MockClientMockRecorder) CreateEvaluationSetVersion(ctx, req any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, req}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateEvaluationSetVersion", reflect.TypeOf((*MockClient)(nil).CreateEvaluationSetVersion), varargs...) +} + +// DeleteEvaluationSet mocks base method. +func (m *MockClient) DeleteEvaluationSet(ctx context.Context, req *eval_set.DeleteEvaluationSetRequest, callOptions ...callopt.Option) (*eval_set.DeleteEvaluationSetResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, req} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "DeleteEvaluationSet", varargs...) + ret0, _ := ret[0].(*eval_set.DeleteEvaluationSetResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DeleteEvaluationSet indicates an expected call of DeleteEvaluationSet. +func (mr *MockClientMockRecorder) DeleteEvaluationSet(ctx, req any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, req}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteEvaluationSet", reflect.TypeOf((*MockClient)(nil).DeleteEvaluationSet), varargs...) +} + +// GetEvaluationSet mocks base method. +func (m *MockClient) GetEvaluationSet(ctx context.Context, req *eval_set.GetEvaluationSetRequest, callOptions ...callopt.Option) (*eval_set.GetEvaluationSetResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, req} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetEvaluationSet", varargs...) + ret0, _ := ret[0].(*eval_set.GetEvaluationSetResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetEvaluationSet indicates an expected call of GetEvaluationSet. +func (mr *MockClientMockRecorder) GetEvaluationSet(ctx, req any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, req}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEvaluationSet", reflect.TypeOf((*MockClient)(nil).GetEvaluationSet), varargs...) +} + +// GetEvaluationSetVersion mocks base method. +func (m *MockClient) GetEvaluationSetVersion(ctx context.Context, req *eval_set.GetEvaluationSetVersionRequest, callOptions ...callopt.Option) (*eval_set.GetEvaluationSetVersionResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, req} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetEvaluationSetVersion", varargs...) + ret0, _ := ret[0].(*eval_set.GetEvaluationSetVersionResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetEvaluationSetVersion indicates an expected call of GetEvaluationSetVersion. +func (mr *MockClientMockRecorder) GetEvaluationSetVersion(ctx, req any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, req}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEvaluationSetVersion", reflect.TypeOf((*MockClient)(nil).GetEvaluationSetVersion), varargs...) +} + +// ListEvaluationSetItems mocks base method. +func (m *MockClient) ListEvaluationSetItems(ctx context.Context, req *eval_set.ListEvaluationSetItemsRequest, callOptions ...callopt.Option) (*eval_set.ListEvaluationSetItemsResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, req} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "ListEvaluationSetItems", varargs...) + ret0, _ := ret[0].(*eval_set.ListEvaluationSetItemsResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListEvaluationSetItems indicates an expected call of ListEvaluationSetItems. +func (mr *MockClientMockRecorder) ListEvaluationSetItems(ctx, req any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, req}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListEvaluationSetItems", reflect.TypeOf((*MockClient)(nil).ListEvaluationSetItems), varargs...) +} + +// ListEvaluationSetVersions mocks base method. +func (m *MockClient) ListEvaluationSetVersions(ctx context.Context, req *eval_set.ListEvaluationSetVersionsRequest, callOptions ...callopt.Option) (*eval_set.ListEvaluationSetVersionsResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, req} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "ListEvaluationSetVersions", varargs...) + ret0, _ := ret[0].(*eval_set.ListEvaluationSetVersionsResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListEvaluationSetVersions indicates an expected call of ListEvaluationSetVersions. +func (mr *MockClientMockRecorder) ListEvaluationSetVersions(ctx, req any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, req}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListEvaluationSetVersions", reflect.TypeOf((*MockClient)(nil).ListEvaluationSetVersions), varargs...) +} + +// ListEvaluationSets mocks base method. +func (m *MockClient) ListEvaluationSets(ctx context.Context, req *eval_set.ListEvaluationSetsRequest, callOptions ...callopt.Option) (*eval_set.ListEvaluationSetsResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, req} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "ListEvaluationSets", varargs...) + ret0, _ := ret[0].(*eval_set.ListEvaluationSetsResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListEvaluationSets indicates an expected call of ListEvaluationSets. +func (mr *MockClientMockRecorder) ListEvaluationSets(ctx, req any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, req}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListEvaluationSets", reflect.TypeOf((*MockClient)(nil).ListEvaluationSets), varargs...) +} + +// UpdateEvaluationSet mocks base method. +func (m *MockClient) UpdateEvaluationSet(ctx context.Context, req *eval_set.UpdateEvaluationSetRequest, callOptions ...callopt.Option) (*eval_set.UpdateEvaluationSetResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, req} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "UpdateEvaluationSet", varargs...) + ret0, _ := ret[0].(*eval_set.UpdateEvaluationSetResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateEvaluationSet indicates an expected call of UpdateEvaluationSet. +func (mr *MockClientMockRecorder) UpdateEvaluationSet(ctx, req any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, req}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateEvaluationSet", reflect.TypeOf((*MockClient)(nil).UpdateEvaluationSet), varargs...) +} + +// UpdateEvaluationSetItem mocks base method. +func (m *MockClient) UpdateEvaluationSetItem(ctx context.Context, req *eval_set.UpdateEvaluationSetItemRequest, callOptions ...callopt.Option) (*eval_set.UpdateEvaluationSetItemResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, req} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "UpdateEvaluationSetItem", varargs...) + ret0, _ := ret[0].(*eval_set.UpdateEvaluationSetItemResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateEvaluationSetItem indicates an expected call of UpdateEvaluationSetItem. +func (mr *MockClientMockRecorder) UpdateEvaluationSetItem(ctx, req any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, req}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateEvaluationSetItem", reflect.TypeOf((*MockClient)(nil).UpdateEvaluationSetItem), varargs...) +} + +// UpdateEvaluationSetSchema mocks base method. +func (m *MockClient) UpdateEvaluationSetSchema(ctx context.Context, req *eval_set.UpdateEvaluationSetSchemaRequest, callOptions ...callopt.Option) (*eval_set.UpdateEvaluationSetSchemaResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, req} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "UpdateEvaluationSetSchema", varargs...) + ret0, _ := ret[0].(*eval_set.UpdateEvaluationSetSchemaResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateEvaluationSetSchema indicates an expected call of UpdateEvaluationSetSchema. +func (mr *MockClientMockRecorder) UpdateEvaluationSetSchema(ctx, req any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, req}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateEvaluationSetSchema", reflect.TypeOf((*MockClient)(nil).UpdateEvaluationSetSchema), varargs...) +} diff --git a/backend/modules/prompt/application/convertor/debug_context_test.go b/backend/modules/prompt/application/convertor/debug_context_test.go index caa68ef38..f3ccfe773 100644 --- a/backend/modules/prompt/application/convertor/debug_context_test.go +++ b/backend/modules/prompt/application/convertor/debug_context_test.go @@ -123,6 +123,7 @@ func mockDebugContextCases() []debugContextTestCase { PromptDetail: &prompt.PromptDetail{ PromptTemplate: &prompt.PromptTemplate{ TemplateType: ptr.Of(prompt.TemplateTypeNormal), + HasSnippet: ptr.Of(false), Messages: []*prompt.Message{ { Role: ptr.Of(prompt.RoleSystem), @@ -152,6 +153,7 @@ func mockDebugContextCases() []debugContextTestCase { PromptDetail: &entity.PromptDetail{ PromptTemplate: &entity.PromptTemplate{ TemplateType: entity.TemplateTypeNormal, + HasSnippets: false, Messages: []*entity.Message{ { Role: entity.RoleSystem, @@ -205,6 +207,7 @@ func mockDebugContextCases() []debugContextTestCase { PromptDetail: &prompt.PromptDetail{ PromptTemplate: &prompt.PromptTemplate{ TemplateType: ptr.Of(prompt.TemplateTypeNormal), + HasSnippet: ptr.Of(false), Messages: []*prompt.Message{ { Role: ptr.Of(prompt.RoleSystem), @@ -247,6 +250,7 @@ func mockDebugContextCases() []debugContextTestCase { PromptDetail: &entity.PromptDetail{ PromptTemplate: &entity.PromptTemplate{ TemplateType: entity.TemplateTypeNormal, + HasSnippets: false, Messages: []*entity.Message{ { Role: entity.RoleSystem, diff --git a/backend/modules/prompt/application/convertor/openapi_test.go b/backend/modules/prompt/application/convertor/openapi_test.go index 2b3234ea1..16e1e33b6 100755 --- a/backend/modules/prompt/application/convertor/openapi_test.go +++ b/backend/modules/prompt/application/convertor/openapi_test.go @@ -370,7 +370,7 @@ func mockOpenAPIPromptCases() []openAPIPromptTestCase { PromptKey: ptr.Of("test_prompt"), Version: ptr.Of("1.0.0"), PromptTemplate: &openapi.PromptTemplate{ - TemplateType: ptr.Of(prompt.TemplateType("")), + TemplateType: ptr.Of(""), VariableDefs: []*openapi.VariableDef{}, Metadata: map[string]string{"commit": "meta"}, }, @@ -442,7 +442,8 @@ func TestOpenAPIPromptTemplateDO2DTO(t *testing.T) { Metadata: map[string]string{"k": "v"}, }, want: &openapi.PromptTemplate{ - TemplateType: ptr.Of(prompt.TemplateType("")), + TemplateType: ptr.Of(""), + Messages: nil, VariableDefs: []*openapi.VariableDef{}, Metadata: map[string]string{"k": "v"}, }, diff --git a/backend/modules/prompt/application/convertor/prompt.go b/backend/modules/prompt/application/convertor/prompt.go index 12e49b60b..b70998524 100644 --- a/backend/modules/prompt/application/convertor/prompt.go +++ b/backend/modules/prompt/application/convertor/prompt.go @@ -25,6 +25,23 @@ func PromptDTO2DO(dto *prompt.Prompt) *entity.Prompt { } } +func BatchPromptDTO2DO(dtos []*prompt.Prompt) []*entity.Prompt { + if len(dtos) == 0 { + return nil + } + prompts := make([]*entity.Prompt, 0, len(dtos)) + for _, dto := range dtos { + if dto == nil { + continue + } + prompts = append(prompts, PromptDTO2DO(dto)) + } + if len(prompts) == 0 { + return nil + } + return prompts +} + func PromptDraftDTO2DO(dto *prompt.PromptDraft) *entity.PromptDraft { if dto == nil { return nil @@ -76,6 +93,7 @@ func PromptBasicDTO2DO(dto *prompt.PromptBasic) *entity.PromptBasic { return nil } return &entity.PromptBasic{ + PromptType: PromptTypeDTO2DO(dto.GetPromptType()), DisplayName: dto.GetDisplayName(), Description: dto.GetDescription(), LatestVersion: dto.GetLatestVersion(), @@ -109,6 +127,8 @@ func PromptTemplateDTO2DO(dto *prompt.PromptTemplate) *entity.PromptTemplate { TemplateType: TemplateTypeDTO2DO(dto.GetTemplateType()), Messages: BatchMessageDTO2DO(dto.Messages), VariableDefs: BatchVariableDefDTO2DO(dto.VariableDefs), + HasSnippets: dto.GetHasSnippet(), + Snippets: BatchPromptDTO2DO(dto.Snippets), Metadata: dto.Metadata, } } @@ -128,6 +148,17 @@ func TemplateTypeDTO2DO(dto prompt.TemplateType) entity.TemplateType { } } +func PromptTypeDTO2DO(dto prompt.PromptType) entity.PromptType { + switch dto { + case prompt.PromptTypeNormal: + return entity.PromptTypeNormal + case prompt.PromptTypeSnippet: + return entity.PromptTypeSnippet + default: + return entity.PromptTypeNormal + } +} + func BatchMessageDTO2DO(dtos []*prompt.Message) []*entity.Message { if dtos == nil { return nil @@ -782,9 +813,35 @@ func PromptBasicDO2DTO(do *entity.PromptBasic) *prompt.PromptBasic { } return ptr.Of(do.LatestCommittedAt.UnixMilli()) }(), + PromptType: ptr.Of(PromptTypeDO2DTO(do.PromptType)), } } +func PromptTypeDO2DTO(do entity.PromptType) prompt.PromptType { + switch do { + case entity.PromptTypeNormal: + return prompt.PromptTypeNormal + case entity.PromptTypeSnippet: + return prompt.PromptTypeSnippet + default: + return prompt.PromptTypeNormal + } +} + +func BatchPromptCommitDO2DTO(dos []*entity.PromptCommit) []*prompt.PromptCommit { + if len(dos) == 0 { + return nil + } + dtos := make([]*prompt.PromptCommit, 0, len(dos)) + for _, do := range dos { + if do == nil { + continue + } + dtos = append(dtos, PromptCommitDO2DTO(do)) + } + return dtos +} + func PromptCommitDO2DTO(do *entity.PromptCommit) *prompt.PromptCommit { if do == nil { return nil @@ -918,6 +975,8 @@ func PromptTemplateDO2DTO(do *entity.PromptTemplate) *prompt.PromptTemplate { TemplateType: ptr.Of(prompt.TemplateType(do.TemplateType)), Messages: BatchMessageDO2DTO(do.Messages), VariableDefs: BatchVariableDefDO2DTO(do.VariableDefs), + HasSnippet: ptr.Of(do.HasSnippets), + Snippets: BatchPromptDO2DTO(do.Snippets), Metadata: do.Metadata, } } diff --git a/backend/modules/prompt/application/convertor/prompt_test.go b/backend/modules/prompt/application/convertor/prompt_test.go index 81917072d..5d29593f3 100644 --- a/backend/modules/prompt/application/convertor/prompt_test.go +++ b/backend/modules/prompt/application/convertor/prompt_test.go @@ -63,6 +63,7 @@ func mockPromptCases() []promptTestCase { WorkspaceID: ptr.Of(int64(456)), PromptKey: ptr.Of("test_prompt"), PromptBasic: &prompt.PromptBasic{ + PromptType: ptr.Of(prompt.PromptTypeNormal), DisplayName: ptr.Of("Test Prompt"), Description: ptr.Of("Test PromptDescription"), LatestVersion: ptr.Of("1.0.0"), @@ -82,6 +83,7 @@ func mockPromptCases() []promptTestCase { Detail: &prompt.PromptDetail{ PromptTemplate: &prompt.PromptTemplate{ TemplateType: ptr.Of(prompt.TemplateTypeNormal), + HasSnippet: ptr.Of(false), Messages: []*prompt.Message{ { Role: ptr.Of(prompt.RoleSystem), @@ -143,6 +145,7 @@ func mockPromptCases() []promptTestCase { Detail: &prompt.PromptDetail{ PromptTemplate: &prompt.PromptTemplate{ TemplateType: ptr.Of(prompt.TemplateTypeNormal), + HasSnippet: ptr.Of(false), Messages: []*prompt.Message{ { Role: ptr.Of(prompt.RoleSystem), @@ -158,6 +161,7 @@ func mockPromptCases() []promptTestCase { SpaceID: 456, PromptKey: "test_prompt", PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, DisplayName: "Test Prompt", Description: "Test PromptDescription", LatestVersion: "1.0.0", @@ -256,6 +260,7 @@ func mockPromptCases() []promptTestCase { WorkspaceID: ptr.Of(int64(456)), PromptKey: ptr.Of("test_prompt"), PromptBasic: &prompt.PromptBasic{ + PromptType: ptr.Of(prompt.PromptTypeNormal), DisplayName: ptr.Of("Test Prompt"), Description: ptr.Of("Test PromptDescription"), LatestVersion: ptr.Of("1.0.0"), @@ -270,6 +275,7 @@ func mockPromptCases() []promptTestCase { SpaceID: 456, PromptKey: "test_prompt", PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, DisplayName: "Test Prompt", Description: "Test PromptDescription", LatestVersion: "1.0.0", @@ -352,6 +358,7 @@ func mockPromptCases() []promptTestCase { Detail: &prompt.PromptDetail{ PromptTemplate: &prompt.PromptTemplate{ TemplateType: ptr.Of(prompt.TemplateTypeNormal), + HasSnippet: ptr.Of(false), Metadata: map[string]string{"commit-meta": "value"}, }, }, @@ -360,6 +367,7 @@ func mockPromptCases() []promptTestCase { Detail: &prompt.PromptDetail{ PromptTemplate: &prompt.PromptTemplate{ TemplateType: ptr.Of(prompt.TemplateTypeNormal), + HasSnippet: ptr.Of(false), Metadata: map[string]string{"draft-meta": "value"}, }, }, @@ -370,6 +378,7 @@ func mockPromptCases() []promptTestCase { PromptDetail: &entity.PromptDetail{ PromptTemplate: &entity.PromptTemplate{ TemplateType: entity.TemplateTypeNormal, + HasSnippets: false, Metadata: map[string]string{"commit-meta": "value"}, }, }, @@ -378,12 +387,130 @@ func mockPromptCases() []promptTestCase { PromptDetail: &entity.PromptDetail{ PromptTemplate: &entity.PromptTemplate{ TemplateType: entity.TemplateTypeNormal, + HasSnippets: false, Metadata: map[string]string{"draft-meta": "value"}, }, }, }, }, }, + { + name: "snippet prompt with snippets", + dto: &prompt.Prompt{ + ID: ptr.Of(int64(789)), + WorkspaceID: ptr.Of(int64(321)), + PromptKey: ptr.Of("snippet_prompt"), + PromptBasic: &prompt.PromptBasic{ + PromptType: ptr.Of(prompt.PromptTypeSnippet), + DisplayName: ptr.Of("Snippet Prompt"), + Description: ptr.Of("Snippet description"), + LatestVersion: ptr.Of("2.0.0"), + CreatedBy: ptr.Of("snippet_creator"), + UpdatedBy: ptr.Of("snippet_updater"), + CreatedAt: ptr.Of(nowMilli), + UpdatedAt: ptr.Of(nowMilli), + }, + PromptCommit: &prompt.PromptCommit{ + CommitInfo: &prompt.CommitInfo{ + Version: ptr.Of("2.0.0"), + BaseVersion: ptr.Of("1.0.0"), + Description: ptr.Of("Snippet version"), + CommittedBy: ptr.Of("snippet_creator"), + CommittedAt: ptr.Of(nowMilli), + }, + Detail: &prompt.PromptDetail{ + PromptTemplate: &prompt.PromptTemplate{ + TemplateType: ptr.Of(prompt.TemplateTypeNormal), + HasSnippet: ptr.Of(true), + Messages: []*prompt.Message{ + { + Role: ptr.Of(prompt.RoleSystem), + Content: ptr.Of("Snippet content"), + }, + }, + }, + }, + }, + PromptDraft: &prompt.PromptDraft{ + DraftInfo: &prompt.DraftInfo{ + UserID: ptr.Of("snippet_creator"), + BaseVersion: ptr.Of("2.0.0"), + IsModified: ptr.Of(false), + CreatedAt: ptr.Of(nowMilli), + UpdatedAt: ptr.Of(nowMilli), + }, + Detail: &prompt.PromptDetail{ + PromptTemplate: &prompt.PromptTemplate{ + TemplateType: ptr.Of(prompt.TemplateTypeNormal), + HasSnippet: ptr.Of(true), + Messages: []*prompt.Message{ + { + Role: ptr.Of(prompt.RoleUser), + Content: ptr.Of("Draft snippet content"), + }, + }, + }, + }, + }, + }, + do: &entity.Prompt{ + ID: 789, + SpaceID: 321, + PromptKey: "snippet_prompt", + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeSnippet, + DisplayName: "Snippet Prompt", + Description: "Snippet description", + LatestVersion: "2.0.0", + CreatedBy: "snippet_creator", + UpdatedBy: "snippet_updater", + CreatedAt: time.UnixMilli(nowMilli), + UpdatedAt: time.UnixMilli(nowMilli), + }, + PromptCommit: &entity.PromptCommit{ + CommitInfo: &entity.CommitInfo{ + Version: "2.0.0", + BaseVersion: "1.0.0", + Description: "Snippet version", + CommittedBy: "snippet_creator", + CommittedAt: time.UnixMilli(nowMilli), + }, + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + TemplateType: entity.TemplateTypeNormal, + HasSnippets: true, + Messages: []*entity.Message{ + { + Role: entity.RoleSystem, + Content: ptr.Of("Snippet content"), + }, + }, + }, + }, + }, + PromptDraft: &entity.PromptDraft{ + DraftInfo: &entity.DraftInfo{ + UserID: "snippet_creator", + BaseVersion: "2.0.0", + IsModified: false, + CreatedAt: time.UnixMilli(nowMilli), + UpdatedAt: time.UnixMilli(nowMilli), + }, + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + TemplateType: entity.TemplateTypeNormal, + HasSnippets: true, + Messages: []*entity.Message{ + { + Role: entity.RoleUser, + Content: ptr.Of("Draft snippet content"), + }, + }, + }, + }, + }, + }, + }, } } diff --git a/backend/modules/prompt/application/debug.go b/backend/modules/prompt/application/debug.go index af07f685f..8e65fd82a 100644 --- a/backend/modules/prompt/application/debug.go +++ b/backend/modules/prompt/application/debug.go @@ -226,6 +226,11 @@ func (p *PromptDebugApplicationImpl) doDebugStreaming(ctx context.Context, req * prompt := convertor.PromptDTO2DO(req.Prompt) // prompt hub span report p.reportDebugPromptHubSpan(ctx, prompt) + // expand snippets + err = p.promptService.ExpandSnippets(ctx, prompt) + if err != nil { + return nil, err + } // execute resultStream := make(chan *entity.Reply) errChan := make(chan error) diff --git a/backend/modules/prompt/application/debug_test.go b/backend/modules/prompt/application/debug_test.go index 7b79c147d..fb4eace52 100644 --- a/backend/modules/prompt/application/debug_test.go +++ b/backend/modules/prompt/application/debug_test.go @@ -64,6 +64,7 @@ func TestPromptDebugApplicationImpl_DebugStreaming(t *testing.T) { mockDebugLogRepo := repomocks.NewMockIDebugLogRepo(ctrl) mockDebugLogRepo.EXPECT().SaveDebugLog(gomock.Any(), gomock.Any()).Return(nil) mockPromptSvc := servicemocks.NewMockIPromptService(ctrl) + mockPromptSvc.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil) mockPromptSvc.EXPECT().MCompleteMultiModalFileURL(gomock.Any(), gomock.Any(), nil).Return(nil) mockPromptSvc.EXPECT().MConvertBase64DataURLToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptSvc.EXPECT().ExecuteStreaming(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, param service.ExecuteStreamingParam) (*entity.Reply, error) { @@ -126,6 +127,7 @@ func TestPromptDebugApplicationImpl_DebugStreaming(t *testing.T) { mockDebugLogRepo := repomocks.NewMockIDebugLogRepo(ctrl) mockDebugLogRepo.EXPECT().SaveDebugLog(gomock.Any(), gomock.Any()).Return(nil) mockPromptSvc := servicemocks.NewMockIPromptService(ctrl) + mockPromptSvc.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil) mockPromptSvc.EXPECT().MCompleteMultiModalFileURL(gomock.Any(), gomock.Any(), nil).Return(nil) mockPromptSvc.EXPECT().MConvertBase64DataURLToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptSvc.EXPECT().ExecuteStreaming(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, param service.ExecuteStreamingParam) (*entity.Reply, error) { @@ -182,6 +184,45 @@ func TestPromptDebugApplicationImpl_DebugStreaming(t *testing.T) { }, wantErr: nil, }, + { + name: "expand snippets error", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockPromptSvc := servicemocks.NewMockIPromptService(ctrl) + mockPromptSvc.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(errorx.New("expand error")) + + mockBenefitSvc := benefitmocks.NewMockIBenefitService(ctrl) + mockBenefitSvc.EXPECT().CheckPromptBenefit(gomock.Any(), gomock.Any()).Return(&benefit.CheckPromptBenefitResult{}, nil) + + mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) + mockAuth.EXPECT().MCheckPromptPermission(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + + return fields{ + promptService: mockPromptSvc, + benefitService: mockBenefitSvc, + auth: mockAuth, + } + }, + args: args{ + ctx: session.WithCtxUser(context.Background(), mockUser), + req: &debug.DebugStreamingRequest{ + Prompt: &prompt.Prompt{ + ID: ptr.Of(int64(123456)), + WorkspaceID: ptr.Of(int64(123456)), + PromptDraft: &prompt.PromptDraft{ + Detail: &prompt.PromptDetail{ + PromptTemplate: &prompt.PromptTemplate{ + TemplateType: ptr.Of(prompt.TemplateTypeNormal), + }, + ModelConfig: &prompt.ModelConfig{}, + }, + }, + }, + SingleStepDebug: ptr.Of(true), + }, + stream: localstream.NewInMemStream(context.Background(), make(chan *debug.DebugStreamingResponse), make(chan error)), + }, + wantErr: errorx.NewByCode(prompterr.CommonInternalErrorCode), + }, { name: "invalid param: prompt is nil", fieldsGetter: func(ctrl *gomock.Controller) fields { return fields{} }, @@ -224,6 +265,7 @@ func TestPromptDebugApplicationImpl_DebugStreaming(t *testing.T) { mockDebugLogRepo := repomocks.NewMockIDebugLogRepo(ctrl) mockDebugLogRepo.EXPECT().SaveDebugLog(gomock.Any(), gomock.Any()).Return(nil) mockPromptSvc := servicemocks.NewMockIPromptService(ctrl) + mockPromptSvc.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil) mockPromptSvc.EXPECT().MCompleteMultiModalFileURL(gomock.Any(), gomock.Any(), nil).Return(nil) mockPromptSvc.EXPECT().ExecuteStreaming(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, param service.ExecuteStreamingParam) (*entity.Reply, error) { param.ResultStream <- &entity.Reply{ @@ -291,6 +333,7 @@ func TestPromptDebugApplicationImpl_DebugStreaming(t *testing.T) { mockDebugLogRepo := repomocks.NewMockIDebugLogRepo(ctrl) mockDebugLogRepo.EXPECT().SaveDebugLog(gomock.Any(), gomock.Any()).Return(nil) mockPromptSvc := servicemocks.NewMockIPromptService(ctrl) + mockPromptSvc.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil) mockPromptSvc.EXPECT().MCompleteMultiModalFileURL(gomock.Any(), gomock.Any(), nil).Return(nil) mockPromptSvc.EXPECT().MConvertBase64DataURLToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptSvc.EXPECT().ExecuteStreaming(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, param service.ExecuteStreamingParam) (*entity.Reply, error) { diff --git a/backend/modules/prompt/application/execute.go b/backend/modules/prompt/application/execute.go index 10e3c981c..a3da775d3 100644 --- a/backend/modules/prompt/application/execute.go +++ b/backend/modules/prompt/application/execute.go @@ -66,6 +66,12 @@ func (p *PromptExecuteApplicationImpl) ExecuteInternal(ctx context.Context, req if err != nil { return r, err } + // expand snippets + err = p.promptService.ExpandSnippets(ctx, promptDO) + if err != nil { + return r, err + } + // override prompt params overridePromptParams(promptDO, req.OverridePromptParams) // execute reply, err = p.promptService.Execute(ctx, service.ExecuteParam{ diff --git a/backend/modules/prompt/application/execute_test.go b/backend/modules/prompt/application/execute_test.go index 0b64848c0..bbf5fc4d6 100755 --- a/backend/modules/prompt/application/execute_test.go +++ b/backend/modules/prompt/application/execute_test.go @@ -20,8 +20,7 @@ import ( "github.com/coze-dev/coze-loop/backend/modules/prompt/domain/service" servicemocks "github.com/coze-dev/coze-loop/backend/modules/prompt/domain/service/mocks" - // prompterr "github.com/coze-dev/coze-loop/backend/modules/prompt/pkg/errno" - // "github.com/coze-dev/coze-loop/backend/pkg/errorx" + "github.com/coze-dev/coze-loop/backend/pkg/errorx" "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" "github.com/coze-dev/coze-loop/backend/pkg/unittest" ) @@ -109,6 +108,7 @@ func TestPromptExecuteApplicationImpl_ExecuteInternal(t *testing.T) { mockManageRepo.EXPECT().GetPrompt(gomock.Any(), gomock.Any()).Return(createMockPrompt(), nil) mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil) mockPromptService.EXPECT().Execute(gomock.Any(), gomock.Any()).Return(mockReply, nil) mockPromptService.EXPECT().MConvertBase64DataURLToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) @@ -147,6 +147,7 @@ func TestPromptExecuteApplicationImpl_ExecuteInternal(t *testing.T) { mockManageRepo.EXPECT().GetPrompt(gomock.Any(), gomock.Any()).Return(createMockPrompt(), nil) mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil) mockPromptService.EXPECT().Execute(gomock.Any(), gomock.Any()).Return(mockReply, nil) mockPromptService.EXPECT().MConvertBase64DataURLToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("convert error")) @@ -198,6 +199,7 @@ func TestPromptExecuteApplicationImpl_ExecuteInternal(t *testing.T) { mockManageRepo.EXPECT().GetPrompt(gomock.Any(), gomock.Any()).Return(createMockPrompt(), nil) mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil) mockPromptService.EXPECT().Execute(gomock.Any(), gomock.Any()). Return(nil, errors.New("execution error")) @@ -217,6 +219,31 @@ func TestPromptExecuteApplicationImpl_ExecuteInternal(t *testing.T) { wantR: execute.NewExecuteInternalResponse(), wantErr: errors.New("execution error"), }, + { + name: "expand snippets error", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockManageRepo := repomocks.NewMockIManageRepo(ctrl) + mockManageRepo.EXPECT().GetPrompt(gomock.Any(), gomock.Any()).Return(createMockPrompt(), nil) + + mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(errorx.New("expand error")) + + return fields{ + promptService: mockPromptService, + manageRepo: mockManageRepo, + } + }, + args: args{ + ctx: context.Background(), + req: &execute.ExecuteInternalRequest{ + PromptID: ptr.Of(int64(123)), + WorkspaceID: ptr.Of(int64(123456)), + Version: ptr.Of("1.0.0"), + }, + }, + wantR: execute.NewExecuteInternalResponse(), + wantErr: errorx.New("expand error"), + }, { name: "success with override params", fieldsGetter: func(ctrl *gomock.Controller) fields { @@ -224,6 +251,7 @@ func TestPromptExecuteApplicationImpl_ExecuteInternal(t *testing.T) { mockManageRepo.EXPECT().GetPrompt(gomock.Any(), gomock.Any()).Return(createMockPrompt(), nil) mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil) mockPromptService.EXPECT().Execute(gomock.Any(), gomock.Any()).Return(mockReply, nil) mockPromptService.EXPECT().MConvertBase64DataURLToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) @@ -268,6 +296,7 @@ func TestPromptExecuteApplicationImpl_ExecuteInternal(t *testing.T) { mockManageRepo.EXPECT().GetPrompt(gomock.Any(), gomock.Any()).Return(createMockPrompt(), nil) mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil) mockPromptService.EXPECT().Execute(gomock.Any(), gomock.Any()).Return(mockReply, nil) mockPromptService.EXPECT().MConvertBase64DataURLToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) diff --git a/backend/modules/prompt/application/manage.go b/backend/modules/prompt/application/manage.go index 519a831fb..02faa3b8a 100644 --- a/backend/modules/prompt/application/manage.go +++ b/backend/modules/prompt/application/manage.go @@ -58,6 +58,56 @@ type PromptManageApplicationImpl struct { configProvider conf.IConfigProvider } +func (app *PromptManageApplicationImpl) ListParentPrompt(ctx context.Context, request *manage.ListParentPromptRequest) (r *manage.ListParentPromptResponse, err error) { + r = manage.NewListParentPromptResponse() + + // 用户验证 + userID, ok := session.UserIDInCtx(ctx) + if !ok || lo.IsEmpty(userID) { + return r, errorx.NewByCode(prompterr.CommonInvalidParamCode, errorx.WithExtraMsg("User not found")) + } + + // 权限检查 + err = app.authRPCProvider.CheckSpacePermission(ctx, request.GetWorkspaceID(), consts.ActionLoopPromptRead) + if err != nil { + return r, err + } + + // 参数验证 + if request.GetPromptID() <= 0 { + return r, errorx.NewByCode(prompterr.CommonInvalidParamCode, errorx.WithExtraMsg("Prompt ID is required")) + } + + // 调用repository层查询父prompt + result, err := app.manageRepo.ListParentPrompt(ctx, repo.ListParentPromptParam{ + SubPromptID: request.GetPromptID(), + SubPromptVersions: request.GetCommitVersions(), + }) + if err != nil { + return r, err + } + + // 转换结果 + parentPrompts := make(map[string][]*prompt.PromptCommitVersions) + for version, promptCommitVersions := range result { + promptVersionDTOs := make([]*prompt.PromptCommitVersions, 0, len(promptCommitVersions)) + for _, promptCommitVersion := range promptCommitVersions { + promptVersionDTO := &prompt.PromptCommitVersions{ + ID: ptr.Of(promptCommitVersion.PromptID), + WorkspaceID: ptr.Of(promptCommitVersion.SpaceID), + PromptKey: ptr.Of(promptCommitVersion.PromptKey), + PromptBasic: convertor.PromptBasicDO2DTO(promptCommitVersion.PromptBasic), + CommitVersions: promptCommitVersion.CommitVersions, + } + promptVersionDTOs = append(promptVersionDTOs, promptVersionDTO) + } + parentPrompts[version] = promptVersionDTOs + } + + r.ParentPrompts = parentPrompts + return r, nil +} + func (app *PromptManageApplicationImpl) CreatePrompt(ctx context.Context, request *manage.CreatePromptRequest) (r *manage.CreatePromptResponse, err error) { r = manage.NewCreatePromptResponse() @@ -73,11 +123,15 @@ func (app *PromptManageApplicationImpl) CreatePrompt(ctx context.Context, reques return r, err } + if request.PromptType == nil { + request.PromptType = ptr.Of(prompt.PromptTypeNormal) + } // create prompt promptDTO := &prompt.Prompt{ WorkspaceID: request.WorkspaceID, PromptKey: request.PromptKey, PromptBasic: &prompt.PromptBasic{ + PromptType: request.PromptType, DisplayName: request.PromptName, Description: request.PromptDescription, CreatedBy: ptr.Of(userID), @@ -104,8 +158,9 @@ func (app *PromptManageApplicationImpl) CreatePrompt(ctx context.Context, reques return r, err } - // create prompt - promptID, err := app.manageRepo.CreatePrompt(ctx, promptDO) + // create prompt using domain service with snippet validation + var promptID int64 + promptID, err = app.promptService.CreatePrompt(ctx, promptDO) if err != nil { return r, err } @@ -159,7 +214,7 @@ func (app *PromptManageApplicationImpl) ClonePrompt(ctx context.Context, request PromptDetail: clonedPromptDO.PromptCommit.PromptDetail, } clonedPromptDO.PromptCommit = nil - clonedPromptID, err := app.manageRepo.CreatePrompt(ctx, clonedPromptDO) + clonedPromptID, err := app.promptService.CreatePrompt(ctx, clonedPromptDO) if err != nil { return r, err } @@ -184,6 +239,9 @@ func (app *PromptManageApplicationImpl) DeletePrompt(ctx context.Context, reques if err != nil { return r, err } + if promptDO.PromptBasic.PromptType == entity.PromptTypeSnippet { + return r, errorx.NewByCode(prompterr.CommonInvalidParamCode, errorx.WithExtraMsg("Snippet prompt can not be deleted")) + } // 权限 err = app.authRPCProvider.MCheckPromptPermission(ctx, promptDO.SpaceID, []int64{request.GetPromptID()}, consts.ActionLoopPromptEdit) @@ -219,16 +277,17 @@ func (app *PromptManageApplicationImpl) GetPrompt(ctx context.Context, request * } // prompt - getPromptParam := repo.GetPromptParam{ + getPromptParam := service.GetPromptParam{ PromptID: request.GetPromptID(), WithCommit: !lo.IsEmpty(commitVersion), CommitVersion: commitVersion, - WithDraft: request.GetWithDraft(), - UserID: userID, + WithDraft: request.GetWithDraft(), + UserID: userID, + ExpandSnippet: request.GetExpandSnippet(), } - promptDO, err := app.manageRepo.GetPrompt(ctx, getPromptParam) + promptDO, err := app.promptService.GetPrompt(ctx, getPromptParam) if err != nil { return r, err } @@ -255,6 +314,38 @@ func (app *PromptManageApplicationImpl) GetPrompt(ctx context.Context, request * } r.DefaultConfig = defaultConfig } + + // [prompt片段]返回被引用总次数 + if promptDO.PromptBasic != nil && promptDO.PromptBasic.PromptType == entity.PromptTypeSnippet { + var commitVersionParams []string + if request.GetWithCommit() && lo.IsNotEmpty(request.GetCommitVersion()) { + commitVersionParams = append(commitVersionParams, commitVersion) + } else { + commitVersions, err := app.manageRepo.MGetVersionsByPromptID(ctx, request.GetPromptID()) + if err != nil { + return r, err + } + commitVersionParams = append(commitVersionParams, commitVersions...) + } + if len(commitVersionParams) > 0 { + parentPromptCommitVersions, err := app.manageRepo.ListParentPrompt(ctx, repo.ListParentPromptParam{ + SubPromptID: request.GetPromptID(), + SubPromptVersions: commitVersionParams, + }) + if err != nil { + return r, err + } + if len(parentPromptCommitVersions) > 0 { + var total int32 + for _, parents := range parentPromptCommitVersions { + for _, parent := range parents { + total += int32(len(parent.CommitVersions)) + } + } + r.TotalParentReferences = ptr.Of(total) + } + } + } return r, err } @@ -300,14 +391,26 @@ func (app *PromptManageApplicationImpl) ListPrompt(ctx context.Context, request return r, err } - // list prompt + // Default filtering behavior: if no filter_prompt_types specified, only show normal prompts + filterPromptTypes := request.GetFilterPromptTypes() + if len(filterPromptTypes) == 0 { + filterPromptTypes = []prompt.PromptType{prompt.PromptTypeNormal} + } + + // Convert prompt.PromptType to entity.PromptType + var entityFilterPromptTypes []entity.PromptType + for _, pt := range filterPromptTypes { + entityFilterPromptTypes = append(entityFilterPromptTypes, convertor.PromptTypeDTO2DO(pt)) + } + listPromptParam := repo.ListPromptParam{ SpaceID: request.GetWorkspaceID(), - KeyWord: request.GetKeyWord(), - CreatedBys: request.GetCreatedBys(), - UserID: userID, - CommittedOnly: request.GetCommittedOnly(), + KeyWord: request.GetKeyWord(), + CreatedBys: request.GetCreatedBys(), + UserID: userID, + CommittedOnly: request.GetCommittedOnly(), + FilterPromptTypes: entityFilterPromptTypes, PageNum: int(request.GetPageNum()), PageSize: int(request.GetPageSize()), @@ -329,6 +432,7 @@ func (app *PromptManageApplicationImpl) ListPrompt(ctx context.Context, request continue } userIDSet[promptDTO.PromptBasic.GetCreatedBy()] = struct{}{} + userIDSet[promptDTO.PromptBasic.GetUpdatedBy()] = struct{}{} } userDOs, err := app.userRPCProvider.MGetUserInfo(ctx, maps.Keys(userIDSet)) if err != nil { @@ -429,7 +533,7 @@ func (app *PromptManageApplicationImpl) SaveDraft(ctx context.Context, request * } // save draft - draftInfoDO, err := app.manageRepo.SaveDraft(ctx, savingPromptDO) + draftInfoDO, err := app.promptService.SaveDraft(ctx, savingPromptDO) if err != nil { return r, err } @@ -547,6 +651,17 @@ func (app *PromptManageApplicationImpl) ListCommit(ctx context.Context, request r.HasMore = ptr.Of(true) } r.PromptCommitInfos = convertor.BatchCommitInfoDO2DTO(listCommitResult.CommitInfoDOs) + if request.GetWithCommitDetail() { + commitDTOs := convertor.BatchPromptCommitDO2DTO(listCommitResult.CommitDOs) + promptCommitDetailMap := make(map[string]*prompt.PromptDetail) + for _, commitDTO := range commitDTOs { + if commitDTO == nil || commitDTO.CommitInfo == nil || lo.IsEmpty(commitDTO.CommitInfo.Version) { + continue + } + promptCommitDetailMap[commitDTO.GetCommitInfo().GetVersion()] = commitDTO.Detail + } + r.PromptCommitDetailMapping = promptCommitDetailMap + } userIDSet := make(map[string]struct{}) for _, commitInfoDTO := range r.PromptCommitInfos { if commitInfoDTO == nil || lo.IsEmpty(commitInfoDTO.GetCommittedBy()) { @@ -560,7 +675,6 @@ func (app *PromptManageApplicationImpl) ListCommit(ctx context.Context, request } r.Users = convertor.BatchUserInfoDO2DTO(userDOs) - // 填充commit版本标签映射 if len(r.PromptCommitInfos) > 0 { var commitVersions []string for _, commitInfo := range r.PromptCommitInfos { @@ -569,6 +683,7 @@ func (app *PromptManageApplicationImpl) ListCommit(ctx context.Context, request } } + // 填充commit版本标签映射 if len(commitVersions) > 0 { // 查询这些版本的标签映射,使用labelService commitLabelMapping, err := app.promptService.BatchGetCommitLabels(ctx, request.GetPromptID(), commitVersions) @@ -590,6 +705,30 @@ func (app *PromptManageApplicationImpl) ListCommit(ctx context.Context, request r.CommitVersionLabelMapping = commitVersionLabelMapping } + // 填充被引用次数映射 + if len(commitVersions) > 0 && promptDO.PromptBasic != nil && promptDO.PromptBasic.PromptType == entity.PromptTypeSnippet { + // 查询这些版本的被引用次数,使用labelService + parentPromptCommitVersions, err := app.manageRepo.ListParentPrompt(ctx, repo.ListParentPromptParam{ + SubPromptID: request.GetPromptID(), + SubPromptVersions: commitVersions, + }) + if err != nil { + return r, err + } + + // 构建版本到被引用次数的映射 + commitVersionReferencesMapping := make(map[string]int32) + for version, parents := range parentPromptCommitVersions { + for _, parent := range parents { + if parent == nil { + continue + } + commitVersionReferencesMapping[version] += int32(len(parent.CommitVersions)) + } + } + + r.ParentReferencesMapping = commitVersionReferencesMapping + } } return r, nil @@ -634,7 +773,7 @@ func (app *PromptManageApplicationImpl) RevertDraftFromCommit(ctx context.Contex }, PromptDetail: promptDO.PromptCommit.PromptDetail, } - _, err = app.manageRepo.SaveDraft(ctx, promptDO) + _, err = app.promptService.SaveDraft(ctx, promptDO) return r, err } diff --git a/backend/modules/prompt/application/manage_test.go b/backend/modules/prompt/application/manage_test.go index f06697c5e..2c60e9fd9 100644 --- a/backend/modules/prompt/application/manage_test.go +++ b/backend/modules/prompt/application/manage_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/Masterminds/semver/v3" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" @@ -121,14 +122,20 @@ func TestPromptManageApplicationImpl_ClonePrompt(t *testing.T) { }, }, nil) - mockRepo.EXPECT().CreatePrompt(gomock.Any(), gomock.Any()).Return(int64(0), errorx.New("create prompt error")) + // 注意:在promptService.CreatePrompt内部会调用manageRepo.CreatePrompt + // 当manageRepo.CreatePrompt返回错误时,promptService.CreatePrompt也会返回错误 + mockRepo.EXPECT().CreatePrompt(gomock.Any(), gomock.Any()).Return(int64(0), errorx.New("create prompt error")).MinTimes(0).MaxTimes(1) mockAuth := mocks.NewMockIAuthProvider(ctrl) mockAuth.EXPECT().MCheckPromptPermission(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) mockAuth.EXPECT().CheckSpacePermission(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().CreatePrompt(gomock.Any(), gomock.Any()).Return(int64(0), errorx.New("create prompt error")) + return fields{ manageRepo: mockRepo, + promptService: mockPromptService, authRPCProvider: mockAuth, } }, @@ -175,7 +182,8 @@ func TestPromptManageApplicationImpl_ClonePrompt(t *testing.T) { mockAuth.EXPECT().MCheckPromptPermission(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) mockAuth.EXPECT().CheckSpacePermission(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - mockRepo.EXPECT().CreatePrompt(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, prompt *entity.Prompt) (int64, error) { + mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().CreatePrompt(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, prompt *entity.Prompt) (int64, error) { assert.Equal(t, "test_key", prompt.PromptKey) assert.Equal(t, "test_key", prompt.PromptBasic.DisplayName) assert.Equal(t, "test description", prompt.PromptBasic.Description) @@ -189,6 +197,7 @@ func TestPromptManageApplicationImpl_ClonePrompt(t *testing.T) { return fields{ manageRepo: mockRepo, + promptService: mockPromptService, authRPCProvider: mockAuth, } }, @@ -233,24 +242,248 @@ func TestPromptManageApplicationImpl_ClonePrompt(t *testing.T) { } } -func TestPromptManageApplicationImpl_GetPrompt(t *testing.T) { +func TestPromptManageApplicationImpl_DeletePrompt(t *testing.T) { type fields struct { manageRepo repo.IManageRepo - promptService service.IPromptService authRPCProvider rpc.IAuthProvider - userRPCProvider rpc.IUserProvider - configProvider conf.IConfigProvider } type args struct { ctx context.Context - request *manage.GetPromptRequest + request *manage.DeletePromptRequest } - now := time.Now() tests := []struct { name string fieldsGetter func(ctrl *gomock.Controller) fields args args - want *manage.GetPromptResponse + want *manage.DeletePromptResponse + wantErr error + }{ + { + name: "user not found", + fieldsGetter: func(ctrl *gomock.Controller) fields { return fields{} }, + args: args{ + ctx: context.Background(), + request: &manage.DeletePromptRequest{ + PromptID: ptr.Of(int64(1)), + }, + }, + want: manage.NewDeletePromptResponse(), + wantErr: errorx.NewByCode(prompterr.CommonInvalidParamCode, errorx.WithExtraMsg("User not found")), + }, + { + name: "get prompt error", + fieldsGetter: func(ctrl *gomock.Controller) fields { + repoMock := repomocks.NewMockIManageRepo(ctrl) + repoMock.EXPECT().GetPrompt(gomock.Any(), repo.GetPromptParam{PromptID: 1}).Return(nil, errorx.New("get error")) + return fields{manageRepo: repoMock} + }, + args: args{ + ctx: session.WithCtxUser(context.Background(), &session.User{ID: "user"}), + request: &manage.DeletePromptRequest{ + PromptID: ptr.Of(int64(1)), + }, + }, + want: manage.NewDeletePromptResponse(), + wantErr: errorx.New("get error"), + }, + { + name: "snippet prompt not allowed", + fieldsGetter: func(ctrl *gomock.Controller) fields { + repoMock := repomocks.NewMockIManageRepo(ctrl) + repoMock.EXPECT().GetPrompt(gomock.Any(), repo.GetPromptParam{PromptID: 2}).Return(&entity.Prompt{ + ID: 2, + SpaceID: 10, + PromptBasic: &entity.PromptBasic{PromptType: entity.PromptTypeSnippet}, + }, nil) + return fields{manageRepo: repoMock} + }, + args: args{ + ctx: session.WithCtxUser(context.Background(), &session.User{ID: "user"}), + request: &manage.DeletePromptRequest{ + PromptID: ptr.Of(int64(2)), + }, + }, + want: manage.NewDeletePromptResponse(), + wantErr: errorx.NewByCode(prompterr.CommonInvalidParamCode, errorx.WithExtraMsg("Snippet prompt can not be deleted")), + }, + { + name: "success", + fieldsGetter: func(ctrl *gomock.Controller) fields { + repoMock := repomocks.NewMockIManageRepo(ctrl) + repoMock.EXPECT().GetPrompt(gomock.Any(), repo.GetPromptParam{PromptID: 3}).Return(&entity.Prompt{ + ID: 3, + SpaceID: 20, + PromptBasic: &entity.PromptBasic{PromptType: entity.PromptTypeNormal}, + }, nil) + repoMock.EXPECT().DeletePrompt(gomock.Any(), int64(3)).Return(nil) + auth := mocks.NewMockIAuthProvider(ctrl) + auth.EXPECT().MCheckPromptPermission(gomock.Any(), int64(20), []int64{int64(3)}, consts.ActionLoopPromptEdit).Return(nil) + return fields{manageRepo: repoMock, authRPCProvider: auth} + }, + args: args{ + ctx: session.WithCtxUser(context.Background(), &session.User{ID: "user"}), + request: &manage.DeletePromptRequest{ + PromptID: ptr.Of(int64(3)), + }, + }, + want: manage.NewDeletePromptResponse(), + wantErr: nil, + }, + } + + for _, tt := range tests { + caseData := tt + t.Run(caseData.name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + tFields := caseData.fieldsGetter(ctrl) + app := &PromptManageApplicationImpl{ + manageRepo: tFields.manageRepo, + authRPCProvider: tFields.authRPCProvider, + } + + resp, err := app.DeletePrompt(caseData.args.ctx, caseData.args.request) + unittest.AssertErrorEqual(t, caseData.wantErr, err) + if err == nil { + assert.Equal(t, caseData.want, resp) + } + }) + } +} + +func TestPromptManageApplicationImpl_BatchGetPrompt(t *testing.T) { + type fields struct { + manageRepo repo.IManageRepo + } + type args struct { + ctx context.Context + request *manage.BatchGetPromptRequest + } + tests := []struct { + name string + fieldsGetter func(ctrl *gomock.Controller) fields + args args + wantLen int + wantErr error + }{ + { + name: "repo error", + fieldsGetter: func(ctrl *gomock.Controller) fields { + repoMock := repomocks.NewMockIManageRepo(ctrl) + repoMock.EXPECT().MGetPrompt(gomock.Any(), gomock.Any()).Return(nil, errorx.New("mget error")) + return fields{manageRepo: repoMock} + }, + args: args{ + ctx: context.Background(), + request: &manage.BatchGetPromptRequest{ + Queries: []*manage.PromptQuery{ + { + PromptID: ptr.Of(int64(1)), + WithCommit: ptr.Of(true), + CommitVersion: ptr.Of("v1"), + }, + }, + }, + }, + wantLen: 0, + wantErr: errorx.New("mget error"), + }, + { + name: "success", + fieldsGetter: func(ctrl *gomock.Controller) fields { + repoMock := repomocks.NewMockIManageRepo(ctrl) + repoMock.EXPECT().MGetPrompt(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, params []repo.GetPromptParam, _ ...repo.GetPromptOptionFunc) (map[repo.GetPromptParam]*entity.Prompt, error) { + assert.Len(t, params, 1) + return map[repo.GetPromptParam]*entity.Prompt{ + params[0]: { + ID: params[0].PromptID, + SpaceID: 100, + PromptKey: "key", + PromptBasic: &entity.PromptBasic{ + DisplayName: "name", + }, + }, + }, nil + }) + return fields{manageRepo: repoMock} + }, + args: args{ + ctx: context.Background(), + request: &manage.BatchGetPromptRequest{ + Queries: []*manage.PromptQuery{ + { + PromptID: ptr.Of(int64(5)), + WithCommit: ptr.Of(false), + }, + }, + }, + }, + wantLen: 1, + wantErr: nil, + }, + } + + for _, tt := range tests { + caseData := tt + t.Run(caseData.name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + tFields := caseData.fieldsGetter(ctrl) + app := &PromptManageApplicationImpl{manageRepo: tFields.manageRepo} + + resp, err := app.BatchGetPrompt(caseData.args.ctx, caseData.args.request) + unittest.AssertErrorEqual(t, caseData.wantErr, err) + if err == nil { + assert.Len(t, resp.Results, caseData.wantLen) + } + }) + } +} + +func TestNewPromptManageApplication(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + manageRepo := repomocks.NewMockIManageRepo(ctrl) + labelRepo := repomocks.NewMockILabelRepo(ctrl) + promptService := servicemocks.NewMockIPromptService(ctrl) + auth := mocks.NewMockIAuthProvider(ctrl) + user := mocks.NewMockIUserProvider(ctrl) + audit := mocks.NewMockIAuditProvider(ctrl) + config := confmocks.NewMockIConfigProvider(ctrl) + + app := NewPromptManageApplication(manageRepo, labelRepo, promptService, auth, user, audit, config) + impl, ok := app.(*PromptManageApplicationImpl) + assert.True(t, ok) + assert.Equal(t, manageRepo, impl.manageRepo) + assert.Equal(t, labelRepo, impl.labelRepo) + assert.Equal(t, promptService, impl.promptService) + assert.Equal(t, auth, impl.authRPCProvider) + assert.Equal(t, user, impl.userRPCProvider) + assert.Equal(t, audit, impl.auditRPCProvider) + assert.Equal(t, config, impl.configProvider) +} + +func TestPromptManageApplicationImpl_CreatePrompt(t *testing.T) { + type fields struct { + promptService service.IPromptService + authRPCProvider rpc.IAuthProvider + auditRPCProvider rpc.IAuditProvider + } + type args struct { + ctx context.Context + request *manage.CreatePromptRequest + } + tests := []struct { + name string + fieldsGetter func(ctrl *gomock.Controller) fields + args args + want *manage.CreatePromptResponse wantErr error }{ { @@ -260,50 +493,173 @@ func TestPromptManageApplicationImpl_GetPrompt(t *testing.T) { }, args: args{ ctx: context.Background(), - request: &manage.GetPromptRequest{ - PromptID: ptr.Of(int64(1)), + request: &manage.CreatePromptRequest{ + WorkspaceID: ptr.Of(int64(100)), + PromptKey: ptr.Of("prompt_key"), + PromptName: ptr.Of("prompt_name"), }, }, - want: manage.NewGetPromptResponse(), + want: manage.NewCreatePromptResponse(), wantErr: errorx.NewByCode(prompterr.CommonInvalidParamCode, errorx.WithExtraMsg("User not found")), }, { - name: "get latest version error", + name: "permission denied", fieldsGetter: func(ctrl *gomock.Controller) fields { - mockRepo := repomocks.NewMockIManageRepo(ctrl) - mockRepo.EXPECT().GetPrompt(gomock.Any(), repo.GetPromptParam{ - PromptID: 1, - }).Return(nil, errorx.New("get prompt error")) - + auth := mocks.NewMockIAuthProvider(ctrl) + auth.EXPECT().CheckSpacePermission(gomock.Any(), int64(100), consts.ActionWorkspaceCreateLoopPrompt).Return(errorx.New("permission denied")) + return fields{authRPCProvider: auth} + }, + args: args{ + ctx: session.WithCtxUser(context.Background(), &session.User{ID: "user"}), + request: &manage.CreatePromptRequest{ + WorkspaceID: ptr.Of(int64(100)), + PromptKey: ptr.Of("prompt_key"), + PromptName: ptr.Of("prompt_name"), + }, + }, + want: manage.NewCreatePromptResponse(), + wantErr: errorx.New("permission denied"), + }, + { + name: "audit failed", + fieldsGetter: func(ctrl *gomock.Controller) fields { + auth := mocks.NewMockIAuthProvider(ctrl) + auth.EXPECT().CheckSpacePermission(gomock.Any(), int64(100), consts.ActionWorkspaceCreateLoopPrompt).Return(nil) + audit := mocks.NewMockIAuditProvider(ctrl) + audit.EXPECT().AuditPrompt(gomock.Any(), gomock.Any()).Return(errorx.New("audit failed")) + return fields{authRPCProvider: auth, auditRPCProvider: audit} + }, + args: args{ + ctx: session.WithCtxUser(context.Background(), &session.User{ID: "user"}), + request: &manage.CreatePromptRequest{ + WorkspaceID: ptr.Of(int64(100)), + PromptKey: ptr.Of("prompt_key"), + PromptName: ptr.Of("prompt_name"), + PromptDescription: ptr.Of("desc"), + DraftDetail: &prompt.PromptDetail{}, + }, + }, + want: manage.NewCreatePromptResponse(), + wantErr: errorx.New("audit failed"), + }, + { + name: "success", + fieldsGetter: func(ctrl *gomock.Controller) fields { + auth := mocks.NewMockIAuthProvider(ctrl) + auth.EXPECT().CheckSpacePermission(gomock.Any(), int64(100), consts.ActionWorkspaceCreateLoopPrompt).Return(nil) + audit := mocks.NewMockIAuditProvider(ctrl) + audit.EXPECT().AuditPrompt(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, promptDO *entity.Prompt) error { + assert.Equal(t, int64(100), promptDO.SpaceID) + assert.Equal(t, "user", promptDO.PromptBasic.CreatedBy) + return nil + }) + promptSvc := servicemocks.NewMockIPromptService(ctrl) + promptSvc.EXPECT().CreatePrompt(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, promptDO *entity.Prompt) (int64, error) { + assert.Equal(t, entity.PromptTypeNormal, promptDO.PromptBasic.PromptType) + assert.NotNil(t, promptDO.PromptDraft) + return 999, nil + }) return fields{ - manageRepo: mockRepo, + promptService: promptSvc, + authRPCProvider: auth, + auditRPCProvider: audit, } }, args: args{ - ctx: session.WithCtxUser(context.Background(), &session.User{ID: "123"}), + ctx: session.WithCtxUser(context.Background(), &session.User{ID: "user"}), + request: &manage.CreatePromptRequest{ + WorkspaceID: ptr.Of(int64(100)), + PromptKey: ptr.Of("prompt_key"), + PromptName: ptr.Of("prompt_name"), + PromptDescription: ptr.Of("desc"), + DraftDetail: &prompt.PromptDetail{ + PromptTemplate: &prompt.PromptTemplate{}, + }, + }, + }, + want: &manage.CreatePromptResponse{PromptID: ptr.Of(int64(999))}, + wantErr: nil, + }, + } + + for _, tt := range tests { + caseData := tt + t.Run(caseData.name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + tFields := caseData.fieldsGetter(ctrl) + app := &PromptManageApplicationImpl{ + promptService: tFields.promptService, + authRPCProvider: tFields.authRPCProvider, + auditRPCProvider: tFields.auditRPCProvider, + } + + resp, err := app.CreatePrompt(caseData.args.ctx, caseData.args.request) + unittest.AssertErrorEqual(t, caseData.wantErr, err) + if err == nil { + assert.Equal(t, caseData.want, resp) + } + }) + } +} + +func TestPromptManageApplicationImpl_GetPrompt(t *testing.T) { + type fields struct { + manageRepo repo.IManageRepo + promptService service.IPromptService + authRPCProvider rpc.IAuthProvider + userRPCProvider rpc.IUserProvider + configProvider conf.IConfigProvider + } + type args struct { + ctx context.Context + request *manage.GetPromptRequest + } + + baseTime := time.Date(2025, 1, 2, 3, 4, 5, 0, time.UTC) + draftTime := baseTime.Add(time.Minute) + snippetTime := baseTime.Add(2 * time.Minute) + + tests := []struct { + name string + fieldsGetter func(ctrl *gomock.Controller) fields + args args + want *manage.GetPromptResponse + wantErr error + }{ + { + name: "user not found", + fieldsGetter: func(ctrl *gomock.Controller) fields { + return fields{} + }, + args: args{ + ctx: context.Background(), request: &manage.GetPromptRequest{ - PromptID: ptr.Of(int64(1)), - WithCommit: ptr.Of(true), - CommitVersion: nil, + PromptID: ptr.Of(int64(1)), }, }, want: manage.NewGetPromptResponse(), - wantErr: errorx.New("get prompt error"), + wantErr: errorx.NewByCode(prompterr.CommonInvalidParamCode, errorx.WithExtraMsg("User not found")), }, { - name: "get prompt error", + name: "prompt service error when commit version provided", fieldsGetter: func(ctrl *gomock.Controller) fields { mockRepo := repomocks.NewMockIManageRepo(ctrl) - mockRepo.EXPECT().GetPrompt(gomock.Any(), repo.GetPromptParam{ + mockRepo.EXPECT().GetPrompt(gomock.Any(), gomock.Any()).Times(0) + mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().GetPrompt(gomock.Any(), service.GetPromptParam{ PromptID: 1, WithCommit: true, CommitVersion: "1.0.0", WithDraft: false, UserID: "123", - }).Return(nil, errorx.New("get prompt error")) - + ExpandSnippet: false, + }).Return(nil, errorx.New("prompt service error")) return fields{ - manageRepo: mockRepo, + manageRepo: mockRepo, + promptService: mockPromptService, } }, args: args{ @@ -315,59 +671,62 @@ func TestPromptManageApplicationImpl_GetPrompt(t *testing.T) { }, }, want: manage.NewGetPromptResponse(), - wantErr: errorx.New("get prompt error"), + wantErr: errorx.New("prompt service error"), }, { name: "get prompt with commit success", fieldsGetter: func(ctrl *gomock.Controller) fields { mockRepo := repomocks.NewMockIManageRepo(ctrl) - mockRepo.EXPECT().GetPrompt(gomock.Any(), repo.GetPromptParam{ - PromptID: 1, - WithCommit: true, - CommitVersion: "1.0.0", - WithDraft: false, - UserID: "123", - }).Return(&entity.Prompt{ + mockRepo.EXPECT().GetPrompt(gomock.Any(), gomock.Any()).Times(0) + promptDO := &entity.Prompt{ ID: 1, SpaceID: 100, - PromptKey: "test_key", + PromptKey: "commit_key", PromptBasic: &entity.PromptBasic{ - DisplayName: "test_name", - Description: "test_description", - LatestVersion: "1.0.0", - CreatedBy: "test_creator", - UpdatedBy: "test_updater", - CreatedAt: now, - UpdatedAt: now, - LatestCommittedAt: nil, + PromptType: entity.PromptTypeNormal, + DisplayName: "commit name", + Description: "commit description", + LatestVersion: "1.0.0", + CreatedBy: "creator", + UpdatedBy: "updater", + CreatedAt: baseTime, + UpdatedAt: baseTime, }, PromptCommit: &entity.PromptCommit{ PromptDetail: &entity.PromptDetail{ PromptTemplate: &entity.PromptTemplate{ TemplateType: entity.TemplateTypeNormal, - Messages: []*entity.Message{ - { - Role: entity.RoleUser, - Content: ptr.Of("test content"), - }, - }, + Messages: []*entity.Message{{ + Role: entity.RoleUser, + Content: ptr.Of("commit content"), + }}, + HasSnippets: false, }, }, CommitInfo: &entity.CommitInfo{ Version: "1.0.0", BaseVersion: "0.9.0", - Description: "test commit", - CommittedBy: "test_user", - CommittedAt: now, + Description: "commit description", + CommittedBy: "committer", + CommittedAt: baseTime, }, }, - }, nil) - + } + mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().GetPrompt(gomock.Any(), service.GetPromptParam{ + PromptID: 1, + WithCommit: true, + CommitVersion: "1.0.0", + WithDraft: false, + UserID: "123", + ExpandSnippet: false, + }).Return(promptDO, nil) mockAuth := mocks.NewMockIAuthProvider(ctrl) - mockAuth.EXPECT().MCheckPromptPermission(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + mockAuth.EXPECT().MCheckPromptPermission(gomock.Any(), int64(100), []int64{int64(1)}, consts.ActionLoopPromptRead).Return(nil) return fields{ manageRepo: mockRepo, + promptService: mockPromptService, authRPCProvider: mockAuth, } }, @@ -379,1230 +738,894 @@ func TestPromptManageApplicationImpl_GetPrompt(t *testing.T) { CommitVersion: ptr.Of("1.0.0"), }, }, - want: &manage.GetPromptResponse{ - Prompt: &prompt.Prompt{ + want: func() *manage.GetPromptResponse { + resp := manage.NewGetPromptResponse() + resp.Prompt = &prompt.Prompt{ ID: ptr.Of(int64(1)), WorkspaceID: ptr.Of(int64(100)), - PromptKey: ptr.Of("test_key"), + PromptKey: ptr.Of("commit_key"), PromptBasic: &prompt.PromptBasic{ - DisplayName: ptr.Of("test_name"), - Description: ptr.Of("test_description"), + DisplayName: ptr.Of("commit name"), + Description: ptr.Of("commit description"), LatestVersion: ptr.Of("1.0.0"), - CreatedBy: ptr.Of("test_creator"), - UpdatedBy: ptr.Of("test_updater"), - CreatedAt: ptr.Of(now.UnixMilli()), - UpdatedAt: ptr.Of(now.UnixMilli()), + CreatedBy: ptr.Of("creator"), + UpdatedBy: ptr.Of("updater"), + CreatedAt: ptr.Of(baseTime.UnixMilli()), + UpdatedAt: ptr.Of(baseTime.UnixMilli()), + PromptType: ptr.Of(prompt.PromptTypeNormal), }, PromptCommit: &prompt.PromptCommit{ Detail: &prompt.PromptDetail{ PromptTemplate: &prompt.PromptTemplate{ TemplateType: ptr.Of(prompt.TemplateTypeNormal), - Messages: []*prompt.Message{ - { - Role: ptr.Of(prompt.RoleUser), - Content: ptr.Of("test content"), - }, - }, + HasSnippet: ptr.Of(false), + Messages: []*prompt.Message{{ + Role: ptr.Of(prompt.RoleUser), + Content: ptr.Of("commit content"), + }}, }, }, CommitInfo: &prompt.CommitInfo{ Version: ptr.Of("1.0.0"), BaseVersion: ptr.Of("0.9.0"), - Description: ptr.Of("test commit"), - CommittedBy: ptr.Of("test_user"), - CommittedAt: ptr.Of(now.UnixMilli()), + Description: ptr.Of("commit description"), + CommittedBy: ptr.Of("committer"), + CommittedAt: ptr.Of(baseTime.UnixMilli()), }, }, - }, - }, + } + return resp + }(), wantErr: nil, }, { - name: "get prompt with draft success", + name: "get prompt with draft and default config success", fieldsGetter: func(ctrl *gomock.Controller) fields { mockRepo := repomocks.NewMockIManageRepo(ctrl) - mockRepo.EXPECT().GetPrompt(gomock.Any(), repo.GetPromptParam{ - PromptID: 1, - WithDraft: true, - UserID: "123", - }).Return(&entity.Prompt{ - ID: 1, - SpaceID: 100, - PromptKey: "test_key", + mockRepo.EXPECT().GetPrompt(gomock.Any(), gomock.Any()).Times(0) + promptDO := &entity.Prompt{ + ID: 2, + SpaceID: 200, + PromptKey: "draft_key", PromptBasic: &entity.PromptBasic{ - DisplayName: "test_name", - Description: "test_description", - LatestVersion: "1.0.0", - CreatedBy: "test_creator", - UpdatedBy: "test_updater", - CreatedAt: now, - UpdatedAt: now, - LatestCommittedAt: nil, + PromptType: entity.PromptTypeNormal, + DisplayName: "draft name", + Description: "draft description", + LatestVersion: "2.0.0", + CreatedBy: "creator", + UpdatedBy: "updater", + CreatedAt: draftTime, + UpdatedAt: draftTime, }, PromptDraft: &entity.PromptDraft{ PromptDetail: &entity.PromptDetail{ PromptTemplate: &entity.PromptTemplate{ TemplateType: entity.TemplateTypeNormal, - Messages: []*entity.Message{ - { - Role: entity.RoleUser, - Content: ptr.Of("test content"), - }, - }, + Messages: []*entity.Message{{ + Role: entity.RoleSystem, + Content: ptr.Of("draft content"), + }}, + HasSnippets: false, }, }, DraftInfo: &entity.DraftInfo{ UserID: "123", - BaseVersion: "1.0.0", + BaseVersion: "2.0.0", IsModified: true, - CreatedAt: now, - UpdatedAt: now, + CreatedAt: draftTime, + UpdatedAt: draftTime, }, }, - }, nil) - + } + mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().GetPrompt(gomock.Any(), service.GetPromptParam{ + PromptID: 2, + WithCommit: false, + CommitVersion: "", + WithDraft: true, + UserID: "123", + ExpandSnippet: false, + }).Return(promptDO, nil) mockAuth := mocks.NewMockIAuthProvider(ctrl) - mockAuth.EXPECT().MCheckPromptPermission(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + mockAuth.EXPECT().MCheckPromptPermission(gomock.Any(), int64(200), []int64{int64(2)}, consts.ActionLoopPromptRead).Return(nil) + mockConfig := confmocks.NewMockIConfigProvider(ctrl) + mockConfig.EXPECT().GetPromptDefaultConfig(gomock.Any()).Return(&prompt.PromptDetail{ + PromptTemplate: &prompt.PromptTemplate{ + TemplateType: ptr.Of(prompt.TemplateTypeNormal), + HasSnippet: ptr.Of(false), + Messages: []*prompt.Message{{ + Role: ptr.Of(prompt.RoleSystem), + Content: ptr.Of("default config"), + }}, + }, + }, nil) return fields{ manageRepo: mockRepo, + promptService: mockPromptService, authRPCProvider: mockAuth, + configProvider: mockConfig, } }, args: args{ ctx: session.WithCtxUser(context.Background(), &session.User{ID: "123"}), request: &manage.GetPromptRequest{ - PromptID: ptr.Of(int64(1)), - WithDraft: ptr.Of(true), + PromptID: ptr.Of(int64(2)), + WithDraft: ptr.Of(true), + WithDefaultConfig: ptr.Of(true), }, }, - want: &manage.GetPromptResponse{ - Prompt: &prompt.Prompt{ - ID: ptr.Of(int64(1)), - WorkspaceID: ptr.Of(int64(100)), - PromptKey: ptr.Of("test_key"), + want: func() *manage.GetPromptResponse { + resp := manage.NewGetPromptResponse() + resp.Prompt = &prompt.Prompt{ + ID: ptr.Of(int64(2)), + WorkspaceID: ptr.Of(int64(200)), + PromptKey: ptr.Of("draft_key"), PromptBasic: &prompt.PromptBasic{ - DisplayName: ptr.Of("test_name"), - Description: ptr.Of("test_description"), - LatestVersion: ptr.Of("1.0.0"), - CreatedBy: ptr.Of("test_creator"), - UpdatedBy: ptr.Of("test_updater"), - CreatedAt: ptr.Of(now.UnixMilli()), - UpdatedAt: ptr.Of(now.UnixMilli()), - LatestCommittedAt: nil, + DisplayName: ptr.Of("draft name"), + Description: ptr.Of("draft description"), + LatestVersion: ptr.Of("2.0.0"), + CreatedBy: ptr.Of("creator"), + UpdatedBy: ptr.Of("updater"), + CreatedAt: ptr.Of(draftTime.UnixMilli()), + UpdatedAt: ptr.Of(draftTime.UnixMilli()), + PromptType: ptr.Of(prompt.PromptTypeNormal), }, PromptDraft: &prompt.PromptDraft{ Detail: &prompt.PromptDetail{ PromptTemplate: &prompt.PromptTemplate{ TemplateType: ptr.Of(prompt.TemplateTypeNormal), - Messages: []*prompt.Message{ - { - Role: ptr.Of(prompt.RoleUser), - Content: ptr.Of("test content"), - }, - }, + HasSnippet: ptr.Of(false), + Messages: []*prompt.Message{{ + Role: ptr.Of(prompt.RoleSystem), + Content: ptr.Of("draft content"), + }}, }, }, DraftInfo: &prompt.DraftInfo{ UserID: ptr.Of("123"), - BaseVersion: ptr.Of("1.0.0"), + BaseVersion: ptr.Of("2.0.0"), IsModified: ptr.Of(true), - CreatedAt: ptr.Of(now.UnixMilli()), - UpdatedAt: ptr.Of(now.UnixMilli()), + CreatedAt: ptr.Of(draftTime.UnixMilli()), + UpdatedAt: ptr.Of(draftTime.UnixMilli()), }, }, - }, - }, + } + resp.DefaultConfig = &prompt.PromptDetail{ + PromptTemplate: &prompt.PromptTemplate{ + TemplateType: ptr.Of(prompt.TemplateTypeNormal), + HasSnippet: ptr.Of(false), + Messages: []*prompt.Message{{ + Role: ptr.Of(prompt.RoleSystem), + Content: ptr.Of("default config"), + }}, + }, + } + return resp + }(), wantErr: nil, }, { - name: "get prompt with latest version success", + name: "workspace mismatch returns resource not found", fieldsGetter: func(ctrl *gomock.Controller) fields { mockRepo := repomocks.NewMockIManageRepo(ctrl) - mockRepo.EXPECT().GetPrompt(gomock.Any(), repo.GetPromptParam{ - PromptID: 1, - }).Return(&entity.Prompt{ - ID: 1, - SpaceID: 100, - PromptKey: "test_key", + mockRepo.EXPECT().GetPrompt(gomock.Any(), gomock.Any()).Times(0) + promptDO := &entity.Prompt{ + ID: 3, + SpaceID: 300, + PromptKey: "workspace_key", PromptBasic: &entity.PromptBasic{ - DisplayName: "test_name", - Description: "test_description", - LatestVersion: "1.0.0", - CreatedBy: "test_creator", - UpdatedBy: "test_updater", - CreatedAt: now, - UpdatedAt: now, - LatestCommittedAt: nil, + PromptType: entity.PromptTypeNormal, + DisplayName: "workspace name", + Description: "workspace description", + LatestVersion: "3.0.0", + CreatedBy: "creator", + UpdatedBy: "updater", + CreatedAt: baseTime, + UpdatedAt: baseTime, }, - }, nil) - - mockRepo.EXPECT().GetPrompt(gomock.Any(), repo.GetPromptParam{ - PromptID: 1, + } + mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().GetPrompt(gomock.Any(), service.GetPromptParam{ + PromptID: 3, WithCommit: true, - CommitVersion: "1.0.0", + CommitVersion: "3.0.0", WithDraft: false, UserID: "123", - }).Return(&entity.Prompt{ - ID: 1, - SpaceID: 100, - PromptKey: "test_key", - PromptBasic: &entity.PromptBasic{ - DisplayName: "test_name", - Description: "test_description", - CreatedBy: "test_creator", - UpdatedBy: "test_updater", - CreatedAt: now, - UpdatedAt: now, - LatestVersion: "1.0.0", - }, - PromptCommit: &entity.PromptCommit{ - PromptDetail: &entity.PromptDetail{ - PromptTemplate: &entity.PromptTemplate{ - TemplateType: entity.TemplateTypeNormal, - Messages: []*entity.Message{ - { - Role: entity.RoleUser, - Content: ptr.Of("test content"), - }, - }, - }, - }, - CommitInfo: &entity.CommitInfo{ - Version: "1.0.0", - BaseVersion: "0.9.0", - Description: "test commit", - CommittedBy: "test_user", - CommittedAt: now, - }, - }, - }, nil) - + ExpandSnippet: false, + }).Return(promptDO, nil) mockAuth := mocks.NewMockIAuthProvider(ctrl) - mockAuth.EXPECT().MCheckPromptPermission(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + mockAuth.EXPECT().MCheckPromptPermission(gomock.Any(), int64(300), []int64{int64(3)}, consts.ActionLoopPromptRead).Return(nil) return fields{ manageRepo: mockRepo, + promptService: mockPromptService, authRPCProvider: mockAuth, } }, args: args{ ctx: session.WithCtxUser(context.Background(), &session.User{ID: "123"}), request: &manage.GetPromptRequest{ - PromptID: ptr.Of(int64(1)), + PromptID: ptr.Of(int64(3)), + WorkspaceID: ptr.Of(int64(999)), WithCommit: ptr.Of(true), - CommitVersion: nil, - }, - }, - want: &manage.GetPromptResponse{ - Prompt: &prompt.Prompt{ - ID: ptr.Of(int64(1)), - WorkspaceID: ptr.Of(int64(100)), - PromptKey: ptr.Of("test_key"), - PromptBasic: &prompt.PromptBasic{ - DisplayName: ptr.Of("test_name"), - Description: ptr.Of("test_description"), - LatestVersion: ptr.Of("1.0.0"), - CreatedBy: ptr.Of("test_creator"), - UpdatedBy: ptr.Of("test_updater"), - CreatedAt: ptr.Of(now.UnixMilli()), - UpdatedAt: ptr.Of(now.UnixMilli()), - LatestCommittedAt: nil, - }, - PromptCommit: &prompt.PromptCommit{ - Detail: &prompt.PromptDetail{ - PromptTemplate: &prompt.PromptTemplate{ - TemplateType: ptr.Of(prompt.TemplateTypeNormal), - Messages: []*prompt.Message{ - { - Role: ptr.Of(prompt.RoleUser), - Content: ptr.Of("test content"), - }, - }, - }, - }, - CommitInfo: &prompt.CommitInfo{ - Version: ptr.Of("1.0.0"), - BaseVersion: ptr.Of("0.9.0"), - Description: ptr.Of("test commit"), - CommittedBy: ptr.Of("test_user"), - CommittedAt: ptr.Of(now.UnixMilli()), - }, - }, + CommitVersion: ptr.Of("3.0.0"), }, }, - wantErr: nil, + want: manage.NewGetPromptResponse(), + wantErr: errorx.NewByCode(prompterr.ResourceNotFoundCode, errorx.WithExtraMsg("WorkspaceID not match")), }, { - name: "get prompt with default config success", + name: "snippet prompt parent references success", fieldsGetter: func(ctrl *gomock.Controller) fields { - mockRepo := repomocks.NewMockIManageRepo(ctrl) - mockRepo.EXPECT().GetPrompt(gomock.Any(), repo.GetPromptParam{ - PromptID: 1, - WithDraft: true, - UserID: "123", - }).Return(&entity.Prompt{ - ID: 1, - SpaceID: 100, - PromptKey: "test_key", + promptDO := &entity.Prompt{ + ID: 4, + SpaceID: 400, + PromptKey: "snippet_key", PromptBasic: &entity.PromptBasic{ - DisplayName: "test_name", - Description: "test_description", - LatestVersion: "1.0.0", - CreatedBy: "test_creator", - UpdatedBy: "test_updater", - CreatedAt: now, - UpdatedAt: now, - LatestCommittedAt: nil, - }, - PromptDraft: &entity.PromptDraft{ - PromptDetail: &entity.PromptDetail{ - PromptTemplate: &entity.PromptTemplate{ - TemplateType: entity.TemplateTypeNormal, - Messages: []*entity.Message{ - { - Role: entity.RoleUser, - Content: ptr.Of("test content"), - }, - }, - }, - }, - DraftInfo: &entity.DraftInfo{ - UserID: "123", - BaseVersion: "1.0.0", - IsModified: true, - CreatedAt: now, - UpdatedAt: now, - }, + PromptType: entity.PromptTypeSnippet, + DisplayName: "snippet name", + Description: "snippet description", + LatestVersion: "4.0.0", + CreatedBy: "creator", + UpdatedBy: "updater", + CreatedAt: snippetTime, + UpdatedAt: snippetTime, }, - }, nil) - + } + mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().GetPrompt(gomock.Any(), service.GetPromptParam{ + PromptID: 4, + WithCommit: false, + CommitVersion: "", + WithDraft: false, + UserID: "123", + ExpandSnippet: false, + }).Return(promptDO, nil) mockAuth := mocks.NewMockIAuthProvider(ctrl) - mockAuth.EXPECT().MCheckPromptPermission(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + mockAuth.EXPECT().MCheckPromptPermission(gomock.Any(), int64(400), []int64{int64(4)}, consts.ActionLoopPromptRead).Return(nil) - mockConfig := &prompt.PromptDetail{ - PromptTemplate: &prompt.PromptTemplate{ - TemplateType: ptr.Of(prompt.TemplateTypeNormal), - Messages: []*prompt.Message{ - { - Role: ptr.Of(prompt.RoleSystem), - Content: ptr.Of("Default system message"), - }, + mockRepo := repomocks.NewMockIManageRepo(ctrl) + gomock.InOrder( + mockRepo.EXPECT().MGetVersionsByPromptID(gomock.Any(), int64(4)).Return([]string{"4.0.0", "4.1.0"}, nil), + mockRepo.EXPECT().ListParentPrompt(gomock.Any(), repo.ListParentPromptParam{ + SubPromptID: 4, + SubPromptVersions: []string{"4.0.0", "4.1.0"}, + }).Return(map[string][]*repo.PromptCommitVersions{ + "4.0.0": {{CommitVersions: []string{"10.0.0", "10.1.0"}}}, + "4.1.0": { + {CommitVersions: []string{"11.0.0"}}, + {CommitVersions: []string{"12.0.0", "12.1.0"}}, }, - }, - } - mockConfigProvider := confmocks.NewMockIConfigProvider(ctrl) - mockConfigProvider.EXPECT().GetPromptDefaultConfig(gomock.Any()).Return(mockConfig, nil) + }, nil), + ) return fields{ manageRepo: mockRepo, + promptService: mockPromptService, authRPCProvider: mockAuth, - configProvider: mockConfigProvider, } }, args: args{ ctx: session.WithCtxUser(context.Background(), &session.User{ID: "123"}), request: &manage.GetPromptRequest{ - PromptID: ptr.Of(int64(1)), - WithDraft: ptr.Of(true), - WithDefaultConfig: ptr.Of(true), + PromptID: ptr.Of(int64(4)), }, }, - want: &manage.GetPromptResponse{ - Prompt: &prompt.Prompt{ - ID: ptr.Of(int64(1)), - WorkspaceID: ptr.Of(int64(100)), - PromptKey: ptr.Of("test_key"), + want: func() *manage.GetPromptResponse { + resp := manage.NewGetPromptResponse() + resp.Prompt = &prompt.Prompt{ + ID: ptr.Of(int64(4)), + WorkspaceID: ptr.Of(int64(400)), + PromptKey: ptr.Of("snippet_key"), PromptBasic: &prompt.PromptBasic{ - DisplayName: ptr.Of("test_name"), - Description: ptr.Of("test_description"), - LatestVersion: ptr.Of("1.0.0"), - CreatedBy: ptr.Of("test_creator"), - UpdatedBy: ptr.Of("test_updater"), - CreatedAt: ptr.Of(now.UnixMilli()), - UpdatedAt: ptr.Of(now.UnixMilli()), - LatestCommittedAt: nil, - }, - PromptDraft: &prompt.PromptDraft{ - Detail: &prompt.PromptDetail{ - PromptTemplate: &prompt.PromptTemplate{ - TemplateType: ptr.Of(prompt.TemplateTypeNormal), - Messages: []*prompt.Message{ - { - Role: ptr.Of(prompt.RoleUser), - Content: ptr.Of("test content"), - }, - }, - }, - }, - DraftInfo: &prompt.DraftInfo{ - UserID: ptr.Of("123"), - BaseVersion: ptr.Of("1.0.0"), - IsModified: ptr.Of(true), - CreatedAt: ptr.Of(now.UnixMilli()), - UpdatedAt: ptr.Of(now.UnixMilli()), - }, - }, - }, - DefaultConfig: &prompt.PromptDetail{ - PromptTemplate: &prompt.PromptTemplate{ - TemplateType: ptr.Of(prompt.TemplateTypeNormal), - Messages: []*prompt.Message{ - { - Role: ptr.Of(prompt.RoleSystem), - Content: ptr.Of("Default system message"), - }, - }, + DisplayName: ptr.Of("snippet name"), + Description: ptr.Of("snippet description"), + LatestVersion: ptr.Of("4.0.0"), + CreatedBy: ptr.Of("creator"), + UpdatedBy: ptr.Of("updater"), + CreatedAt: ptr.Of(snippetTime.UnixMilli()), + UpdatedAt: ptr.Of(snippetTime.UnixMilli()), + PromptType: ptr.Of(prompt.PromptTypeSnippet), }, + } + resp.TotalParentReferences = ptr.Of(int32(5)) + return resp + }(), + wantErr: nil, + }, + } + + for _, tt := range tests { + caseData := tt + t.Run(caseData.name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ff := caseData.fieldsGetter(ctrl) + app := &PromptManageApplicationImpl{ + manageRepo: ff.manageRepo, + promptService: ff.promptService, + authRPCProvider: ff.authRPCProvider, + userRPCProvider: ff.userRPCProvider, + configProvider: ff.configProvider, + } + + got, err := app.GetPrompt(caseData.args.ctx, caseData.args.request) + unittest.AssertErrorEqual(t, caseData.wantErr, err) + if err == nil { + assert.Equal(t, caseData.want, got) + } + }) + } +} + +func TestPromptManageApplicationImpl_ListPrompt(t *testing.T) { + type fields struct { + manageRepo repo.IManageRepo + promptService service.IPromptService + authRPCProvider rpc.IAuthProvider + userRPCProvider rpc.IUserProvider + auditRPCProvider rpc.IAuditProvider + configProvider conf.IConfigProvider + } + type args struct { + ctx context.Context + request *manage.ListPromptRequest + } + now := time.Now() + tests := []struct { + name string + fieldsGetter func(ctrl *gomock.Controller) fields + args args + want *manage.ListPromptResponse + wantErr error + }{ + { + name: "user not found", + fieldsGetter: func(ctrl *gomock.Controller) fields { + return fields{} + }, + args: args{ + ctx: context.Background(), + request: &manage.ListPromptRequest{ + WorkspaceID: ptr.Of(int64(100)), + PageNum: ptr.Of(int32(1)), + PageSize: ptr.Of(int32(10)), }, }, - wantErr: nil, + want: manage.NewListPromptResponse(), + wantErr: errorx.NewByCode(prompterr.CommonInvalidParamCode, errorx.WithExtraMsg("User not found")), }, { - name: "get prompt with default config false", + name: "permission check error", fieldsGetter: func(ctrl *gomock.Controller) fields { - mockRepo := repomocks.NewMockIManageRepo(ctrl) - mockRepo.EXPECT().GetPrompt(gomock.Any(), repo.GetPromptParam{ - PromptID: 1, - WithDraft: true, - UserID: "123", - }).Return(&entity.Prompt{ - ID: 1, - SpaceID: 100, - PromptKey: "test_key", - PromptBasic: &entity.PromptBasic{ - DisplayName: "test_name", - Description: "test_description", - LatestVersion: "1.0.0", - CreatedBy: "test_creator", - UpdatedBy: "test_updater", - CreatedAt: now, - UpdatedAt: now, - LatestCommittedAt: nil, - }, - PromptDraft: &entity.PromptDraft{ - PromptDetail: &entity.PromptDetail{ - PromptTemplate: &entity.PromptTemplate{ - TemplateType: entity.TemplateTypeNormal, - Messages: []*entity.Message{ - { - Role: entity.RoleUser, - Content: ptr.Of("test content"), - }, - }, - }, - }, - DraftInfo: &entity.DraftInfo{ - UserID: "123", - BaseVersion: "1.0.0", - IsModified: true, - CreatedAt: now, - UpdatedAt: now, - }, - }, - }, nil) - mockAuth := mocks.NewMockIAuthProvider(ctrl) - mockAuth.EXPECT().MCheckPromptPermission(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + mockAuth.EXPECT().CheckSpacePermission(gomock.Any(), int64(100), consts.ActionWorkspaceListLoopPrompt).Return(errorx.New("permission denied")) return fields{ - manageRepo: mockRepo, authRPCProvider: mockAuth, } }, args: args{ ctx: session.WithCtxUser(context.Background(), &session.User{ID: "123"}), - request: &manage.GetPromptRequest{ - PromptID: ptr.Of(int64(1)), - WithDraft: ptr.Of(true), - WithDefaultConfig: ptr.Of(false), + request: &manage.ListPromptRequest{ + WorkspaceID: ptr.Of(int64(100)), + PageNum: ptr.Of(int32(1)), + PageSize: ptr.Of(int32(10)), }, }, - want: &manage.GetPromptResponse{ - Prompt: &prompt.Prompt{ - ID: ptr.Of(int64(1)), - WorkspaceID: ptr.Of(int64(100)), - PromptKey: ptr.Of("test_key"), - PromptBasic: &prompt.PromptBasic{ - DisplayName: ptr.Of("test_name"), - Description: ptr.Of("test_description"), - LatestVersion: ptr.Of("1.0.0"), - CreatedBy: ptr.Of("test_creator"), - UpdatedBy: ptr.Of("test_updater"), - CreatedAt: ptr.Of(now.UnixMilli()), - UpdatedAt: ptr.Of(now.UnixMilli()), - LatestCommittedAt: nil, - }, - PromptDraft: &prompt.PromptDraft{ - Detail: &prompt.PromptDetail{ - PromptTemplate: &prompt.PromptTemplate{ - TemplateType: ptr.Of(prompt.TemplateTypeNormal), - Messages: []*prompt.Message{ - { - Role: ptr.Of(prompt.RoleUser), - Content: ptr.Of("test content"), - }, - }, - }, - }, - DraftInfo: &prompt.DraftInfo{ - UserID: ptr.Of("123"), - BaseVersion: ptr.Of("1.0.0"), - IsModified: ptr.Of(true), - CreatedAt: ptr.Of(now.UnixMilli()), - UpdatedAt: ptr.Of(now.UnixMilli()), - }, - }, - }, - }, - wantErr: nil, + want: manage.NewListPromptResponse(), + wantErr: errorx.New("permission denied"), }, { - name: "config provider error", + name: "list prompt with committed only true", fieldsGetter: func(ctrl *gomock.Controller) fields { mockRepo := repomocks.NewMockIManageRepo(ctrl) - mockRepo.EXPECT().GetPrompt(gomock.Any(), repo.GetPromptParam{ - PromptID: 1, - WithDraft: true, - UserID: "123", - }).Return(&entity.Prompt{ - ID: 1, - SpaceID: 100, - PromptKey: "test_key", - PromptBasic: &entity.PromptBasic{ - DisplayName: "test_name", - Description: "test_description", - LatestVersion: "1.0.0", - CreatedBy: "test_creator", - UpdatedBy: "test_updater", - CreatedAt: now, - UpdatedAt: now, - LatestCommittedAt: nil, - }, - PromptDraft: &entity.PromptDraft{ - PromptDetail: &entity.PromptDetail{ - PromptTemplate: &entity.PromptTemplate{ - TemplateType: entity.TemplateTypeNormal, - Messages: []*entity.Message{ - { - Role: entity.RoleUser, - Content: ptr.Of("test content"), - }, - }, + mockRepo.EXPECT().ListPrompt(gomock.Any(), repo.ListPromptParam{ + SpaceID: 100, + UserID: "123", + CommittedOnly: true, + FilterPromptTypes: []entity.PromptType{prompt.PromptTypeNormal}, + PageNum: 1, + PageSize: 10, + OrderBy: mysql.ListPromptBasicOrderByID, + Asc: false, + }).Return(&repo.ListPromptResult{ + Total: 1, + PromptDOs: []*entity.Prompt{ + { + ID: 1, + SpaceID: 100, + PromptKey: "test_key", + PromptBasic: &entity.PromptBasic{ + DisplayName: "test_name", + Description: "test_description", + LatestVersion: "1.0.0", + CreatedBy: "test_creator", + UpdatedBy: "test_updater", + CreatedAt: now, + UpdatedAt: now, + LatestCommittedAt: &now, + PromptType: entity.PromptTypeNormal, }, }, - DraftInfo: &entity.DraftInfo{ - UserID: "123", - BaseVersion: "1.0.0", - IsModified: true, - CreatedAt: now, - UpdatedAt: now, - }, }, }, nil) mockAuth := mocks.NewMockIAuthProvider(ctrl) - mockAuth.EXPECT().MCheckPromptPermission(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + mockAuth.EXPECT().CheckSpacePermission(gomock.Any(), int64(100), consts.ActionWorkspaceListLoopPrompt).Return(nil) - mockConfigProvider := confmocks.NewMockIConfigProvider(ctrl) - mockConfigProvider.EXPECT().GetPromptDefaultConfig(gomock.Any()).Return(nil, errorx.New("config provider error")) + mockUser := mocks.NewMockIUserProvider(ctrl) + mockUser.EXPECT().MGetUserInfo(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, ids []string) ([]*rpc.UserInfo, error) { + assert.ElementsMatch(t, []string{"test_creator", "test_updater"}, ids) + return []*rpc.UserInfo{ + {UserID: "test_creator", UserName: "Test Creator"}, + {UserID: "test_updater", UserName: "Test Updater"}, + }, nil + }) return fields{ manageRepo: mockRepo, authRPCProvider: mockAuth, - configProvider: mockConfigProvider, + userRPCProvider: mockUser, } }, args: args{ ctx: session.WithCtxUser(context.Background(), &session.User{ID: "123"}), - request: &manage.GetPromptRequest{ - PromptID: ptr.Of(int64(1)), - WithDraft: ptr.Of(true), - WithDefaultConfig: ptr.Of(true), - }, - }, - want: manage.NewGetPromptResponse(), - wantErr: errorx.New("config provider error"), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - ttFields := tt.fieldsGetter(ctrl) - - d := &PromptManageApplicationImpl{ - manageRepo: ttFields.manageRepo, - promptService: ttFields.promptService, - authRPCProvider: ttFields.authRPCProvider, - userRPCProvider: ttFields.userRPCProvider, - configProvider: ttFields.configProvider, - } - - got, err := d.GetPrompt(tt.args.ctx, tt.args.request) - unittest.AssertErrorEqual(t, tt.wantErr, err) - if err == nil { - assert.Equal(t, tt.want, got) - } - }) - } -} - -func TestPromptManageApplicationImpl_RevertDraftFromCommit(t *testing.T) { - type fields struct { - manageRepo repo.IManageRepo - promptService service.IPromptService - authRPCProvider rpc.IAuthProvider - userRPCProvider rpc.IUserProvider - auditRPCProvider rpc.IAuditProvider - configProvider conf.IConfigProvider - } - type args struct { - ctx context.Context - request *manage.RevertDraftFromCommitRequest - } - now := time.Now() - tests := []struct { - name string - fieldsGetter func(ctrl *gomock.Controller) fields - args args - want *manage.RevertDraftFromCommitResponse - wantErr error - }{ - { - name: "user not found", - fieldsGetter: func(ctrl *gomock.Controller) fields { - return fields{} - }, - args: args{ - ctx: context.Background(), - request: &manage.RevertDraftFromCommitRequest{ - PromptID: ptr.Of(int64(1)), - CommitVersionRevertingFrom: ptr.Of("1.0.0"), + request: &manage.ListPromptRequest{ + WorkspaceID: ptr.Of(int64(100)), + CommittedOnly: ptr.Of(true), + PageNum: ptr.Of(int32(1)), + PageSize: ptr.Of(int32(10)), }, }, - want: manage.NewRevertDraftFromCommitResponse(), - wantErr: errorx.NewByCode(prompterr.CommonInvalidParamCode, errorx.WithExtraMsg("User not found")), - }, - { - name: "get prompt error", - fieldsGetter: func(ctrl *gomock.Controller) fields { - mockManageRepo := repomocks.NewMockIManageRepo(ctrl) - mockManageRepo.EXPECT().GetPrompt(gomock.Any(), repo.GetPromptParam{ - PromptID: 1, - WithCommit: true, - CommitVersion: "1.0.0", - }).Return(nil, errorx.New("get prompt error")) - - return fields{ - manageRepo: mockManageRepo, - } - }, - args: args{ - ctx: session.WithCtxUser(context.Background(), &session.User{ID: "123"}), - request: &manage.RevertDraftFromCommitRequest{ - PromptID: ptr.Of(int64(1)), - CommitVersionRevertingFrom: ptr.Of("1.0.0"), + want: &manage.ListPromptResponse{ + Total: ptr.Of(int32(1)), + Prompts: []*prompt.Prompt{ + { + ID: ptr.Of(int64(1)), + WorkspaceID: ptr.Of(int64(100)), + PromptKey: ptr.Of("test_key"), + PromptBasic: &prompt.PromptBasic{ + DisplayName: ptr.Of("test_name"), + Description: ptr.Of("test_description"), + LatestVersion: ptr.Of("1.0.0"), + CreatedBy: ptr.Of("test_creator"), + UpdatedBy: ptr.Of("test_updater"), + CreatedAt: ptr.Of(now.UnixMilli()), + UpdatedAt: ptr.Of(now.UnixMilli()), + LatestCommittedAt: ptr.Of(now.UnixMilli()), + PromptType: ptr.Of(prompt.PromptTypeNormal), + }, + }, }, - }, - want: manage.NewRevertDraftFromCommitResponse(), - wantErr: errorx.New("get prompt error"), - }, - { - name: "prompt or commit not found", - fieldsGetter: func(ctrl *gomock.Controller) fields { - mockManageRepo := repomocks.NewMockIManageRepo(ctrl) - mockManageRepo.EXPECT().GetPrompt(gomock.Any(), repo.GetPromptParam{ - PromptID: 1, - WithCommit: true, - CommitVersion: "1.0.0", - }).Return(&entity.Prompt{ - ID: 1, - SpaceID: 100, - PromptKey: "test_key", - PromptBasic: &entity.PromptBasic{ - DisplayName: "test_name", - Description: "test_description", - LatestVersion: "1.0.0", - CreatedBy: "test_creator", - UpdatedBy: "test_updater", - CreatedAt: now, - UpdatedAt: now, - LatestCommittedAt: nil, + Users: []*user.UserInfoDetail{ + { + UserID: ptr.Of("test_creator"), + Name: ptr.Of("Test Creator"), + NickName: ptr.Of(""), + AvatarURL: ptr.Of(""), + Email: ptr.Of(""), + Mobile: ptr.Of(""), + }, + { + UserID: ptr.Of("test_updater"), + Name: ptr.Of("Test Updater"), + NickName: ptr.Of(""), + AvatarURL: ptr.Of(""), + Email: ptr.Of(""), + Mobile: ptr.Of(""), }, - }, nil) - - return fields{ - manageRepo: mockManageRepo, - } - }, - args: args{ - ctx: session.WithCtxUser(context.Background(), &session.User{ID: "123"}), - request: &manage.RevertDraftFromCommitRequest{ - PromptID: ptr.Of(int64(1)), - CommitVersionRevertingFrom: ptr.Of("1.0.0"), }, }, - want: manage.NewRevertDraftFromCommitResponse(), - wantErr: errorx.New("Prompt or commit not found, prompt id = 1, commit version = 1.0.0"), + wantErr: nil, }, { - name: "save draft error", + name: "list prompt with committed only false", fieldsGetter: func(ctrl *gomock.Controller) fields { - mockManageRepo := repomocks.NewMockIManageRepo(ctrl) - mockManageRepo.EXPECT().GetPrompt(gomock.Any(), repo.GetPromptParam{ - PromptID: 1, - WithCommit: true, - CommitVersion: "1.0.0", - }).Return(&entity.Prompt{ - ID: 1, - SpaceID: 100, - PromptKey: "test_key", - PromptBasic: &entity.PromptBasic{ - DisplayName: "test_name", - Description: "test_description", - LatestVersion: "1.0.0", - CreatedBy: "test_creator", - UpdatedBy: "test_updater", - CreatedAt: now, - UpdatedAt: now, - LatestCommittedAt: nil, - }, - PromptCommit: &entity.PromptCommit{ - PromptDetail: &entity.PromptDetail{ - PromptTemplate: &entity.PromptTemplate{ - TemplateType: entity.TemplateTypeNormal, - Messages: []*entity.Message{ - { - Role: entity.RoleUser, - Content: ptr.Of("test content"), - }, - }, + mockRepo := repomocks.NewMockIManageRepo(ctrl) + mockRepo.EXPECT().ListPrompt(gomock.Any(), repo.ListPromptParam{ + SpaceID: 100, + UserID: "123", + CommittedOnly: false, + FilterPromptTypes: []entity.PromptType{entity.PromptTypeNormal}, + PageNum: 1, + PageSize: 10, + OrderBy: mysql.ListPromptBasicOrderByID, + Asc: false, + }).Return(&repo.ListPromptResult{ + Total: 2, + PromptDOs: []*entity.Prompt{ + { + ID: 1, + SpaceID: 100, + PromptKey: "test_key_1", + PromptBasic: &entity.PromptBasic{ + DisplayName: "test_name_1", + Description: "test_description_1", + LatestVersion: "1.0.0", + CreatedBy: "test_creator", + UpdatedBy: "test_updater", + CreatedAt: now, + UpdatedAt: now, + LatestCommittedAt: &now, + PromptType: entity.PromptTypeNormal, }, }, - CommitInfo: &entity.CommitInfo{ - Version: "1.0.0", - BaseVersion: "0.9.0", - Description: "test commit", - CommittedBy: "test_user", - CommittedAt: now, + { + ID: 2, + SpaceID: 100, + PromptKey: "test_key_2", + PromptBasic: &entity.PromptBasic{ + DisplayName: "test_name_2", + Description: "test_description_2", + LatestVersion: "", + CreatedBy: "test_creator", + UpdatedBy: "test_updater", + CreatedAt: now, + UpdatedAt: now, + LatestCommittedAt: nil, + }, }, }, }, nil) mockAuth := mocks.NewMockIAuthProvider(ctrl) - mockAuth.EXPECT().MCheckPromptPermission(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + mockAuth.EXPECT().CheckSpacePermission(gomock.Any(), int64(100), consts.ActionWorkspaceListLoopPrompt).Return(nil) - mockManageRepo.EXPECT().SaveDraft(gomock.Any(), gomock.Any()).Return(nil, errorx.New("save draft error")) + mockUser := mocks.NewMockIUserProvider(ctrl) + mockUser.EXPECT().MGetUserInfo(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, ids []string) ([]*rpc.UserInfo, error) { + assert.ElementsMatch(t, []string{"test_creator", "test_updater"}, ids) + return []*rpc.UserInfo{ + {UserID: "test_creator", UserName: "Test Creator"}, + {UserID: "test_updater", UserName: "Test Updater"}, + }, nil + }) return fields{ - manageRepo: mockManageRepo, + manageRepo: mockRepo, authRPCProvider: mockAuth, + userRPCProvider: mockUser, } }, args: args{ ctx: session.WithCtxUser(context.Background(), &session.User{ID: "123"}), - request: &manage.RevertDraftFromCommitRequest{ - PromptID: ptr.Of(int64(1)), - CommitVersionRevertingFrom: ptr.Of("1.0.0"), - }, - }, - want: manage.NewRevertDraftFromCommitResponse(), - wantErr: errorx.New("save draft error"), - }, + request: &manage.ListPromptRequest{ + WorkspaceID: ptr.Of(int64(100)), + CommittedOnly: ptr.Of(false), + PageNum: ptr.Of(int32(1)), + PageSize: ptr.Of(int32(10)), + }, + }, + want: &manage.ListPromptResponse{ + Total: ptr.Of(int32(2)), + Prompts: []*prompt.Prompt{ + { + ID: ptr.Of(int64(1)), + WorkspaceID: ptr.Of(int64(100)), + PromptKey: ptr.Of("test_key_1"), + PromptBasic: &prompt.PromptBasic{ + DisplayName: ptr.Of("test_name_1"), + Description: ptr.Of("test_description_1"), + LatestVersion: ptr.Of("1.0.0"), + CreatedBy: ptr.Of("test_creator"), + UpdatedBy: ptr.Of("test_updater"), + CreatedAt: ptr.Of(now.UnixMilli()), + UpdatedAt: ptr.Of(now.UnixMilli()), + LatestCommittedAt: ptr.Of(now.UnixMilli()), + PromptType: ptr.Of(prompt.PromptTypeNormal), + }, + }, + { + ID: ptr.Of(int64(2)), + WorkspaceID: ptr.Of(int64(100)), + PromptKey: ptr.Of("test_key_2"), + PromptBasic: &prompt.PromptBasic{ + DisplayName: ptr.Of("test_name_2"), + Description: ptr.Of("test_description_2"), + LatestVersion: ptr.Of(""), + CreatedBy: ptr.Of("test_creator"), + UpdatedBy: ptr.Of("test_updater"), + CreatedAt: ptr.Of(now.UnixMilli()), + UpdatedAt: ptr.Of(now.UnixMilli()), + PromptType: ptr.Of(prompt.PromptTypeNormal), + }, + }, + }, + Users: []*user.UserInfoDetail{ + { + UserID: ptr.Of("test_creator"), + Name: ptr.Of("Test Creator"), + NickName: ptr.Of(""), + AvatarURL: ptr.Of(""), + Email: ptr.Of(""), + Mobile: ptr.Of(""), + }, + { + UserID: ptr.Of("test_updater"), + Name: ptr.Of("Test Updater"), + NickName: ptr.Of(""), + AvatarURL: ptr.Of(""), + Email: ptr.Of(""), + Mobile: ptr.Of(""), + }, + }, + }, + wantErr: nil, + }, { - name: "success", + name: "list prompt with user draft association", fieldsGetter: func(ctrl *gomock.Controller) fields { - mockManageRepo := repomocks.NewMockIManageRepo(ctrl) - mockManageRepo.EXPECT().GetPrompt(gomock.Any(), repo.GetPromptParam{ - PromptID: 1, - WithCommit: true, - CommitVersion: "1.0.0", - }).Return(&entity.Prompt{ - ID: 1, - SpaceID: 100, - PromptKey: "test_key", - PromptBasic: &entity.PromptBasic{ - DisplayName: "test_name", - Description: "test_description", - LatestVersion: "1.0.0", - CreatedBy: "test_creator", - UpdatedBy: "test_updater", - CreatedAt: now, - UpdatedAt: now, - LatestCommittedAt: nil, - }, - PromptCommit: &entity.PromptCommit{ - PromptDetail: &entity.PromptDetail{ - PromptTemplate: &entity.PromptTemplate{ - TemplateType: entity.TemplateTypeNormal, - Messages: []*entity.Message{ - { - Role: entity.RoleUser, - Content: ptr.Of("test content"), + mockRepo := repomocks.NewMockIManageRepo(ctrl) + mockRepo.EXPECT().ListPrompt(gomock.Any(), repo.ListPromptParam{ + SpaceID: 100, + UserID: "123", + KeyWord: "draft", + FilterPromptTypes: []entity.PromptType{entity.PromptTypeNormal}, + PageNum: 1, + PageSize: 10, + OrderBy: mysql.ListPromptBasicOrderByID, + Asc: false, + }).Return(&repo.ListPromptResult{ + Total: 1, + PromptDOs: []*entity.Prompt{ + { + ID: 1, + SpaceID: 100, + PromptKey: "test_key", + PromptBasic: &entity.PromptBasic{ + DisplayName: "test_name", + Description: "test_description", + LatestVersion: "1.0.0", + CreatedBy: "test_creator", + UpdatedBy: "test_updater", + CreatedAt: now, + UpdatedAt: now, + LatestCommittedAt: &now, + PromptType: entity.PromptTypeNormal, + }, + PromptDraft: &entity.PromptDraft{ + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + TemplateType: entity.TemplateTypeNormal, + Messages: []*entity.Message{ + { + Role: entity.RoleUser, + Content: ptr.Of("draft content"), + }, + }, + HasSnippets: false, }, }, + DraftInfo: &entity.DraftInfo{ + UserID: "123", + BaseVersion: "1.0.0", + IsModified: true, + CreatedAt: now, + UpdatedAt: now, + }, }, }, - CommitInfo: &entity.CommitInfo{ - Version: "1.0.0", - BaseVersion: "0.9.0", - Description: "test commit", - CommittedBy: "test_user", - CommittedAt: now, - }, }, }, nil) mockAuth := mocks.NewMockIAuthProvider(ctrl) - mockAuth.EXPECT().MCheckPromptPermission(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + mockAuth.EXPECT().CheckSpacePermission(gomock.Any(), int64(100), consts.ActionWorkspaceListLoopPrompt).Return(nil) - mockManageRepo.EXPECT().SaveDraft(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, promptDO *entity.Prompt) (*entity.DraftInfo, error) { - assert.Equal(t, int64(1), promptDO.ID) - assert.Equal(t, "123", promptDO.PromptDraft.DraftInfo.UserID) - assert.Equal(t, "1.0.0", promptDO.PromptDraft.DraftInfo.BaseVersion) - assert.Equal(t, entity.TemplateTypeNormal, promptDO.PromptDraft.PromptDetail.PromptTemplate.TemplateType) - assert.Equal(t, 1, len(promptDO.PromptDraft.PromptDetail.PromptTemplate.Messages)) - assert.Equal(t, entity.RoleUser, promptDO.PromptDraft.PromptDetail.PromptTemplate.Messages[0].Role) - assert.Equal(t, "test content", *promptDO.PromptDraft.PromptDetail.PromptTemplate.Messages[0].Content) - return &entity.DraftInfo{}, nil + mockUser := mocks.NewMockIUserProvider(ctrl) + mockUser.EXPECT().MGetUserInfo(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, ids []string) ([]*rpc.UserInfo, error) { + assert.ElementsMatch(t, []string{"test_creator", "test_updater"}, ids) + return []*rpc.UserInfo{ + {UserID: "test_creator", UserName: "Test Creator"}, + {UserID: "test_updater", UserName: "Test Updater"}, + }, nil }) return fields{ - manageRepo: mockManageRepo, + manageRepo: mockRepo, authRPCProvider: mockAuth, + userRPCProvider: mockUser, } }, args: args{ ctx: session.WithCtxUser(context.Background(), &session.User{ID: "123"}), - request: &manage.RevertDraftFromCommitRequest{ - PromptID: ptr.Of(int64(1)), - CommitVersionRevertingFrom: ptr.Of("1.0.0"), + request: &manage.ListPromptRequest{ + WorkspaceID: ptr.Of(int64(100)), + KeyWord: ptr.Of("draft"), + PageNum: ptr.Of(int32(1)), + PageSize: ptr.Of(int32(10)), }, }, - want: manage.NewRevertDraftFromCommitResponse(), - wantErr: nil, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - ttFields := tt.fieldsGetter(ctrl) - - app := &PromptManageApplicationImpl{ - manageRepo: ttFields.manageRepo, - promptService: ttFields.promptService, - authRPCProvider: ttFields.authRPCProvider, - userRPCProvider: ttFields.userRPCProvider, - auditRPCProvider: ttFields.auditRPCProvider, - configProvider: ttFields.configProvider, - } - - got, err := app.RevertDraftFromCommit(tt.args.ctx, tt.args.request) - unittest.AssertErrorEqual(t, tt.wantErr, err) - if err == nil { - assert.Equal(t, tt.want, got) - } - }) - } -} - -func TestPromptManageApplicationImpl_ListCommit(t *testing.T) { - type fields struct { - manageRepo repo.IManageRepo - promptService service.IPromptService - authRPCProvider rpc.IAuthProvider - userRPCProvider rpc.IUserProvider - auditRPCProvider rpc.IAuditProvider - configProvider conf.IConfigProvider - } - type args struct { - ctx context.Context - request *manage.ListCommitRequest - } - now := time.Now() - tests := []struct { - name string - fieldsGetter func(ctrl *gomock.Controller) fields - args args - want *manage.ListCommitResponse - wantErr error - }{ - { - name: "user not found", - fieldsGetter: func(ctrl *gomock.Controller) fields { - return fields{} - }, - args: args{ - ctx: context.Background(), - request: &manage.ListCommitRequest{ - PromptID: ptr.Of(int64(1)), - PageSize: ptr.Of(int32(10)), - PageToken: nil, - Asc: ptr.Of(false), + want: &manage.ListPromptResponse{ + Total: ptr.Of(int32(1)), + Prompts: []*prompt.Prompt{ + { + ID: ptr.Of(int64(1)), + WorkspaceID: ptr.Of(int64(100)), + PromptKey: ptr.Of("test_key"), + PromptBasic: &prompt.PromptBasic{ + DisplayName: ptr.Of("test_name"), + Description: ptr.Of("test_description"), + LatestVersion: ptr.Of("1.0.0"), + CreatedBy: ptr.Of("test_creator"), + UpdatedBy: ptr.Of("test_updater"), + CreatedAt: ptr.Of(now.UnixMilli()), + UpdatedAt: ptr.Of(now.UnixMilli()), + LatestCommittedAt: ptr.Of(now.UnixMilli()), + PromptType: ptr.Of(prompt.PromptTypeNormal), + }, + PromptDraft: &prompt.PromptDraft{ + Detail: &prompt.PromptDetail{ + PromptTemplate: &prompt.PromptTemplate{ + TemplateType: ptr.Of(prompt.TemplateTypeNormal), + Messages: []*prompt.Message{ + { + Role: ptr.Of(prompt.RoleUser), + Content: ptr.Of("draft content"), + }, + }, + HasSnippet: ptr.Of(false), + }, + }, + DraftInfo: &prompt.DraftInfo{ + UserID: ptr.Of("123"), + BaseVersion: ptr.Of("1.0.0"), + IsModified: ptr.Of(true), + CreatedAt: ptr.Of(now.UnixMilli()), + UpdatedAt: ptr.Of(now.UnixMilli()), + }, + }, + }, + }, + Users: []*user.UserInfoDetail{ + { + UserID: ptr.Of("test_creator"), + Name: ptr.Of("Test Creator"), + NickName: ptr.Of(""), + AvatarURL: ptr.Of(""), + Email: ptr.Of(""), + Mobile: ptr.Of(""), + }, + { + UserID: ptr.Of("test_updater"), + Name: ptr.Of("Test Updater"), + NickName: ptr.Of(""), + AvatarURL: ptr.Of(""), + Email: ptr.Of(""), + Mobile: ptr.Of(""), + }, }, }, - want: manage.NewListCommitResponse(), - wantErr: errorx.NewByCode(prompterr.CommonInvalidParamCode, errorx.WithExtraMsg("User not found")), + wantErr: nil, }, { - name: "invalid page token", + name: "list prompt repo error", fieldsGetter: func(ctrl *gomock.Controller) fields { - mockManageRepo := repomocks.NewMockIManageRepo(ctrl) - mockManageRepo.EXPECT().GetPrompt(gomock.Any(), gomock.Any()).Return(&entity.Prompt{ID: 1}, nil) + mockRepo := repomocks.NewMockIManageRepo(ctrl) + mockRepo.EXPECT().ListPrompt(gomock.Any(), repo.ListPromptParam{ + SpaceID: 100, + UserID: "123", + FilterPromptTypes: []entity.PromptType{entity.PromptTypeNormal}, + PageNum: 1, + PageSize: 10, + OrderBy: mysql.ListPromptBasicOrderByID, + Asc: false, + }).Return(nil, errorx.New("list prompt error")) mockAuth := mocks.NewMockIAuthProvider(ctrl) - mockAuth.EXPECT().MCheckPromptPermission(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + mockAuth.EXPECT().CheckSpacePermission(gomock.Any(), int64(100), consts.ActionWorkspaceListLoopPrompt).Return(nil) + return fields{ - manageRepo: mockManageRepo, + manageRepo: mockRepo, authRPCProvider: mockAuth, } }, args: args{ ctx: session.WithCtxUser(context.Background(), &session.User{ID: "123"}), - request: &manage.ListCommitRequest{ - PromptID: ptr.Of(int64(1)), - PageSize: ptr.Of(int32(10)), - PageToken: ptr.Of("invalid"), - Asc: ptr.Of(false), + request: &manage.ListPromptRequest{ + WorkspaceID: ptr.Of(int64(100)), + PageNum: ptr.Of(int32(1)), + PageSize: ptr.Of(int32(10)), }, }, - want: manage.NewListCommitResponse(), - wantErr: errorx.NewByCode(prompterr.CommonInvalidParamCode, errorx.WithExtraMsg("Page token is invalid, page token = invalid")), - }, - { - name: "list commit error", - fieldsGetter: func(ctrl *gomock.Controller) fields { - mockManageRepo := repomocks.NewMockIManageRepo(ctrl) - mockManageRepo.EXPECT().GetPrompt(gomock.Any(), gomock.Any()).Return(&entity.Prompt{ID: 1}, nil) - mockManageRepo.EXPECT().ListCommitInfo(gomock.Any(), repo.ListCommitInfoParam{ - PromptID: 1, - PageSize: 10, - PageToken: nil, - Asc: false, - }).Return(nil, errorx.New("list commit error")) - - mockAuth := mocks.NewMockIAuthProvider(ctrl) - mockAuth.EXPECT().MCheckPromptPermission(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - - return fields{ - manageRepo: mockManageRepo, - authRPCProvider: mockAuth, - } - }, - args: args{ - ctx: session.WithCtxUser(context.Background(), &session.User{ID: "123"}), - request: &manage.ListCommitRequest{ - PromptID: ptr.Of(int64(1)), - PageSize: ptr.Of(int32(10)), - PageToken: nil, - Asc: ptr.Of(false), - }, - }, - want: manage.NewListCommitResponse(), - wantErr: errorx.New("list commit error"), - }, - { - name: "empty result", - fieldsGetter: func(ctrl *gomock.Controller) fields { - mockManageRepo := repomocks.NewMockIManageRepo(ctrl) - mockManageRepo.EXPECT().GetPrompt(gomock.Any(), gomock.Any()).Return(&entity.Prompt{ID: 1}, nil) - mockManageRepo.EXPECT().ListCommitInfo(gomock.Any(), repo.ListCommitInfoParam{ - PromptID: 1, - PageSize: 10, - PageToken: nil, - Asc: false, - }).Return(nil, nil) - - mockAuth := mocks.NewMockIAuthProvider(ctrl) - mockAuth.EXPECT().MCheckPromptPermission(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - - return fields{ - manageRepo: mockManageRepo, - authRPCProvider: mockAuth, - } - }, - args: args{ - ctx: session.WithCtxUser(context.Background(), &session.User{ID: "123"}), - request: &manage.ListCommitRequest{ - PromptID: ptr.Of(int64(1)), - PageSize: ptr.Of(int32(10)), - PageToken: nil, - Asc: ptr.Of(false), - }, - }, - want: manage.NewListCommitResponse(), - wantErr: nil, + want: manage.NewListPromptResponse(), + wantErr: errorx.New("list prompt error"), }, { - name: "single page result", + name: "list prompt with snippet type filter", fieldsGetter: func(ctrl *gomock.Controller) fields { - mockManageRepo := repomocks.NewMockIManageRepo(ctrl) - mockManageRepo.EXPECT().GetPrompt(gomock.Any(), gomock.Any()).Return(&entity.Prompt{ID: 1}, nil) - mockManageRepo.EXPECT().ListCommitInfo(gomock.Any(), repo.ListCommitInfoParam{ - PromptID: 1, - PageSize: 10, - PageToken: nil, - Asc: false, - }).Return(&repo.ListCommitResult{ - CommitInfoDOs: []*entity.CommitInfo{ - { - Version: "1.0.0", - BaseVersion: "0.9.0", - Description: "test commit 1", - CommittedBy: "test_user", - CommittedAt: now, - }, + mockRepo := repomocks.NewMockIManageRepo(ctrl) + mockRepo.EXPECT().ListPrompt(gomock.Any(), repo.ListPromptParam{ + SpaceID: 100, + UserID: "123", + FilterPromptTypes: []entity.PromptType{entity.PromptTypeSnippet}, + PageNum: 1, + PageSize: 10, + OrderBy: mysql.ListPromptBasicOrderByID, + Asc: false, + }).Return(&repo.ListPromptResult{ + Total: 1, + PromptDOs: []*entity.Prompt{ { - Version: "1.1.0", - BaseVersion: "1.0.0", - Description: "test commit 2", - CommittedBy: "test_user", - CommittedAt: now, + ID: 1, + SpaceID: 100, + PromptKey: "snippet_key", + PromptBasic: &entity.PromptBasic{ + DisplayName: "snippet_name", + Description: "snippet_description", + LatestVersion: "1.0.0", + CreatedBy: "test_creator", + UpdatedBy: "test_updater", + CreatedAt: now, + UpdatedAt: now, + LatestCommittedAt: &now, + PromptType: entity.PromptTypeSnippet, + }, }, }, }, nil) mockAuth := mocks.NewMockIAuthProvider(ctrl) - mockAuth.EXPECT().MCheckPromptPermission(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + mockAuth.EXPECT().CheckSpacePermission(gomock.Any(), int64(100), consts.ActionWorkspaceListLoopPrompt).Return(nil) mockUser := mocks.NewMockIUserProvider(ctrl) - mockUser.EXPECT().MGetUserInfo(gomock.Any(), []string{"test_user"}).Return([]*rpc.UserInfo{ - { - UserID: "test_user", - UserName: "Test User", - }, - }, nil) - - mockPromptService := servicemocks.NewMockIPromptService(ctrl) - mockPromptService.EXPECT().BatchGetCommitLabels(gomock.Any(), int64(1), []string{"1.0.0", "1.1.0"}).Return(map[string][]string{}, nil) + mockUser.EXPECT().MGetUserInfo(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, ids []string) ([]*rpc.UserInfo, error) { + assert.ElementsMatch(t, []string{"test_creator", "test_updater"}, ids) + return []*rpc.UserInfo{ + {UserID: "test_creator", UserName: "Test Creator"}, + {UserID: "test_updater", UserName: "Test Updater"}, + }, nil + }) return fields{ - manageRepo: mockManageRepo, - promptService: mockPromptService, + manageRepo: mockRepo, authRPCProvider: mockAuth, userRPCProvider: mockUser, } }, args: args{ ctx: session.WithCtxUser(context.Background(), &session.User{ID: "123"}), - request: &manage.ListCommitRequest{ - PromptID: ptr.Of(int64(1)), - PageSize: ptr.Of(int32(10)), - PageToken: nil, - Asc: ptr.Of(false), + request: &manage.ListPromptRequest{ + WorkspaceID: ptr.Of(int64(100)), + FilterPromptTypes: []prompt.PromptType{prompt.PromptTypeSnippet}, + PageNum: ptr.Of(int32(1)), + PageSize: ptr.Of(int32(10)), }, }, - want: &manage.ListCommitResponse{ - PromptCommitInfos: []*prompt.CommitInfo{ - { - Version: ptr.Of("1.0.0"), - BaseVersion: ptr.Of("0.9.0"), - Description: ptr.Of("test commit 1"), - CommittedBy: ptr.Of("test_user"), - CommittedAt: ptr.Of(now.UnixMilli()), - }, - { - Version: ptr.Of("1.1.0"), - BaseVersion: ptr.Of("1.0.0"), - Description: ptr.Of("test commit 2"), - CommittedBy: ptr.Of("test_user"), - CommittedAt: ptr.Of(now.UnixMilli()), - }, - }, - CommitVersionLabelMapping: map[string][]*prompt.Label{}, - Users: []*user.UserInfoDetail{ + want: &manage.ListPromptResponse{ + Total: ptr.Of(int32(1)), + Prompts: []*prompt.Prompt{ { - UserID: ptr.Of("test_user"), - Name: ptr.Of("Test User"), - NickName: ptr.Of(""), - AvatarURL: ptr.Of(""), - Email: ptr.Of(""), - Mobile: ptr.Of(""), - }, - }, - }, - wantErr: nil, - }, - { - name: "multiple pages result", - fieldsGetter: func(ctrl *gomock.Controller) fields { - mockManageRepo := repomocks.NewMockIManageRepo(ctrl) - mockManageRepo.EXPECT().GetPrompt(gomock.Any(), gomock.Any()).Return(&entity.Prompt{ID: 1}, nil) - mockManageRepo.EXPECT().ListCommitInfo(gomock.Any(), repo.ListCommitInfoParam{ - PromptID: 1, - PageSize: 2, - PageToken: nil, - Asc: false, - }).Return(&repo.ListCommitResult{ - CommitInfoDOs: []*entity.CommitInfo{ - { - Version: "1.0.0", - BaseVersion: "0.9.0", - Description: "test commit 1", - CommittedBy: "test_user", - CommittedAt: now, - }, - { - Version: "1.1.0", - BaseVersion: "1.0.0", - Description: "test commit 2", - CommittedBy: "test_user", - CommittedAt: now, + ID: ptr.Of(int64(1)), + WorkspaceID: ptr.Of(int64(100)), + PromptKey: ptr.Of("snippet_key"), + PromptBasic: &prompt.PromptBasic{ + DisplayName: ptr.Of("snippet_name"), + Description: ptr.Of("snippet_description"), + LatestVersion: ptr.Of("1.0.0"), + CreatedBy: ptr.Of("test_creator"), + UpdatedBy: ptr.Of("test_updater"), + CreatedAt: ptr.Of(now.UnixMilli()), + UpdatedAt: ptr.Of(now.UnixMilli()), + LatestCommittedAt: ptr.Of(now.UnixMilli()), + PromptType: ptr.Of(prompt.PromptTypeSnippet), }, }, - NextPageToken: 3, - }, nil) - - mockAuth := mocks.NewMockIAuthProvider(ctrl) - mockAuth.EXPECT().MCheckPromptPermission(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - - mockUser := mocks.NewMockIUserProvider(ctrl) - mockUser.EXPECT().MGetUserInfo(gomock.Any(), []string{"test_user"}).Return([]*rpc.UserInfo{ - { - UserID: "test_user", - UserName: "Test User", - }, - }, nil) - - mockPromptService := servicemocks.NewMockIPromptService(ctrl) - mockPromptService.EXPECT().BatchGetCommitLabels(gomock.Any(), int64(1), []string{"1.0.0", "1.1.0"}).Return(map[string][]string{}, nil) - - return fields{ - manageRepo: mockManageRepo, - promptService: mockPromptService, - authRPCProvider: mockAuth, - userRPCProvider: mockUser, - } - }, - args: args{ - ctx: session.WithCtxUser(context.Background(), &session.User{ID: "123"}), - request: &manage.ListCommitRequest{ - PromptID: ptr.Of(int64(1)), - PageSize: ptr.Of(int32(2)), - PageToken: nil, - Asc: ptr.Of(false), - }, - }, - want: &manage.ListCommitResponse{ - PromptCommitInfos: []*prompt.CommitInfo{ - { - Version: ptr.Of("1.0.0"), - BaseVersion: ptr.Of("0.9.0"), - Description: ptr.Of("test commit 1"), - CommittedBy: ptr.Of("test_user"), - CommittedAt: ptr.Of(now.UnixMilli()), - }, - { - Version: ptr.Of("1.1.0"), - BaseVersion: ptr.Of("1.0.0"), - Description: ptr.Of("test commit 2"), - CommittedBy: ptr.Of("test_user"), - CommittedAt: ptr.Of(now.UnixMilli()), - }, }, - CommitVersionLabelMapping: map[string][]*prompt.Label{}, - HasMore: ptr.Of(true), - NextPageToken: ptr.Of("3"), Users: []*user.UserInfoDetail{ { - UserID: ptr.Of("test_user"), - Name: ptr.Of("Test User"), + UserID: ptr.Of("test_creator"), + Name: ptr.Of("Test Creator"), NickName: ptr.Of(""), AvatarURL: ptr.Of(""), Email: ptr.Of(""), Mobile: ptr.Of(""), }, - }, - }, - wantErr: nil, - }, - { - name: "with page token and asc", - fieldsGetter: func(ctrl *gomock.Controller) fields { - mockManageRepo := repomocks.NewMockIManageRepo(ctrl) - mockManageRepo.EXPECT().GetPrompt(gomock.Any(), gomock.Any()).Return(&entity.Prompt{ID: 1}, nil) - mockManageRepo.EXPECT().ListCommitInfo(gomock.Any(), repo.ListCommitInfoParam{ - PromptID: 1, - PageSize: 10, - PageToken: ptr.Of(int64(2)), - Asc: true, - }).Return(&repo.ListCommitResult{ - CommitInfoDOs: []*entity.CommitInfo{ - { - Version: "1.2.0", - BaseVersion: "1.1.0", - Description: "test commit 3", - CommittedBy: "test_user", - CommittedAt: now, - }, - { - Version: "1.3.0", - BaseVersion: "1.2.0", - Description: "test commit 4", - CommittedBy: "test_user", - CommittedAt: now, - }, - }, - }, nil) - - mockAuth := mocks.NewMockIAuthProvider(ctrl) - mockAuth.EXPECT().MCheckPromptPermission(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - - mockUser := mocks.NewMockIUserProvider(ctrl) - mockUser.EXPECT().MGetUserInfo(gomock.Any(), []string{"test_user"}).Return([]*rpc.UserInfo{ - { - UserID: "test_user", - UserName: "Test User", - }, - }, nil) - - mockPromptService := servicemocks.NewMockIPromptService(ctrl) - mockPromptService.EXPECT().BatchGetCommitLabels(gomock.Any(), int64(1), []string{"1.2.0", "1.3.0"}).Return(map[string][]string{}, nil) - - return fields{ - manageRepo: mockManageRepo, - promptService: mockPromptService, - authRPCProvider: mockAuth, - userRPCProvider: mockUser, - } - }, - args: args{ - ctx: session.WithCtxUser(context.Background(), &session.User{ID: "123"}), - request: &manage.ListCommitRequest{ - PromptID: ptr.Of(int64(1)), - PageSize: ptr.Of(int32(10)), - PageToken: ptr.Of("2"), - Asc: ptr.Of(true), - }, - }, - want: &manage.ListCommitResponse{ - PromptCommitInfos: []*prompt.CommitInfo{ - { - Version: ptr.Of("1.2.0"), - BaseVersion: ptr.Of("1.1.0"), - Description: ptr.Of("test commit 3"), - CommittedBy: ptr.Of("test_user"), - CommittedAt: ptr.Of(now.UnixMilli()), - }, - { - Version: ptr.Of("1.3.0"), - BaseVersion: ptr.Of("1.2.0"), - Description: ptr.Of("test commit 4"), - CommittedBy: ptr.Of("test_user"), - CommittedAt: ptr.Of(now.UnixMilli()), - }, - }, - CommitVersionLabelMapping: map[string][]*prompt.Label{}, - Users: []*user.UserInfoDetail{ { - UserID: ptr.Of("test_user"), - Name: ptr.Of("Test User"), + UserID: ptr.Of("test_updater"), + Name: ptr.Of("Test Updater"), NickName: ptr.Of(""), AvatarURL: ptr.Of(""), Email: ptr.Of(""), @@ -1630,7 +1653,7 @@ func TestPromptManageApplicationImpl_ListCommit(t *testing.T) { configProvider: ttFields.configProvider, } - got, err := app.ListCommit(tt.args.ctx, tt.args.request) + got, err := app.ListPrompt(tt.args.ctx, tt.args.request) unittest.AssertErrorEqual(t, tt.wantErr, err) if err == nil { assert.Equal(t, tt.want, got) @@ -1639,9 +1662,12 @@ func TestPromptManageApplicationImpl_ListCommit(t *testing.T) { } } -func TestPromptManageApplicationImpl_CommitDraft(t *testing.T) { +func TestPromptManageApplicationImpl_CreateLabel(t *testing.T) { + t.Parallel() + type fields struct { manageRepo repo.IManageRepo + labelRepo repo.ILabelRepo promptService service.IPromptService authRPCProvider rpc.IAuthProvider userRPCProvider rpc.IUserProvider @@ -1650,170 +1676,80 @@ func TestPromptManageApplicationImpl_CommitDraft(t *testing.T) { } type args struct { ctx context.Context - request *manage.CommitDraftRequest + request *manage.CreateLabelRequest } tests := []struct { name string fieldsGetter func(ctrl *gomock.Controller) fields args args - want *manage.CommitDraftResponse + want *manage.CreateLabelResponse wantErr error }{ { - name: "user not found", - fieldsGetter: func(ctrl *gomock.Controller) fields { - return fields{} - }, - args: args{ - ctx: context.Background(), - request: &manage.CommitDraftRequest{ - PromptID: ptr.Of(int64(1)), - CommitVersion: ptr.Of("1.0.0"), - }, - }, - want: manage.NewCommitDraftResponse(), - wantErr: errorx.NewByCode(prompterr.CommonInvalidParamCode, errorx.WithExtraMsg("User not found")), - }, - { - name: "invalid version format", + name: "成功创建标签", fieldsGetter: func(ctrl *gomock.Controller) fields { - return fields{} + mockAuth := mocks.NewMockIAuthProvider(ctrl) + mockAuth.EXPECT().CheckSpacePermission(gomock.Any(), int64(100), consts.ActionWorkspaceCreateLoopPrompt).Return(nil) + + mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().CreateLabel(gomock.Any(), gomock.Any()).Return(nil) + + return fields{ + authRPCProvider: mockAuth, + promptService: mockPromptService, + } }, args: args{ ctx: session.WithCtxUser(context.Background(), &session.User{ID: "123"}), - request: &manage.CommitDraftRequest{ - PromptID: ptr.Of(int64(1)), - CommitVersion: ptr.Of("invalid-version"), + request: &manage.CreateLabelRequest{ + WorkspaceID: ptr.Of(int64(100)), + Label: &prompt.Label{ + Key: ptr.Of("test-label"), + }, }, }, - want: manage.NewCommitDraftResponse(), - wantErr: errorx.New("Invalid Semantic Version"), + want: manage.NewCreateLabelResponse(), + wantErr: nil, }, { - name: "get prompt error", + name: "用户未找到", fieldsGetter: func(ctrl *gomock.Controller) fields { - mockRepo := repomocks.NewMockIManageRepo(ctrl) - mockRepo.EXPECT().GetPrompt(gomock.Any(), repo.GetPromptParam{ - PromptID: 1, - }).Return(nil, errorx.New("get prompt error")) - - return fields{ - manageRepo: mockRepo, - } + return fields{} }, args: args{ - ctx: session.WithCtxUser(context.Background(), &session.User{ID: "123"}), - request: &manage.CommitDraftRequest{ - PromptID: ptr.Of(int64(1)), - CommitVersion: ptr.Of("1.0.0"), + ctx: context.Background(), + request: &manage.CreateLabelRequest{ + WorkspaceID: ptr.Of(int64(100)), + Label: &prompt.Label{ + Key: ptr.Of("test-label"), + }, }, }, - want: manage.NewCommitDraftResponse(), - wantErr: errorx.New("get prompt error"), + want: manage.NewCreateLabelResponse(), + wantErr: errorx.NewByCode(prompterr.CommonInvalidParamCode, errorx.WithExtraMsg("User not found")), }, { - name: "permission check error", + name: "权限检查失败", fieldsGetter: func(ctrl *gomock.Controller) fields { - mockRepo := repomocks.NewMockIManageRepo(ctrl) - mockRepo.EXPECT().GetPrompt(gomock.Any(), repo.GetPromptParam{ - PromptID: 1, - }).Return(&entity.Prompt{ - ID: 1, - SpaceID: 100, - }, nil) - mockAuth := mocks.NewMockIAuthProvider(ctrl) - mockAuth.EXPECT().MCheckPromptPermission(gomock.Any(), int64(100), []int64{1}, consts.ActionLoopPromptEdit).Return(errorx.New("permission denied")) + mockAuth.EXPECT().CheckSpacePermission(gomock.Any(), int64(100), consts.ActionWorkspaceCreateLoopPrompt).Return(errorx.New("permission denied")) return fields{ - manageRepo: mockRepo, authRPCProvider: mockAuth, } }, args: args{ ctx: session.WithCtxUser(context.Background(), &session.User{ID: "123"}), - request: &manage.CommitDraftRequest{ - PromptID: ptr.Of(int64(1)), - CommitVersion: ptr.Of("1.0.0"), + request: &manage.CreateLabelRequest{ + WorkspaceID: ptr.Of(int64(100)), + Label: &prompt.Label{ + Key: ptr.Of("test-label"), + }, }, }, - want: manage.NewCommitDraftResponse(), + want: manage.NewCreateLabelResponse(), wantErr: errorx.New("permission denied"), }, - { - name: "commit draft error", - fieldsGetter: func(ctrl *gomock.Controller) fields { - mockRepo := repomocks.NewMockIManageRepo(ctrl) - mockRepo.EXPECT().GetPrompt(gomock.Any(), repo.GetPromptParam{ - PromptID: 1, - }).Return(&entity.Prompt{ - ID: 1, - SpaceID: 100, - }, nil) - - mockAuth := mocks.NewMockIAuthProvider(ctrl) - mockAuth.EXPECT().MCheckPromptPermission(gomock.Any(), int64(100), []int64{1}, consts.ActionLoopPromptEdit).Return(nil) - - mockRepo.EXPECT().CommitDraft(gomock.Any(), repo.CommitDraftParam{ - PromptID: 1, - UserID: "123", - CommitVersion: "1.0.0", - CommitDescription: "test commit", - }).Return(errorx.New("commit draft error")) - - return fields{ - manageRepo: mockRepo, - authRPCProvider: mockAuth, - } - }, - args: args{ - ctx: session.WithCtxUser(context.Background(), &session.User{ID: "123"}), - request: &manage.CommitDraftRequest{ - PromptID: ptr.Of(int64(1)), - CommitVersion: ptr.Of("1.0.0"), - CommitDescription: ptr.Of("test commit"), - }, - }, - want: manage.NewCommitDraftResponse(), - wantErr: errorx.New("commit draft error"), - }, - { - name: "success", - fieldsGetter: func(ctrl *gomock.Controller) fields { - mockRepo := repomocks.NewMockIManageRepo(ctrl) - mockRepo.EXPECT().GetPrompt(gomock.Any(), repo.GetPromptParam{ - PromptID: 1, - }).Return(&entity.Prompt{ - ID: 1, - SpaceID: 100, - }, nil) - - mockAuth := mocks.NewMockIAuthProvider(ctrl) - mockAuth.EXPECT().MCheckPromptPermission(gomock.Any(), int64(100), []int64{1}, consts.ActionLoopPromptEdit).Return(nil) - - mockRepo.EXPECT().CommitDraft(gomock.Any(), repo.CommitDraftParam{ - PromptID: 1, - UserID: "123", - CommitVersion: "1.0.0", - CommitDescription: "test commit", - }).Return(nil) - - return fields{ - manageRepo: mockRepo, - authRPCProvider: mockAuth, - } - }, - args: args{ - ctx: session.WithCtxUser(context.Background(), &session.User{ID: "123"}), - request: &manage.CommitDraftRequest{ - PromptID: ptr.Of(int64(1)), - CommitVersion: ptr.Of("1.0.0"), - CommitDescription: ptr.Of("test commit"), - }, - }, - want: manage.NewCommitDraftResponse(), - wantErr: nil, - }, } for _, tt := range tests { @@ -1825,6 +1761,7 @@ func TestPromptManageApplicationImpl_CommitDraft(t *testing.T) { app := &PromptManageApplicationImpl{ manageRepo: ttFields.manageRepo, + labelRepo: ttFields.labelRepo, promptService: ttFields.promptService, authRPCProvider: ttFields.authRPCProvider, userRPCProvider: ttFields.userRPCProvider, @@ -1832,7 +1769,7 @@ func TestPromptManageApplicationImpl_CommitDraft(t *testing.T) { configProvider: ttFields.configProvider, } - got, err := app.CommitDraft(tt.args.ctx, tt.args.request) + got, err := app.CreateLabel(tt.args.ctx, tt.args.request) unittest.AssertErrorEqual(t, tt.wantErr, err) if err == nil { assert.Equal(t, tt.want, got) @@ -1841,9 +1778,12 @@ func TestPromptManageApplicationImpl_CommitDraft(t *testing.T) { } } -func TestPromptManageApplicationImpl_ListPrompt(t *testing.T) { +func TestPromptManageApplicationImpl_ListLabel(t *testing.T) { + t.Parallel() + type fields struct { manageRepo repo.IManageRepo + labelRepo repo.ILabelRepo promptService service.IPromptService authRPCProvider rpc.IAuthProvider userRPCProvider rpc.IUserProvider @@ -1852,34 +1792,55 @@ func TestPromptManageApplicationImpl_ListPrompt(t *testing.T) { } type args struct { ctx context.Context - request *manage.ListPromptRequest + request *manage.ListLabelRequest } - now := time.Now() tests := []struct { name string fieldsGetter func(ctrl *gomock.Controller) fields args args - want *manage.ListPromptResponse + want *manage.ListLabelResponse wantErr error }{ { - name: "user not found", + name: "成功列出标签", fieldsGetter: func(ctrl *gomock.Controller) fields { - return fields{} + mockAuth := mocks.NewMockIAuthProvider(ctrl) + mockAuth.EXPECT().CheckSpacePermission(gomock.Any(), int64(100), consts.ActionWorkspaceListLoopPrompt).Return(nil) + + mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ListLabel(gomock.Any(), gomock.Any()).Return([]*entity.PromptLabel{ + { + ID: 1, + SpaceID: 100, + LabelKey: "test-label", + }, + }, ptr.Of(int64(2)), nil) + + return fields{ + authRPCProvider: mockAuth, + promptService: mockPromptService, + } }, args: args{ ctx: context.Background(), - request: &manage.ListPromptRequest{ + request: &manage.ListLabelRequest{ WorkspaceID: ptr.Of(int64(100)), - PageNum: ptr.Of(int32(1)), PageSize: ptr.Of(int32(10)), }, }, - want: manage.NewListPromptResponse(), - wantErr: errorx.NewByCode(prompterr.CommonInvalidParamCode, errorx.WithExtraMsg("User not found")), + want: &manage.ListLabelResponse{ + Labels: []*prompt.Label{ + { + Key: ptr.Of("test-label"), + }, + }, + NextPageToken: ptr.Of("2"), + HasMore: ptr.Of(true), + }, + wantErr: nil, }, { - name: "permission check error", + name: "权限检查失败", fieldsGetter: func(ctrl *gomock.Controller) fields { mockAuth := mocks.NewMockIAuthProvider(ctrl) mockAuth.EXPECT().CheckSpacePermission(gomock.Any(), int64(100), consts.ActionWorkspaceListLoopPrompt).Return(errorx.New("permission denied")) @@ -1889,389 +1850,266 @@ func TestPromptManageApplicationImpl_ListPrompt(t *testing.T) { } }, args: args{ - ctx: session.WithCtxUser(context.Background(), &session.User{ID: "123"}), - request: &manage.ListPromptRequest{ + ctx: context.Background(), + request: &manage.ListLabelRequest{ WorkspaceID: ptr.Of(int64(100)), - PageNum: ptr.Of(int32(1)), PageSize: ptr.Of(int32(10)), }, }, - want: manage.NewListPromptResponse(), + want: manage.NewListLabelResponse(), wantErr: errorx.New("permission denied"), }, { - name: "list prompt with committed only true", + name: "需要版本映射但未提供PromptID", fieldsGetter: func(ctrl *gomock.Controller) fields { - mockRepo := repomocks.NewMockIManageRepo(ctrl) - mockRepo.EXPECT().ListPrompt(gomock.Any(), repo.ListPromptParam{ - SpaceID: 100, - UserID: "123", - CommittedOnly: true, - PageNum: 1, - PageSize: 10, - OrderBy: mysql.ListPromptBasicOrderByID, - Asc: false, - }).Return(&repo.ListPromptResult{ - Total: 1, - PromptDOs: []*entity.Prompt{ - { - ID: 1, - SpaceID: 100, - PromptKey: "test_key", - PromptBasic: &entity.PromptBasic{ - DisplayName: "test_name", - Description: "test_description", - LatestVersion: "1.0.0", - CreatedBy: "test_creator", - UpdatedBy: "test_updater", - CreatedAt: now, - UpdatedAt: now, - LatestCommittedAt: &now, - }, - }, - }, - }, nil) + mockAuth := mocks.NewMockIAuthProvider(ctrl) + mockAuth.EXPECT().CheckSpacePermission(gomock.Any(), int64(100), consts.ActionWorkspaceListLoopPrompt).Return(nil) + + return fields{ + authRPCProvider: mockAuth, + } + }, + args: args{ + ctx: context.Background(), + request: &manage.ListLabelRequest{ + WorkspaceID: ptr.Of(int64(100)), + PageSize: ptr.Of(int32(10)), + WithPromptVersionMapping: ptr.Of(true), + }, + }, + want: manage.NewListLabelResponse(), + wantErr: errorx.NewByCode(prompterr.CommonInvalidParamCode, errorx.WithExtraMsg("PromptID must be provided when WithPromptVersionMapping is true")), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + ttFields := tt.fieldsGetter(ctrl) + + app := &PromptManageApplicationImpl{ + manageRepo: ttFields.manageRepo, + labelRepo: ttFields.labelRepo, + promptService: ttFields.promptService, + authRPCProvider: ttFields.authRPCProvider, + userRPCProvider: ttFields.userRPCProvider, + auditRPCProvider: ttFields.auditRPCProvider, + configProvider: ttFields.configProvider, + } + + got, err := app.ListLabel(tt.args.ctx, tt.args.request) + unittest.AssertErrorEqual(t, tt.wantErr, err) + if err == nil { + assert.Equal(t, tt.want, got) + } + }) + } +} + +func TestPromptManageApplicationImpl_BatchGetLabel(t *testing.T) { + t.Parallel() + type fields struct { + manageRepo repo.IManageRepo + labelRepo repo.ILabelRepo + promptService service.IPromptService + authRPCProvider rpc.IAuthProvider + userRPCProvider rpc.IUserProvider + auditRPCProvider rpc.IAuditProvider + configProvider conf.IConfigProvider + } + type args struct { + ctx context.Context + request *manage.BatchGetLabelRequest + } + tests := []struct { + name string + fieldsGetter func(ctrl *gomock.Controller) fields + args args + want *manage.BatchGetLabelResponse + wantErr error + }{ + { + name: "成功批量获取标签", + fieldsGetter: func(ctrl *gomock.Controller) fields { mockAuth := mocks.NewMockIAuthProvider(ctrl) mockAuth.EXPECT().CheckSpacePermission(gomock.Any(), int64(100), consts.ActionWorkspaceListLoopPrompt).Return(nil) - mockUser := mocks.NewMockIUserProvider(ctrl) - mockUser.EXPECT().MGetUserInfo(gomock.Any(), []string{"test_creator"}).Return([]*rpc.UserInfo{ + mockLabelRepo := repomocks.NewMockILabelRepo(ctrl) + mockLabelRepo.EXPECT().BatchGetLabel(gomock.Any(), int64(100), []string{"label1", "label2"}).Return([]*entity.PromptLabel{ + { + ID: 1, + SpaceID: 100, + LabelKey: "label1", + }, { - UserID: "test_creator", - UserName: "Test Creator", + ID: 2, + SpaceID: 100, + LabelKey: "label2", }, }, nil) return fields{ - manageRepo: mockRepo, authRPCProvider: mockAuth, - userRPCProvider: mockUser, + labelRepo: mockLabelRepo, } }, args: args{ - ctx: session.WithCtxUser(context.Background(), &session.User{ID: "123"}), - request: &manage.ListPromptRequest{ - WorkspaceID: ptr.Of(int64(100)), - CommittedOnly: ptr.Of(true), - PageNum: ptr.Of(int32(1)), - PageSize: ptr.Of(int32(10)), + ctx: context.Background(), + request: &manage.BatchGetLabelRequest{ + WorkspaceID: ptr.Of(int64(100)), + LabelKeys: []string{"label1", "label2"}, }, }, - want: &manage.ListPromptResponse{ - Total: ptr.Of(int32(1)), - Prompts: []*prompt.Prompt{ + want: &manage.BatchGetLabelResponse{ + Labels: []*prompt.Label{ { - ID: ptr.Of(int64(1)), - WorkspaceID: ptr.Of(int64(100)), - PromptKey: ptr.Of("test_key"), - PromptBasic: &prompt.PromptBasic{ - DisplayName: ptr.Of("test_name"), - Description: ptr.Of("test_description"), - LatestVersion: ptr.Of("1.0.0"), - CreatedBy: ptr.Of("test_creator"), - UpdatedBy: ptr.Of("test_updater"), - CreatedAt: ptr.Of(now.UnixMilli()), - UpdatedAt: ptr.Of(now.UnixMilli()), - LatestCommittedAt: ptr.Of(now.UnixMilli()), - }, + Key: ptr.Of("label1"), }, - }, - Users: []*user.UserInfoDetail{ { - UserID: ptr.Of("test_creator"), - Name: ptr.Of("Test Creator"), - NickName: ptr.Of(""), - AvatarURL: ptr.Of(""), - Email: ptr.Of(""), - Mobile: ptr.Of(""), + Key: ptr.Of("label2"), }, }, }, wantErr: nil, }, { - name: "list prompt with committed only false", + name: "权限检查失败", fieldsGetter: func(ctrl *gomock.Controller) fields { - mockRepo := repomocks.NewMockIManageRepo(ctrl) - mockRepo.EXPECT().ListPrompt(gomock.Any(), repo.ListPromptParam{ - SpaceID: 100, - UserID: "123", - CommittedOnly: false, - PageNum: 1, - PageSize: 10, - OrderBy: mysql.ListPromptBasicOrderByID, - Asc: false, - }).Return(&repo.ListPromptResult{ - Total: 2, - PromptDOs: []*entity.Prompt{ - { - ID: 1, - SpaceID: 100, - PromptKey: "test_key_1", - PromptBasic: &entity.PromptBasic{ - DisplayName: "test_name_1", - Description: "test_description_1", - LatestVersion: "1.0.0", - CreatedBy: "test_creator", - UpdatedBy: "test_updater", - CreatedAt: now, - UpdatedAt: now, - LatestCommittedAt: &now, - }, - }, - { - ID: 2, - SpaceID: 100, - PromptKey: "test_key_2", - PromptBasic: &entity.PromptBasic{ - DisplayName: "test_name_2", - Description: "test_description_2", - LatestVersion: "", - CreatedBy: "test_creator", - UpdatedBy: "test_updater", - CreatedAt: now, - UpdatedAt: now, - LatestCommittedAt: nil, - }, - }, - }, - }, nil) - mockAuth := mocks.NewMockIAuthProvider(ctrl) - mockAuth.EXPECT().CheckSpacePermission(gomock.Any(), int64(100), consts.ActionWorkspaceListLoopPrompt).Return(nil) - - mockUser := mocks.NewMockIUserProvider(ctrl) - mockUser.EXPECT().MGetUserInfo(gomock.Any(), []string{"test_creator"}).Return([]*rpc.UserInfo{ - { - UserID: "test_creator", - UserName: "Test Creator", - }, - }, nil) + mockAuth.EXPECT().CheckSpacePermission(gomock.Any(), int64(100), consts.ActionWorkspaceListLoopPrompt).Return(errorx.New("permission denied")) return fields{ - manageRepo: mockRepo, authRPCProvider: mockAuth, - userRPCProvider: mockUser, } }, args: args{ - ctx: session.WithCtxUser(context.Background(), &session.User{ID: "123"}), - request: &manage.ListPromptRequest{ - WorkspaceID: ptr.Of(int64(100)), - CommittedOnly: ptr.Of(false), - PageNum: ptr.Of(int32(1)), - PageSize: ptr.Of(int32(10)), - }, - }, - want: &manage.ListPromptResponse{ - Total: ptr.Of(int32(2)), - Prompts: []*prompt.Prompt{ - { - ID: ptr.Of(int64(1)), - WorkspaceID: ptr.Of(int64(100)), - PromptKey: ptr.Of("test_key_1"), - PromptBasic: &prompt.PromptBasic{ - DisplayName: ptr.Of("test_name_1"), - Description: ptr.Of("test_description_1"), - LatestVersion: ptr.Of("1.0.0"), - CreatedBy: ptr.Of("test_creator"), - UpdatedBy: ptr.Of("test_updater"), - CreatedAt: ptr.Of(now.UnixMilli()), - UpdatedAt: ptr.Of(now.UnixMilli()), - LatestCommittedAt: ptr.Of(now.UnixMilli()), - }, - }, - { - ID: ptr.Of(int64(2)), - WorkspaceID: ptr.Of(int64(100)), - PromptKey: ptr.Of("test_key_2"), - PromptBasic: &prompt.PromptBasic{ - DisplayName: ptr.Of("test_name_2"), - Description: ptr.Of("test_description_2"), - LatestVersion: ptr.Of(""), - CreatedBy: ptr.Of("test_creator"), - UpdatedBy: ptr.Of("test_updater"), - CreatedAt: ptr.Of(now.UnixMilli()), - UpdatedAt: ptr.Of(now.UnixMilli()), - }, - }, - }, - Users: []*user.UserInfoDetail{ - { - UserID: ptr.Of("test_creator"), - Name: ptr.Of("Test Creator"), - NickName: ptr.Of(""), - AvatarURL: ptr.Of(""), - Email: ptr.Of(""), - Mobile: ptr.Of(""), - }, + ctx: context.Background(), + request: &manage.BatchGetLabelRequest{ + WorkspaceID: ptr.Of(int64(100)), + LabelKeys: []string{"label1", "label2"}, }, }, - wantErr: nil, + want: manage.NewBatchGetLabelResponse(), + wantErr: errorx.New("permission denied"), }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + ttFields := tt.fieldsGetter(ctrl) + + app := &PromptManageApplicationImpl{ + manageRepo: ttFields.manageRepo, + labelRepo: ttFields.labelRepo, + promptService: ttFields.promptService, + authRPCProvider: ttFields.authRPCProvider, + userRPCProvider: ttFields.userRPCProvider, + auditRPCProvider: ttFields.auditRPCProvider, + configProvider: ttFields.configProvider, + } + + got, err := app.BatchGetLabel(tt.args.ctx, tt.args.request) + unittest.AssertErrorEqual(t, tt.wantErr, err) + if err == nil { + assert.Equal(t, tt.want, got) + } + }) + } +} + +func TestPromptManageApplicationImpl_UpdateCommitLabels(t *testing.T) { + t.Parallel() + + type fields struct { + manageRepo repo.IManageRepo + labelRepo repo.ILabelRepo + promptService service.IPromptService + authRPCProvider rpc.IAuthProvider + userRPCProvider rpc.IUserProvider + auditRPCProvider rpc.IAuditProvider + configProvider conf.IConfigProvider + } + type args struct { + ctx context.Context + request *manage.UpdateCommitLabelsRequest + } + tests := []struct { + name string + fieldsGetter func(ctrl *gomock.Controller) fields + args args + want *manage.UpdateCommitLabelsResponse + wantErr error + }{ { - name: "list prompt with user draft association", + name: "成功更新提交标签", fieldsGetter: func(ctrl *gomock.Controller) fields { - mockRepo := repomocks.NewMockIManageRepo(ctrl) - mockRepo.EXPECT().ListPrompt(gomock.Any(), repo.ListPromptParam{ - SpaceID: 100, - UserID: "123", - KeyWord: "draft", - PageNum: 1, - PageSize: 10, - OrderBy: mysql.ListPromptBasicOrderByID, - Asc: false, - }).Return(&repo.ListPromptResult{ - Total: 1, - PromptDOs: []*entity.Prompt{ - { - ID: 1, - SpaceID: 100, - PromptKey: "test_key", - PromptBasic: &entity.PromptBasic{ - DisplayName: "test_name", - Description: "test_description", - LatestVersion: "1.0.0", - CreatedBy: "test_creator", - UpdatedBy: "test_updater", - CreatedAt: now, - UpdatedAt: now, - LatestCommittedAt: &now, - }, - PromptDraft: &entity.PromptDraft{ - PromptDetail: &entity.PromptDetail{ - PromptTemplate: &entity.PromptTemplate{ - TemplateType: entity.TemplateTypeNormal, - Messages: []*entity.Message{ - { - Role: entity.RoleUser, - Content: ptr.Of("draft content"), - }, - }, - }, - }, - DraftInfo: &entity.DraftInfo{ - UserID: "123", - BaseVersion: "1.0.0", - IsModified: true, - CreatedAt: now, - UpdatedAt: now, - }, - }, - }, - }, - }, nil) - mockAuth := mocks.NewMockIAuthProvider(ctrl) - mockAuth.EXPECT().CheckSpacePermission(gomock.Any(), int64(100), consts.ActionWorkspaceListLoopPrompt).Return(nil) + mockAuth.EXPECT().MCheckPromptPermission(gomock.Any(), int64(100), []int64{1}, consts.ActionLoopPromptEdit).Return(nil) - mockUser := mocks.NewMockIUserProvider(ctrl) - mockUser.EXPECT().MGetUserInfo(gomock.Any(), []string{"test_creator"}).Return([]*rpc.UserInfo{ - { - UserID: "test_creator", - UserName: "Test Creator", - }, - }, nil) + mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().UpdateCommitLabels(gomock.Any(), gomock.Any()).Return(nil) return fields{ - manageRepo: mockRepo, authRPCProvider: mockAuth, - userRPCProvider: mockUser, + promptService: mockPromptService, } }, args: args{ ctx: session.WithCtxUser(context.Background(), &session.User{ID: "123"}), - request: &manage.ListPromptRequest{ - WorkspaceID: ptr.Of(int64(100)), - KeyWord: ptr.Of("draft"), - PageNum: ptr.Of(int32(1)), - PageSize: ptr.Of(int32(10)), - }, - }, - want: &manage.ListPromptResponse{ - Total: ptr.Of(int32(1)), - Prompts: []*prompt.Prompt{ - { - ID: ptr.Of(int64(1)), - WorkspaceID: ptr.Of(int64(100)), - PromptKey: ptr.Of("test_key"), - PromptBasic: &prompt.PromptBasic{ - DisplayName: ptr.Of("test_name"), - Description: ptr.Of("test_description"), - LatestVersion: ptr.Of("1.0.0"), - CreatedBy: ptr.Of("test_creator"), - UpdatedBy: ptr.Of("test_updater"), - CreatedAt: ptr.Of(now.UnixMilli()), - UpdatedAt: ptr.Of(now.UnixMilli()), - LatestCommittedAt: ptr.Of(now.UnixMilli()), - }, - PromptDraft: &prompt.PromptDraft{ - Detail: &prompt.PromptDetail{ - PromptTemplate: &prompt.PromptTemplate{ - TemplateType: ptr.Of(prompt.TemplateTypeNormal), - Messages: []*prompt.Message{ - { - Role: ptr.Of(prompt.RoleUser), - Content: ptr.Of("draft content"), - }, - }, - }, - }, - DraftInfo: &prompt.DraftInfo{ - UserID: ptr.Of("123"), - BaseVersion: ptr.Of("1.0.0"), - IsModified: ptr.Of(true), - CreatedAt: ptr.Of(now.UnixMilli()), - UpdatedAt: ptr.Of(now.UnixMilli()), - }, - }, - }, - }, - Users: []*user.UserInfoDetail{ - { - UserID: ptr.Of("test_creator"), - Name: ptr.Of("Test Creator"), - NickName: ptr.Of(""), - AvatarURL: ptr.Of(""), - Email: ptr.Of(""), - Mobile: ptr.Of(""), - }, + request: &manage.UpdateCommitLabelsRequest{ + WorkspaceID: ptr.Of(int64(100)), + PromptID: ptr.Of(int64(1)), + CommitVersion: ptr.Of("1.0.0"), + LabelKeys: []string{"label1", "label2"}, }, }, + want: manage.NewUpdateCommitLabelsResponse(), wantErr: nil, }, { - name: "list prompt repo error", + name: "用户未找到", + fieldsGetter: func(ctrl *gomock.Controller) fields { + return fields{} + }, + args: args{ + ctx: context.Background(), + request: &manage.UpdateCommitLabelsRequest{ + WorkspaceID: ptr.Of(int64(100)), + PromptID: ptr.Of(int64(1)), + CommitVersion: ptr.Of("1.0.0"), + LabelKeys: []string{"label1", "label2"}, + }, + }, + want: manage.NewUpdateCommitLabelsResponse(), + wantErr: errorx.NewByCode(prompterr.CommonInvalidParamCode, errorx.WithExtraMsg("User not found")), + }, + { + name: "权限检查失败", fieldsGetter: func(ctrl *gomock.Controller) fields { - mockRepo := repomocks.NewMockIManageRepo(ctrl) - mockRepo.EXPECT().ListPrompt(gomock.Any(), repo.ListPromptParam{ - SpaceID: 100, - UserID: "123", - PageNum: 1, - PageSize: 10, - OrderBy: mysql.ListPromptBasicOrderByID, - Asc: false, - }).Return(nil, errorx.New("list prompt error")) - mockAuth := mocks.NewMockIAuthProvider(ctrl) - mockAuth.EXPECT().CheckSpacePermission(gomock.Any(), int64(100), consts.ActionWorkspaceListLoopPrompt).Return(nil) + mockAuth.EXPECT().MCheckPromptPermission(gomock.Any(), int64(100), []int64{1}, consts.ActionLoopPromptEdit).Return(errorx.New("permission denied")) return fields{ - manageRepo: mockRepo, authRPCProvider: mockAuth, } }, args: args{ ctx: session.WithCtxUser(context.Background(), &session.User{ID: "123"}), - request: &manage.ListPromptRequest{ - WorkspaceID: ptr.Of(int64(100)), - PageNum: ptr.Of(int32(1)), - PageSize: ptr.Of(int32(10)), + request: &manage.UpdateCommitLabelsRequest{ + WorkspaceID: ptr.Of(int64(100)), + PromptID: ptr.Of(int64(1)), + CommitVersion: ptr.Of("1.0.0"), + LabelKeys: []string{"label1", "label2"}, }, }, - want: manage.NewListPromptResponse(), - wantErr: errorx.New("list prompt error"), + want: manage.NewUpdateCommitLabelsResponse(), + wantErr: errorx.New("permission denied"), }, } @@ -2284,6 +2122,7 @@ func TestPromptManageApplicationImpl_ListPrompt(t *testing.T) { app := &PromptManageApplicationImpl{ manageRepo: ttFields.manageRepo, + labelRepo: ttFields.labelRepo, promptService: ttFields.promptService, authRPCProvider: ttFields.authRPCProvider, userRPCProvider: ttFields.userRPCProvider, @@ -2291,7 +2130,7 @@ func TestPromptManageApplicationImpl_ListPrompt(t *testing.T) { configProvider: ttFields.configProvider, } - got, err := app.ListPrompt(tt.args.ctx, tt.args.request) + got, err := app.UpdateCommitLabels(tt.args.ctx, tt.args.request) unittest.AssertErrorEqual(t, tt.wantErr, err) if err == nil { assert.Equal(t, tt.want, got) @@ -2300,101 +2139,199 @@ func TestPromptManageApplicationImpl_ListPrompt(t *testing.T) { } } -func TestPromptManageApplicationImpl_CreateLabel(t *testing.T) { - t.Parallel() - +func TestPromptManageApplicationImpl_ListParentPrompt(t *testing.T) { type fields struct { manageRepo repo.IManageRepo - labelRepo repo.ILabelRepo promptService service.IPromptService authRPCProvider rpc.IAuthProvider userRPCProvider rpc.IUserProvider auditRPCProvider rpc.IAuditProvider configProvider conf.IConfigProvider + labelRepo repo.ILabelRepo } type args struct { ctx context.Context - request *manage.CreateLabelRequest + request *manage.ListParentPromptRequest } tests := []struct { name string fieldsGetter func(ctrl *gomock.Controller) fields args args - want *manage.CreateLabelResponse + want *manage.ListParentPromptResponse wantErr error }{ { - name: "成功创建标签", + name: "user not found", + fieldsGetter: func(ctrl *gomock.Controller) fields { + return fields{} + }, + args: args{ + ctx: context.Background(), + request: &manage.ListParentPromptRequest{ + WorkspaceID: ptr.Of(int64(1)), + PromptID: ptr.Of(int64(1)), + }, + }, + want: manage.NewListParentPromptResponse(), + wantErr: errorx.NewByCode(prompterr.CommonInvalidParamCode, errorx.WithExtraMsg("User not found")), + }, + { + name: "permission denied", fieldsGetter: func(ctrl *gomock.Controller) fields { mockAuth := mocks.NewMockIAuthProvider(ctrl) - mockAuth.EXPECT().CheckSpacePermission(gomock.Any(), int64(100), consts.ActionWorkspaceCreateLoopPrompt).Return(nil) - - mockPromptService := servicemocks.NewMockIPromptService(ctrl) - mockPromptService.EXPECT().CreateLabel(gomock.Any(), gomock.Any()).Return(nil) + mockAuth.EXPECT().CheckSpacePermission(gomock.Any(), int64(1), consts.ActionLoopPromptRead). + Return(errorx.NewByCode(prompterr.CommonNoPermissionCode)) return fields{ authRPCProvider: mockAuth, - promptService: mockPromptService, } }, args: args{ ctx: session.WithCtxUser(context.Background(), &session.User{ID: "123"}), - request: &manage.CreateLabelRequest{ - WorkspaceID: ptr.Of(int64(100)), - Label: &prompt.Label{ - Key: ptr.Of("test-label"), - }, + request: &manage.ListParentPromptRequest{ + WorkspaceID: ptr.Of(int64(1)), + PromptID: ptr.Of(int64(1)), }, }, - want: manage.NewCreateLabelResponse(), - wantErr: nil, + want: manage.NewListParentPromptResponse(), + wantErr: errorx.NewByCode(prompterr.CommonNoPermissionCode), }, { - name: "用户未找到", + name: "invalid prompt ID", fieldsGetter: func(ctrl *gomock.Controller) fields { - return fields{} + mockAuth := mocks.NewMockIAuthProvider(ctrl) + mockAuth.EXPECT().CheckSpacePermission(gomock.Any(), int64(1), consts.ActionLoopPromptRead). + Return(nil) + + return fields{ + authRPCProvider: mockAuth, + manageRepo: nil, + promptService: nil, + userRPCProvider: nil, + auditRPCProvider: nil, + configProvider: nil, + labelRepo: nil, + } }, args: args{ - ctx: context.Background(), - request: &manage.CreateLabelRequest{ - WorkspaceID: ptr.Of(int64(100)), - Label: &prompt.Label{ - Key: ptr.Of("test-label"), - }, + ctx: session.WithCtxUser(context.Background(), &session.User{ID: "123"}), + request: &manage.ListParentPromptRequest{ + WorkspaceID: ptr.Of(int64(1)), + PromptID: ptr.Of(int64(0)), }, }, - want: manage.NewCreateLabelResponse(), - wantErr: errorx.NewByCode(prompterr.CommonInvalidParamCode, errorx.WithExtraMsg("User not found")), + want: manage.NewListParentPromptResponse(), + wantErr: errorx.NewByCode(prompterr.CommonInvalidParamCode, errorx.WithExtraMsg("Prompt ID is required")), }, { - name: "权限检查失败", + name: "successful list parent prompts", fieldsGetter: func(ctrl *gomock.Controller) fields { mockAuth := mocks.NewMockIAuthProvider(ctrl) - mockAuth.EXPECT().CheckSpacePermission(gomock.Any(), int64(100), consts.ActionWorkspaceCreateLoopPrompt).Return(errorx.New("permission denied")) + mockAuth.EXPECT().CheckSpacePermission(gomock.Any(), int64(1), consts.ActionLoopPromptRead). + Return(nil) + + mockRepo := repomocks.NewMockIManageRepo(ctrl) + mockRepo.EXPECT().ListParentPrompt(gomock.Any(), repo.ListParentPromptParam{ + SubPromptID: 1, + SubPromptVersions: []string{"v1.0.0"}, + }).Return(map[string][]*repo.PromptCommitVersions{ + "v1.0.0": { + { + PromptID: 2, + PromptKey: "parent_prompt", + SpaceID: 1, + PromptBasic: &entity.PromptBasic{ + DisplayName: "parent name", + Description: "parent description", + LatestVersion: "2.0.0", + PromptType: entity.PromptTypeSnippet, + }, + CommitVersions: []string{"v2.0.0"}, + }, + }, + }, nil) return fields{ - authRPCProvider: mockAuth, + manageRepo: mockRepo, + authRPCProvider: mockAuth, + promptService: nil, + userRPCProvider: nil, + auditRPCProvider: nil, + configProvider: nil, + labelRepo: nil, } }, args: args{ ctx: session.WithCtxUser(context.Background(), &session.User{ID: "123"}), - request: &manage.CreateLabelRequest{ - WorkspaceID: ptr.Of(int64(100)), - Label: &prompt.Label{ - Key: ptr.Of("test-label"), + request: &manage.ListParentPromptRequest{ + WorkspaceID: ptr.Of(int64(1)), + PromptID: ptr.Of(int64(1)), + CommitVersions: []string{"v1.0.0"}, + }, + }, + want: &manage.ListParentPromptResponse{ + ParentPrompts: map[string][]*prompt.PromptCommitVersions{ + "v1.0.0": { + { + ID: ptr.Of(int64(2)), + WorkspaceID: ptr.Of(int64(1)), + PromptKey: ptr.Of("parent_prompt"), + PromptBasic: &prompt.PromptBasic{ + DisplayName: ptr.Of("parent name"), + Description: ptr.Of("parent description"), + LatestVersion: ptr.Of("2.0.0"), + PromptType: ptr.Of(prompt.PromptTypeSnippet), + CreatedBy: ptr.Of(""), + UpdatedBy: ptr.Of(""), + CreatedAt: ptr.Of(time.Time{}.UnixMilli()), + UpdatedAt: ptr.Of(time.Time{}.UnixMilli()), + }, + CommitVersions: []string{"v2.0.0"}, + }, }, }, }, - want: manage.NewCreateLabelResponse(), - wantErr: errorx.New("permission denied"), + wantErr: nil, + }, + { + name: "repository error", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockAuth := mocks.NewMockIAuthProvider(ctrl) + mockAuth.EXPECT().CheckSpacePermission(gomock.Any(), int64(1), consts.ActionLoopPromptRead). + Return(nil) + + mockRepo := repomocks.NewMockIManageRepo(ctrl) + mockRepo.EXPECT().ListParentPrompt(gomock.Any(), repo.ListParentPromptParam{ + SubPromptID: 1, + }).Return(nil, errorx.New("database error")) + + return fields{ + manageRepo: mockRepo, + authRPCProvider: mockAuth, + promptService: nil, + userRPCProvider: nil, + auditRPCProvider: nil, + configProvider: nil, + labelRepo: nil, + } + }, + args: args{ + ctx: session.WithCtxUser(context.Background(), &session.User{ID: "123"}), + request: &manage.ListParentPromptRequest{ + WorkspaceID: ptr.Of(int64(1)), + PromptID: ptr.Of(int64(1)), + }, + }, + want: manage.NewListParentPromptResponse(), + wantErr: errorx.New("database error"), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - t.Parallel() ctrl := gomock.NewController(t) defer ctrl.Finish() + ttFields := tt.fieldsGetter(ctrl) app := &PromptManageApplicationImpl{ @@ -2407,371 +2344,626 @@ func TestPromptManageApplicationImpl_CreateLabel(t *testing.T) { configProvider: ttFields.configProvider, } - got, err := app.CreateLabel(tt.args.ctx, tt.args.request) + got, err := app.ListParentPrompt(tt.args.ctx, tt.args.request) unittest.AssertErrorEqual(t, tt.wantErr, err) if err == nil { - assert.Equal(t, tt.want, got) + assert.Equal(t, tt.want.ParentPrompts, got.ParentPrompts) } }) } } -func TestPromptManageApplicationImpl_ListLabel(t *testing.T) { +func TestPromptManageApplicationImpl_UpdatePrompt(t *testing.T) { + type fields struct { + manageRepo repo.IManageRepo + authProvider rpc.IAuthProvider + auditProvider rpc.IAuditProvider + } + type args struct { + ctx context.Context + request *manage.UpdatePromptRequest + } + tests := []struct { + name string + fieldsGetter func(ctrl *gomock.Controller) fields + args args + wantErr error + }{ + { + name: "user not found", + fieldsGetter: func(ctrl *gomock.Controller) fields { return fields{} }, + args: args{ + ctx: context.Background(), + request: &manage.UpdatePromptRequest{PromptID: ptr.Of(int64(1))}, + }, + wantErr: errorx.NewByCode(prompterr.CommonInvalidParamCode, errorx.WithExtraMsg("User not found")), + }, + { + name: "success", + fieldsGetter: func(ctrl *gomock.Controller) fields { + repoMock := repomocks.NewMockIManageRepo(ctrl) + repoMock.EXPECT().GetPrompt(gomock.Any(), repo.GetPromptParam{PromptID: 2}).Return(&entity.Prompt{ + ID: 2, + SpaceID: 20, + PromptBasic: &entity.PromptBasic{PromptType: entity.PromptTypeNormal}, + }, nil) + repoMock.EXPECT().UpdatePrompt(gomock.Any(), repo.UpdatePromptParam{ + PromptID: 2, + UpdatedBy: "user", + PromptName: "name", + PromptDescription: "desc", + }).Return(nil) + auth := mocks.NewMockIAuthProvider(ctrl) + auth.EXPECT().MCheckPromptPermission(gomock.Any(), int64(20), []int64{int64(2)}, consts.ActionLoopPromptEdit).Return(nil) + audit := mocks.NewMockIAuditProvider(ctrl) + audit.EXPECT().AuditPrompt(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, promptDO *entity.Prompt) error { + assert.Equal(t, int64(2), promptDO.ID) + assert.Equal(t, "name", promptDO.PromptBasic.DisplayName) + return nil + }) + return fields{manageRepo: repoMock, authProvider: auth, auditProvider: audit} + }, + args: args{ + ctx: session.WithCtxUser(context.Background(), &session.User{ID: "user"}), + request: &manage.UpdatePromptRequest{ + PromptID: ptr.Of(int64(2)), + PromptName: ptr.Of("name"), + PromptDescription: ptr.Of("desc"), + }, + }, + wantErr: nil, + }, + } + + for _, tt := range tests { + caseData := tt + t.Run(caseData.name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + tFields := caseData.fieldsGetter(ctrl) + app := &PromptManageApplicationImpl{ + manageRepo: tFields.manageRepo, + authRPCProvider: tFields.authProvider, + auditRPCProvider: tFields.auditProvider, + } + + resp, err := app.UpdatePrompt(caseData.args.ctx, caseData.args.request) + unittest.AssertErrorEqual(t, caseData.wantErr, err) + if err == nil { + assert.NotNil(t, resp) + } + }) + } +} + +func TestPromptManageApplicationImpl_GetPrompt_AutoCommitVersion(t *testing.T) { t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + repoMock := repomocks.NewMockIManageRepo(ctrl) + repoMock.EXPECT().GetPrompt(gomock.Any(), repo.GetPromptParam{PromptID: 10}).Return(&entity.Prompt{ + ID: 10, + PromptBasic: &entity.PromptBasic{LatestVersion: "v2"}, + }, nil) + promptSvc := servicemocks.NewMockIPromptService(ctrl) + promptSvc.EXPECT().GetPrompt(gomock.Any(), service.GetPromptParam{ + PromptID: 10, + WithCommit: true, + CommitVersion: "v2", + WithDraft: false, + UserID: "user", + ExpandSnippet: false, + }).Return(&entity.Prompt{ID: 10, SpaceID: 200}, nil) + auth := mocks.NewMockIAuthProvider(ctrl) + auth.EXPECT().MCheckPromptPermission(gomock.Any(), int64(200), []int64{int64(10)}, consts.ActionLoopPromptRead).Return(nil) + + app := &PromptManageApplicationImpl{ + manageRepo: repoMock, + promptService: promptSvc, + authRPCProvider: auth, + } + + resp, err := app.GetPrompt(session.WithCtxUser(context.Background(), &session.User{ID: "user"}), &manage.GetPromptRequest{ + PromptID: ptr.Of(int64(10)), + WithCommit: ptr.Of(true), + }) + unittest.AssertErrorEqual(t, nil, err) + assert.NotNil(t, resp) + assert.Equal(t, int64(10), resp.GetPrompt().GetID()) +} +func TestPromptManageApplicationImpl_SaveDraft(t *testing.T) { type fields struct { - manageRepo repo.IManageRepo - labelRepo repo.ILabelRepo - promptService service.IPromptService - authRPCProvider rpc.IAuthProvider - userRPCProvider rpc.IUserProvider - auditRPCProvider rpc.IAuditProvider - configProvider conf.IConfigProvider + manageRepo repo.IManageRepo + authProvider rpc.IAuthProvider + auditProvider rpc.IAuditProvider + promptService service.IPromptService + } + type args struct { + ctx context.Context + request *manage.SaveDraftRequest + } + tests := []struct { + name string + fieldsGetter func(ctrl *gomock.Controller) fields + args args + wantErr error + }{ + { + name: "invalid draft", + fieldsGetter: func(ctrl *gomock.Controller) fields { return fields{} }, + args: args{ + ctx: session.WithCtxUser(context.Background(), &session.User{ID: "user"}), + request: &manage.SaveDraftRequest{PromptDraft: &prompt.PromptDraft{}}, + }, + wantErr: errorx.NewByCode(prompterr.CommonInvalidParamCode, errorx.WithExtraMsg("Draft is not specified")), + }, + { + name: "get prompt error", + fieldsGetter: func(ctrl *gomock.Controller) fields { + repoMock := repomocks.NewMockIManageRepo(ctrl) + repoMock.EXPECT().GetPrompt(gomock.Any(), repo.GetPromptParam{PromptID: 3}).Return(nil, errorx.New("repo error")) + return fields{manageRepo: repoMock} + }, + args: args{ + ctx: session.WithCtxUser(context.Background(), &session.User{ID: "user"}), + request: &manage.SaveDraftRequest{ + PromptID: ptr.Of(int64(3)), + PromptDraft: &prompt.PromptDraft{ + DraftInfo: &prompt.DraftInfo{}, + Detail: &prompt.PromptDetail{}, + }, + }, + }, + wantErr: errorx.New("repo error"), + }, + { + name: "success", + fieldsGetter: func(ctrl *gomock.Controller) fields { + repoMock := repomocks.NewMockIManageRepo(ctrl) + repoMock.EXPECT().GetPrompt(gomock.Any(), repo.GetPromptParam{PromptID: 3}).Return(&entity.Prompt{ID: 3, SpaceID: 30}, nil) + auth := mocks.NewMockIAuthProvider(ctrl) + auth.EXPECT().MCheckPromptPermission(gomock.Any(), int64(30), []int64{int64(3)}, consts.ActionLoopPromptEdit).Return(nil) + audit := mocks.NewMockIAuditProvider(ctrl) + audit.EXPECT().AuditPrompt(gomock.Any(), gomock.Any()).Return(nil) + promptSvc := servicemocks.NewMockIPromptService(ctrl) + promptSvc.EXPECT().SaveDraft(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, promptDO *entity.Prompt) (*entity.DraftInfo, error) { + assert.Equal(t, "user", promptDO.PromptDraft.DraftInfo.UserID) + return &entity.DraftInfo{UserID: "user", IsModified: true}, nil + }) + return fields{manageRepo: repoMock, authProvider: auth, auditProvider: audit, promptService: promptSvc} + }, + args: args{ + ctx: session.WithCtxUser(context.Background(), &session.User{ID: "user"}), + request: &manage.SaveDraftRequest{ + PromptID: ptr.Of(int64(3)), + PromptDraft: &prompt.PromptDraft{ + DraftInfo: &prompt.DraftInfo{}, + Detail: &prompt.PromptDetail{}, + }, + }, + }, + wantErr: nil, + }, + } + + for _, tt := range tests { + caseData := tt + t.Run(caseData.name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + tFields := caseData.fieldsGetter(ctrl) + app := &PromptManageApplicationImpl{ + manageRepo: tFields.manageRepo, + authRPCProvider: tFields.authProvider, + auditRPCProvider: tFields.auditProvider, + promptService: tFields.promptService, + } + + resp, err := app.SaveDraft(caseData.args.ctx, caseData.args.request) + unittest.AssertErrorEqual(t, caseData.wantErr, err) + if err == nil { + assert.NotNil(t, resp.DraftInfo) + } + }) + } +} + +func TestPromptManageApplicationImpl_CommitDraft(t *testing.T) { + invalidVersionErr := func() error { + _, err := semver.StrictNewVersion("invalid") + return err + }() + type fields struct { + manageRepo repo.IManageRepo + authProvider rpc.IAuthProvider + promptService service.IPromptService + } + type args struct { + ctx context.Context + request *manage.CommitDraftRequest + } + tests := []struct { + name string + fieldsGetter func(ctrl *gomock.Controller) fields + args args + wantErr error + }{ + { + name: "invalid semver", + fieldsGetter: func(ctrl *gomock.Controller) fields { return fields{} }, + args: args{ + ctx: session.WithCtxUser(context.Background(), &session.User{ID: "user"}), + request: &manage.CommitDraftRequest{PromptID: ptr.Of(int64(4)), CommitVersion: ptr.Of("invalid")}, + }, + wantErr: invalidVersionErr, + }, + { + name: "success", + fieldsGetter: func(ctrl *gomock.Controller) fields { + repoMock := repomocks.NewMockIManageRepo(ctrl) + repoMock.EXPECT().GetPrompt(gomock.Any(), repo.GetPromptParam{PromptID: 4}).Return(&entity.Prompt{ID: 4, SpaceID: 40}, nil) + repoMock.EXPECT().CommitDraft(gomock.Any(), repo.CommitDraftParam{ + PromptID: 4, + UserID: "user", + CommitVersion: "1.0.0", + CommitDescription: "desc", + LabelKeys: []string{"label"}, + }).Return(nil) + auth := mocks.NewMockIAuthProvider(ctrl) + auth.EXPECT().MCheckPromptPermission(gomock.Any(), int64(40), []int64{int64(4)}, consts.ActionLoopPromptEdit).Return(nil) + promptSvc := servicemocks.NewMockIPromptService(ctrl) + promptSvc.EXPECT().ValidateLabelsExist(gomock.Any(), int64(40), []string{"label"}).Return(nil) + return fields{manageRepo: repoMock, authProvider: auth, promptService: promptSvc} + }, + args: args{ + ctx: session.WithCtxUser(context.Background(), &session.User{ID: "user"}), + request: &manage.CommitDraftRequest{ + PromptID: ptr.Of(int64(4)), + CommitVersion: ptr.Of("1.0.0"), + CommitDescription: ptr.Of("desc"), + LabelKeys: []string{"label"}, + }, + }, + wantErr: nil, + }, + } + + for _, tt := range tests { + caseData := tt + t.Run(caseData.name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + tFields := caseData.fieldsGetter(ctrl) + app := &PromptManageApplicationImpl{ + manageRepo: tFields.manageRepo, + authRPCProvider: tFields.authProvider, + promptService: tFields.promptService, + } + + resp, err := app.CommitDraft(caseData.args.ctx, caseData.args.request) + unittest.AssertErrorEqual(t, caseData.wantErr, err) + if err == nil { + assert.NotNil(t, resp) + } + }) + } +} + +func TestPromptManageApplicationImpl_ListCommit(t *testing.T) { + type fields struct { + manageRepo repo.IManageRepo + authProvider rpc.IAuthProvider + promptService service.IPromptService + userProvider rpc.IUserProvider } type args struct { ctx context.Context - request *manage.ListLabelRequest + request *manage.ListCommitRequest } tests := []struct { name string fieldsGetter func(ctrl *gomock.Controller) fields args args - want *manage.ListLabelResponse wantErr error }{ { - name: "成功列出标签", + name: "invalid page token", fieldsGetter: func(ctrl *gomock.Controller) fields { - mockAuth := mocks.NewMockIAuthProvider(ctrl) - mockAuth.EXPECT().CheckSpacePermission(gomock.Any(), int64(100), consts.ActionWorkspaceListLoopPrompt).Return(nil) - - mockPromptService := servicemocks.NewMockIPromptService(ctrl) - mockPromptService.EXPECT().ListLabel(gomock.Any(), gomock.Any()).Return([]*entity.PromptLabel{ - { - ID: 1, - SpaceID: 100, - LabelKey: "test-label", - }, - }, ptr.Of(int64(2)), nil) - - return fields{ - authRPCProvider: mockAuth, - promptService: mockPromptService, - } + repoMock := repomocks.NewMockIManageRepo(ctrl) + repoMock.EXPECT().GetPrompt(gomock.Any(), repo.GetPromptParam{PromptID: 5}).Return(&entity.Prompt{ID: 5, SpaceID: 50}, nil) + auth := mocks.NewMockIAuthProvider(ctrl) + auth.EXPECT().MCheckPromptPermission(gomock.Any(), int64(50), []int64{int64(5)}, consts.ActionLoopPromptRead).Return(nil) + return fields{manageRepo: repoMock, authProvider: auth} }, args: args{ - ctx: context.Background(), - request: &manage.ListLabelRequest{ - WorkspaceID: ptr.Of(int64(100)), - PageSize: ptr.Of(int32(10)), - }, - }, - want: &manage.ListLabelResponse{ - Labels: []*prompt.Label{ - { - Key: ptr.Of("test-label"), - }, + ctx: session.WithCtxUser(context.Background(), &session.User{ID: "user"}), + request: &manage.ListCommitRequest{ + PromptID: ptr.Of(int64(5)), + PageToken: ptr.Of("bad"), }, - NextPageToken: ptr.Of("2"), - HasMore: ptr.Of(true), }, - wantErr: nil, + wantErr: errorx.NewByCode(prompterr.CommonInvalidParamCode, errorx.WithExtraMsg("Page token is invalid, page token = bad")), }, { - name: "权限检查失败", + name: "success", fieldsGetter: func(ctrl *gomock.Controller) fields { - mockAuth := mocks.NewMockIAuthProvider(ctrl) - mockAuth.EXPECT().CheckSpacePermission(gomock.Any(), int64(100), consts.ActionWorkspaceListLoopPrompt).Return(errorx.New("permission denied")) - - return fields{ - authRPCProvider: mockAuth, - } + repoMock := repomocks.NewMockIManageRepo(ctrl) + repoMock.EXPECT().GetPrompt(gomock.Any(), repo.GetPromptParam{PromptID: 5}).Return(&entity.Prompt{ID: 5, SpaceID: 50, PromptBasic: &entity.PromptBasic{PromptType: entity.PromptTypeNormal}}, nil) + repoMock.EXPECT().ListCommitInfo(gomock.Any(), repo.ListCommitInfoParam{PromptID: 5, PageSize: 10, Asc: true}).Return(&repo.ListCommitResult{ + CommitInfoDOs: []*entity.CommitInfo{{Version: "1.0.0", CommittedBy: "userA"}}, + CommitDOs: []*entity.PromptCommit{{CommitInfo: &entity.CommitInfo{Version: "1.0.0"}}}, + NextPageToken: 77, + }, nil) + auth := mocks.NewMockIAuthProvider(ctrl) + auth.EXPECT().MCheckPromptPermission(gomock.Any(), int64(50), []int64{int64(5)}, consts.ActionLoopPromptRead).Return(nil) + promptSvc := servicemocks.NewMockIPromptService(ctrl) + promptSvc.EXPECT().BatchGetCommitLabels(gomock.Any(), int64(5), []string{"1.0.0"}).Return(map[string][]string{"1.0.0": {"label"}}, nil) + userProvider := mocks.NewMockIUserProvider(ctrl) + userProvider.EXPECT().MGetUserInfo(gomock.Any(), []string{"userA"}).Return([]*rpc.UserInfo{{UserID: "userA", UserName: "User A"}}, nil) + return fields{manageRepo: repoMock, authProvider: auth, promptService: promptSvc, userProvider: userProvider} }, args: args{ - ctx: context.Background(), - request: &manage.ListLabelRequest{ - WorkspaceID: ptr.Of(int64(100)), - PageSize: ptr.Of(int32(10)), + ctx: session.WithCtxUser(context.Background(), &session.User{ID: "user"}), + request: &manage.ListCommitRequest{ + PromptID: ptr.Of(int64(5)), + PageSize: ptr.Of(int32(10)), + Asc: ptr.Of(true), + WithCommitDetail: ptr.Of(true), }, }, - want: manage.NewListLabelResponse(), - wantErr: errorx.New("permission denied"), + wantErr: nil, }, { - name: "需要版本映射但未提供PromptID", + name: "snippet success", fieldsGetter: func(ctrl *gomock.Controller) fields { - mockAuth := mocks.NewMockIAuthProvider(ctrl) - mockAuth.EXPECT().CheckSpacePermission(gomock.Any(), int64(100), consts.ActionWorkspaceListLoopPrompt).Return(nil) - - return fields{ - authRPCProvider: mockAuth, - } + repoMock := repomocks.NewMockIManageRepo(ctrl) + repoMock.EXPECT().GetPrompt(gomock.Any(), repo.GetPromptParam{PromptID: 6}).Return(&entity.Prompt{ID: 6, SpaceID: 60, PromptBasic: &entity.PromptBasic{PromptType: entity.PromptTypeSnippet}}, nil) + repoMock.EXPECT().ListCommitInfo(gomock.Any(), repo.ListCommitInfoParam{PromptID: 6, PageSize: 5, Asc: false}).Return(&repo.ListCommitResult{ + CommitInfoDOs: []*entity.CommitInfo{{Version: "2.0.0", CommittedBy: "userB"}}, + CommitDOs: []*entity.PromptCommit{{CommitInfo: &entity.CommitInfo{Version: "2.0.0"}}}, + }, nil) + repoMock.EXPECT().ListParentPrompt(gomock.Any(), repo.ListParentPromptParam{SubPromptID: 6, SubPromptVersions: []string{"2.0.0"}}).Return(map[string][]*repo.PromptCommitVersions{ + "2.0.0": {{CommitVersions: []string{"3.0.0", "3.1.0"}}}, + }, nil) + auth := mocks.NewMockIAuthProvider(ctrl) + auth.EXPECT().MCheckPromptPermission(gomock.Any(), int64(60), []int64{int64(6)}, consts.ActionLoopPromptRead).Return(nil) + promptSvc := servicemocks.NewMockIPromptService(ctrl) + promptSvc.EXPECT().BatchGetCommitLabels(gomock.Any(), int64(6), []string{"2.0.0"}).Return(map[string][]string{"2.0.0": {"labelB"}}, nil) + userProvider := mocks.NewMockIUserProvider(ctrl) + userProvider.EXPECT().MGetUserInfo(gomock.Any(), []string{"userB"}).Return([]*rpc.UserInfo{{UserID: "userB"}}, nil) + return fields{manageRepo: repoMock, authProvider: auth, promptService: promptSvc, userProvider: userProvider} }, args: args{ - ctx: context.Background(), - request: &manage.ListLabelRequest{ - WorkspaceID: ptr.Of(int64(100)), - PageSize: ptr.Of(int32(10)), - WithPromptVersionMapping: ptr.Of(true), + ctx: session.WithCtxUser(context.Background(), &session.User{ID: "user"}), + request: &manage.ListCommitRequest{ + PromptID: ptr.Of(int64(6)), + PageSize: ptr.Of(int32(5)), }, }, - want: manage.NewListLabelResponse(), - wantErr: errorx.NewByCode(prompterr.CommonInvalidParamCode, errorx.WithExtraMsg("PromptID must be provided when WithPromptVersionMapping is true")), + wantErr: nil, }, } for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { + caseData := tt + t.Run(caseData.name, func(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) defer ctrl.Finish() - ttFields := tt.fieldsGetter(ctrl) + tFields := caseData.fieldsGetter(ctrl) app := &PromptManageApplicationImpl{ - manageRepo: ttFields.manageRepo, - labelRepo: ttFields.labelRepo, - promptService: ttFields.promptService, - authRPCProvider: ttFields.authRPCProvider, - userRPCProvider: ttFields.userRPCProvider, - auditRPCProvider: ttFields.auditRPCProvider, - configProvider: ttFields.configProvider, + manageRepo: tFields.manageRepo, + authRPCProvider: tFields.authProvider, + promptService: tFields.promptService, + userRPCProvider: tFields.userProvider, } - got, err := app.ListLabel(tt.args.ctx, tt.args.request) - unittest.AssertErrorEqual(t, tt.wantErr, err) + resp, err := app.ListCommit(caseData.args.ctx, caseData.args.request) + unittest.AssertErrorEqual(t, caseData.wantErr, err) if err == nil { - assert.Equal(t, tt.want, got) + assert.NotNil(t, resp) + assert.Len(t, resp.PromptCommitInfos, 1) + switch caseData.name { + case "success": + assert.Equal(t, ptr.Of("77"), resp.NextPageToken) + assert.Equal(t, ptr.Of(true), resp.HasMore) + case "snippet success": + assert.Nil(t, resp.NextPageToken) + } } }) } } -func TestPromptManageApplicationImpl_BatchGetLabel(t *testing.T) { - t.Parallel() - +func TestPromptManageApplicationImpl_RevertDraftFromCommit(t *testing.T) { type fields struct { - manageRepo repo.IManageRepo - labelRepo repo.ILabelRepo - promptService service.IPromptService - authRPCProvider rpc.IAuthProvider - userRPCProvider rpc.IUserProvider - auditRPCProvider rpc.IAuditProvider - configProvider conf.IConfigProvider + manageRepo repo.IManageRepo + authProvider rpc.IAuthProvider + promptService service.IPromptService } type args struct { ctx context.Context - request *manage.BatchGetLabelRequest + request *manage.RevertDraftFromCommitRequest } tests := []struct { name string fieldsGetter func(ctrl *gomock.Controller) fields args args - want *manage.BatchGetLabelResponse wantErr error }{ { - name: "成功批量获取标签", + name: "commit missing", fieldsGetter: func(ctrl *gomock.Controller) fields { - mockAuth := mocks.NewMockIAuthProvider(ctrl) - mockAuth.EXPECT().CheckSpacePermission(gomock.Any(), int64(100), consts.ActionWorkspaceListLoopPrompt).Return(nil) - - mockLabelRepo := repomocks.NewMockILabelRepo(ctrl) - mockLabelRepo.EXPECT().BatchGetLabel(gomock.Any(), int64(100), []string{"label1", "label2"}).Return([]*entity.PromptLabel{ - { - ID: 1, - SpaceID: 100, - LabelKey: "label1", - }, - { - ID: 2, - SpaceID: 100, - LabelKey: "label2", - }, - }, nil) - - return fields{ - authRPCProvider: mockAuth, - labelRepo: mockLabelRepo, - } + repoMock := repomocks.NewMockIManageRepo(ctrl) + repoMock.EXPECT().GetPrompt(gomock.Any(), repo.GetPromptParam{PromptID: 6, WithCommit: true, CommitVersion: "1.0.0"}).Return(&entity.Prompt{ID: 6, PromptCommit: nil}, nil) + return fields{manageRepo: repoMock} }, args: args{ - ctx: context.Background(), - request: &manage.BatchGetLabelRequest{ - WorkspaceID: ptr.Of(int64(100)), - LabelKeys: []string{"label1", "label2"}, - }, - }, - want: &manage.BatchGetLabelResponse{ - Labels: []*prompt.Label{ - { - Key: ptr.Of("label1"), - }, - { - Key: ptr.Of("label2"), - }, - }, + ctx: session.WithCtxUser(context.Background(), &session.User{ID: "user"}), + request: &manage.RevertDraftFromCommitRequest{PromptID: ptr.Of(int64(6)), CommitVersionRevertingFrom: ptr.Of("1.0.0")}, }, - wantErr: nil, + wantErr: errorx.New("Prompt or commit not found, prompt id = 6, commit version = 1.0.0"), }, { - name: "权限检查失败", + name: "success", fieldsGetter: func(ctrl *gomock.Controller) fields { - mockAuth := mocks.NewMockIAuthProvider(ctrl) - mockAuth.EXPECT().CheckSpacePermission(gomock.Any(), int64(100), consts.ActionWorkspaceListLoopPrompt).Return(errorx.New("permission denied")) - - return fields{ - authRPCProvider: mockAuth, - } + repoMock := repomocks.NewMockIManageRepo(ctrl) + repoMock.EXPECT().GetPrompt(gomock.Any(), repo.GetPromptParam{PromptID: 6, WithCommit: true, CommitVersion: "1.0.0"}).Return(&entity.Prompt{ + ID: 6, + SpaceID: 60, + PromptCommit: &entity.PromptCommit{ + CommitInfo: &entity.CommitInfo{Version: "1.0.0"}, + PromptDetail: &entity.PromptDetail{PromptTemplate: &entity.PromptTemplate{}}, + }, + }, nil) + auth := mocks.NewMockIAuthProvider(ctrl) + auth.EXPECT().MCheckPromptPermission(gomock.Any(), int64(60), []int64{int64(6)}, consts.ActionLoopPromptEdit).Return(nil) + promptSvc := servicemocks.NewMockIPromptService(ctrl) + promptSvc.EXPECT().SaveDraft(gomock.Any(), gomock.Any()).Return(&entity.DraftInfo{}, nil) + return fields{manageRepo: repoMock, authProvider: auth, promptService: promptSvc} }, args: args{ - ctx: context.Background(), - request: &manage.BatchGetLabelRequest{ - WorkspaceID: ptr.Of(int64(100)), - LabelKeys: []string{"label1", "label2"}, - }, + ctx: session.WithCtxUser(context.Background(), &session.User{ID: "user"}), + request: &manage.RevertDraftFromCommitRequest{PromptID: ptr.Of(int64(6)), CommitVersionRevertingFrom: ptr.Of("1.0.0")}, }, - want: manage.NewBatchGetLabelResponse(), - wantErr: errorx.New("permission denied"), + wantErr: nil, }, } for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { + caseData := tt + t.Run(caseData.name, func(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) defer ctrl.Finish() - ttFields := tt.fieldsGetter(ctrl) + tFields := caseData.fieldsGetter(ctrl) app := &PromptManageApplicationImpl{ - manageRepo: ttFields.manageRepo, - labelRepo: ttFields.labelRepo, - promptService: ttFields.promptService, - authRPCProvider: ttFields.authRPCProvider, - userRPCProvider: ttFields.userRPCProvider, - auditRPCProvider: ttFields.auditRPCProvider, - configProvider: ttFields.configProvider, + manageRepo: tFields.manageRepo, + authRPCProvider: tFields.authProvider, + promptService: tFields.promptService, } - got, err := app.BatchGetLabel(tt.args.ctx, tt.args.request) - unittest.AssertErrorEqual(t, tt.wantErr, err) + resp, err := app.RevertDraftFromCommit(caseData.args.ctx, caseData.args.request) + unittest.AssertErrorEqual(t, caseData.wantErr, err) if err == nil { - assert.Equal(t, tt.want, got) + assert.NotNil(t, resp) } }) } } -func TestPromptManageApplicationImpl_UpdateCommitLabels(t *testing.T) { - t.Parallel() +func TestPromptManageApplicationImpl_listPromptOrderBy(t *testing.T) { + app := &PromptManageApplicationImpl{} + tests := []struct { + name string + arg *manage.ListPromptOrderBy + exp int + }{ + {"nil", nil, mysql.ListPromptBasicOrderByID}, + {"created", ptr.Of(manage.ListPromptOrderByCreatedAt), mysql.ListPromptBasicOrderByCreatedAt}, + {"committed", ptr.Of(manage.ListPromptOrderByCommitedAt), mysql.ListPromptBasicOrderByLatestCommittedAt}, + {"default", ptr.Of(manage.ListPromptOrderBy("unknown")), mysql.ListPromptBasicOrderByID}, + } + for _, tt := range tests { + assert.Equal(t, tt.exp, app.listPromptOrderBy(tt.arg)) + } +} +func TestPromptManageApplicationImpl_ListLabelAdditional(t *testing.T) { type fields struct { - manageRepo repo.IManageRepo - labelRepo repo.ILabelRepo - promptService service.IPromptService - authRPCProvider rpc.IAuthProvider - userRPCProvider rpc.IUserProvider - auditRPCProvider rpc.IAuditProvider - configProvider conf.IConfigProvider + authProvider rpc.IAuthProvider + promptService service.IPromptService } type args struct { ctx context.Context - request *manage.UpdateCommitLabelsRequest + request *manage.ListLabelRequest } tests := []struct { name string fieldsGetter func(ctrl *gomock.Controller) fields args args - want *manage.UpdateCommitLabelsResponse wantErr error }{ { - name: "成功更新提交标签", + name: "mapping without prompt id", fieldsGetter: func(ctrl *gomock.Controller) fields { - mockAuth := mocks.NewMockIAuthProvider(ctrl) - mockAuth.EXPECT().MCheckPromptPermission(gomock.Any(), int64(100), []int64{1}, consts.ActionLoopPromptEdit).Return(nil) - - mockPromptService := servicemocks.NewMockIPromptService(ctrl) - mockPromptService.EXPECT().UpdateCommitLabels(gomock.Any(), gomock.Any()).Return(nil) - - return fields{ - authRPCProvider: mockAuth, - promptService: mockPromptService, - } + auth := mocks.NewMockIAuthProvider(ctrl) + auth.EXPECT().CheckSpacePermission(gomock.Any(), int64(70), consts.ActionWorkspaceListLoopPrompt).Return(nil) + return fields{authProvider: auth} }, args: args{ - ctx: session.WithCtxUser(context.Background(), &session.User{ID: "123"}), - request: &manage.UpdateCommitLabelsRequest{ - WorkspaceID: ptr.Of(int64(100)), - PromptID: ptr.Of(int64(1)), - CommitVersion: ptr.Of("1.0.0"), - LabelKeys: []string{"label1", "label2"}, - }, + ctx: session.WithCtxUser(context.Background(), &session.User{ID: "user"}), + request: &manage.ListLabelRequest{WorkspaceID: ptr.Of(int64(70)), WithPromptVersionMapping: ptr.Of(true)}, }, - want: manage.NewUpdateCommitLabelsResponse(), - wantErr: nil, + wantErr: errorx.NewByCode(prompterr.CommonInvalidParamCode, errorx.WithExtraMsg("PromptID must be provided when WithPromptVersionMapping is true")), }, { - name: "用户未找到", + name: "invalid page token", fieldsGetter: func(ctrl *gomock.Controller) fields { - return fields{} + auth := mocks.NewMockIAuthProvider(ctrl) + auth.EXPECT().CheckSpacePermission(gomock.Any(), int64(70), consts.ActionWorkspaceListLoopPrompt).Return(nil) + return fields{authProvider: auth} }, args: args{ - ctx: context.Background(), - request: &manage.UpdateCommitLabelsRequest{ - WorkspaceID: ptr.Of(int64(100)), - PromptID: ptr.Of(int64(1)), - CommitVersion: ptr.Of("1.0.0"), - LabelKeys: []string{"label1", "label2"}, - }, + ctx: session.WithCtxUser(context.Background(), &session.User{ID: "user"}), + request: &manage.ListLabelRequest{WorkspaceID: ptr.Of(int64(70)), PageToken: ptr.Of("bad")}, }, - want: manage.NewUpdateCommitLabelsResponse(), - wantErr: errorx.NewByCode(prompterr.CommonInvalidParamCode, errorx.WithExtraMsg("User not found")), + wantErr: errorx.NewByCode(prompterr.CommonInvalidParamCode, errorx.WithExtraMsg("Invalid page token")), }, { - name: "权限检查失败", + name: "success", fieldsGetter: func(ctrl *gomock.Controller) fields { - mockAuth := mocks.NewMockIAuthProvider(ctrl) - mockAuth.EXPECT().MCheckPromptPermission(gomock.Any(), int64(100), []int64{1}, consts.ActionLoopPromptEdit).Return(errorx.New("permission denied")) - - return fields{ - authRPCProvider: mockAuth, - } + auth := mocks.NewMockIAuthProvider(ctrl) + auth.EXPECT().CheckSpacePermission(gomock.Any(), int64(70), consts.ActionWorkspaceListLoopPrompt).Return(nil) + promptSvc := servicemocks.NewMockIPromptService(ctrl) + next := int64(88) + promptSvc.EXPECT().ListLabel(gomock.Any(), service.ListLabelParam{SpaceID: 70, LabelKeyLike: "key", PageSize: 10}).Return([]*entity.PromptLabel{{LabelKey: "label"}}, &next, nil) + promptSvc.EXPECT().BatchGetLabelMappingPromptVersion(gomock.Any(), []service.PromptLabelQuery{{PromptID: 99, LabelKey: "label"}}).Return(map[service.PromptLabelQuery]string{{PromptID: 99, LabelKey: "label"}: "1.0.0"}, nil) + return fields{authProvider: auth, promptService: promptSvc} }, args: args{ - ctx: session.WithCtxUser(context.Background(), &session.User{ID: "123"}), - request: &manage.UpdateCommitLabelsRequest{ - WorkspaceID: ptr.Of(int64(100)), - PromptID: ptr.Of(int64(1)), - CommitVersion: ptr.Of("1.0.0"), - LabelKeys: []string{"label1", "label2"}, + ctx: session.WithCtxUser(context.Background(), &session.User{ID: "user"}), + request: &manage.ListLabelRequest{ + WorkspaceID: ptr.Of(int64(70)), + PromptID: ptr.Of(int64(99)), + LabelKeyLike: ptr.Of("key"), + PageSize: ptr.Of(int32(10)), + WithPromptVersionMapping: ptr.Of(true), }, }, - want: manage.NewUpdateCommitLabelsResponse(), - wantErr: errorx.New("permission denied"), + wantErr: nil, }, } for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { + caseData := tt + t.Run(caseData.name, func(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) defer ctrl.Finish() - ttFields := tt.fieldsGetter(ctrl) + tFields := caseData.fieldsGetter(ctrl) app := &PromptManageApplicationImpl{ - manageRepo: ttFields.manageRepo, - labelRepo: ttFields.labelRepo, - promptService: ttFields.promptService, - authRPCProvider: ttFields.authRPCProvider, - userRPCProvider: ttFields.userRPCProvider, - auditRPCProvider: ttFields.auditRPCProvider, - configProvider: ttFields.configProvider, + authRPCProvider: tFields.authProvider, + promptService: tFields.promptService, } - got, err := app.UpdateCommitLabels(tt.args.ctx, tt.args.request) - unittest.AssertErrorEqual(t, tt.wantErr, err) + resp, err := app.ListLabel(caseData.args.ctx, caseData.args.request) + unittest.AssertErrorEqual(t, caseData.wantErr, err) if err == nil { - assert.Equal(t, tt.want, got) + assert.Len(t, resp.Labels, 1) + assert.Equal(t, ptr.Of("88"), resp.NextPageToken) + assert.Equal(t, "1.0.0", resp.PromptVersionMapping["label"]) } }) } diff --git a/backend/modules/prompt/application/openapi.go b/backend/modules/prompt/application/openapi.go index e6305bbeb..d71acc5b4 100644 --- a/backend/modules/prompt/application/openapi.go +++ b/backend/modules/prompt/application/openapi.go @@ -232,9 +232,13 @@ func (p *PromptOpenAPIApplicationImpl) fetchPromptResults(ctx context.Context, r return nil, err } - // 构建版本映射 + // 展开片段内容(若有),构建版本映射 promptMap := make(map[service.PromptKeyVersionPair]*entity.Prompt) for _, prompt := range maps.Values(prompts) { + err = p.promptService.ExpandSnippets(ctx, prompt) + if err != nil { + return nil, err + } promptMap[service.PromptKeyVersionPair{ PromptKey: prompt.PromptKey, Version: prompt.GetVersion(), @@ -387,6 +391,11 @@ func (p *PromptOpenAPIApplicationImpl) doExecute(ctx context.Context, req *opena if err != nil { return promptDO, nil, err } + // expand snippets + err = p.promptService.ExpandSnippets(ctx, promptDO) + if err != nil { + return promptDO, nil, err + } // 执行权限检查 if err = p.auth.MCheckPromptPermissionForOpenAPI(ctx, req.GetWorkspaceID(), []int64{promptDO.ID}, consts.ActionLoopPromptExecute); err != nil { @@ -483,6 +492,11 @@ func (p *PromptOpenAPIApplicationImpl) doExecuteStreaming(ctx context.Context, r if err != nil { return promptDO, nil, err } + // expand snippets + err = p.promptService.ExpandSnippets(ctx, promptDO) + if err != nil { + return promptDO, nil, err + } // 执行权限检查 if err = p.auth.MCheckPromptPermissionForOpenAPI(ctx, req.GetWorkspaceID(), []int64{promptDO.ID}, consts.ActionLoopPromptExecute); err != nil { diff --git a/backend/modules/prompt/application/openapi_test.go b/backend/modules/prompt/application/openapi_test.go index 3fd599ec0..46064ef09 100644 --- a/backend/modules/prompt/application/openapi_test.go +++ b/backend/modules/prompt/application/openapi_test.go @@ -65,6 +65,7 @@ func TestPromptOpenAPIApplicationImpl_BatchGetPromptByPromptKey(t *testing.T) { name: "success: specific version", fieldsGetter: func(ctrl *gomock.Controller) fields { mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), gomock.Any(), gomock.Any()).Return(map[string]int64{ "test_prompt1": 123, "test_prompt2": 456, @@ -257,10 +258,100 @@ func TestPromptOpenAPIApplicationImpl_BatchGetPromptByPromptKey(t *testing.T) { }, wantErr: nil, }, + { + name: "expand snippets error", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(errorx.New("expand error")) + mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), gomock.Any(), gomock.Any()).Return(map[string]int64{ + "test_prompt1": 123, + }, nil) + mockPromptService.EXPECT().MParseCommitVersion(gomock.Any(), gomock.Any(), gomock.Any()).Return(map[service.PromptQueryParam]string{ + {PromptID: 123, PromptKey: "test_prompt1", Version: "1.0.0"}: "1.0.0", + }, nil) + + mockManageRepo := repomocks.NewMockIManageRepo(ctrl) + startTime := time.Now() + mockManageRepo.EXPECT().MGetPrompt(gomock.Any(), gomock.Any(), gomock.Any()).Return(map[repo.GetPromptParam]*entity.Prompt{ + { + PromptID: 123, + WithCommit: true, + CommitVersion: "1.0.0", + }: { + ID: 123, + SpaceID: 123456, + PromptKey: "test_prompt1", + PromptBasic: &entity.PromptBasic{ + DisplayName: "Test Prompt 1", + Description: "Test PromptDescription 1", + LatestVersion: "1.0.0", + CreatedBy: "test_user", + UpdatedBy: "test_user", + CreatedAt: startTime, + UpdatedAt: startTime, + }, + PromptCommit: &entity.PromptCommit{ + CommitInfo: &entity.CommitInfo{ + Version: "1.0.0", + BaseVersion: "0.9.0", + Description: "Initial version", + CommittedBy: "test_user", + CommittedAt: startTime, + }, + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + TemplateType: entity.TemplateTypeNormal, + Messages: []*entity.Message{ + { + Role: entity.RoleSystem, + Content: ptr.Of("You are a helpful assistant."), + }, + }, + }, + }, + }, + }, + }, nil) + + mockConfig := confmocks.NewMockIConfigProvider(ctrl) + mockConfig.EXPECT().GetPromptHubMaxQPSBySpace(gomock.Any(), gomock.Any()).Return(10, nil) + + mockRateLimiter := limitermocks.NewMockIRateLimiter(ctrl) + mockRateLimiter.EXPECT().AllowN(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&limiter.Result{ + Allowed: true, + }, nil) + + mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) + mockAuth.EXPECT().MCheckPromptPermissionForOpenAPI(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + + return fields{ + promptService: mockPromptService, + promptManageRepo: mockManageRepo, + config: mockConfig, + auth: mockAuth, + rateLimiter: mockRateLimiter, + } + }, + args: args{ + ctx: context.Background(), + req: &openapi.BatchGetPromptByPromptKeyRequest{ + WorkspaceID: ptr.Of(int64(123456)), + Queries: []*openapi.PromptQuery{ + { + PromptKey: ptr.Of("test_prompt1"), + Version: ptr.Of("1.0.0"), + }, + }, + }, + }, + wantR: nil, + wantErr: errorx.New("expand error"), + }, { name: "success: latest commit version", fieldsGetter: func(ctrl *gomock.Controller) fields { mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), gomock.Any(), gomock.Any()).Return(map[string]int64{ "test_prompt1": 123, }, nil) @@ -515,6 +606,7 @@ func TestPromptOpenAPIApplicationImpl_BatchGetPromptByPromptKey(t *testing.T) { name: "mget prompt ids error", fieldsGetter: func(ctrl *gomock.Controller) fields { mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), gomock.Any(), gomock.Any()). Return(nil, errors.New("database error")) @@ -551,6 +643,7 @@ func TestPromptOpenAPIApplicationImpl_BatchGetPromptByPromptKey(t *testing.T) { name: "permission check failed", fieldsGetter: func(ctrl *gomock.Controller) fields { mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), gomock.Any(), gomock.Any()).Return(map[string]int64{ "test_prompt1": 123, }, nil) @@ -593,6 +686,7 @@ func TestPromptOpenAPIApplicationImpl_BatchGetPromptByPromptKey(t *testing.T) { name: "parse commit version error", fieldsGetter: func(ctrl *gomock.Controller) fields { mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), gomock.Any(), gomock.Any()).Return(map[string]int64{ "test_prompt1": 123, }, nil) @@ -636,6 +730,7 @@ func TestPromptOpenAPIApplicationImpl_BatchGetPromptByPromptKey(t *testing.T) { name: "mget prompt error", fieldsGetter: func(ctrl *gomock.Controller) fields { mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), gomock.Any(), gomock.Any()).Return(map[string]int64{ "test_prompt1": 123, }, nil) @@ -684,6 +779,7 @@ func TestPromptOpenAPIApplicationImpl_BatchGetPromptByPromptKey(t *testing.T) { name: "prompt version not exist", fieldsGetter: func(ctrl *gomock.Controller) fields { mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), gomock.Any(), gomock.Any()).Return(map[string]int64{ "test_prompt1": 123, }, nil) @@ -772,6 +868,7 @@ func TestPromptOpenAPIApplicationImpl_BatchGetPromptByPromptKey(t *testing.T) { name: "enhanced error info with prompt_key", fieldsGetter: func(ctrl *gomock.Controller) fields { mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), gomock.Any(), gomock.Any()).Return(map[string]int64{ "test_prompt1": 123, }, nil) @@ -823,6 +920,7 @@ func TestPromptOpenAPIApplicationImpl_BatchGetPromptByPromptKey(t *testing.T) { name: "success: query with label", fieldsGetter: func(ctrl *gomock.Controller) fields { mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), gomock.Any(), gomock.Any()).Return(map[string]int64{ "test_prompt1": 123, }, nil) @@ -948,6 +1046,7 @@ func TestPromptOpenAPIApplicationImpl_BatchGetPromptByPromptKey(t *testing.T) { name: "success: mixed version and label queries", fieldsGetter: func(ctrl *gomock.Controller) fields { mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), gomock.Any(), gomock.Any()).Return(map[string]int64{ "test_prompt1": 123, "test_prompt2": 456, @@ -1145,6 +1244,7 @@ func TestPromptOpenAPIApplicationImpl_BatchGetPromptByPromptKey(t *testing.T) { name: "error: label not found", fieldsGetter: func(ctrl *gomock.Controller) fields { mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), gomock.Any(), gomock.Any()).Return(map[string]int64{ "test_prompt1": 123, }, nil) @@ -1188,6 +1288,7 @@ func TestPromptOpenAPIApplicationImpl_BatchGetPromptByPromptKey(t *testing.T) { name: "error: prompt key not found in result construction", fieldsGetter: func(ctrl *gomock.Controller) fields { mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), gomock.Any(), gomock.Any()).Return(map[string]int64{ "test_prompt1": 123, // test_prompt2 不存在,但在查询构建阶段会被跳过,在结果构建阶段会报错 @@ -1285,6 +1386,7 @@ func TestPromptOpenAPIApplicationImpl_BatchGetPromptByPromptKey(t *testing.T) { name: "error: prompt version not exist in result construction", fieldsGetter: func(ctrl *gomock.Controller) fields { mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), gomock.Any(), gomock.Any()).Return(map[string]int64{ "test_prompt1": 123, }, nil) @@ -1704,6 +1806,7 @@ func TestPromptOpenAPIApplicationImpl_getPromptByPromptKey(t *testing.T) { name: "success: get prompt by key and version", fieldsGetter: func(ctrl *gomock.Controller) fields { mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), int64(123456), []string{"test_prompt"}).Return(map[string]int64{ "test_prompt": 123, }, nil) @@ -1824,6 +1927,7 @@ func TestPromptOpenAPIApplicationImpl_getPromptByPromptKey(t *testing.T) { name: "error: get prompt IDs failed", fieldsGetter: func(ctrl *gomock.Controller) fields { mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), int64(123456), []string{"test_prompt"}).Return(nil, errors.New("database error")) return fields{ @@ -1845,6 +1949,7 @@ func TestPromptOpenAPIApplicationImpl_getPromptByPromptKey(t *testing.T) { name: "error: parse commit version failed", fieldsGetter: func(ctrl *gomock.Controller) fields { mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), int64(123456), []string{"test_prompt"}).Return(map[string]int64{ "test_prompt": 123, }, nil) @@ -1871,6 +1976,7 @@ func TestPromptOpenAPIApplicationImpl_getPromptByPromptKey(t *testing.T) { name: "error: get prompt failed", fieldsGetter: func(ctrl *gomock.Controller) fields { mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), int64(123456), []string{"test_prompt"}).Return(map[string]int64{ "test_prompt": 123, }, nil) @@ -1903,6 +2009,7 @@ func TestPromptOpenAPIApplicationImpl_getPromptByPromptKey(t *testing.T) { name: "error: prompt version not exist with enhanced error info", fieldsGetter: func(ctrl *gomock.Controller) fields { mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), int64(123456), []string{"test_prompt"}).Return(map[string]int64{ "test_prompt": 123, }, nil) @@ -1938,6 +2045,7 @@ func TestPromptOpenAPIApplicationImpl_getPromptByPromptKey(t *testing.T) { name: "success: get prompt by label", fieldsGetter: func(ctrl *gomock.Controller) fields { mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), int64(123456), []string{"test_prompt"}).Return(map[string]int64{ "test_prompt": 123, }, nil) @@ -2292,6 +2400,7 @@ func TestPromptOpenAPIApplicationImpl_doExecute(t *testing.T) { }, nil) mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), int64(123456), []string{"test_prompt"}).Return(map[string]int64{ "test_prompt": 123, }, nil) @@ -2456,6 +2565,7 @@ func TestPromptOpenAPIApplicationImpl_doExecute(t *testing.T) { }, }, } + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptService.EXPECT().Execute(gomock.Any(), gomock.Any()).Return(expectedReply, nil) mockPromptService.EXPECT().MConvertBase64DataURLToFileURL(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("convert error")) @@ -2527,6 +2637,7 @@ func TestPromptOpenAPIApplicationImpl_doExecute(t *testing.T) { }, nil) mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), int64(123456), []string{"test_prompt"}).Return(nil, errors.New("database error")) return fields{ @@ -2561,6 +2672,7 @@ func TestPromptOpenAPIApplicationImpl_doExecute(t *testing.T) { }, nil) mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), int64(123456), []string{"test_prompt"}).Return(map[string]int64{ "test_prompt": 123, }, nil) @@ -2619,6 +2731,7 @@ func TestPromptOpenAPIApplicationImpl_doExecute(t *testing.T) { }, nil) mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), int64(123456), []string{"test_prompt"}).Return(map[string]int64{ "test_prompt": 123, }, nil) @@ -2741,6 +2854,7 @@ func TestPromptOpenAPIApplicationImpl_Execute(t *testing.T) { }, nil) mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), int64(123456), []string{"test_prompt"}).Return(map[string]int64{ "test_prompt": 123, }, nil) @@ -2867,6 +2981,7 @@ func TestPromptOpenAPIApplicationImpl_Execute(t *testing.T) { }, nil) mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), int64(123456), []string{"test_prompt"}).Return(map[string]int64{ "test_prompt": 123, }, nil) @@ -3010,6 +3125,7 @@ func TestPromptOpenAPIApplicationImpl_Execute(t *testing.T) { }, nil) mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), int64(123456), []string{"test_prompt"}).Return(map[string]int64{ "test_prompt": 123, }, nil) @@ -3283,6 +3399,7 @@ func TestPromptOpenAPIApplicationImpl_ExecuteStreaming(t *testing.T) { }, }, } + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil) mockPromptService.EXPECT().ExecuteStreaming(gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, param service.ExecuteStreamingParam) (*entity.Reply, error) { // 模拟发送多个流式响应 - 使用同步方式避免竞争条件 @@ -3391,6 +3508,7 @@ func TestPromptOpenAPIApplicationImpl_ExecuteStreaming(t *testing.T) { mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) mockAuth.EXPECT().MCheckPromptPermissionForOpenAPI(gomock.Any(), int64(123456), []int64{123}, consts.ActionLoopPromptExecute).Return(nil) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil) mockPromptService.EXPECT().ExecuteStreaming(gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, param service.ExecuteStreamingParam) (*entity.Reply, error) { param.ResultStream <- &entity.Reply{ @@ -3654,6 +3772,7 @@ func TestPromptOpenAPIApplicationImpl_ExecuteStreaming(t *testing.T) { }, nil) mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), int64(123456), []string{"test_prompt"}).Return(map[string]int64{ "test_prompt": 123, }, nil) @@ -3760,6 +3879,7 @@ func TestPromptOpenAPIApplicationImpl_ExecuteStreaming(t *testing.T) { }, nil) mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), int64(123456), []string{"test_prompt"}).Return(nil, errors.New("database error")) mockCollector := collectormocks.NewMockICollectorProvider(ctrl) @@ -3811,6 +3931,7 @@ func TestPromptOpenAPIApplicationImpl_ExecuteStreaming(t *testing.T) { }, nil) mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), int64(123456), []string{"test_prompt"}).Return(map[string]int64{ "test_prompt": 123, }, nil) @@ -3923,6 +4044,7 @@ func TestPromptOpenAPIApplicationImpl_ExecuteStreaming(t *testing.T) { }, nil) mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), int64(123456), []string{"test_prompt"}).Return(map[string]int64{ "test_prompt": 123, }, nil) @@ -4064,6 +4186,7 @@ func TestPromptOpenAPIApplicationImpl_ExecuteStreaming(t *testing.T) { }, nil) mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), int64(123456), []string{"test_prompt"}).Return(map[string]int64{ "test_prompt": 123, }, nil) @@ -4205,6 +4328,7 @@ func TestPromptOpenAPIApplicationImpl_ExecuteStreaming(t *testing.T) { }, nil) mockPromptService := servicemocks.NewMockIPromptService(ctrl) + mockPromptService.EXPECT().ExpandSnippets(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockPromptService.EXPECT().MGetPromptIDs(gomock.Any(), int64(123456), []string{"test_prompt"}).Return(map[string]int64{ "test_prompt": 123, }, nil) diff --git a/backend/modules/prompt/application/wire.go b/backend/modules/prompt/application/wire.go index 37359831b..bc3584c24 100644 --- a/backend/modules/prompt/application/wire.go +++ b/backend/modules/prompt/application/wire.go @@ -45,6 +45,7 @@ var ( mysql.NewPromptBasicDAO, mysql.NewPromptCommitDAO, mysql.NewPromptUserDraftDAO, + mysql.NewPromptRelationDAO, mysql.NewLabelDAO, mysql.NewCommitLabelMappingDAO, mysql.NewDebugLogDAO, @@ -59,6 +60,7 @@ var ( rpc.NewUserRPCProvider, rpc.NewAuditRPCProvider, collector.NewEventCollectorProvider, + service.NewCozeLoopSnippetParser, ) manageSet = wire.NewSet( NewPromptManageApplication, diff --git a/backend/modules/prompt/application/wire_gen.go b/backend/modules/prompt/application/wire_gen.go index b506425eb..acd15bcee 100644 --- a/backend/modules/prompt/application/wire_gen.go +++ b/backend/modules/prompt/application/wire_gen.go @@ -40,9 +40,10 @@ func InitPromptManageApplication(idgen2 idgen.IIDGenerator, db2 db.Provider, red iPromptCommitDAO := mysql.NewPromptCommitDAO(db2, redisCli) iPromptUserDraftDAO := mysql.NewPromptUserDraftDAO(db2, redisCli) iCommitLabelMappingDAO := mysql.NewCommitLabelMappingDAO(db2, redisCli) + iPromptRelationDAO := mysql.NewPromptRelationDAO(db2, redisCli) redisIPromptBasicDAO := redis2.NewPromptBasicDAO() iPromptDAO := redis2.NewPromptDAO() - iManageRepo := repo.NewManageRepo(db2, idgen2, meter, iPromptBasicDAO, iPromptCommitDAO, iPromptUserDraftDAO, iCommitLabelMappingDAO, redisIPromptBasicDAO, iPromptDAO) + iManageRepo := repo.NewManageRepo(db2, idgen2, meter, iPromptBasicDAO, iPromptCommitDAO, iPromptUserDraftDAO, iCommitLabelMappingDAO, iPromptRelationDAO, redisIPromptBasicDAO, iPromptDAO) iLabelDAO := mysql.NewLabelDAO(db2, redisCli) iConfigProvider, err := conf2.NewPromptConfigProvider(configFactory) if err != nil { @@ -57,7 +58,8 @@ func InitPromptManageApplication(idgen2 idgen.IIDGenerator, db2 db.Provider, red iDebugContextRepo := repo.NewDebugContextRepo(idgen2, iDebugContextDAO) illmProvider := rpc.NewLLMRPCProvider(llmClient) iFileProvider := rpc.NewFileRPCProvider(fileClient) - iPromptService := service.NewPromptService(iPromptFormatter, idgen2, iDebugLogRepo, iDebugContextRepo, iManageRepo, iLabelRepo, iConfigProvider, illmProvider, iFileProvider) + snippetParser := service.NewCozeLoopSnippetParser() + iPromptService := service.NewPromptService(iPromptFormatter, idgen2, iDebugLogRepo, iDebugContextRepo, iManageRepo, iLabelRepo, iConfigProvider, illmProvider, iFileProvider, snippetParser) iAuthProvider := rpc.NewAuthRPCProvider(authClient) iUserProvider := rpc.NewUserRPCProvider(userClient) iAuditProvider := rpc.NewAuditRPCProvider(auditClient) @@ -75,9 +77,10 @@ func InitPromptDebugApplication(idgen2 idgen.IIDGenerator, db2 db.Provider, redi iPromptCommitDAO := mysql.NewPromptCommitDAO(db2, redisCli) iPromptUserDraftDAO := mysql.NewPromptUserDraftDAO(db2, redisCli) iCommitLabelMappingDAO := mysql.NewCommitLabelMappingDAO(db2, redisCli) + iPromptRelationDAO := mysql.NewPromptRelationDAO(db2, redisCli) redisIPromptBasicDAO := redis2.NewPromptBasicDAO() iPromptDAO := redis2.NewPromptDAO() - iManageRepo := repo.NewManageRepo(db2, idgen2, meter, iPromptBasicDAO, iPromptCommitDAO, iPromptUserDraftDAO, iCommitLabelMappingDAO, redisIPromptBasicDAO, iPromptDAO) + iManageRepo := repo.NewManageRepo(db2, idgen2, meter, iPromptBasicDAO, iPromptCommitDAO, iPromptUserDraftDAO, iCommitLabelMappingDAO, iPromptRelationDAO, redisIPromptBasicDAO, iPromptDAO) iLabelDAO := mysql.NewLabelDAO(db2, redisCli) iConfigProvider, err := conf2.NewPromptConfigProvider(configFactory) if err != nil { @@ -87,7 +90,8 @@ func InitPromptDebugApplication(idgen2 idgen.IIDGenerator, db2 db.Provider, redi iLabelRepo := repo.NewLabelRepo(db2, idgen2, meter, iLabelDAO, iCommitLabelMappingDAO, iPromptBasicDAO, iPromptLabelVersionDAO) illmProvider := rpc.NewLLMRPCProvider(llmClient) iFileProvider := rpc.NewFileRPCProvider(fileClient) - iPromptService := service.NewPromptService(iPromptFormatter, idgen2, iDebugLogRepo, iDebugContextRepo, iManageRepo, iLabelRepo, iConfigProvider, illmProvider, iFileProvider) + snippetParser := service.NewCozeLoopSnippetParser() + iPromptService := service.NewPromptService(iPromptFormatter, idgen2, iDebugLogRepo, iDebugContextRepo, iManageRepo, iLabelRepo, iConfigProvider, illmProvider, iFileProvider, snippetParser) iAuthProvider := rpc.NewAuthRPCProvider(authClient) promptDebugService := NewPromptDebugApplication(iDebugLogRepo, iDebugContextRepo, iPromptService, benefitSvc, iAuthProvider, iFileProvider) return promptDebugService, nil @@ -103,9 +107,10 @@ func InitPromptExecuteApplication(idgen2 idgen.IIDGenerator, db2 db.Provider, re iPromptCommitDAO := mysql.NewPromptCommitDAO(db2, redisCli) iPromptUserDraftDAO := mysql.NewPromptUserDraftDAO(db2, redisCli) iCommitLabelMappingDAO := mysql.NewCommitLabelMappingDAO(db2, redisCli) + iPromptRelationDAO := mysql.NewPromptRelationDAO(db2, redisCli) redisIPromptBasicDAO := redis2.NewPromptBasicDAO() iPromptDAO := redis2.NewPromptDAO() - iManageRepo := repo.NewManageRepo(db2, idgen2, meter, iPromptBasicDAO, iPromptCommitDAO, iPromptUserDraftDAO, iCommitLabelMappingDAO, redisIPromptBasicDAO, iPromptDAO) + iManageRepo := repo.NewManageRepo(db2, idgen2, meter, iPromptBasicDAO, iPromptCommitDAO, iPromptUserDraftDAO, iCommitLabelMappingDAO, iPromptRelationDAO, redisIPromptBasicDAO, iPromptDAO) iLabelDAO := mysql.NewLabelDAO(db2, redisCli) iConfigProvider, err := conf2.NewPromptConfigProvider(configFactory) if err != nil { @@ -115,7 +120,8 @@ func InitPromptExecuteApplication(idgen2 idgen.IIDGenerator, db2 db.Provider, re iLabelRepo := repo.NewLabelRepo(db2, idgen2, meter, iLabelDAO, iCommitLabelMappingDAO, iPromptBasicDAO, iPromptLabelVersionDAO) illmProvider := rpc.NewLLMRPCProvider(llmClient) iFileProvider := rpc.NewFileRPCProvider(fileClient) - iPromptService := service.NewPromptService(iPromptFormatter, idgen2, iDebugLogRepo, iDebugContextRepo, iManageRepo, iLabelRepo, iConfigProvider, illmProvider, iFileProvider) + snippetParser := service.NewCozeLoopSnippetParser() + iPromptService := service.NewPromptService(iPromptFormatter, idgen2, iDebugLogRepo, iDebugContextRepo, iManageRepo, iLabelRepo, iConfigProvider, illmProvider, iFileProvider, snippetParser) promptExecuteService := NewPromptExecuteApplication(iPromptService, iManageRepo) return promptExecuteService, nil } @@ -130,9 +136,10 @@ func InitPromptOpenAPIApplication(idgen2 idgen.IIDGenerator, db2 db.Provider, re iPromptCommitDAO := mysql.NewPromptCommitDAO(db2, redisCli) iPromptUserDraftDAO := mysql.NewPromptUserDraftDAO(db2, redisCli) iCommitLabelMappingDAO := mysql.NewCommitLabelMappingDAO(db2, redisCli) + iPromptRelationDAO := mysql.NewPromptRelationDAO(db2, redisCli) redisIPromptBasicDAO := redis2.NewPromptBasicDAO() iPromptDAO := redis2.NewPromptDAO() - iManageRepo := repo.NewManageRepo(db2, idgen2, meter, iPromptBasicDAO, iPromptCommitDAO, iPromptUserDraftDAO, iCommitLabelMappingDAO, redisIPromptBasicDAO, iPromptDAO) + iManageRepo := repo.NewManageRepo(db2, idgen2, meter, iPromptBasicDAO, iPromptCommitDAO, iPromptUserDraftDAO, iCommitLabelMappingDAO, iPromptRelationDAO, redisIPromptBasicDAO, iPromptDAO) iLabelDAO := mysql.NewLabelDAO(db2, redisCli) iConfigProvider, err := conf2.NewPromptConfigProvider(configFactory) if err != nil { @@ -142,7 +149,8 @@ func InitPromptOpenAPIApplication(idgen2 idgen.IIDGenerator, db2 db.Provider, re iLabelRepo := repo.NewLabelRepo(db2, idgen2, meter, iLabelDAO, iCommitLabelMappingDAO, iPromptBasicDAO, iPromptLabelVersionDAO) illmProvider := rpc.NewLLMRPCProvider(llmClient) iFileProvider := rpc.NewFileRPCProvider(fileClient) - iPromptService := service.NewPromptService(iPromptFormatter, idgen2, iDebugLogRepo, iDebugContextRepo, iManageRepo, iLabelRepo, iConfigProvider, illmProvider, iFileProvider) + snippetParser := service.NewCozeLoopSnippetParser() + iPromptService := service.NewPromptService(iPromptFormatter, idgen2, iDebugLogRepo, iDebugContextRepo, iManageRepo, iLabelRepo, iConfigProvider, illmProvider, iFileProvider, snippetParser) iAuthProvider := rpc.NewAuthRPCProvider(authClient) iCollectorProvider := collector.NewEventCollectorProvider() promptOpenAPIService, err := NewPromptOpenAPIApplication(iPromptService, iManageRepo, iConfigProvider, iAuthProvider, limiterFactory, iCollectorProvider) @@ -155,7 +163,7 @@ func InitPromptOpenAPIApplication(idgen2 idgen.IIDGenerator, db2 db.Provider, re // wire.go: var ( - promptDomainSet = wire.NewSet(service.NewPromptFormatter, service.NewPromptService, repo.NewManageRepo, repo.NewLabelRepo, repo.NewDebugLogRepo, repo.NewDebugContextRepo, mysql.NewPromptBasicDAO, mysql.NewPromptCommitDAO, mysql.NewPromptUserDraftDAO, mysql.NewLabelDAO, mysql.NewCommitLabelMappingDAO, mysql.NewDebugLogDAO, mysql.NewDebugContextDAO, redis2.NewPromptBasicDAO, redis2.NewPromptDAO, redis2.NewPromptLabelVersionDAO, conf2.NewPromptConfigProvider, rpc.NewLLMRPCProvider, rpc.NewAuthRPCProvider, rpc.NewFileRPCProvider, rpc.NewUserRPCProvider, rpc.NewAuditRPCProvider, collector.NewEventCollectorProvider) + promptDomainSet = wire.NewSet(service.NewPromptFormatter, service.NewPromptService, repo.NewManageRepo, repo.NewLabelRepo, repo.NewDebugLogRepo, repo.NewDebugContextRepo, mysql.NewPromptBasicDAO, mysql.NewPromptCommitDAO, mysql.NewPromptUserDraftDAO, mysql.NewPromptRelationDAO, mysql.NewLabelDAO, mysql.NewCommitLabelMappingDAO, mysql.NewDebugLogDAO, mysql.NewDebugContextDAO, redis2.NewPromptBasicDAO, redis2.NewPromptDAO, redis2.NewPromptLabelVersionDAO, conf2.NewPromptConfigProvider, rpc.NewLLMRPCProvider, rpc.NewAuthRPCProvider, rpc.NewFileRPCProvider, rpc.NewUserRPCProvider, rpc.NewAuditRPCProvider, collector.NewEventCollectorProvider, service.NewCozeLoopSnippetParser) manageSet = wire.NewSet( NewPromptManageApplication, promptDomainSet, diff --git a/backend/modules/prompt/domain/entity/prompt_basic.go b/backend/modules/prompt/domain/entity/prompt_basic.go index 3648c2342..6c2bf277a 100644 --- a/backend/modules/prompt/domain/entity/prompt_basic.go +++ b/backend/modules/prompt/domain/entity/prompt_basic.go @@ -6,6 +6,7 @@ package entity import "time" type PromptBasic struct { + PromptType PromptType `json:"prompt_type"` DisplayName string `json:"display_name"` Description string `json:"description"` LatestVersion string `json:"latest_version"` @@ -15,3 +16,10 @@ type PromptBasic struct { UpdatedAt time.Time `json:"updated_at"` LatestCommittedAt *time.Time `json:"latest_committed_at"` } + +type PromptType string + +const ( + PromptTypeNormal PromptType = "normal" + PromptTypeSnippet PromptType = "snippet" +) diff --git a/backend/modules/prompt/domain/entity/prompt_detail.go b/backend/modules/prompt/domain/entity/prompt_detail.go index c5558dc19..f7c76a696 100644 --- a/backend/modules/prompt/domain/entity/prompt_detail.go +++ b/backend/modules/prompt/domain/entity/prompt_detail.go @@ -34,10 +34,13 @@ type PromptDetail struct { } type PromptTemplate struct { - TemplateType TemplateType `json:"template_type"` - Messages []*Message `json:"messages,omitempty"` - VariableDefs []*VariableDef `json:"variable_defs,omitempty"` - Metadata map[string]string `json:"metadata,omitempty"` + TemplateType TemplateType `json:"template_type"` + Messages []*Message `json:"messages,omitempty"` + VariableDefs []*VariableDef `json:"variable_defs,omitempty"` + + HasSnippets bool `json:"has_snippets"` + Snippets []*Prompt `json:"snippets,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` } type TemplateType string diff --git a/backend/modules/prompt/domain/repo/manage.go b/backend/modules/prompt/domain/repo/manage.go index ea151ed05..4ecf0e7a4 100644 --- a/backend/modules/prompt/domain/repo/manage.go +++ b/backend/modules/prompt/domain/repo/manage.go @@ -17,10 +17,12 @@ type IManageRepo interface { MGetPrompt(ctx context.Context, queries []GetPromptParam, opts ...GetPromptOptionFunc) (promptDOMap map[GetPromptParam]*entity.Prompt, err error) MGetPromptBasicByPromptKey(ctx context.Context, spaceID int64, promptKeys []string, opts ...GetPromptBasicOptionFunc) (promptDOs []*entity.Prompt, err error) ListPrompt(ctx context.Context, param ListPromptParam) (result *ListPromptResult, err error) + ListParentPrompt(ctx context.Context, param ListParentPromptParam) (result map[string][]*PromptCommitVersions, err error) UpdatePrompt(ctx context.Context, param UpdatePromptParam) (err error) SaveDraft(ctx context.Context, promptDO *entity.Prompt) (draftInfo *entity.DraftInfo, err error) CommitDraft(ctx context.Context, param CommitDraftParam) (err error) ListCommitInfo(ctx context.Context, param ListCommitInfoParam) (result *ListCommitResult, err error) + MGetVersionsByPromptID(ctx context.Context, promptID int64) (versions []string, err error) } type GetPromptParam struct { @@ -36,10 +38,12 @@ type GetPromptParam struct { type ListPromptParam struct { SpaceID int64 - KeyWord string - CreatedBys []string - UserID string - CommittedOnly bool + KeyWord string + CreatedBys []string + UserID string + CommittedOnly bool + FilterPromptTypes []entity.PromptType + PromptIDs []int64 PageNum int PageSize int @@ -80,9 +84,29 @@ type ListCommitInfoParam struct { type ListCommitResult struct { CommitInfoDOs []*entity.CommitInfo + CommitDOs []*entity.PromptCommit NextPageToken int64 } +type ListParentPromptParam struct { + SubPromptID int64 + SubPromptVersions []string +} + +type ListSubPromptParam struct { + PromptID int64 + PromptVersions []string + PromptDraftUserID string +} + +type PromptCommitVersions struct { + PromptID int64 + SpaceID int64 + PromptKey string + PromptBasic *entity.PromptBasic + CommitVersions []string +} + type CacheOption struct { CacheEnable bool } diff --git a/backend/modules/prompt/domain/repo/mocks/manage_repo.go b/backend/modules/prompt/domain/repo/mocks/manage_repo.go index 5291b890a..09a65463d 100644 --- a/backend/modules/prompt/domain/repo/mocks/manage_repo.go +++ b/backend/modules/prompt/domain/repo/mocks/manage_repo.go @@ -115,6 +115,21 @@ func (mr *MockIManageRepoMockRecorder) ListCommitInfo(ctx, param any) *gomock.Ca return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListCommitInfo", reflect.TypeOf((*MockIManageRepo)(nil).ListCommitInfo), ctx, param) } +// ListParentPrompt mocks base method. +func (m *MockIManageRepo) ListParentPrompt(ctx context.Context, param repo.ListParentPromptParam) (map[string][]*repo.PromptCommitVersions, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListParentPrompt", ctx, param) + ret0, _ := ret[0].(map[string][]*repo.PromptCommitVersions) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListParentPrompt indicates an expected call of ListParentPrompt. +func (mr *MockIManageRepoMockRecorder) ListParentPrompt(ctx, param any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListParentPrompt", reflect.TypeOf((*MockIManageRepo)(nil).ListParentPrompt), ctx, param) +} + // ListPrompt mocks base method. func (m *MockIManageRepo) ListPrompt(ctx context.Context, param repo.ListPromptParam) (*repo.ListPromptResult, error) { m.ctrl.T.Helper() @@ -170,6 +185,21 @@ func (mr *MockIManageRepoMockRecorder) MGetPromptBasicByPromptKey(ctx, spaceID, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetPromptBasicByPromptKey", reflect.TypeOf((*MockIManageRepo)(nil).MGetPromptBasicByPromptKey), varargs...) } +// MGetVersionsByPromptID mocks base method. +func (m *MockIManageRepo) MGetVersionsByPromptID(ctx context.Context, promptID int64) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MGetVersionsByPromptID", ctx, promptID) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// MGetVersionsByPromptID indicates an expected call of MGetVersionsByPromptID. +func (mr *MockIManageRepoMockRecorder) MGetVersionsByPromptID(ctx, promptID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetVersionsByPromptID", reflect.TypeOf((*MockIManageRepo)(nil).MGetVersionsByPromptID), ctx, promptID) +} + // SaveDraft mocks base method. func (m *MockIManageRepo) SaveDraft(ctx context.Context, promptDO *entity.Prompt) (*entity.DraftInfo, error) { m.ctrl.T.Helper() diff --git a/backend/modules/prompt/domain/service/manage.go b/backend/modules/prompt/domain/service/manage.go index b14f3db12..eea9400c7 100644 --- a/backend/modules/prompt/domain/service/manage.go +++ b/backend/modules/prompt/domain/service/manage.go @@ -9,10 +9,13 @@ import ( "fmt" "strings" + "github.com/bytedance/gg/gmap" + "github.com/coze-dev/coze-loop/backend/modules/prompt/domain/entity" "github.com/coze-dev/coze-loop/backend/modules/prompt/domain/repo" prompterr "github.com/coze-dev/coze-loop/backend/modules/prompt/pkg/errno" "github.com/coze-dev/coze-loop/backend/pkg/errorx" + "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" "github.com/coze-dev/coze-loop/backend/pkg/logs" ) @@ -302,3 +305,293 @@ func (p *PromptServiceImpl) MParseCommitVersion(ctx context.Context, spaceID int return promptKeyCommitVersionMap, nil } + +// GetPrompt retrieves a prompt by its ID +func (p *PromptServiceImpl) GetPrompt(ctx context.Context, param GetPromptParam) (*entity.Prompt, error) { + promptDO, err := p.manageRepo.GetPrompt(ctx, repo.GetPromptParam{ + PromptID: param.PromptID, + WithCommit: param.WithCommit, + CommitVersion: param.CommitVersion, + WithDraft: param.WithDraft, + UserID: param.UserID, + }) + if err != nil { + return nil, err + } + + err = p.parseAndValidateSnippets(ctx, promptDO) + if err != nil { + return nil, err + } + + if param.ExpandSnippet { + // expand snippets + err = p.ExpandSnippets(ctx, promptDO) + if err != nil { + return nil, err + } + } + return promptDO, nil +} + +// CreatePrompt creates a prompt with optional snippet validation +func (p *PromptServiceImpl) CreatePrompt(ctx context.Context, promptDO *entity.Prompt) (promptID int64, err error) { + if promptDO == nil { + return 0, errorx.New("promptDO is empty") + } + + // Validate basic prompt information + if promptDO.SpaceID <= 0 { + return 0, errorx.New("spaceID is invalid: %d", promptDO.SpaceID) + } + if promptDO.PromptKey == "" { + return 0, errorx.New("promptKey is empty") + } + if promptDO.PromptBasic == nil { + return 0, errorx.New("promptBasic is empty") + } + + err = p.parseAndValidateSnippets(ctx, promptDO) + if err != nil { + return 0, err + } + + // Create the prompt through repository + promptID, err = p.manageRepo.CreatePrompt(ctx, promptDO) + if err != nil { + return 0, err + } + + return promptID, nil +} + +func (p *PromptServiceImpl) parseAndValidateSnippets(ctx context.Context, promptDO *entity.Prompt) error { + // Check if prompt has snippets based on the flag + hasSnippets := false + if promptDetail := promptDO.GetPromptDetail(); promptDetail != nil && promptDetail.PromptTemplate != nil { + hasSnippets = promptDetail.PromptTemplate.HasSnippets + } + if !hasSnippets { + return nil + } + + // Only parse and validate snippets if hasSnippets is true + if promptDetail := promptDO.GetPromptDetail(); promptDetail != nil && promptDetail.PromptTemplate != nil { + var allContent string + for _, message := range promptDetail.PromptTemplate.Messages { + if ptr.From(message.Content) != "" { + allContent += ptr.From(message.Content) + } + for _, part := range message.Parts { + if ptr.From(part.Text) != "" { + allContent += ptr.From(part.Text) + } + } + } + + var snippetRefs []*SnippetReference + var err error + if allContent != "" { + snippetRefs, err = p.snippetParser.ParseReferences(allContent) + if err != nil { + return errorx.WrapByCode(err, prompterr.CommonInvalidParamCode, + errorx.WithExtraMsg("failed to parse snippet references")) + } + } + + // Validate that snippets were actually found + if len(snippetRefs) == 0 { + return errorx.NewByCode(prompterr.CommonInvalidParamCode, + errorx.WithExtraMsg("has_snippets is true but no snippet references found in content")) + } + + // Validate snippet references exist and are of correct type with valid versions + queriesMap := make(map[repo.GetPromptParam]bool) + for _, ref := range snippetRefs { + queriesMap[repo.GetPromptParam{ + PromptID: ref.PromptID, + WithCommit: true, + CommitVersion: ref.CommitVersion, + }] = true + } + + snippetPrompts, err := p.manageRepo.MGetPrompt(ctx, gmap.Keys(queriesMap), repo.WithPromptCacheEnable()) + if err != nil { + return errorx.WrapByCode(err, prompterr.CommonInvalidParamCode, + errorx.WithExtraMsg("failed to get snippet prompts")) + } + // fill snippets + promptDetail.PromptTemplate.Snippets = gmap.Values(snippetPrompts) + + // Validate each snippet reference using map access + for _, ref := range snippetRefs { + key := repo.GetPromptParam{ + PromptID: ref.PromptID, + WithCommit: true, + CommitVersion: ref.CommitVersion, + } + + prompt, exists := snippetPrompts[key] + if !exists || prompt == nil { + return errorx.NewByCode(prompterr.ResourceNotFoundCode, + errorx.WithExtraMsg(fmt.Sprintf("snippet prompt %d with version %s not found", ref.PromptID, ref.CommitVersion))) + } + + // Check if prompt is a snippet type + if prompt.PromptBasic == nil || prompt.PromptBasic.PromptType != entity.PromptTypeSnippet { + return errorx.NewByCode(prompterr.CommonInvalidParamCode, + errorx.WithExtraMsg(fmt.Sprintf("prompt %d is not a snippet type", ref.PromptID))) + } + } + } + return nil +} + +// SaveDraft saves a draft with snippet validation and relationship management +func (p *PromptServiceImpl) SaveDraft(ctx context.Context, promptDO *entity.Prompt) (*entity.DraftInfo, error) { + if promptDO == nil || promptDO.PromptDraft == nil { + return nil, errorx.New("promptDO or promptDO.PromptDraft is empty") + } + + // Parse and validate snippets in the draft content + err := p.parseAndValidateSnippets(ctx, promptDO) + if err != nil { + return nil, err + } + + // Save the draft through repository (which will handle snippet relationships) + draftInfo, err := p.manageRepo.SaveDraft(ctx, promptDO) + if err != nil { + return nil, err + } + + return draftInfo, nil +} + +// ExpandSnippets expands all snippet references in the prompt's messages +func (p *PromptServiceImpl) ExpandSnippets(ctx context.Context, promptDO *entity.Prompt) error { + maxDepth := 2 + return p.doExpandSnippets(ctx, promptDO, maxDepth) +} + +func (p *PromptServiceImpl) doExpandSnippets(ctx context.Context, promptDO *entity.Prompt, maxDepth int) error { + if promptDO == nil { + return errorx.New("promptDO is empty") + } + + // Get the prompt detail + promptDetail := promptDO.GetPromptDetail() + if promptDetail == nil || promptDetail.PromptTemplate == nil { + return nil // No template to expand + } + + // Check if prompt has snippets + if !promptDetail.PromptTemplate.HasSnippets { + return nil // No snippets to expand + } + + // Validate max depth to prevent infinite recursion + if maxDepth <= 0 { + return errorx.New("max recursion depth reached") + } + // First, parse and validate snippets to populate the Snippets field + // This will call MGetPrompt once to get all snippet data + err := p.parseAndValidateSnippets(ctx, promptDO) + if err != nil { + return err + } + + for _, snippet := range promptDetail.PromptTemplate.Snippets { + if snippet.GetPromptDetail().PromptTemplate == nil || !snippet.GetPromptDetail().PromptTemplate.HasSnippets { + continue + } + err = p.doExpandSnippets(ctx, snippet, maxDepth-1) + if err != nil { + return err + } + } + + // Build map for quick lookup of snippet content + snippetContentMap := make(map[string]string) + for _, snippet := range promptDetail.PromptTemplate.Snippets { + if snippet == nil || snippet.PromptBasic == nil { + continue + } + + snippetDetail := snippet.GetPromptDetail() + if snippetDetail == nil || snippetDetail.PromptTemplate == nil { + continue + } + + // Build lookup key: "promptID_version" + key := fmt.Sprintf("%d_%s", snippet.ID, snippet.GetVersion()) + // Get snippet content from the first message + if len(snippetDetail.PromptTemplate.Messages) > 0 { + snippetContent := ptr.From(snippetDetail.PromptTemplate.Messages[0].Content) + snippetContentMap[key] = snippetContent + } + } + + // Expand all snippet references in messages using the pre-built content map + for _, message := range promptDetail.PromptTemplate.Messages { + if message == nil { + continue + } + + // Expand content if it exists + if message.Content != nil && *message.Content != "" { + expandedContent, err := p.expandWithSnippetMap(ctx, *message.Content, snippetContentMap) + if err != nil { + return err + } + message.Content = &expandedContent + } + + // Expand text in parts + for _, part := range message.Parts { + if part == nil || part.Text == nil || *part.Text == "" { + continue + } + expandedText, err := p.expandWithSnippetMap(ctx, *part.Text, snippetContentMap) + if err != nil { + return err + } + part.Text = &expandedText + } + } + + return nil +} + +// expandWithSnippetMap expands snippet references using a pre-built content map and returns expanded content +func (p *PromptServiceImpl) expandWithSnippetMap(ctx context.Context, content string, snippetContentMap map[string]string) (string, error) { + // Parse snippet references from content + snippetRefs, err := p.snippetParser.ParseReferences(content) + if err != nil { + return "", err + } + + // If no references found, return original content and empty variable definitions + if len(snippetRefs) == 0 { + return content, nil + } + + // Replace each reference with expanded content from the map + expandedContent := content + + for _, ref := range snippetRefs { + // Build lookup key: "promptID_version" + key := fmt.Sprintf("%d_%s", ref.PromptID, ref.CommitVersion) + expandedSnippetContent, exists := snippetContentMap[key] + if !exists { + return "", errorx.NewByCode(prompterr.ResourceNotFoundCode, + errorx.WithExtraMsg(fmt.Sprintf("snippet content for prompt %d with version %s not found in cache", ref.PromptID, ref.CommitVersion))) + } + + // Replace the reference with expanded content + refString := p.snippetParser.SerializeReference(ref) + expandedContent = strings.ReplaceAll(expandedContent, refString, expandedSnippetContent) + } + + return expandedContent, nil +} diff --git a/backend/modules/prompt/domain/service/manage_test.go b/backend/modules/prompt/domain/service/manage_test.go index 8ae0145c1..f11fa5b0a 100755 --- a/backend/modules/prompt/domain/service/manage_test.go +++ b/backend/modules/prompt/domain/service/manage_test.go @@ -6,6 +6,8 @@ package service import ( "context" "encoding/base64" + "errors" + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -25,6 +27,1016 @@ import ( "github.com/coze-dev/coze-loop/backend/pkg/unittest" ) +type fakeSnippetParser struct { + parseFunc func(string) ([]*SnippetReference, error) + serializeFunc func(*SnippetReference) string +} + +func (f fakeSnippetParser) ParseReferences(content string) ([]*SnippetReference, error) { + if f.parseFunc != nil { + return f.parseFunc(content) + } + return nil, nil +} + +func (f fakeSnippetParser) SerializeReference(ref *SnippetReference) string { + if f.serializeFunc != nil { + return f.serializeFunc(ref) + } + return fmt.Sprintf("id=%d&version=%s", ref.PromptID, ref.CommitVersion) +} + +func TestPromptServiceImpl_SaveDraft(t *testing.T) { + t.Parallel() + type fields struct { + idgen idgen.IIDGenerator + debugLogRepo repo.IDebugLogRepo + debugContextRepo repo.IDebugContextRepo + manageRepo repo.IManageRepo + labelRepo repo.ILabelRepo + configProvider conf.IConfigProvider + llm rpc.ILLMProvider + file rpc.IFileProvider + snippetParser SnippetParser + } + type args struct { + ctx context.Context + promptDO *entity.Prompt + } + tests := []struct { + name string + fieldsGetter func(ctrl *gomock.Controller) fields + args args + want *entity.DraftInfo + wantErr error + assertFunc func(t *testing.T, prompt *entity.Prompt) + }{ + { + name: "正常保存草稿 - 无片段", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockManageRepo := repomocks.NewMockIManageRepo(ctrl) + mockManageRepo.EXPECT(). + SaveDraft(gomock.Any(), gomock.Any()). + Return(&entity.DraftInfo{ + UserID: "user123", + IsModified: true, + }, nil) + return fields{ + manageRepo: mockManageRepo, + } + }, + args: args{ + ctx: context.Background(), + promptDO: &entity.Prompt{ + ID: 123, + SpaceID: 456, + PromptDraft: &entity.PromptDraft{ + DraftInfo: &entity.DraftInfo{ + UserID: "user123", + }, + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + HasSnippets: false, + }, + }, + }, + }, + }, + want: &entity.DraftInfo{ + UserID: "user123", + IsModified: true, + }, + }, + { + name: "参数错误 - promptDO为空", + fieldsGetter: func(ctrl *gomock.Controller) fields { + return fields{} + }, + args: args{ + ctx: context.Background(), + promptDO: nil, + }, + wantErr: errorx.New("promptDO or promptDO.PromptDraft is empty"), + }, + { + name: "参数错误 - PromptDraft为空", + fieldsGetter: func(ctrl *gomock.Controller) fields { + return fields{} + }, + args: args{ + ctx: context.Background(), + promptDO: &entity.Prompt{ + ID: 123, + PromptDraft: nil, + }, + }, + wantErr: errorx.New("promptDO or promptDO.PromptDraft is empty"), + }, + { + name: "保存失败 - repository错误", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockManageRepo := repomocks.NewMockIManageRepo(ctrl) + mockManageRepo.EXPECT(). + SaveDraft(gomock.Any(), gomock.Any()). + Return(nil, errorx.New("repository error")) + return fields{ + manageRepo: mockManageRepo, + } + }, + args: args{ + ctx: context.Background(), + promptDO: &entity.Prompt{ + ID: 123, + SpaceID: 456, + PromptDraft: &entity.PromptDraft{ + DraftInfo: &entity.DraftInfo{ + UserID: "user123", + }, + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + HasSnippets: false, + }, + }, + }, + }, + }, + wantErr: errorx.New("repository error"), + }, + { + name: "片段解析失败", + fieldsGetter: func(ctrl *gomock.Controller) fields { + return fields{ + snippetParser: fakeSnippetParser{ + parseFunc: func(string) ([]*SnippetReference, error) { + return nil, errors.New("parse error") + }, + }, + } + }, + args: args{ + ctx: context.Background(), + promptDO: &entity.Prompt{ + SpaceID: 1, + PromptKey: "key", + PromptDraft: &entity.PromptDraft{ + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + HasSnippets: true, + Messages: []*entity.Message{ + {Content: ptr.Of("content")}, + }, + }, + }, + }, + }, + }, + wantErr: errorx.WrapByCode(errors.New("parse error"), prompterr.CommonInvalidParamCode, errorx.WithExtraMsg("failed to parse snippet references")), + }, + { + name: "片段引用不存在", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockManageRepo := repomocks.NewMockIManageRepo(ctrl) + mockManageRepo.EXPECT(). + MGetPrompt(gomock.Any(), gomock.Any(), gomock.Any()). + Return(map[repo.GetPromptParam]*entity.Prompt{}, nil) + return fields{ + manageRepo: mockManageRepo, + snippetParser: NewCozeLoopSnippetParser(), + } + }, + args: args{ + ctx: context.Background(), + promptDO: &entity.Prompt{ + SpaceID: 1, + PromptKey: "key", + PromptDraft: &entity.PromptDraft{ + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + HasSnippets: true, + Messages: []*entity.Message{ + {Content: ptr.Of("id=2&version=v1")}, + }, + }, + }, + }, + }, + }, + wantErr: errorx.NewByCode(prompterr.ResourceNotFoundCode, errorx.WithExtraMsg("snippet prompt 2 with version v1 not found")), + }, + { + name: "片段校验成功并保存", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockManageRepo := repomocks.NewMockIManageRepo(ctrl) + mockManageRepo.EXPECT(). + MGetPrompt(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, params []repo.GetPromptParam, _ ...repo.GetPromptOptionFunc) (map[repo.GetPromptParam]*entity.Prompt, error) { + assert.Len(t, params, 1) + query := params[0] + snippetPrompt := &entity.Prompt{ + ID: query.PromptID, + SpaceID: 1, + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeSnippet, + }, + PromptCommit: &entity.PromptCommit{ + CommitInfo: &entity.CommitInfo{Version: query.CommitVersion}, + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + Messages: []*entity.Message{{Content: ptr.Of("snippet content")}}, + }, + }, + }, + } + return map[repo.GetPromptParam]*entity.Prompt{query: snippetPrompt}, nil + }) + mockManageRepo.EXPECT(). + SaveDraft(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, prompt *entity.Prompt) (*entity.DraftInfo, error) { + if detail := prompt.GetPromptDetail(); detail != nil && detail.PromptTemplate != nil { + assert.NotEmpty(t, detail.PromptTemplate.Snippets) + } + return &entity.DraftInfo{UserID: "user123", IsModified: true}, nil + }) + return fields{ + manageRepo: mockManageRepo, + snippetParser: NewCozeLoopSnippetParser(), + } + }, + args: args{ + ctx: context.Background(), + promptDO: &entity.Prompt{ + SpaceID: 1, + PromptKey: "key", + PromptDraft: &entity.PromptDraft{ + DraftInfo: &entity.DraftInfo{UserID: "user123"}, + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + HasSnippets: true, + Messages: []*entity.Message{ + {Content: ptr.Of("id=2&version=v1")}, + }, + }, + }, + }, + }, + }, + want: &entity.DraftInfo{UserID: "user123", IsModified: true}, + assertFunc: func(t *testing.T, prompt *entity.Prompt) { + as := assert.New(t) + detail := prompt.GetPromptDetail() + as.NotNil(detail) + if detail == nil || detail.PromptTemplate == nil { + return + } + as.NotEmpty(detail.PromptTemplate.Snippets) + }, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + tFields := tt.fieldsGetter(ctrl) + + p := &PromptServiceImpl{ + idgen: tFields.idgen, + debugLogRepo: tFields.debugLogRepo, + debugContextRepo: tFields.debugContextRepo, + manageRepo: tFields.manageRepo, + labelRepo: tFields.labelRepo, + configProvider: tFields.configProvider, + llm: tFields.llm, + file: tFields.file, + snippetParser: tFields.snippetParser, + } + + got, err := p.SaveDraft(tt.args.ctx, tt.args.promptDO) + unittest.AssertErrorEqual(t, tt.wantErr, err) + if tt.wantErr != nil { + assert.Nil(t, got) + return + } + assert.NotNil(t, got) + if tt.want != nil { + assert.Equal(t, tt.want.UserID, got.UserID) + assert.Equal(t, tt.want.IsModified, got.IsModified) + } + if tt.assertFunc != nil { + tt.assertFunc(t, tt.args.promptDO) + } + }) + } +} + +func TestPromptServiceImpl_CreatePrompt(t *testing.T) { + t.Parallel() + type fields struct { + manageRepo repo.IManageRepo + snippetParser SnippetParser + } + type args struct { + ctx context.Context + promptDO *entity.Prompt + } + tests := []struct { + name string + fieldsGetter func(ctrl *gomock.Controller) fields + args args + wantID int64 + wantErr error + }{ + { + name: "prompt do is nil", + fieldsGetter: func(ctrl *gomock.Controller) fields { + return fields{} + }, + args: args{ + ctx: context.Background(), + promptDO: nil, + }, + wantErr: errorx.New("promptDO is empty"), + }, + { + name: "prompt key empty", + fieldsGetter: func(ctrl *gomock.Controller) fields { + return fields{} + }, + args: args{ + ctx: context.Background(), + promptDO: &entity.Prompt{ + SpaceID: 1, + PromptKey: "", + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, + }, + }, + }, + wantErr: errorx.New("promptKey is empty"), + }, + { + name: "prompt basic nil", + fieldsGetter: func(ctrl *gomock.Controller) fields { + return fields{} + }, + args: args{ + ctx: context.Background(), + promptDO: &entity.Prompt{ + SpaceID: 1, + PromptKey: "key", + }, + }, + wantErr: errorx.New("promptBasic is empty"), + }, + { + name: "space id invalid", + fieldsGetter: func(ctrl *gomock.Controller) fields { + return fields{} + }, + args: args{ + ctx: context.Background(), + promptDO: &entity.Prompt{ + SpaceID: 0, + PromptKey: "key", + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, + }, + }, + }, + wantErr: errorx.New("spaceID is invalid: %d", 0), + }, + { + name: "has snippets parse error", + fieldsGetter: func(ctrl *gomock.Controller) fields { + return fields{ + snippetParser: fakeSnippetParser{ + parseFunc: func(string) ([]*SnippetReference, error) { + return nil, errors.New("parse error") + }, + }, + } + }, + args: args{ + ctx: context.Background(), + promptDO: &entity.Prompt{ + SpaceID: 1, + PromptKey: "key", + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, + }, + PromptDraft: &entity.PromptDraft{ + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + HasSnippets: true, + Messages: []*entity.Message{ + {Content: ptr.Of("test content")}, + }, + }, + }, + }, + }, + }, + wantErr: errorx.WrapByCode(errors.New("parse error"), prompterr.CommonInvalidParamCode, errorx.WithExtraMsg("failed to parse snippet references")), + }, + { + name: "has snippets but no references", + fieldsGetter: func(ctrl *gomock.Controller) fields { + return fields{ + snippetParser: fakeSnippetParser{ + parseFunc: func(string) ([]*SnippetReference, error) { + return []*SnippetReference{}, nil + }, + }, + } + }, + args: args{ + ctx: context.Background(), + promptDO: &entity.Prompt{ + SpaceID: 1, + PromptKey: "key", + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, + }, + PromptDraft: &entity.PromptDraft{ + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + HasSnippets: true, + Messages: []*entity.Message{ + {Content: ptr.Of("id=2&version=v1")}, + }, + }, + }, + }, + }, + }, + wantErr: errorx.NewByCode(prompterr.CommonInvalidParamCode, errorx.WithExtraMsg("has_snippets is true but no snippet references found in content")), + }, + { + name: "snippet prompt type invalid", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockRepo := repomocks.NewMockIManageRepo(ctrl) + mockRepo.EXPECT().MGetPrompt(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, params []repo.GetPromptParam, _ ...repo.GetPromptOptionFunc) (map[repo.GetPromptParam]*entity.Prompt, error) { + assert.Len(t, params, 1) + query := params[0] + snippetPrompt := &entity.Prompt{ + ID: query.PromptID, + SpaceID: 1, + PromptKey: "snippet", + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, + }, + PromptCommit: &entity.PromptCommit{ + CommitInfo: &entity.CommitInfo{Version: query.CommitVersion}, + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + Messages: []*entity.Message{ + {Content: ptr.Of("snippet content")}, + }, + }, + }, + }, + } + return map[repo.GetPromptParam]*entity.Prompt{query: snippetPrompt}, nil + }) + return fields{ + manageRepo: mockRepo, + snippetParser: NewCozeLoopSnippetParser(), + } + }, + args: func() args { + promptDO := &entity.Prompt{ + SpaceID: 1, + PromptKey: "key", + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, + }, + PromptDraft: &entity.PromptDraft{ + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + HasSnippets: true, + Messages: []*entity.Message{ + {Content: ptr.Of("id=2&version=v1")}, + }, + }, + }, + }, + } + return args{ + ctx: context.Background(), + promptDO: promptDO, + } + }(), + wantErr: errorx.NewByCode(prompterr.CommonInvalidParamCode, errorx.WithExtraMsg(fmt.Sprintf("prompt %d is not a snippet type", 2))), + }, + { + name: "create prompt repo error", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockRepo := repomocks.NewMockIManageRepo(ctrl) + mockRepo.EXPECT().CreatePrompt(gomock.Any(), gomock.Any()).Return(int64(0), errorx.New("create failed")) + return fields{ + manageRepo: mockRepo, + } + }, + args: args{ + ctx: context.Background(), + promptDO: &entity.Prompt{ + SpaceID: 1, + PromptKey: "key", + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, + }, + }, + }, + wantErr: errorx.New("create failed"), + }, + { + name: "snippet repo error", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockRepo := repomocks.NewMockIManageRepo(ctrl) + mockRepo.EXPECT().MGetPrompt(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errorx.New("mget error")) + return fields{ + manageRepo: mockRepo, + snippetParser: fakeSnippetParser{ + parseFunc: func(string) ([]*SnippetReference, error) { + return []*SnippetReference{{PromptID: 2, CommitVersion: "v1"}}, nil + }, + }, + } + }, + args: func() args { + promptDO := &entity.Prompt{ + SpaceID: 1, + PromptKey: "key", + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, + }, + PromptDraft: &entity.PromptDraft{ + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + HasSnippets: true, + Messages: []*entity.Message{ + {Content: ptr.Of("id=2&version=v1")}, + }, + }, + }, + }, + } + return args{ + ctx: context.Background(), + promptDO: promptDO, + } + }(), + wantErr: errorx.WrapByCode(errorx.New("mget error"), prompterr.CommonInvalidParamCode, errorx.WithExtraMsg("failed to get snippet prompts")), + }, + { + name: "success without snippets", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockRepo := repomocks.NewMockIManageRepo(ctrl) + mockRepo.EXPECT().CreatePrompt(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, prompt *entity.Prompt) (int64, error) { + return 11, nil + }) + return fields{ + manageRepo: mockRepo, + } + }, + args: args{ + ctx: context.Background(), + promptDO: &entity.Prompt{ + SpaceID: 1, + PromptKey: "key", + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, + }, + }, + }, + wantID: 11, + }, + { + name: "success with snippets", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockRepo := repomocks.NewMockIManageRepo(ctrl) + mockRepo.EXPECT().MGetPrompt(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, params []repo.GetPromptParam, _ ...repo.GetPromptOptionFunc) (map[repo.GetPromptParam]*entity.Prompt, error) { + assert.Len(t, params, 1) + query := params[0] + snippetPrompt := &entity.Prompt{ + ID: query.PromptID, + SpaceID: 1, + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeSnippet, + }, + PromptCommit: &entity.PromptCommit{ + CommitInfo: &entity.CommitInfo{Version: query.CommitVersion}, + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + Messages: []*entity.Message{ + {Content: ptr.Of("snippet content")}, + }, + VariableDefs: []*entity.VariableDef{{Key: "snippet_var"}}, + }, + }, + }, + } + return map[repo.GetPromptParam]*entity.Prompt{query: snippetPrompt}, nil + }) + mockRepo.EXPECT().CreatePrompt(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, prompt *entity.Prompt) (int64, error) { + if prompt.PromptDraft != nil && prompt.PromptDraft.PromptDetail != nil && prompt.PromptDraft.PromptDetail.PromptTemplate != nil { + assert.NotEmpty(t, prompt.PromptDraft.PromptDetail.PromptTemplate.Snippets) + } + return 101, nil + }) + return fields{ + manageRepo: mockRepo, + snippetParser: NewCozeLoopSnippetParser(), + } + }, + args: func() args { + promptDO := &entity.Prompt{ + SpaceID: 1, + PromptKey: "key", + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, + }, + PromptDraft: &entity.PromptDraft{ + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + HasSnippets: true, + Messages: []*entity.Message{ + {Content: ptr.Of("id=2&version=v1")}, + }, + VariableDefs: []*entity.VariableDef{{Key: "base_var"}}, + }, + }, + }, + } + return args{ + ctx: context.Background(), + promptDO: promptDO, + } + }(), + wantID: 101, + }, + } + + for _, tt := range tests { + ttt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + tFields := ttt.fieldsGetter(ctrl) + service := &PromptServiceImpl{ + manageRepo: tFields.manageRepo, + snippetParser: tFields.snippetParser, + } + + got, err := service.CreatePrompt(ttt.args.ctx, ttt.args.promptDO) + unittest.AssertErrorEqual(t, tt.wantErr, err) + if tt.wantErr == nil { + assert.Equal(t, tt.wantID, got) + } + }) + } +} + +func TestPromptServiceImpl_ExpandSnippets(t *testing.T) { + t.Parallel() + type fields struct { + manageRepo repo.IManageRepo + snippetParser SnippetParser + } + type args struct { + ctx context.Context + promptDO *entity.Prompt + } + tests := []struct { + name string + fieldsGetter func(ctrl *gomock.Controller) fields + args args + assertFunc func(t *testing.T, prompt *entity.Prompt) + wantErr error + }{ + { + name: "prompt do is nil", + fieldsGetter: func(ctrl *gomock.Controller) fields { + return fields{} + }, + args: args{ + ctx: context.Background(), + promptDO: nil, + }, + wantErr: errorx.New("promptDO is empty"), + }, + { + name: "prompt detail missing", + fieldsGetter: func(ctrl *gomock.Controller) fields { + return fields{} + }, + args: args{ + ctx: context.Background(), + promptDO: &entity.Prompt{ + SpaceID: 1, + PromptKey: "key", + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, + }, + }, + }, + }, + { + name: "no snippets flag", + fieldsGetter: func(ctrl *gomock.Controller) fields { + return fields{} + }, + args: args{ + ctx: context.Background(), + promptDO: &entity.Prompt{ + SpaceID: 1, + PromptKey: "key", + PromptBasic: &entity.PromptBasic{PromptType: entity.PromptTypeNormal}, + PromptDraft: &entity.PromptDraft{ + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + HasSnippets: false, + }, + }, + }, + }, + }, + }, + { + name: "snippet not found", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockRepo := repomocks.NewMockIManageRepo(ctrl) + mockRepo.EXPECT().MGetPrompt(gomock.Any(), gomock.Any(), gomock.Any()).Return(map[repo.GetPromptParam]*entity.Prompt{}, nil) + return fields{ + manageRepo: mockRepo, + snippetParser: NewCozeLoopSnippetParser(), + } + }, + args: func() args { + promptDO := &entity.Prompt{ + SpaceID: 1, + PromptKey: "key", + PromptBasic: &entity.PromptBasic{PromptType: entity.PromptTypeNormal}, + PromptDraft: &entity.PromptDraft{ + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + HasSnippets: true, + Messages: []*entity.Message{{Content: ptr.Of("id=2&version=v1")}}, + }, + }, + }, + } + return args{ctx: context.Background(), promptDO: promptDO} + }(), + wantErr: errorx.NewByCode(prompterr.ResourceNotFoundCode, errorx.WithExtraMsg("snippet prompt 2 with version v1 not found")), + }, + { + name: "exceed max depth", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockRepo := repomocks.NewMockIManageRepo(ctrl) + mockRepo.EXPECT().MGetPrompt(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, params []repo.GetPromptParam, _ ...repo.GetPromptOptionFunc) (map[repo.GetPromptParam]*entity.Prompt, error) { + assert.Len(t, params, 1) + query := params[0] + switch query.PromptID { + case 2: + snippetPrompt := &entity.Prompt{ + ID: query.PromptID, + SpaceID: 1, + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeSnippet, + }, + PromptCommit: &entity.PromptCommit{ + CommitInfo: &entity.CommitInfo{Version: query.CommitVersion}, + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + HasSnippets: true, + Messages: []*entity.Message{{Content: ptr.Of("id=3&version=v2")}}, + }, + }, + }, + } + return map[repo.GetPromptParam]*entity.Prompt{query: snippetPrompt}, nil + case 3: + nestedPrompt := &entity.Prompt{ + ID: query.PromptID, + SpaceID: 1, + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeSnippet, + }, + PromptCommit: &entity.PromptCommit{ + CommitInfo: &entity.CommitInfo{Version: query.CommitVersion}, + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + HasSnippets: true, + }, + }, + }, + } + return map[repo.GetPromptParam]*entity.Prompt{query: nestedPrompt}, nil + default: + return map[repo.GetPromptParam]*entity.Prompt{}, nil + } + }).AnyTimes() + return fields{ + manageRepo: mockRepo, + snippetParser: NewCozeLoopSnippetParser(), + } + }, + args: func() args { + prompt := &entity.Prompt{ + ID: 10, + SpaceID: 1, + PromptKey: "main", + PromptBasic: &entity.PromptBasic{PromptType: entity.PromptTypeNormal}, + PromptDraft: &entity.PromptDraft{ + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + HasSnippets: true, + Messages: []*entity.Message{{Content: ptr.Of("id=2&version=v1")}}, + }, + }, + }, + } + return args{ctx: context.Background(), promptDO: prompt} + }(), + wantErr: errorx.New("max recursion depth reached"), + }, + { + name: "expand snippets success", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockRepo := repomocks.NewMockIManageRepo(ctrl) + mockRepo.EXPECT().MGetPrompt(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, params []repo.GetPromptParam, _ ...repo.GetPromptOptionFunc) (map[repo.GetPromptParam]*entity.Prompt, error) { + assert.Len(t, params, 1) + query := params[0] + snippetPrompt := &entity.Prompt{ + ID: query.PromptID, + SpaceID: 1, + PromptBasic: &entity.PromptBasic{PromptType: entity.PromptTypeSnippet}, + PromptCommit: &entity.PromptCommit{ + CommitInfo: &entity.CommitInfo{Version: query.CommitVersion}, + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + Messages: []*entity.Message{{Content: ptr.Of("snippet body")}}, + VariableDefs: []*entity.VariableDef{{Key: "snippet_var"}}, + }, + }, + }, + } + return map[repo.GetPromptParam]*entity.Prompt{query: snippetPrompt}, nil + }) + return fields{ + manageRepo: mockRepo, + snippetParser: NewCozeLoopSnippetParser(), + } + }, + args: func() args { + prompt := &entity.Prompt{ + ID: 10, + SpaceID: 1, + PromptKey: "main", + PromptBasic: &entity.PromptBasic{PromptType: entity.PromptTypeNormal}, + PromptDraft: &entity.PromptDraft{ + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + HasSnippets: true, + Messages: []*entity.Message{{Content: ptr.Of("hello id=2&version=v1")}}, + VariableDefs: []*entity.VariableDef{{Key: "main_var"}}, + }, + }, + }, + } + return args{ctx: context.Background(), promptDO: prompt} + }(), + assertFunc: func(t *testing.T, prompt *entity.Prompt) { + detail := prompt.GetPromptDetail() + as := assert.New(t) + as.NotNil(detail) + if detail == nil { + return + } + as.NotNil(detail.PromptTemplate) + if detail.PromptTemplate == nil { + return + } + as.NotEmpty(detail.PromptTemplate.Snippets) + as.Equal("hello snippet body", ptr.From(detail.PromptTemplate.Messages[0].Content)) + as.Len(detail.PromptTemplate.VariableDefs, 1) + if len(detail.PromptTemplate.VariableDefs) > 0 { + as.Equal("main_var", detail.PromptTemplate.VariableDefs[0].Key) + } + }, + }, + } + + for _, tt := range tests { + ttt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + tFields := ttt.fieldsGetter(ctrl) + service := &PromptServiceImpl{ + manageRepo: tFields.manageRepo, + snippetParser: tFields.snippetParser, + } + + err := service.ExpandSnippets(ttt.args.ctx, ttt.args.promptDO) + unittest.AssertErrorEqual(t, tt.wantErr, err) + if tt.wantErr == nil && ttt.assertFunc != nil { + tt.assertFunc(t, tt.args.promptDO) + } + }) + } +} + +func TestPromptServiceImpl_expandWithSnippetMap(t *testing.T) { + t.Parallel() + type fields struct { + snippetParser SnippetParser + } + type args struct { + content string + snippetContentMap map[string]string + } + tests := []struct { + name string + fields fields + args args + wantContent string + wantErr error + }{ + { + name: "parse error", + fields: fields{ + snippetParser: fakeSnippetParser{ + parseFunc: func(string) ([]*SnippetReference, error) { + return nil, errors.New("parse fail") + }, + }, + }, + args: args{ + content: "test", + snippetContentMap: map[string]string{}, + }, + wantErr: errors.New("parse fail"), + }, + { + name: "snippet content missing", + fields: fields{ + snippetParser: fakeSnippetParser{ + parseFunc: func(string) ([]*SnippetReference, error) { + return []*SnippetReference{{PromptID: 2, CommitVersion: "v1"}}, nil + }, + }, + }, + args: args{ + content: "id=2&version=v1", + snippetContentMap: map[string]string{}, + }, + wantErr: errorx.NewByCode(prompterr.ResourceNotFoundCode, errorx.WithExtraMsg("snippet content for prompt 2 with version v1 not found in cache")), + }, + { + name: "success expands duplicated snippets", + fields: fields{ + snippetParser: NewCozeLoopSnippetParser(), + }, + args: args{ + content: "hello id=2&version=v1 and again id=2&version=v1", + snippetContentMap: map[string]string{ + "2_v1": "snippet", + }, + }, + wantContent: "hello snippet and again snippet", + }, + } + + for _, tt := range tests { + ttt := tt + t.Run(ttt.name, func(t *testing.T) { + t.Parallel() + svc := &PromptServiceImpl{ + snippetParser: ttt.fields.snippetParser, + } + if svc.snippetParser == nil { + svc.snippetParser = NewCozeLoopSnippetParser() + } + + gotContent, err := svc.expandWithSnippetMap(context.Background(), ttt.args.content, ttt.args.snippetContentMap) + unittest.AssertErrorEqual(t, ttt.wantErr, err) + if ttt.wantErr != nil { + return + } + assert.Equal(t, ttt.wantContent, gotContent) + }) + } +} + func TestPromptServiceImpl_MCompleteMultiModalFileURL(t *testing.T) { type fields struct { idgen idgen.IIDGenerator @@ -1637,3 +2649,212 @@ func TestPromptServiceImpl_MConvertBase64ToFileURL(t *testing.T) { }) } } + +func TestPromptServiceImpl_GetPrompt(t *testing.T) { + type fields struct { + manageRepo repo.IManageRepo + snippetParser SnippetParser + } + type args struct { + ctx context.Context + param GetPromptParam + } + tests := []struct { + name string + fieldsGetter func(ctrl *gomock.Controller) fields + args args + wantErr error + }{ + { + name: "repo error", + fieldsGetter: func(ctrl *gomock.Controller) fields { + repoMock := repomocks.NewMockIManageRepo(ctrl) + repoMock.EXPECT().GetPrompt(gomock.Any(), repo.GetPromptParam{PromptID: 1, WithCommit: true, CommitVersion: "1.0.0", WithDraft: false, UserID: ""}).Return(nil, errorx.New("repo error")) + return fields{manageRepo: repoMock} + }, + args: args{ + ctx: context.Background(), + param: GetPromptParam{PromptID: 1, WithCommit: true, CommitVersion: "1.0.0"}, + }, + wantErr: errorx.New("repo error"), + }, + { + name: "parse snippet error", + fieldsGetter: func(ctrl *gomock.Controller) fields { + repoMock := repomocks.NewMockIManageRepo(ctrl) + repoMock.EXPECT().GetPrompt(gomock.Any(), repo.GetPromptParam{PromptID: 2, WithCommit: true, CommitVersion: "1.0.0", WithDraft: false, UserID: ""}).Return(&entity.Prompt{ + ID: 2, + PromptCommit: &entity.PromptCommit{ + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + HasSnippets: true, + Messages: []*entity.Message{{Content: ptr.Of("content")}}, + }, + }, + }, + }, nil) + parser := fakeSnippetParser{ + parseFunc: func(string) ([]*SnippetReference, error) { return nil, errors.New("parse fail") }, + } + return fields{manageRepo: repoMock, snippetParser: parser} + }, + args: args{ + ctx: context.Background(), + param: GetPromptParam{PromptID: 2, WithCommit: true, CommitVersion: "1.0.0"}, + }, + wantErr: errorx.WrapByCode(errors.New("parse fail"), prompterr.CommonInvalidParamCode, errorx.WithExtraMsg("failed to parse snippet references")), + }, + { + name: "success expand snippet", + fieldsGetter: func(ctrl *gomock.Controller) fields { + repoMock := repomocks.NewMockIManageRepo(ctrl) + promptDO := &entity.Prompt{ + ID: 3, + PromptDraft: &entity.PromptDraft{ + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + HasSnippets: true, + Messages: []*entity.Message{{Content: ptr.Of("hello id=4&version=v1")}}, + }, + }, + }, + } + repoMock.EXPECT().GetPrompt(gomock.Any(), repo.GetPromptParam{PromptID: 3, WithCommit: false, CommitVersion: "", WithDraft: false, UserID: "user"}).Return(promptDO, nil) + repoMock.EXPECT().MGetPrompt(gomock.Any(), gomock.Any(), gomock.Any()).Return(map[repo.GetPromptParam]*entity.Prompt{ + {PromptID: 4, WithCommit: true, CommitVersion: "v1"}: { + ID: 4, + PromptBasic: &entity.PromptBasic{PromptType: entity.PromptTypeSnippet}, + PromptCommit: &entity.PromptCommit{ + CommitInfo: &entity.CommitInfo{Version: "v1"}, + PromptDetail: &entity.PromptDetail{PromptTemplate: &entity.PromptTemplate{Messages: []*entity.Message{{Content: ptr.Of("snippet content")}}}}, + }, + }, + }, nil).AnyTimes() + return fields{manageRepo: repoMock, snippetParser: NewCozeLoopSnippetParser()} + }, + args: args{ + ctx: context.Background(), + param: GetPromptParam{PromptID: 3, UserID: "user", ExpandSnippet: true}, + }, + wantErr: nil, + }, + } + + for _, tt := range tests { + caseData := tt + t.Run(caseData.name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + tFields := caseData.fieldsGetter(ctrl) + service := &PromptServiceImpl{ + manageRepo: tFields.manageRepo, + snippetParser: tFields.snippetParser, + } + + got, err := service.GetPrompt(caseData.args.ctx, caseData.args.param) + unittest.AssertErrorEqual(t, caseData.wantErr, err) + if err == nil { + assert.NotNil(t, got) + if caseData.name == "success expand snippet" { + assert.Contains(t, ptr.From(got.PromptDraft.PromptDetail.PromptTemplate.Messages[0].Content), "snippet content") + } + } + }) + } +} + +func TestPromptServiceImpl_doExpandSnippets_MaxDepth(t *testing.T) { + t.Parallel() + service := &PromptServiceImpl{snippetParser: NewCozeLoopSnippetParser()} + err := service.doExpandSnippets(context.Background(), &entity.Prompt{ + PromptDraft: &entity.PromptDraft{ + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{HasSnippets: true}, + }, + }, + }, 0) + unittest.AssertErrorEqual(t, errorx.New("max recursion depth reached"), err) +} + +func TestPromptServiceImpl_doExpandSnippets_NoTemplate(t *testing.T) { + t.Parallel() + service := &PromptServiceImpl{snippetParser: NewCozeLoopSnippetParser()} + err := service.doExpandSnippets(context.Background(), &entity.Prompt{}, 2) + unittest.AssertErrorEqual(t, nil, err) +} + +func TestPromptServiceImpl_doExpandSnippets_Nested(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + repoMock := repomocks.NewMockIManageRepo(ctrl) + repoMock.EXPECT().MGetPrompt(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, params []repo.GetPromptParam, _ ...repo.GetPromptOptionFunc) (map[repo.GetPromptParam]*entity.Prompt, error) { + assert.Len(t, params, 1) + switch params[0].PromptID { + case 4: + return map[repo.GetPromptParam]*entity.Prompt{ + params[0]: { + ID: 4, + PromptBasic: &entity.PromptBasic{PromptType: entity.PromptTypeSnippet}, + PromptCommit: &entity.PromptCommit{ + CommitInfo: &entity.CommitInfo{Version: params[0].CommitVersion}, + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + HasSnippets: true, + Messages: []*entity.Message{{Content: ptr.Of("id=5&version=v2")}}, + }, + }, + }, + }, + {PromptID: 0}: { + ID: 0, + PromptCommit: &entity.PromptCommit{PromptDetail: &entity.PromptDetail{}}, + }, + {PromptID: 99}: { + ID: 99, + PromptCommit: &entity.PromptCommit{PromptDetail: &entity.PromptDetail{}}, + }, + }, nil + case 5: + return map[repo.GetPromptParam]*entity.Prompt{ + params[0]: { + ID: 5, + PromptBasic: &entity.PromptBasic{PromptType: entity.PromptTypeSnippet}, + PromptCommit: &entity.PromptCommit{ + CommitInfo: &entity.CommitInfo{Version: params[0].CommitVersion}, + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + Messages: []*entity.Message{{Content: ptr.Of("deep snippet")}}, + }, + }, + }, + }, + {PromptID: 101}: { + ID: 101, + PromptCommit: &entity.PromptCommit{PromptDetail: &entity.PromptDetail{}}, + }, + }, nil + default: + return map[repo.GetPromptParam]*entity.Prompt{}, nil + } + }).AnyTimes() + + service := &PromptServiceImpl{manageRepo: repoMock, snippetParser: NewCozeLoopSnippetParser()} + prompt := &entity.Prompt{ + PromptDraft: &entity.PromptDraft{ + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + HasSnippets: true, + Messages: []*entity.Message{{Content: ptr.Of("outer id=4&version=v1")}}, + }, + }, + }, + } + + err := service.doExpandSnippets(context.Background(), prompt, 2) + unittest.AssertErrorEqual(t, nil, err) + assert.Equal(t, "outer deep snippet", ptr.From(prompt.PromptDraft.PromptDetail.PromptTemplate.Messages[0].Content)) +} diff --git a/backend/modules/prompt/domain/service/mocks/prompt_service.go b/backend/modules/prompt/domain/service/mocks/prompt_service.go index d1582e4be..6648d5d3b 100644 --- a/backend/modules/prompt/domain/service/mocks/prompt_service.go +++ b/backend/modules/prompt/domain/service/mocks/prompt_service.go @@ -86,6 +86,21 @@ func (mr *MockIPromptServiceMockRecorder) CreateLabel(ctx, labelDO any) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateLabel", reflect.TypeOf((*MockIPromptService)(nil).CreateLabel), ctx, labelDO) } +// CreatePrompt mocks base method. +func (m *MockIPromptService) CreatePrompt(ctx context.Context, promptDO *entity.Prompt) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreatePrompt", ctx, promptDO) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreatePrompt indicates an expected call of CreatePrompt. +func (mr *MockIPromptServiceMockRecorder) CreatePrompt(ctx, promptDO any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreatePrompt", reflect.TypeOf((*MockIPromptService)(nil).CreatePrompt), ctx, promptDO) +} + // Execute mocks base method. func (m *MockIPromptService) Execute(ctx context.Context, param service.ExecuteParam) (*entity.Reply, error) { m.ctrl.T.Helper() @@ -116,6 +131,20 @@ func (mr *MockIPromptServiceMockRecorder) ExecuteStreaming(ctx, param any) *gomo return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecuteStreaming", reflect.TypeOf((*MockIPromptService)(nil).ExecuteStreaming), ctx, param) } +// ExpandSnippets mocks base method. +func (m *MockIPromptService) ExpandSnippets(ctx context.Context, promptDO *entity.Prompt) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ExpandSnippets", ctx, promptDO) + ret0, _ := ret[0].(error) + return ret0 +} + +// ExpandSnippets indicates an expected call of ExpandSnippets. +func (mr *MockIPromptServiceMockRecorder) ExpandSnippets(ctx, promptDO any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExpandSnippets", reflect.TypeOf((*MockIPromptService)(nil).ExpandSnippets), ctx, promptDO) +} + // FormatPrompt mocks base method. func (m *MockIPromptService) FormatPrompt(ctx context.Context, prompt *entity.Prompt, messages []*entity.Message, variableVals []*entity.VariableVal) ([]*entity.Message, error) { m.ctrl.T.Helper() @@ -131,6 +160,21 @@ func (mr *MockIPromptServiceMockRecorder) FormatPrompt(ctx, prompt, messages, va return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FormatPrompt", reflect.TypeOf((*MockIPromptService)(nil).FormatPrompt), ctx, prompt, messages, variableVals) } +// GetPrompt mocks base method. +func (m *MockIPromptService) GetPrompt(ctx context.Context, param service.GetPromptParam) (*entity.Prompt, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPrompt", ctx, param) + ret0, _ := ret[0].(*entity.Prompt) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPrompt indicates an expected call of GetPrompt. +func (mr *MockIPromptServiceMockRecorder) GetPrompt(ctx, param any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrompt", reflect.TypeOf((*MockIPromptService)(nil).GetPrompt), ctx, param) +} + // ListLabel mocks base method. func (m *MockIPromptService) ListLabel(ctx context.Context, param service.ListLabelParam) ([]*entity.PromptLabel, *int64, error) { m.ctrl.T.Helper() @@ -219,6 +263,21 @@ func (mr *MockIPromptServiceMockRecorder) MParseCommitVersion(ctx, spaceID, para return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MParseCommitVersion", reflect.TypeOf((*MockIPromptService)(nil).MParseCommitVersion), ctx, spaceID, params) } +// SaveDraft mocks base method. +func (m *MockIPromptService) SaveDraft(ctx context.Context, promptDO *entity.Prompt) (*entity.DraftInfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveDraft", ctx, promptDO) + ret0, _ := ret[0].(*entity.DraftInfo) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SaveDraft indicates an expected call of SaveDraft. +func (mr *MockIPromptServiceMockRecorder) SaveDraft(ctx, promptDO any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveDraft", reflect.TypeOf((*MockIPromptService)(nil).SaveDraft), ctx, promptDO) +} + // UpdateCommitLabels mocks base method. func (m *MockIPromptService) UpdateCommitLabels(ctx context.Context, param service.UpdateCommitLabelsParam) error { m.ctrl.T.Helper() diff --git a/backend/modules/prompt/domain/service/service.go b/backend/modules/prompt/domain/service/service.go index f234bf005..a8026206a 100644 --- a/backend/modules/prompt/domain/service/service.go +++ b/backend/modules/prompt/domain/service/service.go @@ -26,6 +26,14 @@ type IPromptService interface { // MParseCommitVersion 统一解析提交版本,支持version和label两种方式 MParseCommitVersion(ctx context.Context, spaceID int64, params []PromptQueryParam) (promptKeyCommitVersionMap map[PromptQueryParam]string, err error) + // Prompt管理相关方法 + CreatePrompt(ctx context.Context, promptDO *entity.Prompt) (promptID int64, err error) + SaveDraft(ctx context.Context, promptDO *entity.Prompt) (*entity.DraftInfo, error) + GetPrompt(ctx context.Context, param GetPromptParam) (*entity.Prompt, error) + + // Snippet扩展相关方法 + ExpandSnippets(ctx context.Context, promptDO *entity.Prompt) error + // Label管理相关方法 CreateLabel(ctx context.Context, labelDO *entity.PromptLabel) error ListLabel(ctx context.Context, param ListLabelParam) ([]*entity.PromptLabel, *int64, error) @@ -76,6 +84,18 @@ type PromptServiceImpl struct { configProvider conf.IConfigProvider llm rpc.ILLMProvider file rpc.IFileProvider + snippetParser SnippetParser +} + +type GetPromptParam struct { + PromptID int64 + + WithCommit bool + CommitVersion string + + WithDraft bool + UserID string + ExpandSnippet bool } func NewPromptService( @@ -88,6 +108,7 @@ func NewPromptService( configProvider conf.IConfigProvider, llm rpc.ILLMProvider, file rpc.IFileProvider, + snippetParser SnippetParser, ) IPromptService { return &PromptServiceImpl{ formatter: formatter, @@ -99,5 +120,6 @@ func NewPromptService( configProvider: configProvider, llm: llm, file: file, + snippetParser: snippetParser, } } diff --git a/backend/modules/prompt/domain/service/service_test.go b/backend/modules/prompt/domain/service/service_test.go index 586a7f268..c95bdf23c 100644 --- a/backend/modules/prompt/domain/service/service_test.go +++ b/backend/modules/prompt/domain/service/service_test.go @@ -42,6 +42,7 @@ func TestNewPromptService(t *testing.T) { mockConfigProvider, mockLLM, mockFile, + NewCozeLoopSnippetParser(), ) // Verify @@ -88,6 +89,7 @@ func TestNewPromptService(t *testing.T) { mockConfigProvider, mockLLM, mockFile, + NewCozeLoopSnippetParser(), ) impl := service.(*PromptServiceImpl) diff --git a/backend/modules/prompt/domain/service/snippet_parser.go b/backend/modules/prompt/domain/service/snippet_parser.go new file mode 100644 index 000000000..ed9add91a --- /dev/null +++ b/backend/modules/prompt/domain/service/snippet_parser.go @@ -0,0 +1,78 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package service + +import ( + "fmt" + "regexp" + "strconv" + "strings" +) + +// SnippetReference represents a parsed snippet reference +type SnippetReference struct { + PromptID int64 + CommitVersion string +} + +// SnippetParser defines the interface for parsing snippet references +// Supports extending to other formats in the future +type SnippetParser interface { + // ParseReferences Parses content and returns snippet references + ParseReferences(content string) ([]*SnippetReference, error) + // SerializeReference Serializes a snippet reference back to string + SerializeReference(ref *SnippetReference) string +} + +// CozeLoopSnippetParser implements only cozeloop format parsing +type CozeLoopSnippetParser struct { + referencePattern *regexp.Regexp +} + +// NewCozeLoopSnippetParser creates a new parser for cozeloop format +func NewCozeLoopSnippetParser() SnippetParser { + // Pattern matches: id=123&version=v1 + pattern := regexp.MustCompile(`id=(\d+)&version=([^&]*)?`) + return &CozeLoopSnippetParser{ + referencePattern: pattern, + } +} + +// ParseReferences parses cozeloop snippet references from content +func (p *CozeLoopSnippetParser) ParseReferences(content string) ([]*SnippetReference, error) { + if content == "" { + return nil, nil + } + + matches := p.referencePattern.FindAllStringSubmatch(content, -1) + var refs []*SnippetReference + + for _, match := range matches { + if len(match) < 2 { + continue + } + + promptID, err := strconv.ParseInt(match[1], 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid prompt ID in reference: %s", match[1]) + } + + version := "" + if len(match) > 2 { + version = strings.TrimSpace(match[2]) + } + + refs = append(refs, &SnippetReference{ + PromptID: promptID, + CommitVersion: version, + }) + } + + return refs, nil +} + +// SerializeReference serializes a snippet reference back to cozeloop format +func (p *CozeLoopSnippetParser) SerializeReference(ref *SnippetReference) string { + return fmt.Sprintf("id=%d&version=%s", ref.PromptID, ref.CommitVersion) +} diff --git a/backend/modules/prompt/domain/service/snippet_parser_test.go b/backend/modules/prompt/domain/service/snippet_parser_test.go new file mode 100755 index 000000000..b1e3c6ec5 --- /dev/null +++ b/backend/modules/prompt/domain/service/snippet_parser_test.go @@ -0,0 +1,108 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package service + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCozeLoopSnippetParser_ParseReferences(t *testing.T) { + t.Parallel() + tests := []struct { + name string + content string + want []*SnippetReference + wantErr bool + }{ + { + name: "empty content", + content: "", + want: nil, + }, + { + name: "single reference with version", + content: "prefix id=123&version=v1 suffix", + want: []*SnippetReference{{ + PromptID: 123, + CommitVersion: "v1", + }}, + }, + { + name: "multiple references", + content: "id=1&version=v1 text id=2&version=v2", + want: []*SnippetReference{ + {PromptID: 1, CommitVersion: "v1"}, + {PromptID: 2, CommitVersion: "v2"}, + }, + }, + { + name: "reference without version", + content: "id=5&version=", + want: []*SnippetReference{{ + PromptID: 5, + CommitVersion: "", + }}, + }, + { + name: "non matching pattern", + content: "id=abc&version=v1", + want: nil, + }, + } + + for _, tt := range tests { + ttt := tt + t.Run(ttt.name, func(t *testing.T) { + t.Parallel() + parser := NewCozeLoopSnippetParser() + + refs, err := parser.ParseReferences(ttt.content) + if ttt.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + assert.Equal(t, ttt.want, refs) + }) + } +} + +func TestCozeLoopSnippetParser_SerializeReference(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + ref *SnippetReference + want string + }{ + { + name: "with version", + ref: &SnippetReference{ + PromptID: 10, + CommitVersion: "v1", + }, + want: "id=10&version=v1", + }, + { + name: "empty version", + ref: &SnippetReference{ + PromptID: 20, + CommitVersion: "", + }, + want: "id=20&version=", + }, + } + + for _, tt := range tests { + ttt := tt + t.Run(ttt.name, func(t *testing.T) { + t.Parallel() + parser := NewCozeLoopSnippetParser() + got := parser.SerializeReference(ttt.ref) + assert.Equal(t, ttt.want, got) + }) + } +} diff --git a/backend/modules/prompt/infra/repo/manage.go b/backend/modules/prompt/infra/repo/manage.go index 92629779b..e75a3ddeb 100644 --- a/backend/modules/prompt/infra/repo/manage.go +++ b/backend/modules/prompt/infra/repo/manage.go @@ -18,6 +18,7 @@ import ( "github.com/coze-dev/coze-loop/backend/infra/metrics" "github.com/coze-dev/coze-loop/backend/modules/prompt/domain/entity" "github.com/coze-dev/coze-loop/backend/modules/prompt/domain/repo" + "github.com/coze-dev/coze-loop/backend/modules/prompt/domain/service" metricsinfra "github.com/coze-dev/coze-loop/backend/modules/prompt/infra/metrics" "github.com/coze-dev/coze-loop/backend/modules/prompt/infra/repo/mysql" "github.com/coze-dev/coze-loop/backend/modules/prompt/infra/repo/mysql/convertor" @@ -38,6 +39,7 @@ type ManageRepoImpl struct { promptCommitDAO mysql.IPromptCommitDAO promptDraftDAO mysql.IPromptUserDraftDAO commitLabelMappingDAO mysql.ICommitLabelMappingDAO + promptRelationDAO mysql.IPromptRelationDAO promptBasicCacheDAO redis.IPromptBasicDAO promptCacheDAO redis.IPromptDAO @@ -53,6 +55,7 @@ func NewManageRepo( promptCommitDao mysql.IPromptCommitDAO, promptDraftDao mysql.IPromptUserDraftDAO, commitLabelMappingDAO mysql.ICommitLabelMappingDAO, + promptRelationDAO mysql.IPromptRelationDAO, promptBasicCacheDAO redis.IPromptBasicDAO, promptCacheDAO redis.IPromptDAO, ) repo.IManageRepo { @@ -63,6 +66,7 @@ func NewManageRepo( promptCommitDAO: promptCommitDao, promptDraftDAO: promptDraftDao, commitLabelMappingDAO: commitLabelMappingDAO, + promptRelationDAO: promptRelationDAO, promptBasicCacheDAO: promptBasicCacheDAO, promptCacheDAO: promptCacheDAO, promptCacheMetrics: metricsinfra.NewPromptCacheMetrics(meter), @@ -106,6 +110,46 @@ func (d *ManageRepoImpl) CreatePrompt(ctx context.Context, promptDO *entity.Prom } } + // Handle snippet relations if prompt contains snippets + if promptDO.PromptDraft != nil && promptDO.PromptDraft.PromptDetail != nil && + promptDO.PromptDraft.PromptDetail.PromptTemplate != nil && + promptDO.PromptDraft.PromptDetail.PromptTemplate.HasSnippets && + len(promptDO.PromptDraft.PromptDetail.PromptTemplate.Snippets) > 0 { + + snippets := promptDO.PromptDraft.PromptDetail.PromptTemplate.Snippets + relations := make([]*model.PromptRelation, 0, len(snippets)) + relationIDs, err := d.idgen.GenMultiIDs(ctx, len(snippets)) + if err != nil { + return err + } + for i, snippet := range snippets { + if snippet == nil { + continue + } + var snippetVersion string + if snippet.PromptCommit != nil && snippet.PromptCommit.CommitInfo != nil { + snippetVersion = snippet.PromptCommit.CommitInfo.Version + } + + relation := &model.PromptRelation{ + ID: relationIDs[i], + SpaceID: promptDO.SpaceID, + MainPromptID: promptID, + MainPromptVersion: "", // Empty for draft + MainDraftUserID: promptDO.PromptDraft.DraftInfo.UserID, + SubPromptID: snippet.ID, + SubPromptVersion: snippetVersion, + } + relations = append(relations, relation) + } + + if len(relations) > 0 { + if err := d.promptRelationDAO.BatchCreate(ctx, relations, opt); err != nil { + return err + } + } + } + return nil }) } @@ -402,12 +446,20 @@ func (d *ManageRepoImpl) ListPrompt(ctx context.Context, param repo.ListPromptPa return nil, errorx.New("param(SpaceID or PageNum or PageSize) is invalid, param = %s", json.Jsonify(param)) } + // Convert PromptType slice to string slice + var promptTypes []string + for _, pt := range param.FilterPromptTypes { + promptTypes = append(promptTypes, string(pt)) + } + listBasicParam := mysql.ListPromptBasicParam{ SpaceID: param.SpaceID, KeyWord: param.KeyWord, CreatedBys: param.CreatedBys, CommittedOnly: param.CommittedOnly, + PromptTypes: promptTypes, + PromptIDs: param.PromptIDs, Offset: (param.PageNum - 1) * param.PageSize, Limit: param.PageSize, @@ -528,6 +580,12 @@ func (d *ManageRepoImpl) SaveDraft(ctx context.Context, promptDO *entity.Prompt) draftInfo = convertor.DraftPO2DO(createdDraftPO).DraftInfo } + // 使用统一的方法管理 snippet relations(无需区分创建/更新场景) + err = d.manageDraftSnippetRelations(ctx, promptDO, userID, basicPO.SpaceID, opt) + if err != nil { + return err + } + return nil } @@ -561,6 +619,13 @@ func (d *ManageRepoImpl) SaveDraft(ctx context.Context, promptDO *entity.Prompt) draftInfo = convertor.DraftPO2DO(updatedDraftPO).DraftInfo } + // Handle snippet relationships incrementally for update scenario + // 使用统一的方法管理 snippet relations(无需区分创建/更新场景) + err = d.manageDraftSnippetRelations(ctx, promptDO, userID, basicPO.SpaceID, opt) + if err != nil { + return err + } + return nil }) if err != nil { @@ -570,6 +635,123 @@ func (d *ManageRepoImpl) SaveDraft(ctx context.Context, promptDO *entity.Prompt) return draftInfo, nil } +// manageDraftSnippetRelations统一管理snippet关系,统一处理创建和更新场景 +// 只要has_snippets为true,就查询已有的relation,和当前草稿嵌入的片段对比, +// 判断哪些需要增加,哪些需要删除,哪些保持不动 +func (d *ManageRepoImpl) manageDraftSnippetRelations(ctx context.Context, promptDO *entity.Prompt, userID string, spaceID int64, opt db.Option) error { + if promptDO == nil || promptDO.PromptDraft == nil || promptDO.PromptDraft.PromptDetail == nil { + return nil + } + + promptDetail := promptDO.PromptDraft.PromptDetail + if promptDetail.PromptTemplate == nil { + return nil + } + + hasSnippets := promptDetail.PromptTemplate.HasSnippets + + // 如果没有片段,删除所有现有关系 + if !hasSnippets { + return d.promptRelationDAO.DeleteByMainPrompt(ctx, promptDO.ID, "", userID, opt) + } + + // 统一处理:查询已有的relation,和当前草稿嵌入的片段对比 + // 判断哪些需要增加,哪些需要删除,哪些保持不动 + + // 获取当前草稿中嵌入的片段引用(包含版本信息) + currentSnippetRefs := make(map[service.SnippetReference]bool) + if promptDetail.PromptTemplate.Snippets != nil { + for _, snippet := range promptDetail.PromptTemplate.Snippets { + if snippet != nil && snippet.ID > 0 { + // 从snippet中获取版本信息 + var version string + if snippet.PromptCommit != nil && snippet.PromptCommit.CommitInfo != nil { + version = snippet.PromptCommit.CommitInfo.Version + } + currentSnippetRefs[service.SnippetReference{ + PromptID: snippet.ID, + CommitVersion: version, + }] = true + } + } + } + + // 查询已有的relation + existingRelations, err := d.promptRelationDAO.List(ctx, mysql.ListPromptRelationParam{ + MainPromptID: &promptDO.ID, + MainDraftUserID: &userID, + }, opt) + if err != nil { + return err + } + + // 构建现有关系的复合key映射 + existingRelationMap := make(map[service.SnippetReference]*model.PromptRelation) + for _, relation := range existingRelations { + key := service.SnippetReference{ + PromptID: relation.SubPromptID, + CommitVersion: relation.SubPromptVersion, + } + existingRelationMap[key] = relation + } + + // 确定需要删除和添加的关系 + var relationsToDelete []int64 + var relationsToAdd []service.SnippetReference + + // 找出需要删除的关系(存在于DB但不在当前片段中) + for key, existingRelation := range existingRelationMap { + if !currentSnippetRefs[key] { + relationsToDelete = append(relationsToDelete, existingRelation.ID) + } + } + + // 找出需要添加的关系(存在于当前片段但不在DB中) + for ref := range currentSnippetRefs { + if _, exists := existingRelationMap[ref]; !exists { + relationsToAdd = append(relationsToAdd, ref) + } + } + + // 删除不再需要的关系 + if len(relationsToDelete) > 0 { + err = d.promptRelationDAO.BatchDeleteByIDs(ctx, relationsToDelete, opt) + if err != nil { + return err + } + } + + // 添加新的关系 + if len(relationsToAdd) > 0 { + ids, err := d.idgen.GenMultiIDs(ctx, len(relationsToAdd)) + if err != nil { + return err + } + var newRelationPOs []*model.PromptRelation + for i, ref := range relationsToAdd { + relationPO := &model.PromptRelation{ + ID: ids[i], + SpaceID: spaceID, + MainPromptID: promptDO.ID, + MainPromptVersion: "", // Empty for draft + MainDraftUserID: userID, + SubPromptID: ref.PromptID, + SubPromptVersion: ref.CommitVersion, + } + newRelationPOs = append(newRelationPOs, relationPO) + } + + if len(newRelationPOs) > 0 { + err = d.promptRelationDAO.BatchCreate(ctx, newRelationPOs, opt) + if err != nil { + return err + } + } + } + + return nil +} + func (d *ManageRepoImpl) CommitDraft(ctx context.Context, param repo.CommitDraftParam) (err error) { if param.PromptID <= 0 || lo.IsEmpty(param.UserID) || lo.IsEmpty(param.CommitVersion) { return errorx.New("param(PromptID or UserID or CommitVersion) is invalid, param = %s", json.Jsonify(param)) @@ -632,11 +814,65 @@ func (d *ManageRepoImpl) CommitDraft(ctx context.Context, param repo.CommitDraft err = d.promptBasicDAO.Update(ctx, basicPO.ID, map[string]interface{}{ q.PromptBasic.LatestCommitTime.ColumnName().String(): timeNow, q.PromptBasic.LatestVersion.ColumnName().String(): param.CommitVersion, + q.PromptBasic.UpdatedBy.ColumnName().String(): param.UserID, }, opt) if err != nil { return err } + // 只有在草稿包含snippet时才处理relation拷贝 + if draftDO.PromptDetail != nil && draftDO.PromptDetail.PromptTemplate != nil && + draftDO.PromptDetail.PromptTemplate.HasSnippets { + + // 拷贝草稿的relation到提交版本 + // 1. 查询草稿的所有relation + draftRelations, err := d.promptRelationDAO.List(ctx, mysql.ListPromptRelationParam{ + MainPromptID: ¶m.PromptID, + MainDraftUserID: ¶m.UserID, + }, opt) + if err != nil { + return err + } + + // 2. 如果有草稿relation,拷贝到提交版本 + if len(draftRelations) > 0 { + relationIDs, err := d.idgen.GenMultiIDs(ctx, len(draftRelations)) + if err != nil { + return err + } + + var commitRelations []*model.PromptRelation + for i, draftRelation := range draftRelations { + commitRelation := &model.PromptRelation{ + ID: relationIDs[i], + SpaceID: draftRelation.SpaceID, + MainPromptID: draftRelation.MainPromptID, + MainPromptVersion: param.CommitVersion, // 使用提交版本号 + MainDraftUserID: "", // 提交版本没有草稿用户ID + SubPromptID: draftRelation.SubPromptID, + SubPromptVersion: draftRelation.SubPromptVersion, + } + commitRelations = append(commitRelations, commitRelation) + } + + // 批量创建提交版本的relation + err = d.promptRelationDAO.BatchCreate(ctx, commitRelations, opt) + if err != nil { + return err + } + + // 3. 删除草稿的relation + draftRelationIDs := make([]int64, 0, len(draftRelations)) + for _, relation := range draftRelations { + draftRelationIDs = append(draftRelationIDs, relation.ID) + } + err = d.promptRelationDAO.BatchDeleteByIDs(ctx, draftRelationIDs, opt) + if err != nil { + return err + } + } + } + // 提交版本绑定label // 根据prompt_id和label_keys查询现有的标签映射 labelExistMappings, err := d.commitLabelMappingDAO.ListByPromptIDAndLabelKeys(ctx, param.PromptID, param.LabelKeys, opt) @@ -736,9 +972,109 @@ func (d *ManageRepoImpl) ListCommitInfo(ctx context.Context, param repo.ListComm commitInfoDOs := convertor.BatchGetCommitInfoDOFromCommitDO(commitDOs) if len(commitPOs) <= param.PageSize { result.CommitInfoDOs = commitInfoDOs + result.CommitDOs = commitDOs return result, nil } result.NextPageToken = commitPOs[param.PageSize].ID result.CommitInfoDOs = commitInfoDOs[:len(commitPOs)-1] + result.CommitDOs = commitDOs[:len(commitPOs)-1] + return result, nil +} + +func (d *ManageRepoImpl) MGetVersionsByPromptID(ctx context.Context, promptID int64) ([]string, error) { + if promptID <= 0 { + return nil, errorx.New("promptID is invalid, promptID = %d", promptID) + } + + versions, err := d.promptCommitDAO.MGetVersionsByPromptID(ctx, promptID) + if err != nil { + return nil, err + } + return versions, nil +} + +func (d *ManageRepoImpl) ListParentPrompt(ctx context.Context, param repo.ListParentPromptParam) (result map[string][]*repo.PromptCommitVersions, err error) { + if param.SubPromptID <= 0 { + return nil, errorx.New("param(SubPromptID) is invalid, param = %s", json.Jsonify(param)) + } + + // Query prompt relations by sub-prompt ID + listRelationParam := mysql.ListPromptRelationParam{ + SubPromptID: ¶m.SubPromptID, + SubPromptVersions: param.SubPromptVersions, + } + + relations, err := d.promptRelationDAO.List(ctx, listRelationParam) + if err != nil { + return nil, err + } + + if len(relations) == 0 { + return nil, nil + } + + // Group relations by sub-prompt version + relationsBySubVersion := make(map[string][]*model.PromptRelation) + for _, relation := range relations { + // filer draft + if relation.MainPromptVersion == "" { + continue + } + subVersion := relation.SubPromptVersion + relationsBySubVersion[subVersion] = append(relationsBySubVersion[subVersion], relation) + } + + // Collect all main prompt IDs to batch query + getMainPromptPram := make([]repo.GetPromptParam, 0) + mainPromptMap := make(map[int64]bool) + for _, relations := range relationsBySubVersion { + for _, relation := range relations { + if !mainPromptMap[relation.MainPromptID] { + mainPromptMap[relation.MainPromptID] = true + getMainPromptPram = append(getMainPromptPram, repo.GetPromptParam{ + PromptID: relation.MainPromptID, + }) + } + } + } + + // Query all main prompt basic info + mainPromptBasics, err := d.MGetPrompt(ctx, getMainPromptPram) + if err != nil { + return nil, err + } + + if len(mainPromptBasics) <= 0 { + return nil, nil + } + + // Build result map + result = make(map[string][]*repo.PromptCommitVersions) + // Organize results by sub-prompt version + for subVersion, relations := range relationsBySubVersion { + promptCommitVersions := make([]*repo.PromptCommitVersions, 0, len(mainPromptBasics)) + + for _, prompt := range mainPromptBasics { + promptCommitVersion := &repo.PromptCommitVersions{ + PromptID: prompt.ID, + SpaceID: prompt.SpaceID, + PromptKey: prompt.PromptKey, + PromptBasic: prompt.PromptBasic, + } + for _, relation := range relations { + if prompt.ID == relation.MainPromptID { + promptCommitVersion.CommitVersions = append(promptCommitVersion.CommitVersions, relation.MainPromptVersion) + } + } + if len(promptCommitVersion.CommitVersions) > 0 { + promptCommitVersions = append(promptCommitVersions, promptCommitVersion) + } + } + + if len(promptCommitVersions) > 0 { + result[subVersion] = promptCommitVersions + } + } + return result, nil } diff --git a/backend/modules/prompt/infra/repo/manage_test.go b/backend/modules/prompt/infra/repo/manage_test.go index 8ce158568..2a01403be 100644 --- a/backend/modules/prompt/infra/repo/manage_test.go +++ b/backend/modules/prompt/infra/repo/manage_test.go @@ -109,9 +109,10 @@ func TestManageRepoImpl_MGetPrompt(t *testing.T) { mockBasicDAO := daomocks.NewMockIPromptBasicDAO(ctrl) mockBasicDAO.EXPECT().MGet(gomock.Any(), []int64{123}, gomock.Any()).Return(map[int64]*model.PromptBasic{ 123: { - ID: 123, - SpaceID: 123456, - PromptKey: "test_key_1", + ID: 123, + SpaceID: 123456, + PromptKey: "test_key_1", + PromptType: "normal", }, }, nil) mockDraftDAO := daomocks.NewMockIPromptUserDraftDAO(ctrl) @@ -144,10 +145,12 @@ func TestManageRepoImpl_MGetPrompt(t *testing.T) { WithDraft: true, UserID: "111222", }: { - ID: 123, - SpaceID: 123456, - PromptKey: "test_key_1", - PromptBasic: &entity.PromptBasic{}, + ID: 123, + SpaceID: 123456, + PromptKey: "test_key_1", + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, + }, PromptDraft: &entity.PromptDraft{ PromptDetail: &entity.PromptDetail{ PromptTemplate: &entity.PromptTemplate{}, @@ -168,10 +171,12 @@ func TestManageRepoImpl_MGetPrompt(t *testing.T) { WithCommit: true, CommitVersion: "1.0.0", }: { - ID: 123, - SpaceID: 123456, - PromptKey: "test_key_1", - PromptBasic: &entity.PromptBasic{}, + ID: 123, + SpaceID: 123456, + PromptKey: "test_key_1", + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, + }, PromptCommit: &entity.PromptCommit{ PromptDetail: &entity.PromptDetail{ PromptTemplate: &entity.PromptTemplate{}, @@ -247,10 +252,12 @@ func TestManageRepoImpl_MGetPrompt(t *testing.T) { WithCommit: true, CommitVersion: "1.0.0", }: { - ID: 123, - SpaceID: 123456, - PromptKey: "test_key_1", - PromptBasic: &entity.PromptBasic{}, + ID: 123, + SpaceID: 123456, + PromptKey: "test_key_1", + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, + }, PromptCommit: &entity.PromptCommit{ PromptDetail: &entity.PromptDetail{ PromptTemplate: &entity.PromptTemplate{}, @@ -263,10 +270,12 @@ func TestManageRepoImpl_MGetPrompt(t *testing.T) { WithCommit: true, CommitVersion: "1.0.0", }: { - ID: 456, - SpaceID: 123456, - PromptKey: "test_key_2", - PromptBasic: &entity.PromptBasic{}, + ID: 456, + SpaceID: 123456, + PromptKey: "test_key_2", + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, + }, PromptCommit: &entity.PromptCommit{ PromptDetail: &entity.PromptDetail{ PromptTemplate: &entity.PromptTemplate{}, @@ -305,10 +314,12 @@ func TestManageRepoImpl_MGetPrompt(t *testing.T) { { PromptID: 123, }: { - ID: 123, - SpaceID: 123456, - PromptKey: "test_key", - PromptBasic: &entity.PromptBasic{}, + ID: 123, + SpaceID: 123456, + PromptKey: "test_key", + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, + }, }, }, wantErr: nil, @@ -370,10 +381,12 @@ func TestManageRepoImpl_MGetPrompt(t *testing.T) { WithCommit: true, CommitVersion: "1.0.0", }: { - ID: 123, - SpaceID: 123456, - PromptKey: "test_key", - PromptBasic: &entity.PromptBasic{}, + ID: 123, + SpaceID: 123456, + PromptKey: "test_key", + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, + }, PromptCommit: &entity.PromptCommit{ PromptDetail: &entity.PromptDetail{ PromptTemplate: &entity.PromptTemplate{ @@ -541,16 +554,20 @@ func TestManageRepoImpl_MGetPromptBasicByPromptKey(t *testing.T) { }, want: []*entity.Prompt{ { - ID: 123, - SpaceID: 123, - PromptKey: "test_key_1", - PromptBasic: &entity.PromptBasic{}, + ID: 123, + SpaceID: 123, + PromptKey: "test_key_1", + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, + }, }, { - ID: 456, - SpaceID: 123, - PromptKey: "test_key_2", - PromptBasic: &entity.PromptBasic{}, + ID: 456, + SpaceID: 123, + PromptKey: "test_key_2", + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, + }, }, }, wantErr: nil, @@ -595,16 +612,20 @@ func TestManageRepoImpl_MGetPromptBasicByPromptKey(t *testing.T) { }, want: []*entity.Prompt{ { - ID: 123, - SpaceID: 123, - PromptKey: "test_key_1", - PromptBasic: &entity.PromptBasic{}, + ID: 123, + SpaceID: 123, + PromptKey: "test_key_1", + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, + }, }, { - ID: 456, - SpaceID: 123, - PromptKey: "test_key_2", - PromptBasic: &entity.PromptBasic{}, + ID: 456, + SpaceID: 123, + PromptKey: "test_key_2", + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, + }, }, }, wantErr: nil, @@ -643,10 +664,12 @@ func TestManageRepoImpl_MGetPromptBasicByPromptKey(t *testing.T) { mockCacheDAO := redismocks.NewMockIPromptBasicDAO(ctrl) mockCacheDAO.EXPECT().MGetByPromptKey(gomock.Any(), int64(123), []string{"test_key_1", "test_key_2"}).Return(map[string]*entity.Prompt{ "test_key_1": { - ID: 123, - SpaceID: 123, - PromptKey: "test_key_1", - PromptBasic: &entity.PromptBasic{}, + ID: 123, + SpaceID: 123, + PromptKey: "test_key_1", + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, + }, }, }, nil) @@ -679,16 +702,20 @@ func TestManageRepoImpl_MGetPromptBasicByPromptKey(t *testing.T) { }, want: []*entity.Prompt{ { - ID: 123, - SpaceID: 123, - PromptKey: "test_key_1", - PromptBasic: &entity.PromptBasic{}, + ID: 123, + SpaceID: 123, + PromptKey: "test_key_1", + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, + }, }, { - ID: 456, - SpaceID: 123, - PromptKey: "test_key_2", - PromptBasic: &entity.PromptBasic{}, + ID: 456, + SpaceID: 123, + PromptKey: "test_key_2", + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, + }, }, }, wantErr: nil, @@ -722,10 +749,11 @@ func TestManageRepoImpl_MGetPromptBasicByPromptKey(t *testing.T) { func TestManageRepoImpl_GetPrompt(t *testing.T) { type fields struct { - db db.Provider - promptBasicDAO mysql.IPromptBasicDAO - promptCommitDAO mysql.IPromptCommitDAO - promptDraftDAO mysql.IPromptUserDraftDAO + db db.Provider + promptBasicDAO mysql.IPromptBasicDAO + promptCommitDAO mysql.IPromptCommitDAO + promptDraftDAO mysql.IPromptUserDraftDAO + promptRelationDAO mysql.IPromptRelationDAO } type args struct { ctx context.Context @@ -882,6 +910,7 @@ func TestManageRepoImpl_GetPrompt(t *testing.T) { SpaceID: 100, PromptKey: "test_key", PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, DisplayName: "test_name", Description: "test_description", CreatedBy: "test_creator", @@ -931,10 +960,11 @@ func TestManageRepoImpl_GetPrompt(t *testing.T) { }, nil) return fields{ - db: mockDB, - promptBasicDAO: mockBasicDAO, - promptCommitDAO: mockCommitDAO, - promptDraftDAO: mockDraftDAO, + db: mockDB, + promptBasicDAO: mockBasicDAO, + promptCommitDAO: mockCommitDAO, + promptDraftDAO: mockDraftDAO, + promptRelationDAO: nil, } }, args: args{ @@ -952,6 +982,7 @@ func TestManageRepoImpl_GetPrompt(t *testing.T) { SpaceID: 100, PromptKey: "test_key", PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, DisplayName: "test_name", Description: "test_description", CreatedBy: "test_creator", @@ -1108,10 +1139,11 @@ func TestManageRepoImpl_GetPrompt(t *testing.T) { ttFields := tt.fieldsGetter(ctrl) d := &ManageRepoImpl{ - db: ttFields.db, - promptBasicDAO: ttFields.promptBasicDAO, - promptCommitDAO: ttFields.promptCommitDAO, - promptDraftDAO: ttFields.promptDraftDAO, + db: ttFields.db, + promptBasicDAO: ttFields.promptBasicDAO, + promptCommitDAO: ttFields.promptCommitDAO, + promptDraftDAO: ttFields.promptDraftDAO, + promptRelationDAO: ttFields.promptRelationDAO, } got, err := d.GetPrompt(tt.args.ctx, tt.args.param) @@ -1271,6 +1303,32 @@ func TestManageRepoImpl_ListCommitInfo(t *testing.T) { CommittedAt: time.Unix(2000, 0), }, }, + CommitDOs: []*entity.PromptCommit{ + { + CommitInfo: &entity.CommitInfo{ + Version: "1.0.0", + BaseVersion: "0.9.0", + Description: "test commit 1", + CommittedBy: "test_user", + CommittedAt: time.Unix(1000, 0), + }, + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{}, + }, + }, + { + CommitInfo: &entity.CommitInfo{ + Version: "1.1.0", + BaseVersion: "1.0.0", + Description: "test commit 2", + CommittedBy: "test_user", + CommittedAt: time.Unix(2000, 0), + }, + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{}, + }, + }, + }, }, wantErr: nil, }, @@ -1338,6 +1396,32 @@ func TestManageRepoImpl_ListCommitInfo(t *testing.T) { }, }, NextPageToken: 3, + CommitDOs: []*entity.PromptCommit{ + { + CommitInfo: &entity.CommitInfo{ + Version: "1.0.0", + BaseVersion: "0.9.0", + Description: "test commit 1", + CommittedBy: "test_user", + CommittedAt: time.Unix(1000, 0), + }, + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{}, + }, + }, + { + CommitInfo: &entity.CommitInfo{ + Version: "1.1.0", + BaseVersion: "1.0.0", + Description: "test commit 2", + CommittedBy: "test_user", + CommittedAt: time.Unix(2000, 0), + }, + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{}, + }, + }, + }, }, wantErr: nil, }, @@ -1399,6 +1483,32 @@ func TestManageRepoImpl_ListCommitInfo(t *testing.T) { CommittedAt: time.Unix(4000, 0), }, }, + CommitDOs: []*entity.PromptCommit{ + { + CommitInfo: &entity.CommitInfo{ + Version: "1.2.0", + BaseVersion: "1.1.0", + Description: "test commit 3", + CommittedBy: "test_user", + CommittedAt: time.Unix(3000, 0), + }, + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{}, + }, + }, + { + CommitInfo: &entity.CommitInfo{ + Version: "1.3.0", + BaseVersion: "1.2.0", + Description: "test commit 4", + CommittedBy: "test_user", + CommittedAt: time.Unix(4000, 0), + }, + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{}, + }, + }, + }, }, wantErr: nil, }, @@ -1426,11 +1536,12 @@ func TestManageRepoImpl_ListCommitInfo(t *testing.T) { func TestManageRepoImpl_SaveDraft(t *testing.T) { type fields struct { - db db.Provider - promptBasicDAO mysql.IPromptBasicDAO - promptCommitDAO mysql.IPromptCommitDAO - promptDraftDAO mysql.IPromptUserDraftDAO - idgen idgen.IIDGenerator + db db.Provider + promptBasicDAO mysql.IPromptBasicDAO + promptCommitDAO mysql.IPromptCommitDAO + promptDraftDAO mysql.IPromptUserDraftDAO + promptRelationDAO mysql.IPromptRelationDAO + idgen idgen.IIDGenerator } type args struct { ctx context.Context @@ -1641,6 +1752,156 @@ func TestManageRepoImpl_SaveDraft(t *testing.T) { // }, // wantErr: errorx.New("Draft's base version is invalid, saving draft's base version = 1.0.0, original draft's base version = 0.9.0 "), //}, + { + name: "create draft with snippets", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockDB := dbmocks.NewMockProvider(ctrl) + mockDB.EXPECT().Transaction(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, fc func(*gorm.DB) error, opts ...db.Option) error { + return fc(nil) + }) + + mockBasicDAO := daomocks.NewMockIPromptBasicDAO(ctrl) + mockBasicDAO.EXPECT().Get(gomock.Any(), int64(1), gomock.Any()).Return(&model.PromptBasic{ + ID: 1, + SpaceID: 100, + }, nil) + + mockDraftDAO := daomocks.NewMockIPromptUserDraftDAO(ctrl) + mockDraftDAO.EXPECT().Get(gomock.Any(), int64(1), "test_user", gomock.Any()).Return(nil, nil) + + mockIDGen := idgenmocks.NewMockIIDGenerator(ctrl) + mockIDGen.EXPECT().GenID(gomock.Any()).Return(int64(1001), nil) + + mockDraftDAO.EXPECT().Create(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, draft *model.PromptUserDraft, opts ...db.Option) error { + assert.Equal(t, int64(100), draft.SpaceID) + assert.Equal(t, "test_user", draft.UserID) + assert.True(t, draft.HasSnippets) + return nil + }) + + mockDraftDAO.EXPECT().GetByID(gomock.Any(), int64(1001), gomock.Any()).Return(&model.PromptUserDraft{ + ID: 1001, + UserID: "test_user", + }, nil) + + mockRelationDAO := daomocks.NewMockIPromptRelationDAO(ctrl) + mockRelationDAO.EXPECT().List(gomock.Any(), gomock.Any(), gomock.Any()).Return([]*model.PromptRelation{}, nil) + + mockIDGen.EXPECT().GenMultiIDs(gomock.Any(), 2).Return([]int64{4001, 4002}, nil) + mockRelationDAO.EXPECT().BatchCreate(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, relations []*model.PromptRelation, opts ...db.Option) error { + assert.Len(t, relations, 2) + for _, relation := range relations { + assert.Equal(t, int64(1), relation.MainPromptID) + assert.Equal(t, "", relation.MainPromptVersion) + assert.Equal(t, "test_user", relation.MainDraftUserID) + assert.Equal(t, int64(100), relation.SpaceID) + } + assert.Equal(t, int64(200), relations[0].SubPromptID) + assert.Equal(t, "v1", relations[0].SubPromptVersion) + assert.Equal(t, int64(201), relations[1].SubPromptID) + assert.Equal(t, "", relations[1].SubPromptVersion) + return nil + }) + + return fields{ + db: mockDB, + promptBasicDAO: mockBasicDAO, + promptDraftDAO: mockDraftDAO, + promptRelationDAO: mockRelationDAO, + idgen: mockIDGen, + } + }, + args: args{ + ctx: context.Background(), + promptDO: &entity.Prompt{ + ID: 1, + SpaceID: 100, + PromptDraft: &entity.PromptDraft{ + DraftInfo: &entity.DraftInfo{ + UserID: "test_user", + }, + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + HasSnippets: true, + Snippets: []*entity.Prompt{ + { + ID: 200, + PromptCommit: &entity.PromptCommit{ + CommitInfo: &entity.CommitInfo{ + Version: "v1", + }, + }, + }, + { + ID: 201, + }, + }, + }, + }, + }, + }, + }, + wantErr: nil, + }, + { + name: "create draft with snippets relation error", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockDB := dbmocks.NewMockProvider(ctrl) + mockDB.EXPECT().Transaction(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, fc func(*gorm.DB) error, opts ...db.Option) error { + return fc(nil) + }) + + mockBasicDAO := daomocks.NewMockIPromptBasicDAO(ctrl) + mockBasicDAO.EXPECT().Get(gomock.Any(), int64(1), gomock.Any()).Return(&model.PromptBasic{ + ID: 1, + SpaceID: 100, + }, nil) + + mockDraftDAO := daomocks.NewMockIPromptUserDraftDAO(ctrl) + mockDraftDAO.EXPECT().Get(gomock.Any(), int64(1), "test_user", gomock.Any()).Return(nil, nil) + + mockIDGen := idgenmocks.NewMockIIDGenerator(ctrl) + mockIDGen.EXPECT().GenID(gomock.Any()).Return(int64(1001), nil) + + mockDraftDAO.EXPECT().Create(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + mockDraftDAO.EXPECT().GetByID(gomock.Any(), int64(1001), gomock.Any()).Return(&model.PromptUserDraft{ + ID: 1001, + UserID: "test_user", + }, nil) + + mockRelationDAO := daomocks.NewMockIPromptRelationDAO(ctrl) + mockRelationDAO.EXPECT().List(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errorx.New("relation list error")) + + return fields{ + db: mockDB, + promptBasicDAO: mockBasicDAO, + promptDraftDAO: mockDraftDAO, + promptRelationDAO: mockRelationDAO, + idgen: mockIDGen, + } + }, + args: args{ + ctx: context.Background(), + promptDO: &entity.Prompt{ + ID: 1, + SpaceID: 100, + PromptDraft: &entity.PromptDraft{ + DraftInfo: &entity.DraftInfo{ + UserID: "test_user", + }, + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + HasSnippets: true, + Snippets: []*entity.Prompt{ + {ID: 200}, + }, + }, + }, + }, + }, + }, + wantErr: errorx.New("relation list error"), + }, { name: "update draft with no changes", fieldsGetter: func(ctrl *gomock.Controller) fields { @@ -1668,10 +1929,11 @@ func TestManageRepoImpl_SaveDraft(t *testing.T) { }, nil) return fields{ - db: mockDB, - promptBasicDAO: mockBasicDAO, - promptCommitDAO: mockCommitDAO, - promptDraftDAO: mockDraftDAO, + db: mockDB, + promptBasicDAO: mockBasicDAO, + promptCommitDAO: mockCommitDAO, + promptDraftDAO: mockDraftDAO, + promptRelationDAO: nil, } }, args: args{ @@ -1727,11 +1989,16 @@ func TestManageRepoImpl_SaveDraft(t *testing.T) { return nil }) + mockRelationDAO := daomocks.NewMockIPromptRelationDAO(ctrl) + mockRelationDAO.EXPECT().DeleteByMainPrompt(gomock.Any(), int64(1), "", "test_user", gomock.Any()).Return(nil) + return fields{ - db: mockDB, - promptBasicDAO: mockBasicDAO, - promptCommitDAO: mockCommitDAO, - promptDraftDAO: mockDraftDAO, + db: mockDB, + promptBasicDAO: mockBasicDAO, + promptCommitDAO: mockCommitDAO, + promptDraftDAO: mockDraftDAO, + promptRelationDAO: mockRelationDAO, + idgen: nil, } }, args: args{ @@ -1760,19 +2027,134 @@ func TestManageRepoImpl_SaveDraft(t *testing.T) { wantErr: nil, }, { - name: "db error", + name: "update draft with snippets", fieldsGetter: func(ctrl *gomock.Controller) fields { mockDB := dbmocks.NewMockProvider(ctrl) - mockDB.EXPECT().Transaction(gomock.Any(), gomock.Any()).Return(errorx.New("db error")) + mockDB.EXPECT().Transaction(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, fc func(*gorm.DB) error, opts ...db.Option) error { + return fc(nil) + }) + + mockBasicDAO := daomocks.NewMockIPromptBasicDAO(ctrl) + mockBasicDAO.EXPECT().Get(gomock.Any(), int64(1), gomock.Any()).Return(&model.PromptBasic{ + ID: 1, + SpaceID: 100, + }, nil) + + mockCommitDAO := daomocks.NewMockIPromptCommitDAO(ctrl) + mockCommitDAO.EXPECT().Get(gomock.Any(), int64(1), "v1", gomock.Any()).Return(&model.PromptCommit{ + Version: "v1", + }, nil) + + mockDraftDAO := daomocks.NewMockIPromptUserDraftDAO(ctrl) + mockDraftDAO.EXPECT().Get(gomock.Any(), int64(1), "test_user", gomock.Any()).Return(&model.PromptUserDraft{ + ID: 1001, + BaseVersion: "v1", + }, nil) + mockDraftDAO.EXPECT().Update(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, draft *model.PromptUserDraft, opts ...db.Option) error { + assert.Equal(t, int64(1001), draft.ID) + assert.True(t, draft.HasSnippets) + return nil + }) + mockDraftDAO.EXPECT().GetByID(gomock.Any(), int64(1001), gomock.Any()).Return(&model.PromptUserDraft{ + ID: 1001, + UserID: "test_user", + }, nil) + + mockRelationDAO := daomocks.NewMockIPromptRelationDAO(ctrl) + mockRelationDAO.EXPECT().List(gomock.Any(), gomock.Any(), gomock.Any()).Return([]*model.PromptRelation{ + { + ID: 2001, + SpaceID: 100, + MainPromptID: 1, + MainPromptVersion: "", + MainDraftUserID: "test_user", + SubPromptID: 200, + SubPromptVersion: "v1", + }, + { + ID: 2002, + SpaceID: 100, + MainPromptID: 1, + MainPromptVersion: "", + MainDraftUserID: "test_user", + SubPromptID: 201, + SubPromptVersion: "old", + }, + }, nil) + mockRelationDAO.EXPECT().BatchDeleteByIDs(gomock.Any(), []int64{2002}, gomock.Any()).Return(nil) + + mockIDGen := idgenmocks.NewMockIIDGenerator(ctrl) + mockIDGen.EXPECT().GenMultiIDs(gomock.Any(), 1).Return([]int64{4001}, nil) + mockRelationDAO.EXPECT().BatchCreate(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, relations []*model.PromptRelation, opts ...db.Option) error { + assert.Len(t, relations, 1) + assert.Equal(t, int64(1), relations[0].MainPromptID) + assert.Equal(t, "test_user", relations[0].MainDraftUserID) + assert.Equal(t, int64(202), relations[0].SubPromptID) + assert.Equal(t, "v2", relations[0].SubPromptVersion) + return nil + }) return fields{ - db: mockDB, + db: mockDB, + promptBasicDAO: mockBasicDAO, + promptCommitDAO: mockCommitDAO, + promptDraftDAO: mockDraftDAO, + promptRelationDAO: mockRelationDAO, + idgen: mockIDGen, } }, args: args{ ctx: context.Background(), promptDO: &entity.Prompt{ - ID: 1, + ID: 1, + SpaceID: 100, + PromptDraft: &entity.PromptDraft{ + DraftInfo: &entity.DraftInfo{ + UserID: "test_user", + BaseVersion: "v1", + }, + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + HasSnippets: true, + Snippets: []*entity.Prompt{ + { + ID: 200, + PromptCommit: &entity.PromptCommit{ + CommitInfo: &entity.CommitInfo{ + Version: "v1", + }, + }, + }, + { + ID: 202, + PromptCommit: &entity.PromptCommit{ + CommitInfo: &entity.CommitInfo{ + Version: "v2", + }, + }, + }, + }, + }, + }, + }, + }, + }, + wantErr: nil, + }, + { + name: "db error", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockDB := dbmocks.NewMockProvider(ctrl) + mockDB.EXPECT().Transaction(gomock.Any(), gomock.Any()).Return(errorx.New("db error")) + + return fields{ + db: mockDB, + } + }, + args: args{ + ctx: context.Background(), + promptDO: &entity.Prompt{ + ID: 1, PromptDraft: &entity.PromptDraft{ DraftInfo: &entity.DraftInfo{ UserID: "test_user", @@ -2000,10 +2382,11 @@ func TestManageRepoImpl_SaveDraft(t *testing.T) { mockDraftDAO.EXPECT().Update(gomock.Any(), gomock.Any(), gomock.Any()).Return(errorx.New("update error")) return fields{ - db: mockDB, - promptBasicDAO: mockBasicDAO, - promptCommitDAO: mockCommitDAO, - promptDraftDAO: mockDraftDAO, + db: mockDB, + promptBasicDAO: mockBasicDAO, + promptCommitDAO: mockCommitDAO, + promptDraftDAO: mockDraftDAO, + promptRelationDAO: nil, } }, args: args{ @@ -2041,11 +2424,12 @@ func TestManageRepoImpl_SaveDraft(t *testing.T) { ttFields := tt.fieldsGetter(ctrl) d := &ManageRepoImpl{ - db: ttFields.db, - promptBasicDAO: ttFields.promptBasicDAO, - promptCommitDAO: ttFields.promptCommitDAO, - promptDraftDAO: ttFields.promptDraftDAO, - idgen: ttFields.idgen, + db: ttFields.db, + promptBasicDAO: ttFields.promptBasicDAO, + promptCommitDAO: ttFields.promptCommitDAO, + promptDraftDAO: ttFields.promptDraftDAO, + promptRelationDAO: ttFields.promptRelationDAO, + idgen: ttFields.idgen, } _, err := d.SaveDraft(tt.args.ctx, tt.args.promptDO) @@ -2056,10 +2440,11 @@ func TestManageRepoImpl_SaveDraft(t *testing.T) { func TestManageRepoImpl_CreatePrompt(t *testing.T) { type fields struct { - db db.Provider - promptBasicDAO mysql.IPromptBasicDAO - promptDraftDAO mysql.IPromptUserDraftDAO - idgen idgen.IIDGenerator + db db.Provider + promptBasicDAO mysql.IPromptBasicDAO + promptDraftDAO mysql.IPromptUserDraftDAO + promptRelationDAO mysql.IPromptRelationDAO + idgen idgen.IIDGenerator } type args struct { ctx context.Context @@ -2111,7 +2496,9 @@ func TestManageRepoImpl_CreatePrompt(t *testing.T) { args: args{ ctx: context.Background(), promptDO: &entity.Prompt{ - PromptBasic: &entity.PromptBasic{}, + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, + }, }, }, want: 0, @@ -2131,7 +2518,9 @@ func TestManageRepoImpl_CreatePrompt(t *testing.T) { args: args{ ctx: context.Background(), promptDO: &entity.Prompt{ - PromptBasic: &entity.PromptBasic{}, + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, + }, PromptDraft: &entity.PromptDraft{ DraftInfo: &entity.DraftInfo{ UserID: "test_user", @@ -2159,7 +2548,9 @@ func TestManageRepoImpl_CreatePrompt(t *testing.T) { args: args{ ctx: context.Background(), promptDO: &entity.Prompt{ - PromptBasic: &entity.PromptBasic{}, + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, + }, }, }, want: 0, @@ -2191,7 +2582,9 @@ func TestManageRepoImpl_CreatePrompt(t *testing.T) { args: args{ ctx: context.Background(), promptDO: &entity.Prompt{ - PromptBasic: &entity.PromptBasic{}, + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, + }, }, }, want: 0, @@ -2232,7 +2625,9 @@ func TestManageRepoImpl_CreatePrompt(t *testing.T) { args: args{ ctx: context.Background(), promptDO: &entity.Prompt{ - PromptBasic: &entity.PromptBasic{}, + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, + }, PromptDraft: &entity.PromptDraft{ DraftInfo: &entity.DraftInfo{ UserID: "test_user", @@ -2269,7 +2664,9 @@ func TestManageRepoImpl_CreatePrompt(t *testing.T) { args: args{ ctx: context.Background(), promptDO: &entity.Prompt{ - PromptBasic: &entity.PromptBasic{}, + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, + }, }, }, want: 1001, @@ -2310,11 +2707,257 @@ func TestManageRepoImpl_CreatePrompt(t *testing.T) { args: args{ ctx: context.Background(), promptDO: &entity.Prompt{ - PromptBasic: &entity.PromptBasic{}, + PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, + }, + PromptDraft: &entity.PromptDraft{ + DraftInfo: &entity.DraftInfo{ + UserID: "test_user", + }, + }, + }, + }, + want: 1001, + wantErr: nil, + }, + { + name: "create prompt with snippets success", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockIDGen := idgenmocks.NewMockIIDGenerator(ctrl) + mockIDGen.EXPECT().GenID(gomock.Any()).Return(int64(1001), nil) + mockIDGen.EXPECT().GenID(gomock.Any()).Return(int64(2001), nil) + mockIDGen.EXPECT().GenMultiIDs(gomock.Any(), 2).Return([]int64{3001, 3002}, nil) + + mockDB := dbmocks.NewMockProvider(ctrl) + mockDB.EXPECT().Transaction(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, fc func(*gorm.DB) error, opts ...db.Option) error { + return fc(nil) + }) + + mockBasicDAO := daomocks.NewMockIPromptBasicDAO(ctrl) + mockBasicDAO.EXPECT().Create(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, basic *model.PromptBasic, opts ...db.Option) error { + assert.Equal(t, int64(1001), basic.ID) + return nil + }) + + mockDraftDAO := daomocks.NewMockIPromptUserDraftDAO(ctrl) + mockDraftDAO.EXPECT().Create(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, draft *model.PromptUserDraft, opts ...db.Option) error { + assert.Equal(t, int64(2001), draft.ID) + assert.Equal(t, int64(1001), draft.PromptID) + return nil + }) + + mockRelationDAO := daomocks.NewMockIPromptRelationDAO(ctrl) + mockRelationDAO.EXPECT().BatchCreate(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, relations []*model.PromptRelation, opts ...db.Option) error { + assert.Len(t, relations, 2) + assert.Equal(t, int64(1001), relations[0].MainPromptID) + assert.Equal(t, int64(200), relations[0].SubPromptID) + assert.Equal(t, int64(1001), relations[1].MainPromptID) + assert.Equal(t, int64(201), relations[1].SubPromptID) + assert.Equal(t, "test_user", relations[0].MainDraftUserID) + assert.Equal(t, "test_user", relations[1].MainDraftUserID) + return nil + }) + + return fields{ + db: mockDB, + promptBasicDAO: mockBasicDAO, + promptDraftDAO: mockDraftDAO, + promptRelationDAO: mockRelationDAO, + idgen: mockIDGen, + } + }, + args: args{ + ctx: context.Background(), + promptDO: &entity.Prompt{ + SpaceID: 123, + PromptBasic: &entity.PromptBasic{ + DisplayName: "Test Prompt", + Description: "Test Description", + }, + PromptDraft: &entity.PromptDraft{ + DraftInfo: &entity.DraftInfo{ + UserID: "test_user", + }, + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + HasSnippets: true, + Snippets: []*entity.Prompt{ + {ID: 200}, + {ID: 201}, + }, + }, + }, + }, + }, + }, + want: 1001, + wantErr: nil, + }, + { + name: "create prompt with snippets - skip nil snippets", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockIDGen := idgenmocks.NewMockIIDGenerator(ctrl) + mockIDGen.EXPECT().GenID(gomock.Any()).Return(int64(1001), nil) + mockIDGen.EXPECT().GenID(gomock.Any()).Return(int64(2001), nil) + mockIDGen.EXPECT().GenMultiIDs(gomock.Any(), 3).Return([]int64{3001, 3002, 3003}, nil) + + mockDB := dbmocks.NewMockProvider(ctrl) + mockDB.EXPECT().Transaction(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, fc func(*gorm.DB) error, opts ...db.Option) error { + return fc(nil) + }) + + mockBasicDAO := daomocks.NewMockIPromptBasicDAO(ctrl) + mockBasicDAO.EXPECT().Create(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + + mockDraftDAO := daomocks.NewMockIPromptUserDraftDAO(ctrl) + mockDraftDAO.EXPECT().Create(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + + // No relation creation expected since valid snippets are 0 + mockRelationDAO := daomocks.NewMockIPromptRelationDAO(ctrl) + mockRelationDAO.EXPECT().BatchCreate(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, relations []*model.PromptRelation, opts ...db.Option) error { + assert.Len(t, relations, 2) // 2 valid snippets (ID: 0 and ID: 202) + return nil + }) + + return fields{ + db: mockDB, + promptBasicDAO: mockBasicDAO, + promptDraftDAO: mockDraftDAO, + promptRelationDAO: mockRelationDAO, + idgen: mockIDGen, + } + }, + args: args{ + ctx: context.Background(), + promptDO: &entity.Prompt{ + SpaceID: 123, + PromptBasic: &entity.PromptBasic{ + DisplayName: "Test Prompt", + Description: "Test Description", + }, + PromptDraft: &entity.PromptDraft{ + DraftInfo: &entity.DraftInfo{ + UserID: "test_user", + }, + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + HasSnippets: true, + Snippets: []*entity.Prompt{ + nil, + {ID: 0}, // Invalid ID + {ID: 202}, // Valid + }, + }, + }, + }, + }, + }, + want: 1001, + wantErr: nil, + }, + { + name: "create prompt with snippets - relation dao error", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockIDGen := idgenmocks.NewMockIIDGenerator(ctrl) + mockIDGen.EXPECT().GenID(gomock.Any()).Return(int64(1001), nil) + mockIDGen.EXPECT().GenID(gomock.Any()).Return(int64(2001), nil) + mockIDGen.EXPECT().GenMultiIDs(gomock.Any(), 1).Return([]int64{3001}, nil) + + mockDB := dbmocks.NewMockProvider(ctrl) + mockDB.EXPECT().Transaction(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, fc func(*gorm.DB) error, opts ...db.Option) error { + return fc(nil) + }) + + mockBasicDAO := daomocks.NewMockIPromptBasicDAO(ctrl) + mockBasicDAO.EXPECT().Create(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + + mockDraftDAO := daomocks.NewMockIPromptUserDraftDAO(ctrl) + mockDraftDAO.EXPECT().Create(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + + mockRelationDAO := daomocks.NewMockIPromptRelationDAO(ctrl) + mockRelationDAO.EXPECT().BatchCreate(gomock.Any(), gomock.Any(), gomock.Any()).Return(errorx.New("relation dao error")) + + return fields{ + db: mockDB, + promptBasicDAO: mockBasicDAO, + promptDraftDAO: mockDraftDAO, + promptRelationDAO: mockRelationDAO, + idgen: mockIDGen, + } + }, + args: args{ + ctx: context.Background(), + promptDO: &entity.Prompt{ + SpaceID: 123, + PromptBasic: &entity.PromptBasic{ + DisplayName: "Test Prompt", + Description: "Test Description", + }, + PromptDraft: &entity.PromptDraft{ + DraftInfo: &entity.DraftInfo{ + UserID: "test_user", + }, + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + HasSnippets: true, + Snippets: []*entity.Prompt{ + {ID: 200}, + }, + }, + }, + }, + }, + }, + want: 0, + wantErr: errorx.New("relation dao error"), + }, + { + name: "create prompt without snippets - no relation creation", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockIDGen := idgenmocks.NewMockIIDGenerator(ctrl) + mockIDGen.EXPECT().GenID(gomock.Any()).Return(int64(1001), nil) + mockIDGen.EXPECT().GenID(gomock.Any()).Return(int64(2001), nil) + + mockDB := dbmocks.NewMockProvider(ctrl) + mockDB.EXPECT().Transaction(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, fc func(*gorm.DB) error, opts ...db.Option) error { + return fc(nil) + }) + + mockBasicDAO := daomocks.NewMockIPromptBasicDAO(ctrl) + mockBasicDAO.EXPECT().Create(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + + mockDraftDAO := daomocks.NewMockIPromptUserDraftDAO(ctrl) + mockDraftDAO.EXPECT().Create(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + + // No relation creation expected + mockRelationDAO := daomocks.NewMockIPromptRelationDAO(ctrl) + + return fields{ + db: mockDB, + promptBasicDAO: mockBasicDAO, + promptDraftDAO: mockDraftDAO, + promptRelationDAO: mockRelationDAO, + idgen: mockIDGen, + } + }, + args: args{ + ctx: context.Background(), + promptDO: &entity.Prompt{ + SpaceID: 123, + PromptBasic: &entity.PromptBasic{ + DisplayName: "Test Prompt", + Description: "Test Description", + }, PromptDraft: &entity.PromptDraft{ DraftInfo: &entity.DraftInfo{ UserID: "test_user", }, + PromptDetail: &entity.PromptDetail{ + PromptTemplate: &entity.PromptTemplate{ + HasSnippets: false, + Snippets: []*entity.Prompt{}, + }, + }, }, }, }, @@ -2331,10 +2974,11 @@ func TestManageRepoImpl_CreatePrompt(t *testing.T) { ttFields := tt.fieldsGetter(ctrl) d := &ManageRepoImpl{ - db: ttFields.db, - promptBasicDAO: ttFields.promptBasicDAO, - promptDraftDAO: ttFields.promptDraftDAO, - idgen: ttFields.idgen, + db: ttFields.db, + promptBasicDAO: ttFields.promptBasicDAO, + promptDraftDAO: ttFields.promptDraftDAO, + promptRelationDAO: ttFields.promptRelationDAO, + idgen: ttFields.idgen, } got, err := d.CreatePrompt(tt.args.ctx, tt.args.promptDO) @@ -2356,6 +3000,7 @@ func TestManageRepoImpl_CommitDraft(t *testing.T) { commitLabelMappingDAO mysql.ICommitLabelMappingDAO promptBasicCacheDAO redis.IPromptBasicDAO promptCacheDAO redis.IPromptDAO + promptRelationDAO mysql.IPromptRelationDAO } type args struct { ctx context.Context @@ -2372,6 +3017,7 @@ func TestManageRepoImpl_CommitDraft(t *testing.T) { fieldsGetter: func(ctrl *gomock.Controller) fields { return fields{ commitLabelMappingDAO: daomocks.NewMockICommitLabelMappingDAO(ctrl), + promptRelationDAO: daomocks.NewMockIPromptRelationDAO(ctrl), } }, args: args{ @@ -2386,7 +3032,10 @@ func TestManageRepoImpl_CommitDraft(t *testing.T) { { name: "invalid user id", fieldsGetter: func(ctrl *gomock.Controller) fields { - return fields{commitLabelMappingDAO: daomocks.NewMockICommitLabelMappingDAO(ctrl)} + return fields{ + commitLabelMappingDAO: daomocks.NewMockICommitLabelMappingDAO(ctrl), + promptRelationDAO: daomocks.NewMockIPromptRelationDAO(ctrl), + } }, args: args{ ctx: context.Background(), @@ -2400,7 +3049,10 @@ func TestManageRepoImpl_CommitDraft(t *testing.T) { { name: "invalid commit version", fieldsGetter: func(ctrl *gomock.Controller) fields { - return fields{commitLabelMappingDAO: daomocks.NewMockICommitLabelMappingDAO(ctrl)} + return fields{ + commitLabelMappingDAO: daomocks.NewMockICommitLabelMappingDAO(ctrl), + promptRelationDAO: daomocks.NewMockIPromptRelationDAO(ctrl), + } }, args: args{ ctx: context.Background(), @@ -3366,6 +4018,177 @@ func TestManageRepoImpl_CommitDraft(t *testing.T) { }, wantErr: nil, }, + { + name: "commit with snippets - hasSnippets true", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockIDGen := idgenmocks.NewMockIIDGenerator(ctrl) + mockIDGen.EXPECT().GenID(gomock.Any()).Return(int64(1001), nil) + mockIDGen.EXPECT().GenMultiIDs(gomock.Any(), 0).Return([]int64{}, nil) + + mockDB := dbmocks.NewMockProvider(ctrl) + mockDB.EXPECT().Transaction(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, fc func(*gorm.DB) error, opts ...db.Option) error { + return fc(nil) + }) + + nilDB, _ := gorm.Open(nil) + mockDB.EXPECT().NewSession(gomock.Any(), gomock.Any()).Return(nilDB) + + mockBasicDAO := daomocks.NewMockIPromptBasicDAO(ctrl) + mockBasicDAO.EXPECT().Get(gomock.Any(), int64(1), gomock.Any()).Return(&model.PromptBasic{ + ID: 1, + SpaceID: 100, + PromptKey: "test_key", + LatestVersion: "1.0.0", + }, nil) + + // 创建包含snippet的草稿 + draftWithSnippets := &model.PromptUserDraft{ + ID: 1001, + BaseVersion: "1.0.0", + HasSnippets: true, + } + + mockDraftDAO := daomocks.NewMockIPromptUserDraftDAO(ctrl) + mockDraftDAO.EXPECT().Get(gomock.Any(), int64(1), "test_user", gomock.Any()).Return(draftWithSnippets, nil) + + mockCommitDAO := daomocks.NewMockIPromptCommitDAO(ctrl) + mockCommitDAO.EXPECT().Create(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, commit *model.PromptCommit, timeNow time.Time, opts ...db.Option) error { + assert.Equal(t, int64(1001), commit.ID) + assert.Equal(t, "2.0.0", commit.Version) + assert.Equal(t, "1.0.0", commit.BaseVersion) + assert.Equal(t, "test_user", commit.CommittedBy) + assert.True(t, commit.HasSnippets, "HasSnippets should be true for drafts with snippets") + return nil + }) + + mockDraftDAO.EXPECT().Delete(gomock.Any(), int64(1001), gomock.Any()).Return(nil) + + mockBasicDAO.EXPECT().Update(gomock.Any(), int64(1), gomock.Any(), gomock.Any()).Return(nil) + + mockPromptBasicCacheDAO := redismocks.NewMockIPromptBasicDAO(ctrl) + mockPromptBasicCacheDAO.EXPECT().DelByPromptKey(gomock.Any(), int64(100), "test_key").Return(nil) + + mockCommitLabelMappingDAO := daomocks.NewMockICommitLabelMappingDAO(ctrl) + mockCommitLabelMappingDAO.EXPECT().ListByPromptIDAndLabelKeys(gomock.Any(), int64(1), gomock.Any(), gomock.Any()).Return(nil, nil) + + // Mock snippet relation operations + mockRelationDAO := daomocks.NewMockIPromptRelationDAO(ctrl) + mockRelationDAO.EXPECT().List(gomock.Any(), gomock.Any(), gomock.Any()).Return([]*model.PromptRelation{ + { + ID: 2001, + SpaceID: 100, + MainPromptID: 1, + MainPromptVersion: "", + MainDraftUserID: "test_user", + SubPromptID: 2, + SubPromptVersion: "1.0.0", + }, + }, nil) + mockIDGen.EXPECT().GenMultiIDs(gomock.Any(), 1).Return([]int64{3001}, nil) + mockRelationDAO.EXPECT().BatchCreate(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + mockRelationDAO.EXPECT().BatchDeleteByIDs(gomock.Any(), []int64{2001}, gomock.Any()).Return(nil) + + return fields{ + db: mockDB, + idgen: mockIDGen, + promptBasicDAO: mockBasicDAO, + promptCommitDAO: mockCommitDAO, + promptDraftDAO: mockDraftDAO, + commitLabelMappingDAO: mockCommitLabelMappingDAO, + promptBasicCacheDAO: mockPromptBasicCacheDAO, + promptCacheDAO: redismocks.NewMockIPromptDAO(ctrl), + promptRelationDAO: mockRelationDAO, + } + }, + args: args{ + ctx: context.Background(), + param: repo.CommitDraftParam{ + PromptID: 1, + UserID: "test_user", + CommitVersion: "2.0.0", + CommitDescription: "commit with snippets", + }, + }, + wantErr: nil, + }, + { + name: "commit without snippets - hasSnippets false", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockIDGen := idgenmocks.NewMockIIDGenerator(ctrl) + mockIDGen.EXPECT().GenID(gomock.Any()).Return(int64(1001), nil) + mockIDGen.EXPECT().GenMultiIDs(gomock.Any(), 0).Return([]int64{}, nil) + + mockDB := dbmocks.NewMockProvider(ctrl) + mockDB.EXPECT().Transaction(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, fc func(*gorm.DB) error, opts ...db.Option) error { + return fc(nil) + }) + + nilDB, _ := gorm.Open(nil) + mockDB.EXPECT().NewSession(gomock.Any(), gomock.Any()).Return(nilDB) + + mockBasicDAO := daomocks.NewMockIPromptBasicDAO(ctrl) + mockBasicDAO.EXPECT().Get(gomock.Any(), int64(1), gomock.Any()).Return(&model.PromptBasic{ + ID: 1, + SpaceID: 100, + PromptKey: "test_key", + LatestVersion: "1.0.0", + }, nil) + + // 创建不包含snippet的草稿 + draftWithoutSnippets := &model.PromptUserDraft{ + ID: 1001, + BaseVersion: "1.0.0", + HasSnippets: false, + } + + mockDraftDAO := daomocks.NewMockIPromptUserDraftDAO(ctrl) + mockDraftDAO.EXPECT().Get(gomock.Any(), int64(1), "test_user", gomock.Any()).Return(draftWithoutSnippets, nil) + + mockCommitDAO := daomocks.NewMockIPromptCommitDAO(ctrl) + mockCommitDAO.EXPECT().Create(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, commit *model.PromptCommit, timeNow time.Time, opts ...db.Option) error { + assert.Equal(t, int64(1001), commit.ID) + assert.Equal(t, "2.0.0", commit.Version) + assert.Equal(t, "1.0.0", commit.BaseVersion) + assert.Equal(t, "test_user", commit.CommittedBy) + assert.False(t, commit.HasSnippets, "HasSnippets should be false for drafts without snippets") + return nil + }) + + mockDraftDAO.EXPECT().Delete(gomock.Any(), int64(1001), gomock.Any()).Return(nil) + + mockBasicDAO.EXPECT().Update(gomock.Any(), int64(1), gomock.Any(), gomock.Any()).Return(nil) + + mockPromptBasicCacheDAO := redismocks.NewMockIPromptBasicDAO(ctrl) + mockPromptBasicCacheDAO.EXPECT().DelByPromptKey(gomock.Any(), int64(100), "test_key").Return(nil) + + mockRelationDAO := daomocks.NewMockIPromptRelationDAO(ctrl) + + mockCommitLabelMappingDAO := daomocks.NewMockICommitLabelMappingDAO(ctrl) + mockCommitLabelMappingDAO.EXPECT().ListByPromptIDAndLabelKeys(gomock.Any(), int64(1), gomock.Any(), gomock.Any()).Return(nil, nil) + + return fields{ + db: mockDB, + idgen: mockIDGen, + promptBasicDAO: mockBasicDAO, + promptCommitDAO: mockCommitDAO, + promptDraftDAO: mockDraftDAO, + commitLabelMappingDAO: mockCommitLabelMappingDAO, + promptBasicCacheDAO: mockPromptBasicCacheDAO, + promptCacheDAO: redismocks.NewMockIPromptDAO(ctrl), + promptRelationDAO: mockRelationDAO, + } + }, + args: args{ + ctx: context.Background(), + param: repo.CommitDraftParam{ + PromptID: 1, + UserID: "test_user", + CommitVersion: "2.0.0", + CommitDescription: "commit without snippets", + }, + }, + wantErr: nil, + }, } for _, tt := range tests { @@ -3384,6 +4207,7 @@ func TestManageRepoImpl_CommitDraft(t *testing.T) { commitLabelMappingDAO: ttFields.commitLabelMappingDAO, promptBasicCacheDAO: ttFields.promptBasicCacheDAO, promptCacheDAO: ttFields.promptCacheDAO, + promptRelationDAO: ttFields.promptRelationDAO, } err := d.CommitDraft(tt.args.ctx, tt.args.param) @@ -3406,6 +4230,7 @@ func TestNewManageRepo(t *testing.T) { mockCommitLabelMappingDAO := daomocks.NewMockICommitLabelMappingDAO(ctrl) mockPromptBasicCacheDAO := redismocks.NewMockIPromptBasicDAO(ctrl) mockPromptCacheDAO := redismocks.NewMockIPromptDAO(ctrl) + mockPromptRelationDAO := daomocks.NewMockIPromptRelationDAO(ctrl) // 调用构造函数 // 调用构造函数 @@ -3417,6 +4242,7 @@ func TestNewManageRepo(t *testing.T) { mockPromptCommitDAO, mockPromptDraftDAO, mockCommitLabelMappingDAO, + mockPromptRelationDAO, mockPromptBasicCacheDAO, mockPromptCacheDAO, ) @@ -3468,7 +4294,7 @@ func TestManageRepoImpl_ListPrompt(t *testing.T) { }, }, want: nil, - wantErr: errorx.New("param(SpaceID or PageNum or PageSize) is invalid, param = {\"SpaceID\":0,\"KeyWord\":\"\",\"CreatedBys\":null,\"UserID\":\"\",\"CommittedOnly\":false,\"PageNum\":1,\"PageSize\":10,\"OrderBy\":0,\"Asc\":false}"), + wantErr: errorx.New("param(SpaceID or PageNum or PageSize) is invalid, param = {\"SpaceID\":0,\"KeyWord\":\"\",\"CreatedBys\":null,\"UserID\":\"\",\"CommittedOnly\":false,\"FilterPromptTypes\":null,\"PromptIDs\":null,\"PageNum\":1,\"PageSize\":10,\"OrderBy\":0,\"Asc\":false}"), }, { name: "invalid page num", @@ -3484,7 +4310,7 @@ func TestManageRepoImpl_ListPrompt(t *testing.T) { }, }, want: nil, - wantErr: errorx.New("param(SpaceID or PageNum or PageSize) is invalid, param = {\"SpaceID\":123,\"KeyWord\":\"\",\"CreatedBys\":null,\"UserID\":\"\",\"CommittedOnly\":false,\"PageNum\":0,\"PageSize\":10,\"OrderBy\":0,\"Asc\":false}"), + wantErr: errorx.New("param(SpaceID or PageNum or PageSize) is invalid, param = {\"SpaceID\":123,\"KeyWord\":\"\",\"CreatedBys\":null,\"UserID\":\"\",\"CommittedOnly\":false,\"FilterPromptTypes\":null,\"PromptIDs\":null,\"PageNum\":0,\"PageSize\":10,\"OrderBy\":0,\"Asc\":false}"), }, { name: "invalid page size", @@ -3500,7 +4326,7 @@ func TestManageRepoImpl_ListPrompt(t *testing.T) { }, }, want: nil, - wantErr: errorx.New("param(SpaceID or PageNum or PageSize) is invalid, param = {\"SpaceID\":123,\"KeyWord\":\"\",\"CreatedBys\":null,\"UserID\":\"\",\"CommittedOnly\":false,\"PageNum\":1,\"PageSize\":0,\"OrderBy\":0,\"Asc\":false}"), + wantErr: errorx.New("param(SpaceID or PageNum or PageSize) is invalid, param = {\"SpaceID\":123,\"KeyWord\":\"\",\"CreatedBys\":null,\"UserID\":\"\",\"CommittedOnly\":false,\"FilterPromptTypes\":null,\"PromptIDs\":null,\"PageNum\":1,\"PageSize\":0,\"OrderBy\":0,\"Asc\":false}"), }, { name: "empty result", @@ -3595,7 +4421,7 @@ func TestManageRepoImpl_ListPrompt(t *testing.T) { ID: 1001, SpaceID: 123, PromptKey: "test_key_1", - PromptBasic: &entity.PromptBasic{DisplayName: "Test Prompt 1"}, + PromptBasic: &entity.PromptBasic{PromptType: entity.PromptTypeNormal, DisplayName: "Test Prompt 1"}, PromptDraft: &entity.PromptDraft{ PromptDetail: &entity.PromptDetail{ PromptTemplate: &entity.PromptTemplate{}, @@ -3609,7 +4435,7 @@ func TestManageRepoImpl_ListPrompt(t *testing.T) { ID: 1002, SpaceID: 123, PromptKey: "test_key_2", - PromptBasic: &entity.PromptBasic{DisplayName: "Test Prompt 2"}, + PromptBasic: &entity.PromptBasic{PromptType: entity.PromptTypeNormal, DisplayName: "Test Prompt 2"}, }, }, }, @@ -3663,7 +4489,7 @@ func TestManageRepoImpl_ListPrompt(t *testing.T) { ID: 1001, SpaceID: 123, PromptKey: "test_key_1", - PromptBasic: &entity.PromptBasic{DisplayName: "Test search_term Prompt"}, + PromptBasic: &entity.PromptBasic{PromptType: entity.PromptTypeNormal, DisplayName: "Test search_term Prompt"}, }, }, }, @@ -3718,7 +4544,7 @@ func TestManageRepoImpl_ListPrompt(t *testing.T) { ID: 1001, SpaceID: 123, PromptKey: "test_key_1", - PromptBasic: &entity.PromptBasic{DisplayName: "Test Prompt 1", CreatedBy: "user1"}, + PromptBasic: &entity.PromptBasic{PromptType: entity.PromptTypeNormal, DisplayName: "Test Prompt 1", CreatedBy: "user1"}, }, }, }, @@ -3774,7 +4600,7 @@ func TestManageRepoImpl_ListPrompt(t *testing.T) { ID: 1001, SpaceID: 123, PromptKey: "test_key_1", - PromptBasic: &entity.PromptBasic{DisplayName: "Test Prompt 1"}, + PromptBasic: &entity.PromptBasic{PromptType: entity.PromptTypeNormal, DisplayName: "Test Prompt 1"}, }, }, }, @@ -3869,3 +4695,197 @@ func TestManageRepoImpl_ListPrompt(t *testing.T) { }) } } + +func TestManageRepoImpl_ListParentPrompt(t *testing.T) { + t.Parallel() + type fields struct { + promptRelationDAO mysql.IPromptRelationDAO + promptBasicDAO mysql.IPromptBasicDAO + } + type args struct { + ctx context.Context + param repo.ListParentPromptParam + } + tests := []struct { + name string + fieldsGetter func(ctrl *gomock.Controller) fields + args args + wantErr error + wantErrMsg string + check func(t *testing.T, got map[string][]*repo.PromptCommitVersions) + }{ + { + name: "invalid sub prompt id", + fieldsGetter: func(ctrl *gomock.Controller) fields { + return fields{} + }, + args: args{ + ctx: context.Background(), + param: repo.ListParentPromptParam{ + SubPromptID: 0, + }, + }, + wantErrMsg: "param(SubPromptID) is invalid", + }, + { + name: "relation dao error", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockRelationDAO := daomocks.NewMockIPromptRelationDAO(ctrl) + mockRelationDAO.EXPECT().List(gomock.Any(), gomock.Any()).Return(nil, errorx.New("list error")) + return fields{ + promptRelationDAO: mockRelationDAO, + } + }, + args: args{ + ctx: context.Background(), + param: repo.ListParentPromptParam{ + SubPromptID: 200, + }, + }, + wantErr: errorx.New("list error"), + }, + { + name: "no relations", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockRelationDAO := daomocks.NewMockIPromptRelationDAO(ctrl) + mockRelationDAO.EXPECT().List(gomock.Any(), gomock.Any()).Return([]*model.PromptRelation{}, nil) + return fields{ + promptRelationDAO: mockRelationDAO, + } + }, + args: args{ + ctx: context.Background(), + param: repo.ListParentPromptParam{ + SubPromptID: 200, + }, + }, + check: func(t *testing.T, got map[string][]*repo.PromptCommitVersions) { + assert.Nil(t, got) + }, + }, + { + name: "list success with versions", + fieldsGetter: func(ctrl *gomock.Controller) fields { + relations := []*model.PromptRelation{ + { + ID: 1, + SpaceID: 1, + MainPromptID: 101, + MainPromptVersion: "1.0.0", + SubPromptID: 200, + SubPromptVersion: "v1", + }, + { + ID: 2, + SpaceID: 1, + MainPromptID: 101, + MainPromptVersion: "", + SubPromptID: 200, + SubPromptVersion: "v1", + }, + { + ID: 3, + SpaceID: 1, + MainPromptID: 102, + MainPromptVersion: "2.0.0", + SubPromptID: 200, + SubPromptVersion: "v2", + }, + { + ID: 4, + SpaceID: 1, + MainPromptID: 102, + MainPromptVersion: "2.0.1", + SubPromptID: 200, + SubPromptVersion: "v2", + }, + } + mockRelationDAO := daomocks.NewMockIPromptRelationDAO(ctrl) + mockRelationDAO.EXPECT().List(gomock.Any(), gomock.Any()).Return(relations, nil) + + mockBasicDAO := daomocks.NewMockIPromptBasicDAO(ctrl) + mockBasicDAO.EXPECT().MGet(gomock.Any(), gomock.Any()).Return(map[int64]*model.PromptBasic{ + 101: { + ID: 101, + SpaceID: 1, + PromptKey: "parent_a", + PromptType: string(entity.PromptTypeNormal), + }, + 102: { + ID: 102, + SpaceID: 1, + PromptKey: "parent_b", + PromptType: string(entity.PromptTypeSnippet), + }, + }, nil) + + return fields{ + promptRelationDAO: mockRelationDAO, + promptBasicDAO: mockBasicDAO, + } + }, + args: args{ + ctx: context.Background(), + param: repo.ListParentPromptParam{ + SubPromptID: 200, + SubPromptVersions: []string{"v1", "v2"}, + }, + }, + check: func(t *testing.T, got map[string][]*repo.PromptCommitVersions) { + assert.Len(t, got, 2) + v1List, ok := got["v1"] + assert.True(t, ok) + assert.Len(t, v1List, 1) + v1 := v1List[0] + assert.Equal(t, int64(101), v1.PromptID) + assert.Equal(t, int64(1), v1.SpaceID) + assert.Equal(t, "parent_a", v1.PromptKey) + if assert.NotNil(t, v1.PromptBasic) { + assert.Equal(t, entity.PromptTypeNormal, v1.PromptBasic.PromptType) + } + assert.Equal(t, []string{"1.0.0"}, v1.CommitVersions) + + v2List, ok := got["v2"] + assert.True(t, ok) + assert.Len(t, v2List, 1) + v2 := v2List[0] + assert.Equal(t, int64(102), v2.PromptID) + assert.Equal(t, "parent_b", v2.PromptKey) + if assert.NotNil(t, v2.PromptBasic) { + assert.Equal(t, entity.PromptTypeSnippet, v2.PromptBasic.PromptType) + } + assert.ElementsMatch(t, []string{"2.0.0", "2.0.1"}, v2.CommitVersions) + }, + }, + } + + for _, tt := range tests { + ttt := tt + t.Run(ttt.name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + fields := ttt.fieldsGetter(ctrl) + repoImpl := &ManageRepoImpl{ + promptRelationDAO: fields.promptRelationDAO, + promptBasicDAO: fields.promptBasicDAO, + } + + got, err := repoImpl.ListParentPrompt(ttt.args.ctx, ttt.args.param) + if ttt.wantErr != nil { + unittest.AssertErrorEqual(t, ttt.wantErr, err) + return + } + if ttt.wantErrMsg != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), ttt.wantErrMsg) + return + } + assert.NoError(t, err) + if ttt.check != nil { + ttt.check(t, got) + } + }) + } +} diff --git a/backend/modules/prompt/infra/repo/mysql/convertor/debug_context_test.go b/backend/modules/prompt/infra/repo/mysql/convertor/debug_context_test.go index 9086cf1a5..329afbfdc 100644 --- a/backend/modules/prompt/infra/repo/mysql/convertor/debug_context_test.go +++ b/backend/modules/prompt/infra/repo/mysql/convertor/debug_context_test.go @@ -100,6 +100,7 @@ func TestDebugContextDO2PO(t *testing.T) { PromptDetail: &entity.PromptDetail{ PromptTemplate: &entity.PromptTemplate{ TemplateType: entity.TemplateTypeNormal, + HasSnippets: false, Messages: []*entity.Message{ { Role: entity.RoleSystem, @@ -115,7 +116,7 @@ func TestDebugContextDO2PO(t *testing.T) { expected: &model.PromptDebugContext{ PromptID: 1, UserID: "test_user", - CompareConfig: ptr.Of(`{"groups":[{"prompt_detail":{"prompt_template":{"template_type":"normal","messages":[{"role":"system","content":"test content"}]}}}]}`), + CompareConfig: ptr.Of(`{"groups":[{"prompt_detail":{"prompt_template":{"template_type":"normal","messages":[{"role":"system","content":"test content"}],"has_snippets":false}}}]}`), }, wantErr: false, }, @@ -245,7 +246,7 @@ func TestDebugContextPO2DO(t *testing.T) { po: &model.PromptDebugContext{ PromptID: 1, UserID: "test_user", - CompareConfig: ptr.Of(`{"groups":[{"prompt_detail":{"prompt_template":{"template_type":"normal","messages":[{"role":"system","content":"test content"}]}}}]}`), + CompareConfig: ptr.Of(`{"groups":[{"prompt_detail":{"prompt_template":{"template_type":"normal","messages":[{"role":"system","content":"test content"}],"has_snippets":false}}}]}`), }, expected: &entity.DebugContext{ PromptID: 1, @@ -257,6 +258,7 @@ func TestDebugContextPO2DO(t *testing.T) { PromptDetail: &entity.PromptDetail{ PromptTemplate: &entity.PromptTemplate{ TemplateType: entity.TemplateTypeNormal, + HasSnippets: false, Messages: []*entity.Message{ { Role: entity.RoleSystem, diff --git a/backend/modules/prompt/infra/repo/mysql/convertor/manage.go b/backend/modules/prompt/infra/repo/mysql/convertor/manage.go index 6aea313fa..988a969e9 100644 --- a/backend/modules/prompt/infra/repo/mysql/convertor/manage.go +++ b/backend/modules/prompt/infra/repo/mysql/convertor/manage.go @@ -71,6 +71,7 @@ func BasicPO2DO(promptPO *model.PromptBasic) *entity.PromptBasic { return nil } return &entity.PromptBasic{ + PromptType: PromptTypePO2DO(promptPO.PromptType), DisplayName: promptPO.Name, Description: promptPO.Description, LatestVersion: promptPO.LatestVersion, @@ -147,6 +148,7 @@ func PromptDO2BasicPO(do *entity.Prompt) *model.PromptBasic { LatestVersion: do.PromptBasic.LatestVersion, CreatedAt: do.PromptBasic.CreatedAt, UpdatedAt: do.PromptBasic.UpdatedAt, + PromptType: PromptTypeDO2PO(do.PromptBasic.PromptType), } } @@ -191,6 +193,8 @@ func PromptDO2CommitPO(do *entity.Prompt) *model.PromptCommit { if do.PromptCommit.PromptDetail.PromptTemplate.Metadata != nil { po.Metadata = ptr.Of(json.Jsonify(do.PromptCommit.PromptDetail.PromptTemplate.Metadata)) } + // 设置has_snippets标志 + po.HasSnippets = do.PromptCommit.PromptDetail.PromptTemplate.HasSnippets } // 序列化ExtInfos到ExtInfo字段 if do.PromptCommit.PromptDetail.ExtInfos != nil { @@ -225,6 +229,7 @@ func PromptDO2DraftPO(promptDO *entity.Prompt) *model.PromptUserDraft { if detailDO.PromptTemplate.Metadata != nil { po.Metadata = ptr.Of(json.Jsonify(detailDO.PromptTemplate.Metadata)) } + po.HasSnippets = detailDO.PromptTemplate.HasSnippets } if detailDO.ModelConfig != nil { po.ModelConfig = ptr.Of(json.Jsonify(detailDO.ModelConfig)) @@ -276,6 +281,7 @@ func PromptUserDraftPO2PromptDetailDO(draftPO *model.PromptUserDraft) *entity.Pr VariableDefs: UnmarshalVariableDefDOs(draftPO.VariableDefs), TemplateType: UnmarshalTemplateType(draftPO.TemplateType), Metadata: UnmarshalMetadata(draftPO.Metadata), + HasSnippets: draftPO.HasSnippets, }, Tools: UnmarshalToolDOs(draftPO.Tools), ToolCallConfig: UnmarshalToolCallConfig(draftPO.ToolCallConfig), @@ -294,6 +300,7 @@ func PromptCommitPO2PromptDetailDO(commitPO *model.PromptCommit) *entity.PromptD VariableDefs: UnmarshalVariableDefDOs(commitPO.VariableDefs), TemplateType: UnmarshalTemplateType(commitPO.TemplateType), Metadata: UnmarshalMetadata(commitPO.Metadata), + HasSnippets: commitPO.HasSnippets, }, Tools: UnmarshalToolDOs(commitPO.Tools), ToolCallConfig: UnmarshalToolCallConfig(commitPO.ToolCallConfig), @@ -379,3 +386,25 @@ func UnmarshalBool(val int32) bool { func MarshalBool(val bool) int32 { return int32(lo.Ternary(val, 1, 0)) } + +func PromptTypePO2DO(po string) entity.PromptType { + switch po { + case string(entity.PromptTypeSnippet): + return entity.PromptTypeSnippet + case string(entity.PromptTypeNormal): + return entity.PromptTypeNormal + default: + return entity.PromptTypeNormal + } +} + +func PromptTypeDO2PO(do entity.PromptType) string { + switch do { + case entity.PromptTypeSnippet: + return string(entity.PromptTypeSnippet) + case entity.PromptTypeNormal: + return string(entity.PromptTypeNormal) + default: + return string(entity.PromptTypeNormal) + } +} diff --git a/backend/modules/prompt/infra/repo/mysql/convertor/manage_test.go b/backend/modules/prompt/infra/repo/mysql/convertor/manage_test.go index 003b072bd..78d400c64 100644 --- a/backend/modules/prompt/infra/repo/mysql/convertor/manage_test.go +++ b/backend/modules/prompt/infra/repo/mysql/convertor/manage_test.go @@ -41,6 +41,7 @@ func TestPromptDO2BasicPO(t *testing.T) { SpaceID: 100, PromptKey: "test_key", PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, DisplayName: "test_name", Description: "test_description", CreatedBy: "test_creator", @@ -61,6 +62,7 @@ func TestPromptDO2BasicPO(t *testing.T) { LatestVersion: "1.0.0", CreatedAt: time.Unix(1000, 0), UpdatedAt: time.Unix(2000, 0), + PromptType: "normal", }, }, } @@ -214,6 +216,7 @@ func TestBatchBasicPO2PromptDO(t *testing.T) { SpaceID: 100, PromptKey: "test_key", PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, DisplayName: "test_name", Description: "test_description", CreatedBy: "test_creator", @@ -280,6 +283,7 @@ func TestPromptPO2DO(t *testing.T) { SpaceID: 100, PromptKey: "test_key", PromptBasic: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, DisplayName: "test_name", Description: "test_description", CreatedBy: "test_creator", @@ -344,6 +348,7 @@ func TestBasicPO2DO(t *testing.T) { UpdatedAt: time.Unix(2000, 0), }, expected: &entity.PromptBasic{ + PromptType: entity.PromptTypeNormal, DisplayName: "test_name", Description: "test_description", CreatedBy: "test_creator", diff --git a/backend/modules/prompt/infra/repo/mysql/gorm_gen/model/prompt_basic.gen.go b/backend/modules/prompt/infra/repo/mysql/gorm_gen/model/prompt_basic.gen.go index 661fc5ab1..3845b063f 100644 --- a/backend/modules/prompt/infra/repo/mysql/gorm_gen/model/prompt_basic.gen.go +++ b/backend/modules/prompt/infra/repo/mysql/gorm_gen/model/prompt_basic.gen.go @@ -14,18 +14,19 @@ const TableNamePromptBasic = "prompt_basic" // PromptBasic Prompt基础表 type PromptBasic struct { - ID int64 `gorm:"column:id;type:bigint(20) unsigned;primaryKey;autoIncrement:true;comment:主键ID" json:"id"` // 主键ID - SpaceID int64 `gorm:"column:space_id;type:bigint(20) unsigned;not null;uniqueIndex:uniq_space_id_prompt_key_deleted_at,priority:1;comment:空间ID" json:"space_id"` // 空间ID - PromptKey string `gorm:"column:prompt_key;type:varchar(128);not null;uniqueIndex:uniq_space_id_prompt_key_deleted_at,priority:2;comment:Prompt key" json:"prompt_key"` // Prompt key - Name string `gorm:"column:name;type:varchar(128);not null;comment:Prompt名称" json:"name"` // Prompt名称 - Description string `gorm:"column:description;type:varchar(1024);not null;comment:描述" json:"description"` // 描述 - CreatedBy string `gorm:"column:created_by;type:varchar(128);not null;comment:创建人" json:"created_by"` // 创建人 - UpdatedBy string `gorm:"column:updated_by;type:varchar(128);not null;comment:更新人" json:"updated_by"` // 更新人 - LatestVersion string `gorm:"column:latest_version;type:varchar(128);not null;comment:最新版本" json:"latest_version"` // 最新版本 - LatestCommitTime *time.Time `gorm:"column:latest_commit_time;type:datetime;comment:最新提交时间" json:"latest_commit_time"` // 最新提交时间 - CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null;index:idx_created_at,priority:1;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间 - UpdatedAt time.Time `gorm:"column:updated_at;type:datetime;not null;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间 - DeletedAt soft_delete.DeletedAt `gorm:"column:deleted_at;type:bigint(20);not null;uniqueIndex:uniq_space_id_prompt_key_deleted_at,priority:3;column:deleted_at;not null;default:0;softDelete:milli;comment:删除时间" json:"deleted_at"` // 删除时间 + ID int64 `gorm:"column:id;type:bigint(20) unsigned;primaryKey;autoIncrement:true;comment:主键ID" json:"id"` // 主键ID + SpaceID int64 `gorm:"column:space_id;type:bigint(20) unsigned;not null;uniqueIndex:uniq_space_id_prompt_key_deleted_at,priority:1;index:idx_pid_ptype_delat,priority:1;comment:空间ID" json:"space_id"` // 空间ID + PromptKey string `gorm:"column:prompt_key;type:varchar(128);not null;uniqueIndex:uniq_space_id_prompt_key_deleted_at,priority:2;comment:Prompt key" json:"prompt_key"` // Prompt key + Name string `gorm:"column:name;type:varchar(128);not null;comment:Prompt名称" json:"name"` // Prompt名称 + Description string `gorm:"column:description;type:varchar(1024);not null;comment:描述" json:"description"` // 描述 + CreatedBy string `gorm:"column:created_by;type:varchar(128);not null;comment:创建人" json:"created_by"` // 创建人 + UpdatedBy string `gorm:"column:updated_by;type:varchar(128);not null;comment:更新人" json:"updated_by"` // 更新人 + LatestVersion string `gorm:"column:latest_version;type:varchar(128);not null;comment:最新版本" json:"latest_version"` // 最新版本 + LatestCommitTime *time.Time `gorm:"column:latest_commit_time;type:datetime;comment:最新提交时间" json:"latest_commit_time"` // 最新提交时间 + CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null;index:idx_created_at,priority:1;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间 + UpdatedAt time.Time `gorm:"column:updated_at;type:datetime;not null;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间 + DeletedAt soft_delete.DeletedAt `gorm:"column:deleted_at;type:bigint(20);not null;uniqueIndex:uniq_space_id_prompt_key_deleted_at,priority:3;index:idx_pid_ptype_delat,priority:3;column:deleted_at;not null;default:0;softDelete:milli;comment:删除时间" json:"deleted_at"` // 删除时间 + PromptType string `gorm:"column:prompt_type;type:varchar(64);not null;index:idx_pid_ptype_delat,priority:2;default:normal;comment:Prompt类型" json:"prompt_type"` // Prompt类型 } // TableName PromptBasic's table name diff --git a/backend/modules/prompt/infra/repo/mysql/gorm_gen/model/prompt_commit.gen.go b/backend/modules/prompt/infra/repo/mysql/gorm_gen/model/prompt_commit.gen.go index ab30250ed..65d5c5a00 100644 --- a/backend/modules/prompt/infra/repo/mysql/gorm_gen/model/prompt_commit.gen.go +++ b/backend/modules/prompt/infra/repo/mysql/gorm_gen/model/prompt_commit.gen.go @@ -30,6 +30,7 @@ type PromptCommit struct { ExtInfo *string `gorm:"column:ext_info;type:text;comment:扩展字段" json:"ext_info"` // 扩展字段 CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间 UpdatedAt time.Time `gorm:"column:updated_at;type:datetime;not null;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间 + HasSnippets bool `gorm:"column:has_snippets;type:tinyint(1);not null;comment:是否包含prompt片段" json:"has_snippets"` // 是否包含prompt片段 } // TableName PromptCommit's table name diff --git a/backend/modules/prompt/infra/repo/mysql/gorm_gen/model/prompt_relation.gen.go b/backend/modules/prompt/infra/repo/mysql/gorm_gen/model/prompt_relation.gen.go new file mode 100644 index 000000000..cfea422b9 --- /dev/null +++ b/backend/modules/prompt/infra/repo/mysql/gorm_gen/model/prompt_relation.gen.go @@ -0,0 +1,29 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package model + +import ( + "time" +) + +const TableNamePromptRelation = "prompt_relation" + +// PromptRelation Prompt关联表 +type PromptRelation struct { + ID int64 `gorm:"column:id;type:bigint(20) unsigned;primaryKey;autoIncrement:true;comment:主键ID" json:"id"` // 主键ID + SpaceID int64 `gorm:"column:space_id;type:bigint(20) unsigned;not null;comment:空间ID" json:"space_id"` // 空间ID + MainPromptID int64 `gorm:"column:main_prompt_id;type:bigint(20) unsigned;not null;index:idx_main_prompt_id_version,priority:1;index:idx_main_prompt_id_user,priority:1;comment:主Prompt ID" json:"main_prompt_id"` // 主Prompt ID + MainPromptVersion string `gorm:"column:main_prompt_version;type:varchar(128);not null;index:idx_main_prompt_id_version,priority:2;comment:主Prompt版本" json:"main_prompt_version"` // 主Prompt版本 + MainDraftUserID string `gorm:"column:main_draft_user_id;type:varchar(128);not null;index:idx_main_prompt_id_user,priority:2;comment:主Prompt草稿Owner" json:"main_draft_user_id"` // 主Prompt草稿Owner + SubPromptID int64 `gorm:"column:sub_prompt_id;type:bigint(20) unsigned;not null;index:idx_sub_prompt_id_version_create_time,priority:1;comment:子Prompt ID" json:"sub_prompt_id"` // 子Prompt ID + SubPromptVersion string `gorm:"column:sub_prompt_version;type:varchar(128);not null;index:idx_sub_prompt_id_version_create_time,priority:2;comment:子Prompt版本" json:"sub_prompt_version"` // 子Prompt版本 + CreateTime time.Time `gorm:"column:create_time;type:datetime;not null;index:idx_sub_prompt_id_version_create_time,priority:3;default:CURRENT_TIMESTAMP;comment:创建时间" json:"create_time"` // 创建时间 + UpdateTime time.Time `gorm:"column:update_time;type:datetime;not null;default:CURRENT_TIMESTAMP;comment:更新时间" json:"update_time"` // 更新时间 +} + +// TableName PromptRelation's table name +func (*PromptRelation) TableName() string { + return TableNamePromptRelation +} diff --git a/backend/modules/prompt/infra/repo/mysql/gorm_gen/model/prompt_user_draft.gen.go b/backend/modules/prompt/infra/repo/mysql/gorm_gen/model/prompt_user_draft.gen.go index 14d1faeb8..9ca4211ec 100644 --- a/backend/modules/prompt/infra/repo/mysql/gorm_gen/model/prompt_user_draft.gen.go +++ b/backend/modules/prompt/infra/repo/mysql/gorm_gen/model/prompt_user_draft.gen.go @@ -31,6 +31,7 @@ type PromptUserDraft struct { CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间 UpdatedAt time.Time `gorm:"column:updated_at;type:datetime;not null;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间 DeletedAt soft_delete.DeletedAt `gorm:"column:deleted_at;type:bigint(20);not null;uniqueIndex:uniq_prompt_id_user_id_deleted_at,priority:3;column:deleted_at;not null;default:0;softDelete:milli;comment:删除时间" json:"deleted_at"` // 删除时间 + HasSnippets bool `gorm:"column:has_snippets;type:tinyint(1);not null;comment:是否包含prompt片段" json:"has_snippets"` // 是否包含prompt片段 } // TableName PromptUserDraft's table name diff --git a/backend/modules/prompt/infra/repo/mysql/gorm_gen/query/gen.go b/backend/modules/prompt/infra/repo/mysql/gorm_gen/query/gen.go index a76a70af0..f757cbe12 100644 --- a/backend/modules/prompt/infra/repo/mysql/gorm_gen/query/gen.go +++ b/backend/modules/prompt/infra/repo/mysql/gorm_gen/query/gen.go @@ -24,6 +24,7 @@ func Use(db *gorm.DB, opts ...gen.DOOption) *Query { PromptDebugContext: newPromptDebugContext(db, opts...), PromptDebugLog: newPromptDebugLog(db, opts...), PromptLabel: newPromptLabel(db, opts...), + PromptRelation: newPromptRelation(db, opts...), PromptUserDraft: newPromptUserDraft(db, opts...), } } @@ -37,6 +38,7 @@ type Query struct { PromptDebugContext promptDebugContext PromptDebugLog promptDebugLog PromptLabel promptLabel + PromptRelation promptRelation PromptUserDraft promptUserDraft } @@ -51,6 +53,7 @@ func (q *Query) clone(db *gorm.DB) *Query { PromptDebugContext: q.PromptDebugContext.clone(db), PromptDebugLog: q.PromptDebugLog.clone(db), PromptLabel: q.PromptLabel.clone(db), + PromptRelation: q.PromptRelation.clone(db), PromptUserDraft: q.PromptUserDraft.clone(db), } } @@ -72,6 +75,7 @@ func (q *Query) ReplaceDB(db *gorm.DB) *Query { PromptDebugContext: q.PromptDebugContext.replaceDB(db), PromptDebugLog: q.PromptDebugLog.replaceDB(db), PromptLabel: q.PromptLabel.replaceDB(db), + PromptRelation: q.PromptRelation.replaceDB(db), PromptUserDraft: q.PromptUserDraft.replaceDB(db), } } @@ -83,6 +87,7 @@ type queryCtx struct { PromptDebugContext *promptDebugContextDo PromptDebugLog *promptDebugLogDo PromptLabel *promptLabelDo + PromptRelation *promptRelationDo PromptUserDraft *promptUserDraftDo } @@ -94,6 +99,7 @@ func (q *Query) WithContext(ctx context.Context) *queryCtx { PromptDebugContext: q.PromptDebugContext.WithContext(ctx), PromptDebugLog: q.PromptDebugLog.WithContext(ctx), PromptLabel: q.PromptLabel.WithContext(ctx), + PromptRelation: q.PromptRelation.WithContext(ctx), PromptUserDraft: q.PromptUserDraft.WithContext(ctx), } } diff --git a/backend/modules/prompt/infra/repo/mysql/gorm_gen/query/prompt_basic.gen.go b/backend/modules/prompt/infra/repo/mysql/gorm_gen/query/prompt_basic.gen.go index 1ae2cf42a..c9a43121c 100644 --- a/backend/modules/prompt/infra/repo/mysql/gorm_gen/query/prompt_basic.gen.go +++ b/backend/modules/prompt/infra/repo/mysql/gorm_gen/query/prompt_basic.gen.go @@ -39,6 +39,7 @@ func newPromptBasic(db *gorm.DB, opts ...gen.DOOption) promptBasic { _promptBasic.CreatedAt = field.NewTime(tableName, "created_at") _promptBasic.UpdatedAt = field.NewTime(tableName, "updated_at") _promptBasic.DeletedAt = field.NewField(tableName, "deleted_at") + _promptBasic.PromptType = field.NewString(tableName, "prompt_type") _promptBasic.fillFieldMap() @@ -62,6 +63,7 @@ type promptBasic struct { CreatedAt field.Time // 创建时间 UpdatedAt field.Time // 更新时间 DeletedAt field.Field // 删除时间 + PromptType field.String // Prompt类型 fieldMap map[string]field.Expr } @@ -90,6 +92,7 @@ func (p *promptBasic) updateTableName(table string) *promptBasic { p.CreatedAt = field.NewTime(table, "created_at") p.UpdatedAt = field.NewTime(table, "updated_at") p.DeletedAt = field.NewField(table, "deleted_at") + p.PromptType = field.NewString(table, "prompt_type") p.fillFieldMap() @@ -116,7 +119,7 @@ func (p *promptBasic) GetFieldByName(fieldName string) (field.OrderExpr, bool) { } func (p *promptBasic) fillFieldMap() { - p.fieldMap = make(map[string]field.Expr, 12) + p.fieldMap = make(map[string]field.Expr, 13) p.fieldMap["id"] = p.ID p.fieldMap["space_id"] = p.SpaceID p.fieldMap["prompt_key"] = p.PromptKey @@ -129,6 +132,7 @@ func (p *promptBasic) fillFieldMap() { p.fieldMap["created_at"] = p.CreatedAt p.fieldMap["updated_at"] = p.UpdatedAt p.fieldMap["deleted_at"] = p.DeletedAt + p.fieldMap["prompt_type"] = p.PromptType } func (p promptBasic) clone(db *gorm.DB) promptBasic { diff --git a/backend/modules/prompt/infra/repo/mysql/gorm_gen/query/prompt_commit.gen.go b/backend/modules/prompt/infra/repo/mysql/gorm_gen/query/prompt_commit.gen.go index eb6bdb327..bac1f4823 100644 --- a/backend/modules/prompt/infra/repo/mysql/gorm_gen/query/prompt_commit.gen.go +++ b/backend/modules/prompt/infra/repo/mysql/gorm_gen/query/prompt_commit.gen.go @@ -45,6 +45,7 @@ func newPromptCommit(db *gorm.DB, opts ...gen.DOOption) promptCommit { _promptCommit.ExtInfo = field.NewString(tableName, "ext_info") _promptCommit.CreatedAt = field.NewTime(tableName, "created_at") _promptCommit.UpdatedAt = field.NewTime(tableName, "updated_at") + _promptCommit.HasSnippets = field.NewBool(tableName, "has_snippets") _promptCommit.fillFieldMap() @@ -74,6 +75,7 @@ type promptCommit struct { ExtInfo field.String // 扩展字段 CreatedAt field.Time // 创建时间 UpdatedAt field.Time // 更新时间 + HasSnippets field.Bool // 是否包含prompt片段 fieldMap map[string]field.Expr } @@ -108,6 +110,7 @@ func (p *promptCommit) updateTableName(table string) *promptCommit { p.ExtInfo = field.NewString(table, "ext_info") p.CreatedAt = field.NewTime(table, "created_at") p.UpdatedAt = field.NewTime(table, "updated_at") + p.HasSnippets = field.NewBool(table, "has_snippets") p.fillFieldMap() @@ -136,7 +139,7 @@ func (p *promptCommit) GetFieldByName(fieldName string) (field.OrderExpr, bool) } func (p *promptCommit) fillFieldMap() { - p.fieldMap = make(map[string]field.Expr, 18) + p.fieldMap = make(map[string]field.Expr, 19) p.fieldMap["id"] = p.ID p.fieldMap["space_id"] = p.SpaceID p.fieldMap["prompt_id"] = p.PromptID @@ -155,6 +158,7 @@ func (p *promptCommit) fillFieldMap() { p.fieldMap["ext_info"] = p.ExtInfo p.fieldMap["created_at"] = p.CreatedAt p.fieldMap["updated_at"] = p.UpdatedAt + p.fieldMap["has_snippets"] = p.HasSnippets } func (p promptCommit) clone(db *gorm.DB) promptCommit { diff --git a/backend/modules/prompt/infra/repo/mysql/gorm_gen/query/prompt_relation.gen.go b/backend/modules/prompt/infra/repo/mysql/gorm_gen/query/prompt_relation.gen.go new file mode 100644 index 000000000..b537bf15a --- /dev/null +++ b/backend/modules/prompt/infra/repo/mysql/gorm_gen/query/prompt_relation.gen.go @@ -0,0 +1,364 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package query + +import ( + "context" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + + "gorm.io/gen" + "gorm.io/gen/field" + + "gorm.io/plugin/dbresolver" + + "github.com/coze-dev/coze-loop/backend/modules/prompt/infra/repo/mysql/gorm_gen/model" +) + +func newPromptRelation(db *gorm.DB, opts ...gen.DOOption) promptRelation { + _promptRelation := promptRelation{} + + _promptRelation.promptRelationDo.UseDB(db, opts...) + _promptRelation.promptRelationDo.UseModel(&model.PromptRelation{}) + + tableName := _promptRelation.promptRelationDo.TableName() + _promptRelation.ALL = field.NewAsterisk(tableName) + _promptRelation.ID = field.NewInt64(tableName, "id") + _promptRelation.SpaceID = field.NewInt64(tableName, "space_id") + _promptRelation.MainPromptID = field.NewInt64(tableName, "main_prompt_id") + _promptRelation.MainPromptVersion = field.NewString(tableName, "main_prompt_version") + _promptRelation.MainDraftUserID = field.NewString(tableName, "main_draft_user_id") + _promptRelation.SubPromptID = field.NewInt64(tableName, "sub_prompt_id") + _promptRelation.SubPromptVersion = field.NewString(tableName, "sub_prompt_version") + _promptRelation.CreateTime = field.NewTime(tableName, "create_time") + _promptRelation.UpdateTime = field.NewTime(tableName, "update_time") + + _promptRelation.fillFieldMap() + + return _promptRelation +} + +// promptRelation Prompt关联表 +type promptRelation struct { + promptRelationDo promptRelationDo + + ALL field.Asterisk + ID field.Int64 // 主键ID + SpaceID field.Int64 // 空间ID + MainPromptID field.Int64 // 主Prompt ID + MainPromptVersion field.String // 主Prompt版本 + MainDraftUserID field.String // 主Prompt草稿Owner + SubPromptID field.Int64 // 子Prompt ID + SubPromptVersion field.String // 子Prompt版本 + CreateTime field.Time // 创建时间 + UpdateTime field.Time // 更新时间 + + fieldMap map[string]field.Expr +} + +func (p promptRelation) Table(newTableName string) *promptRelation { + p.promptRelationDo.UseTable(newTableName) + return p.updateTableName(newTableName) +} + +func (p promptRelation) As(alias string) *promptRelation { + p.promptRelationDo.DO = *(p.promptRelationDo.As(alias).(*gen.DO)) + return p.updateTableName(alias) +} + +func (p *promptRelation) updateTableName(table string) *promptRelation { + p.ALL = field.NewAsterisk(table) + p.ID = field.NewInt64(table, "id") + p.SpaceID = field.NewInt64(table, "space_id") + p.MainPromptID = field.NewInt64(table, "main_prompt_id") + p.MainPromptVersion = field.NewString(table, "main_prompt_version") + p.MainDraftUserID = field.NewString(table, "main_draft_user_id") + p.SubPromptID = field.NewInt64(table, "sub_prompt_id") + p.SubPromptVersion = field.NewString(table, "sub_prompt_version") + p.CreateTime = field.NewTime(table, "create_time") + p.UpdateTime = field.NewTime(table, "update_time") + + p.fillFieldMap() + + return p +} + +func (p *promptRelation) WithContext(ctx context.Context) *promptRelationDo { + return p.promptRelationDo.WithContext(ctx) +} + +func (p promptRelation) TableName() string { return p.promptRelationDo.TableName() } + +func (p promptRelation) Alias() string { return p.promptRelationDo.Alias() } + +func (p promptRelation) Columns(cols ...field.Expr) gen.Columns { + return p.promptRelationDo.Columns(cols...) +} + +func (p *promptRelation) GetFieldByName(fieldName string) (field.OrderExpr, bool) { + _f, ok := p.fieldMap[fieldName] + if !ok || _f == nil { + return nil, false + } + _oe, ok := _f.(field.OrderExpr) + return _oe, ok +} + +func (p *promptRelation) fillFieldMap() { + p.fieldMap = make(map[string]field.Expr, 9) + p.fieldMap["id"] = p.ID + p.fieldMap["space_id"] = p.SpaceID + p.fieldMap["main_prompt_id"] = p.MainPromptID + p.fieldMap["main_prompt_version"] = p.MainPromptVersion + p.fieldMap["main_draft_user_id"] = p.MainDraftUserID + p.fieldMap["sub_prompt_id"] = p.SubPromptID + p.fieldMap["sub_prompt_version"] = p.SubPromptVersion + p.fieldMap["create_time"] = p.CreateTime + p.fieldMap["update_time"] = p.UpdateTime +} + +func (p promptRelation) clone(db *gorm.DB) promptRelation { + p.promptRelationDo.ReplaceConnPool(db.Statement.ConnPool) + return p +} + +func (p promptRelation) replaceDB(db *gorm.DB) promptRelation { + p.promptRelationDo.ReplaceDB(db) + return p +} + +type promptRelationDo struct{ gen.DO } + +func (p promptRelationDo) Debug() *promptRelationDo { + return p.withDO(p.DO.Debug()) +} + +func (p promptRelationDo) WithContext(ctx context.Context) *promptRelationDo { + return p.withDO(p.DO.WithContext(ctx)) +} + +func (p promptRelationDo) ReadDB() *promptRelationDo { + return p.Clauses(dbresolver.Read) +} + +func (p promptRelationDo) WriteDB() *promptRelationDo { + return p.Clauses(dbresolver.Write) +} + +func (p promptRelationDo) Session(config *gorm.Session) *promptRelationDo { + return p.withDO(p.DO.Session(config)) +} + +func (p promptRelationDo) Clauses(conds ...clause.Expression) *promptRelationDo { + return p.withDO(p.DO.Clauses(conds...)) +} + +func (p promptRelationDo) Returning(value interface{}, columns ...string) *promptRelationDo { + return p.withDO(p.DO.Returning(value, columns...)) +} + +func (p promptRelationDo) Not(conds ...gen.Condition) *promptRelationDo { + return p.withDO(p.DO.Not(conds...)) +} + +func (p promptRelationDo) Or(conds ...gen.Condition) *promptRelationDo { + return p.withDO(p.DO.Or(conds...)) +} + +func (p promptRelationDo) Select(conds ...field.Expr) *promptRelationDo { + return p.withDO(p.DO.Select(conds...)) +} + +func (p promptRelationDo) Where(conds ...gen.Condition) *promptRelationDo { + return p.withDO(p.DO.Where(conds...)) +} + +func (p promptRelationDo) Order(conds ...field.Expr) *promptRelationDo { + return p.withDO(p.DO.Order(conds...)) +} + +func (p promptRelationDo) Distinct(cols ...field.Expr) *promptRelationDo { + return p.withDO(p.DO.Distinct(cols...)) +} + +func (p promptRelationDo) Omit(cols ...field.Expr) *promptRelationDo { + return p.withDO(p.DO.Omit(cols...)) +} + +func (p promptRelationDo) Join(table schema.Tabler, on ...field.Expr) *promptRelationDo { + return p.withDO(p.DO.Join(table, on...)) +} + +func (p promptRelationDo) LeftJoin(table schema.Tabler, on ...field.Expr) *promptRelationDo { + return p.withDO(p.DO.LeftJoin(table, on...)) +} + +func (p promptRelationDo) RightJoin(table schema.Tabler, on ...field.Expr) *promptRelationDo { + return p.withDO(p.DO.RightJoin(table, on...)) +} + +func (p promptRelationDo) Group(cols ...field.Expr) *promptRelationDo { + return p.withDO(p.DO.Group(cols...)) +} + +func (p promptRelationDo) Having(conds ...gen.Condition) *promptRelationDo { + return p.withDO(p.DO.Having(conds...)) +} + +func (p promptRelationDo) Limit(limit int) *promptRelationDo { + return p.withDO(p.DO.Limit(limit)) +} + +func (p promptRelationDo) Offset(offset int) *promptRelationDo { + return p.withDO(p.DO.Offset(offset)) +} + +func (p promptRelationDo) Scopes(funcs ...func(gen.Dao) gen.Dao) *promptRelationDo { + return p.withDO(p.DO.Scopes(funcs...)) +} + +func (p promptRelationDo) Unscoped() *promptRelationDo { + return p.withDO(p.DO.Unscoped()) +} + +func (p promptRelationDo) Create(values ...*model.PromptRelation) error { + if len(values) == 0 { + return nil + } + return p.DO.Create(values) +} + +func (p promptRelationDo) CreateInBatches(values []*model.PromptRelation, batchSize int) error { + return p.DO.CreateInBatches(values, batchSize) +} + +// Save : !!! underlying implementation is different with GORM +// The method is equivalent to executing the statement: db.Clauses(clause.OnConflict{UpdateAll: true}).Create(values) +func (p promptRelationDo) Save(values ...*model.PromptRelation) error { + if len(values) == 0 { + return nil + } + return p.DO.Save(values) +} + +func (p promptRelationDo) First() (*model.PromptRelation, error) { + if result, err := p.DO.First(); err != nil { + return nil, err + } else { + return result.(*model.PromptRelation), nil + } +} + +func (p promptRelationDo) Take() (*model.PromptRelation, error) { + if result, err := p.DO.Take(); err != nil { + return nil, err + } else { + return result.(*model.PromptRelation), nil + } +} + +func (p promptRelationDo) Last() (*model.PromptRelation, error) { + if result, err := p.DO.Last(); err != nil { + return nil, err + } else { + return result.(*model.PromptRelation), nil + } +} + +func (p promptRelationDo) Find() ([]*model.PromptRelation, error) { + result, err := p.DO.Find() + return result.([]*model.PromptRelation), err +} + +func (p promptRelationDo) FindInBatch(batchSize int, fc func(tx gen.Dao, batch int) error) (results []*model.PromptRelation, err error) { + buf := make([]*model.PromptRelation, 0, batchSize) + err = p.DO.FindInBatches(&buf, batchSize, func(tx gen.Dao, batch int) error { + defer func() { results = append(results, buf...) }() + return fc(tx, batch) + }) + return results, err +} + +func (p promptRelationDo) FindInBatches(result *[]*model.PromptRelation, batchSize int, fc func(tx gen.Dao, batch int) error) error { + return p.DO.FindInBatches(result, batchSize, fc) +} + +func (p promptRelationDo) Attrs(attrs ...field.AssignExpr) *promptRelationDo { + return p.withDO(p.DO.Attrs(attrs...)) +} + +func (p promptRelationDo) Assign(attrs ...field.AssignExpr) *promptRelationDo { + return p.withDO(p.DO.Assign(attrs...)) +} + +func (p promptRelationDo) Joins(fields ...field.RelationField) *promptRelationDo { + for _, _f := range fields { + p = *p.withDO(p.DO.Joins(_f)) + } + return &p +} + +func (p promptRelationDo) Preload(fields ...field.RelationField) *promptRelationDo { + for _, _f := range fields { + p = *p.withDO(p.DO.Preload(_f)) + } + return &p +} + +func (p promptRelationDo) FirstOrInit() (*model.PromptRelation, error) { + if result, err := p.DO.FirstOrInit(); err != nil { + return nil, err + } else { + return result.(*model.PromptRelation), nil + } +} + +func (p promptRelationDo) FirstOrCreate() (*model.PromptRelation, error) { + if result, err := p.DO.FirstOrCreate(); err != nil { + return nil, err + } else { + return result.(*model.PromptRelation), nil + } +} + +func (p promptRelationDo) FindByPage(offset int, limit int) (result []*model.PromptRelation, count int64, err error) { + result, err = p.Offset(offset).Limit(limit).Find() + if err != nil { + return + } + + if size := len(result); 0 < limit && 0 < size && size < limit { + count = int64(size + offset) + return + } + + count, err = p.Offset(-1).Limit(-1).Count() + return +} + +func (p promptRelationDo) ScanByPage(result interface{}, offset int, limit int) (count int64, err error) { + count, err = p.Count() + if err != nil { + return + } + + err = p.Offset(offset).Limit(limit).Scan(result) + return +} + +func (p promptRelationDo) Scan(result interface{}) (err error) { + return p.DO.Scan(result) +} + +func (p promptRelationDo) Delete(models ...*model.PromptRelation) (result gen.ResultInfo, err error) { + return p.DO.Delete(models) +} + +func (p *promptRelationDo) withDO(do gen.Dao) *promptRelationDo { + p.DO = *do.(*gen.DO) + return p +} diff --git a/backend/modules/prompt/infra/repo/mysql/gorm_gen/query/prompt_user_draft.gen.go b/backend/modules/prompt/infra/repo/mysql/gorm_gen/query/prompt_user_draft.gen.go index b625d760f..5d8d9fbb6 100644 --- a/backend/modules/prompt/infra/repo/mysql/gorm_gen/query/prompt_user_draft.gen.go +++ b/backend/modules/prompt/infra/repo/mysql/gorm_gen/query/prompt_user_draft.gen.go @@ -44,6 +44,7 @@ func newPromptUserDraft(db *gorm.DB, opts ...gen.DOOption) promptUserDraft { _promptUserDraft.CreatedAt = field.NewTime(tableName, "created_at") _promptUserDraft.UpdatedAt = field.NewTime(tableName, "updated_at") _promptUserDraft.DeletedAt = field.NewField(tableName, "deleted_at") + _promptUserDraft.HasSnippets = field.NewBool(tableName, "has_snippets") _promptUserDraft.fillFieldMap() @@ -72,6 +73,7 @@ type promptUserDraft struct { CreatedAt field.Time // 创建时间 UpdatedAt field.Time // 更新时间 DeletedAt field.Field // 删除时间 + HasSnippets field.Bool // 是否包含prompt片段 fieldMap map[string]field.Expr } @@ -105,6 +107,7 @@ func (p *promptUserDraft) updateTableName(table string) *promptUserDraft { p.CreatedAt = field.NewTime(table, "created_at") p.UpdatedAt = field.NewTime(table, "updated_at") p.DeletedAt = field.NewField(table, "deleted_at") + p.HasSnippets = field.NewBool(table, "has_snippets") p.fillFieldMap() @@ -133,7 +136,7 @@ func (p *promptUserDraft) GetFieldByName(fieldName string) (field.OrderExpr, boo } func (p *promptUserDraft) fillFieldMap() { - p.fieldMap = make(map[string]field.Expr, 17) + p.fieldMap = make(map[string]field.Expr, 18) p.fieldMap["id"] = p.ID p.fieldMap["space_id"] = p.SpaceID p.fieldMap["prompt_id"] = p.PromptID @@ -151,6 +154,7 @@ func (p *promptUserDraft) fillFieldMap() { p.fieldMap["created_at"] = p.CreatedAt p.fieldMap["updated_at"] = p.UpdatedAt p.fieldMap["deleted_at"] = p.DeletedAt + p.fieldMap["has_snippets"] = p.HasSnippets } func (p promptUserDraft) clone(db *gorm.DB) promptUserDraft { diff --git a/backend/modules/prompt/infra/repo/mysql/mocks/prompt_commit_dao.go b/backend/modules/prompt/infra/repo/mysql/mocks/prompt_commit_dao.go index 530f37e2a..45f92198b 100644 --- a/backend/modules/prompt/infra/repo/mysql/mocks/prompt_commit_dao.go +++ b/backend/modules/prompt/infra/repo/mysql/mocks/prompt_commit_dao.go @@ -122,3 +122,23 @@ func (mr *MockIPromptCommitDAOMockRecorder) MGet(ctx, pairs any, opts ...any) *g varargs := append([]any{ctx, pairs}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGet", reflect.TypeOf((*MockIPromptCommitDAO)(nil).MGet), varargs...) } + +// MGetVersionsByPromptID mocks base method. +func (m *MockIPromptCommitDAO) MGetVersionsByPromptID(ctx context.Context, promptID int64, opts ...db.Option) ([]string, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, promptID} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "MGetVersionsByPromptID", varargs...) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// MGetVersionsByPromptID indicates an expected call of MGetVersionsByPromptID. +func (mr *MockIPromptCommitDAOMockRecorder) MGetVersionsByPromptID(ctx, promptID any, opts ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, promptID}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetVersionsByPromptID", reflect.TypeOf((*MockIPromptCommitDAO)(nil).MGetVersionsByPromptID), varargs...) +} diff --git a/backend/modules/prompt/infra/repo/mysql/mocks/prompt_relation_dao.go b/backend/modules/prompt/infra/repo/mysql/mocks/prompt_relation_dao.go new file mode 100644 index 000000000..f6d0f9dde --- /dev/null +++ b/backend/modules/prompt/infra/repo/mysql/mocks/prompt_relation_dao.go @@ -0,0 +1,140 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/coze-dev/coze-loop/backend/modules/prompt/infra/repo/mysql (interfaces: IPromptRelationDAO) +// +// Generated by this command: +// +// mockgen -destination=mocks/prompt_relation_dao.go -package=mocks . IPromptRelationDAO +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + db "github.com/coze-dev/coze-loop/backend/infra/db" + mysql "github.com/coze-dev/coze-loop/backend/modules/prompt/infra/repo/mysql" + model "github.com/coze-dev/coze-loop/backend/modules/prompt/infra/repo/mysql/gorm_gen/model" + gomock "go.uber.org/mock/gomock" +) + +// MockIPromptRelationDAO is a mock of IPromptRelationDAO interface. +type MockIPromptRelationDAO struct { + ctrl *gomock.Controller + recorder *MockIPromptRelationDAOMockRecorder + isgomock struct{} +} + +// MockIPromptRelationDAOMockRecorder is the mock recorder for MockIPromptRelationDAO. +type MockIPromptRelationDAOMockRecorder struct { + mock *MockIPromptRelationDAO +} + +// NewMockIPromptRelationDAO creates a new mock instance. +func NewMockIPromptRelationDAO(ctrl *gomock.Controller) *MockIPromptRelationDAO { + mock := &MockIPromptRelationDAO{ctrl: ctrl} + mock.recorder = &MockIPromptRelationDAOMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockIPromptRelationDAO) EXPECT() *MockIPromptRelationDAOMockRecorder { + return m.recorder +} + +// BatchCreate mocks base method. +func (m *MockIPromptRelationDAO) BatchCreate(ctx context.Context, relationPOs []*model.PromptRelation, opts ...db.Option) error { + m.ctrl.T.Helper() + varargs := []any{ctx, relationPOs} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "BatchCreate", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// BatchCreate indicates an expected call of BatchCreate. +func (mr *MockIPromptRelationDAOMockRecorder) BatchCreate(ctx, relationPOs any, opts ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, relationPOs}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchCreate", reflect.TypeOf((*MockIPromptRelationDAO)(nil).BatchCreate), varargs...) +} + +// BatchDeleteByIDs mocks base method. +func (m *MockIPromptRelationDAO) BatchDeleteByIDs(ctx context.Context, relationIDs []int64, opts ...db.Option) error { + m.ctrl.T.Helper() + varargs := []any{ctx, relationIDs} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "BatchDeleteByIDs", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// BatchDeleteByIDs indicates an expected call of BatchDeleteByIDs. +func (mr *MockIPromptRelationDAOMockRecorder) BatchDeleteByIDs(ctx, relationIDs any, opts ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, relationIDs}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchDeleteByIDs", reflect.TypeOf((*MockIPromptRelationDAO)(nil).BatchDeleteByIDs), varargs...) +} + +// Create mocks base method. +func (m *MockIPromptRelationDAO) Create(ctx context.Context, relationPO *model.PromptRelation, opts ...db.Option) error { + m.ctrl.T.Helper() + varargs := []any{ctx, relationPO} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Create", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Create indicates an expected call of Create. +func (mr *MockIPromptRelationDAOMockRecorder) Create(ctx, relationPO any, opts ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, relationPO}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockIPromptRelationDAO)(nil).Create), varargs...) +} + +// DeleteByMainPrompt mocks base method. +func (m *MockIPromptRelationDAO) DeleteByMainPrompt(ctx context.Context, mainPromptID int64, mainPromptVersion, mainDraftUserID string, opts ...db.Option) error { + m.ctrl.T.Helper() + varargs := []any{ctx, mainPromptID, mainPromptVersion, mainDraftUserID} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "DeleteByMainPrompt", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteByMainPrompt indicates an expected call of DeleteByMainPrompt. +func (mr *MockIPromptRelationDAOMockRecorder) DeleteByMainPrompt(ctx, mainPromptID, mainPromptVersion, mainDraftUserID any, opts ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, mainPromptID, mainPromptVersion, mainDraftUserID}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteByMainPrompt", reflect.TypeOf((*MockIPromptRelationDAO)(nil).DeleteByMainPrompt), varargs...) +} + +// List mocks base method. +func (m *MockIPromptRelationDAO) List(ctx context.Context, param mysql.ListPromptRelationParam, opts ...db.Option) ([]*model.PromptRelation, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, param} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "List", varargs...) + ret0, _ := ret[0].([]*model.PromptRelation) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// List indicates an expected call of List. +func (mr *MockIPromptRelationDAOMockRecorder) List(ctx, param any, opts ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, param}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockIPromptRelationDAO)(nil).List), varargs...) +} diff --git a/backend/modules/prompt/infra/repo/mysql/prompt_basic.go b/backend/modules/prompt/infra/repo/mysql/prompt_basic.go index 3480fa4f6..e71efed09 100644 --- a/backend/modules/prompt/infra/repo/mysql/prompt_basic.go +++ b/backend/modules/prompt/infra/repo/mysql/prompt_basic.go @@ -44,6 +44,8 @@ type ListPromptBasicParam struct { KeyWord string CreatedBys []string CommittedOnly bool + PromptTypes []string // Add prompt type filtering + PromptIDs []int64 Offset int Limit int @@ -175,6 +177,9 @@ func (d *PromptBasicDAOImpl) List(ctx context.Context, param ListPromptBasicPara if len(param.CreatedBys) > 0 { tx = tx.Where(q.PromptBasic.CreatedBy.In(param.CreatedBys...)) } + if len(param.PromptIDs) > 0 { + tx = tx.Where(q.PromptBasic.ID.In(param.PromptIDs...)) + } if !lo.IsEmpty(param.KeyWord) { likeExpr := field.Or( q.PromptBasic.PromptKey.Like(fmt.Sprintf("%%%s%%", param.KeyWord)), @@ -185,6 +190,9 @@ func (d *PromptBasicDAOImpl) List(ctx context.Context, param ListPromptBasicPara if param.CommittedOnly { tx = tx.Where(q.PromptBasic.LatestVersion.Neq("")) } + if len(param.PromptTypes) > 0 { + tx = tx.Where(q.PromptBasic.PromptType.In(param.PromptTypes...)) + } total, err = tx.Count() if err != nil { return nil, 0, errorx.WrapByCode(err, prompterr.CommonMySqlErrorCode) diff --git a/backend/modules/prompt/infra/repo/mysql/prompt_commit.go b/backend/modules/prompt/infra/repo/mysql/prompt_commit.go index b17eb049d..14069af29 100644 --- a/backend/modules/prompt/infra/repo/mysql/prompt_commit.go +++ b/backend/modules/prompt/infra/repo/mysql/prompt_commit.go @@ -27,6 +27,7 @@ type IPromptCommitDAO interface { Get(ctx context.Context, promptID int64, commitVersion string, opts ...db.Option) (promptCommitPO *model.PromptCommit, err error) MGet(ctx context.Context, pairs []PromptIDCommitVersionPair, opts ...db.Option) (pairCommitPOMap map[PromptIDCommitVersionPair]*model.PromptCommit, err error) List(ctx context.Context, param ListCommitParam, opts ...db.Option) (commitPOs []*model.PromptCommit, err error) + MGetVersionsByPromptID(ctx context.Context, promptID int64, opts ...db.Option) (versions []string, err error) } type ListCommitParam struct { @@ -154,3 +155,32 @@ func (d *PromptCommitDAOImpl) List(ctx context.Context, param ListCommitParam, o } return commitPOs, nil } + +func (d *PromptCommitDAOImpl) MGetVersionsByPromptID(ctx context.Context, promptID int64, opts ...db.Option) (versions []string, err error) { + if promptID <= 0 { + return nil, errorx.New("promptID is invalid, promptID = %d", promptID) + } + if d.writeTracker.CheckWriteFlagByID(ctx, platestwrite.ResourceTypePromptCommit, promptID) { + opts = append(opts, db.WithMaster()) + } + + q := query.Use(d.db.NewSession(ctx, opts...)) + tx := q.WithContext(ctx).PromptCommit + tx = tx.Select(q.PromptCommit.Version) + tx = tx.Where(q.PromptCommit.PromptID.Eq(promptID)) + commitPOs, err := tx.Find() + if err != nil { + return nil, errorx.WrapByCode(err, prompterr.CommonMySqlErrorCode) + } + if len(commitPOs) == 0 { + return nil, nil + } + versions = make([]string, 0, len(commitPOs)) + for _, po := range commitPOs { + if po == nil || po.Version == "" { + continue + } + versions = append(versions, po.Version) + } + return versions, nil +} diff --git a/backend/modules/prompt/infra/repo/mysql/prompt_relation.go b/backend/modules/prompt/infra/repo/mysql/prompt_relation.go new file mode 100644 index 000000000..5c46cbf06 --- /dev/null +++ b/backend/modules/prompt/infra/repo/mysql/prompt_relation.go @@ -0,0 +1,174 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package mysql + +import ( + "context" + "strconv" + + "github.com/coze-dev/coze-loop/backend/infra/db" + "github.com/coze-dev/coze-loop/backend/infra/platestwrite" + "github.com/coze-dev/coze-loop/backend/infra/redis" + "github.com/coze-dev/coze-loop/backend/modules/prompt/infra/repo/mysql/gorm_gen/model" + "github.com/coze-dev/coze-loop/backend/modules/prompt/infra/repo/mysql/gorm_gen/query" + "github.com/coze-dev/coze-loop/backend/pkg/errorx" +) + +//go:generate mockgen -destination=mocks/prompt_relation_dao.go -package=mocks . IPromptRelationDAO +type IPromptRelationDAO interface { + // BatchCreate creates multiple prompt relations in batch + BatchCreate(ctx context.Context, relationPOs []*model.PromptRelation, opts ...db.Option) error + + // DeleteByMainPrompt deletes all relations for a main prompt + DeleteByMainPrompt(ctx context.Context, mainPromptID int64, mainPromptVersion string, mainDraftUserID string, opts ...db.Option) error + + // BatchDeleteByIDs deletes relations by their IDs + BatchDeleteByIDs(ctx context.Context, relationIDs []int64, opts ...db.Option) error + + // List lists prompt relations with optional filtering parameters + List(ctx context.Context, param ListPromptRelationParam, opts ...db.Option) ([]*model.PromptRelation, error) +} + +// ListPromptRelationParam unified parameter for listing prompt relations +// All fields are optional - only non-zero values will be used for filtering +type ListPromptRelationParam struct { + // Main prompt filtering + MainPromptID *int64 + MainPromptVersions []string + MainDraftUserID *string + + // Sub prompt filtering + SubPromptID *int64 + SubPromptVersions []string + + // Pagination + Limit *int + Offset *int +} + +func NewPromptRelationDAO(db db.Provider, redisCli redis.Cmdable) IPromptRelationDAO { + return &PromptRelationDAOImpl{ + db: db, + writeTracker: platestwrite.NewLatestWriteTracker(redisCli), + } +} + +type PromptRelationDAOImpl struct { + db db.Provider + writeTracker platestwrite.ILatestWriteTracker +} + +func (d *PromptRelationDAOImpl) BatchCreate(ctx context.Context, relationPOs []*model.PromptRelation, opts ...db.Option) error { + if len(relationPOs) == 0 { + return nil + } + + q := query.Use(d.db.NewSession(ctx, opts...)).WithContext(ctx) + err := q.PromptRelation.CreateInBatches(relationPOs, len(relationPOs)) + if err != nil { + return err + } + + // 批量设置写标志,处理主从延迟 + mainPromptIDs := make(map[int64]bool) + subPromptIDs := make(map[int64]bool) + for _, relationPO := range relationPOs { + if relationPO != nil { + mainPromptIDs[relationPO.MainPromptID] = true + subPromptIDs[relationPO.SubPromptID] = true + d.writeTracker.SetWriteFlag(ctx, platestwrite.ResourceTypePromptRelation, relationPO.ID) + } + } + for mainPromptID := range mainPromptIDs { + d.writeTracker.SetWriteFlag(ctx, platestwrite.ResourceTypePromptRelation, 0, + platestwrite.SetWithSearchParam(strconv.FormatInt(mainPromptID, 10))) + } + for subPromptID := range subPromptIDs { + d.writeTracker.SetWriteFlag(ctx, platestwrite.ResourceTypePromptRelation, 0, + platestwrite.SetWithSearchParam(strconv.FormatInt(subPromptID, 10))) + } + + return nil +} + +func (d *PromptRelationDAOImpl) DeleteByMainPrompt(ctx context.Context, mainPromptID int64, mainPromptVersion string, mainDraftUserID string, opts ...db.Option) error { + if mainPromptID <= 0 { + return errorx.New("mainPromptID is invalid, mainPromptID = %d", mainPromptID) + } + + q := query.Use(d.db.NewSession(ctx, opts...)) + tx := q.WithContext(ctx).PromptRelation + tx = tx.Where( + q.PromptRelation.MainPromptID.Eq(mainPromptID), + q.PromptRelation.MainPromptVersion.Eq(mainPromptVersion), + q.PromptRelation.MainDraftUserID.Eq(mainDraftUserID), + ) + _, err := tx.Delete() + if err != nil { + return err + } + + // 设置写标志,处理主从延迟 + d.writeTracker.SetWriteFlag(ctx, platestwrite.ResourceTypePromptRelation, 0, + platestwrite.SetWithSearchParam(strconv.FormatInt(mainPromptID, 10))) + + return nil +} + +func (d *PromptRelationDAOImpl) BatchDeleteByIDs(ctx context.Context, relationIDs []int64, opts ...db.Option) error { + if len(relationIDs) == 0 { + return nil + } + + q := query.Use(d.db.NewSession(ctx, opts...)) + _, err := q.PromptRelation.WithContext(ctx).Where(q.PromptRelation.ID.In(relationIDs...)).Delete() + if err != nil { + return err + } + + // 批量设置写标志,处理主从延迟 + for _, relationID := range relationIDs { + d.writeTracker.SetWriteFlag(ctx, platestwrite.ResourceTypePromptRelation, relationID) + } + + return nil +} + +func (d *PromptRelationDAOImpl) List(ctx context.Context, param ListPromptRelationParam, opts ...db.Option) ([]*model.PromptRelation, error) { + // 检查主从延迟写标志 + if param.MainPromptID != nil && d.writeTracker.CheckWriteFlagBySearchParam(ctx, + platestwrite.ResourceTypePromptRelation, strconv.FormatInt(*param.MainPromptID, 10)) { + opts = append(opts, db.WithMaster()) + } + + q := query.Use(d.db.NewSession(ctx, opts...)) + tx := q.WithContext(ctx).PromptRelation + + // Apply filters only when parameters are provided + if param.MainPromptID != nil { + tx = tx.Where(q.PromptRelation.MainPromptID.Eq(*param.MainPromptID)) + } + if len(param.MainPromptVersions) > 0 { + tx = tx.Where(q.PromptRelation.MainPromptVersion.In(param.MainPromptVersions...)) + } + if param.MainDraftUserID != nil { + tx = tx.Where(q.PromptRelation.MainDraftUserID.Eq(*param.MainDraftUserID)) + } + if param.SubPromptID != nil { + tx = tx.Where(q.PromptRelation.SubPromptID.Eq(*param.SubPromptID)) + } + if len(param.SubPromptVersions) > 0 { + tx = tx.Where(q.PromptRelation.SubPromptVersion.In(param.SubPromptVersions...)) + } + + // Apply pagination if provided + if param.Limit != nil && *param.Limit > 0 { + tx = tx.Limit(*param.Limit) + } + if param.Offset != nil && *param.Offset > 0 { + tx = tx.Offset(*param.Offset) + } + + return tx.Find() +} diff --git a/backend/modules/prompt/infra/repo/mysql/prompt_user_draft.go b/backend/modules/prompt/infra/repo/mysql/prompt_user_draft.go index ff0df1e1e..b21e9653e 100644 --- a/backend/modules/prompt/infra/repo/mysql/prompt_user_draft.go +++ b/backend/modules/prompt/infra/repo/mysql/prompt_user_draft.go @@ -153,6 +153,7 @@ func (d *PromptUserDraftDAOImpl) Update(ctx context.Context, promptDraftPO *mode q.PromptUserDraft.VariableDefs.ColumnName().String(): promptDraftPO.VariableDefs, q.PromptUserDraft.Metadata.ColumnName().String(): promptDraftPO.Metadata, q.PromptUserDraft.IsDraftEdited.ColumnName().String(): promptDraftPO.IsDraftEdited, + q.PromptUserDraft.HasSnippets.ColumnName().String(): promptDraftPO.HasSnippets, }) if err != nil { return errorx.WrapByCode(err, prompterr.CommonMySqlErrorCode) diff --git a/backend/script/gorm_gen/generate.go b/backend/script/gorm_gen/generate.go index 4f2738273..a26dc2d5e 100644 --- a/backend/script/gorm_gen/generate.go +++ b/backend/script/gorm_gen/generate.go @@ -68,7 +68,7 @@ func generateForPrompt(db *gorm.DB) { }))) } - for _, table := range []string{"prompt_commit"} { + for _, table := range []string{"prompt_commit", "prompt_relation"} { models = append(models, g.GenerateModel(table, gen.FieldGORMTag("*", func(tag field.GormTag) field.GormTag { return tag.Set("charset=utf8mb4") diff --git a/idl/thrift/coze/loop/prompt/coze.loop.prompt.manage.thrift b/idl/thrift/coze/loop/prompt/coze.loop.prompt.manage.thrift index 0acec9b87..9608fe966 100644 --- a/idl/thrift/coze/loop/prompt/coze.loop.prompt.manage.thrift +++ b/idl/thrift/coze/loop/prompt/coze.loop.prompt.manage.thrift @@ -20,6 +20,8 @@ service PromptManageService { GetPromptResponse GetPrompt(1: GetPromptRequest request) (api.get = '/api/prompt/v1/prompts/:prompt_id') BatchGetPromptResponse BatchGetPrompt(1: BatchGetPromptRequest request) ListPromptResponse ListPrompt(1: ListPromptRequest request) (api.post = '/api/prompt/v1/prompts/list') + // 查询片段的引用记录 + ListParentPromptResponse ListParentPrompt (1: ListParentPromptRequest request) (api.post = '/api/prompt/v1/prompts/list_parent') // 改 UpdatePromptResponse UpdatePrompt(1: UpdatePromptRequest request) (api.put = '/api/prompt/v1/prompts/:prompt_id') @@ -49,6 +51,7 @@ struct CreatePromptRequest { 11: optional string prompt_name (vt.not_nil="true", vt.min_size="1") 12: optional string prompt_key (vt.not_nil="true", vt.min_size="1") 13: optional string prompt_description + 14: optional prompt.PromptType prompt_type 21: optional prompt.PromptDetail draft_detail @@ -95,6 +98,7 @@ struct GetPromptRequest { 21: optional bool with_draft (api.query="with_draft") 31: optional bool with_default_config (api.query="with_default_config") + 32: optional bool expand_snippet (api.query="expand_snippet") // 是否展开子片段,true:展开 255: optional base.Base Base } @@ -102,6 +106,7 @@ struct GetPromptResponse { 1: optional prompt.Prompt prompt 11: optional prompt.PromptDetail default_config + 12: optional i32 total_parent_references // [片段]被引用的总数 255: optional base.BaseResp BaseResp } @@ -136,6 +141,7 @@ struct ListPromptRequest { 11: optional string key_word 12: optional list created_bys 13: optional bool committed_only + 14: optional list filter_prompt_types // 向前兼容,如果不传,默认查询normal类型的Prompt 127: optional i32 page_num (vt.not_nil="true", vt.gt="0") 128: optional i32 page_size (vt.not_nil="true", vt.gt="0", vt.le="100") @@ -198,6 +204,7 @@ struct CommitDraftResponse { // 搜索Prompt提交版本 struct ListCommitRequest { 1: optional i64 prompt_id (api.path='prompt_id', api.js_conv='true', vt.not_nil='true', vt.gt='0', go.tag='json:"prompt_id"') + 2: optional bool with_commit_detail (api.query="with_commit_detail") // 是否查询详情 127: optional i32 page_size (vt.not_nil="true", vt.gt="0") 128: optional string page_token @@ -208,6 +215,8 @@ struct ListCommitRequest { struct ListCommitResponse { 1: optional list prompt_commit_infos 2: optional map> commit_version_label_mapping + 3: optional map parent_references_mapping // key: version, value:被引用数 + 4: optional map prompt_commit_detail_mapping // key:version, value:PromptDetail 11: optional list users @@ -288,3 +297,17 @@ struct UpdateCommitLabelsRequest { struct UpdateCommitLabelsResponse { 255: optional base.BaseResp BaseResp } + +struct ListParentPromptRequest { + 1: optional i64 workspace_id (api.js_conv='true', vt.not_nil='true', vt.gt='0', go.tag='json:"workspace_id"') + 2: optional i64 prompt_id (api.js_conv='true', vt.not_nil='true', vt.gt='0', go.tag='json:"prompt_id"') + 3: optional list commit_versions // 片段版本,不传则表示查询所有版本的引用记录 + + 255: optional base.Base Base +} + +struct ListParentPromptResponse { + 1: optional map> parent_prompts // 不同片段版本被引用的父prompt记录 + + 255: optional base.BaseResp BaseResp +} \ No newline at end of file diff --git a/idl/thrift/coze/loop/prompt/domain/prompt.thrift b/idl/thrift/coze/loop/prompt/domain/prompt.thrift index 34a6bc443..1d798f3b8 100644 --- a/idl/thrift/coze/loop/prompt/domain/prompt.thrift +++ b/idl/thrift/coze/loop/prompt/domain/prompt.thrift @@ -18,9 +18,14 @@ struct PromptBasic { 6: optional i64 created_at (api.js_conv="true", go.tag='json:"created_at"') 7: optional i64 updated_at (api.js_conv="true", go.tag='json:"updated_at"') 8: optional i64 latest_committed_at (api.js_conv="true", go.tag='json:"latest_committed_at"') + 9: optional PromptType prompt_type } +typedef string PromptType (ts.enum="true") +const PromptType PromptType_Normal = "normal" +const PromptType PromptType_Snippet = "snippet" + struct PromptCommit { 1: optional PromptDetail detail 2: optional CommitInfo commit_info @@ -61,6 +66,8 @@ struct PromptTemplate { 1: optional TemplateType template_type 2: optional list messages 3: optional list variable_defs + 4: optional bool has_snippet + 5: optional list snippets 100: optional map metadata } @@ -283,3 +290,11 @@ const Scenario Scenario_EvalTarget = "eval_target" struct OverridePromptParams { 1: optional ModelConfig model_config } + +struct PromptCommitVersions { + 1: optional i64 id (api.js_conv="true", go.tag='json:"id"') + 2: optional i64 workspace_id (api.js_conv="true", go.tag='json:"workspace_id"') + 3: optional string prompt_key + 4: optional PromptBasic prompt_basic + 5: optional list commit_versions +} \ No newline at end of file diff --git a/release/deployment/docker-compose/bootstrap/mysql-init/init-sql/prompt_basic.sql b/release/deployment/docker-compose/bootstrap/mysql-init/init-sql/prompt_basic.sql index f78a7eabc..55a27c93d 100644 --- a/release/deployment/docker-compose/bootstrap/mysql-init/init-sql/prompt_basic.sql +++ b/release/deployment/docker-compose/bootstrap/mysql-init/init-sql/prompt_basic.sql @@ -13,9 +13,11 @@ CREATE TABLE IF NOT EXISTS `prompt_basic` `created_at` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间', `updated_at` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间', `deleted_at` bigint NOT NULL DEFAULT '0' COMMENT '删除时间', + `prompt_type` varchar(64) NOT NULL DEFAULT 'normal' COMMENT 'Prompt类型', PRIMARY KEY (`id`), UNIQUE KEY `uniq_space_id_prompt_key_deleted_at` (`space_id`, `prompt_key`, `deleted_at`), - KEY `idx_created_at` (`created_at`) USING BTREE + KEY `idx_created_at` (`created_at`) USING BTREE, + KEY `idx_pid_ptype_delat` (`space_id`, `prompt_type`, `deleted_at`) USING BTREE ) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4 COLLATE = utf8mb4_general_ci COMMENT ='Prompt基础表'; \ No newline at end of file diff --git a/release/deployment/docker-compose/bootstrap/mysql-init/init-sql/prompt_commit.sql b/release/deployment/docker-compose/bootstrap/mysql-init/init-sql/prompt_commit.sql index a8b4623c8..22046c5e8 100644 --- a/release/deployment/docker-compose/bootstrap/mysql-init/init-sql/prompt_commit.sql +++ b/release/deployment/docker-compose/bootstrap/mysql-init/init-sql/prompt_commit.sql @@ -18,6 +18,7 @@ CREATE TABLE IF NOT EXISTS `prompt_commit` `ext_info` text COLLATE utf8mb4_general_ci COMMENT '扩展字段', `created_at` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间', `updated_at` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间', + `has_snippets` tinyint(1) NOT NULL DEFAULT 0 COMMENT '是否包含prompt片段', PRIMARY KEY (`id`), UNIQUE KEY `uniq_prompt_id_version` (`prompt_id`, `version`), KEY `idx_prompt_key_version` (`prompt_key`, `version`) USING BTREE diff --git a/release/deployment/docker-compose/bootstrap/mysql-init/init-sql/prompt_relation.sql b/release/deployment/docker-compose/bootstrap/mysql-init/init-sql/prompt_relation.sql new file mode 100644 index 000000000..60844d238 --- /dev/null +++ b/release/deployment/docker-compose/bootstrap/mysql-init/init-sql/prompt_relation.sql @@ -0,0 +1,15 @@ +CREATE TABLE IF NOT EXISTS `prompt_relation` ( + `id` bigint unsigned NOT NULL AUTO_INCREMENT COMMENT '主键ID', + `space_id` bigint unsigned NOT NULL COMMENT '空间ID', + `main_prompt_id` bigint unsigned NOT NULL COMMENT '主Prompt ID', + `main_prompt_version` varchar(128) NOT NULL DEFAULT '' COMMENT '主Prompt版本', + `main_draft_user_id` varchar(128) NOT NULL DEFAULT '' COMMENT '主Prompt草稿Owner', + `sub_prompt_id` bigint unsigned NOT NULL COMMENT '子Prompt ID', + `sub_prompt_version` varchar(128) NOT NULL DEFAULT '' COMMENT '子Prompt版本', + `create_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间', + `update_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间', + PRIMARY KEY (`id`), + KEY `idx_main_prompt_id_version` (`main_prompt_id`,`main_prompt_version`) COMMENT '主prompt_id_版本', + KEY `idx_main_prompt_id_user` (`main_prompt_id`,`main_draft_user_id`) COMMENT '主prompt_id_user', + KEY `idx_sub_prompt_id_version_create_time` (`sub_prompt_id`,`sub_prompt_version`, `create_time`) COMMENT '子prompt_id_版本' +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci COMMENT='Prompt关联表'; \ No newline at end of file diff --git a/release/deployment/docker-compose/bootstrap/mysql-init/init-sql/prompt_user_draft.sql b/release/deployment/docker-compose/bootstrap/mysql-init/init-sql/prompt_user_draft.sql index 804bf143e..cf9d1ed7f 100644 --- a/release/deployment/docker-compose/bootstrap/mysql-init/init-sql/prompt_user_draft.sql +++ b/release/deployment/docker-compose/bootstrap/mysql-init/init-sql/prompt_user_draft.sql @@ -17,6 +17,7 @@ CREATE TABLE IF NOT EXISTS `prompt_user_draft` `created_at` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间', `updated_at` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间', `deleted_at` bigint NOT NULL DEFAULT '0' COMMENT '删除时间', + `has_snippets` tinyint(1) NOT NULL DEFAULT 0 COMMENT '是否包含prompt片段', PRIMARY KEY (`id`), UNIQUE KEY `uniq_prompt_id_user_id_deleted_at` (`prompt_id`, `user_id`, `deleted_at`), KEY `idx_prompt_id_user_id` (`prompt_id`, `user_id`) diff --git a/release/deployment/docker-compose/bootstrap/mysql-init/patch-sql/prompt_basic_alter.sql b/release/deployment/docker-compose/bootstrap/mysql-init/patch-sql/prompt_basic_alter.sql new file mode 100644 index 000000000..a51a2becf --- /dev/null +++ b/release/deployment/docker-compose/bootstrap/mysql-init/patch-sql/prompt_basic_alter.sql @@ -0,0 +1,2 @@ +ALTER TABLE `prompt_basic` ADD COLUMN `prompt_type` varchar(64) NOT NULL DEFAULT 'normal' COMMENT 'Prompt类型'; +ALTER TABLE `prompt_basic` ADD KEY `idx_pid_ptype_delat` (`space_id`, `prompt_type`, `deleted_at`) USING BTREE; \ No newline at end of file diff --git a/release/deployment/docker-compose/bootstrap/mysql-init/patch-sql/prompt_commit_alter.sql b/release/deployment/docker-compose/bootstrap/mysql-init/patch-sql/prompt_commit_alter.sql index 4b622b352..d4919bfee 100644 --- a/release/deployment/docker-compose/bootstrap/mysql-init/patch-sql/prompt_commit_alter.sql +++ b/release/deployment/docker-compose/bootstrap/mysql-init/patch-sql/prompt_commit_alter.sql @@ -1,2 +1,3 @@ ALTER TABLE `prompt_commit` ADD COLUMN `ext_info` text COLLATE utf8mb4_general_ci COMMENT 'Extended information field'; ALTER TABLE `prompt_commit` ADD COLUMN `metadata` text COLLATE utf8mb4_general_ci COMMENT 'Template metadata field'; +ALTER TABLE `prompt_commit` ADD COLUMN `has_snippets` tinyint(1) NOT NULL DEFAULT 0 COMMENT '是否包含prompt片段'; \ No newline at end of file diff --git a/release/deployment/docker-compose/bootstrap/mysql-init/patch-sql/prompt_user_draft_alter.sql b/release/deployment/docker-compose/bootstrap/mysql-init/patch-sql/prompt_user_draft_alter.sql index 54323e1e0..8ecb3d12e 100644 --- a/release/deployment/docker-compose/bootstrap/mysql-init/patch-sql/prompt_user_draft_alter.sql +++ b/release/deployment/docker-compose/bootstrap/mysql-init/patch-sql/prompt_user_draft_alter.sql @@ -1,2 +1,3 @@ ALTER TABLE `prompt_user_draft` ADD COLUMN `ext_info` text COLLATE utf8mb4_general_ci COMMENT 'Extended information field'; ALTER TABLE `prompt_user_draft` ADD COLUMN `metadata` text COLLATE utf8mb4_general_ci COMMENT 'Template metadata field'; +ALTER TABLE `prompt_user_draft` ADD COLUMN `has_snippets` tinyint(1) NOT NULL DEFAULT 0 COMMENT '是否包含prompt片段'; \ No newline at end of file diff --git a/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_basic.sql b/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_basic.sql index f78a7eabc..d461043f5 100644 --- a/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_basic.sql +++ b/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_basic.sql @@ -13,9 +13,11 @@ CREATE TABLE IF NOT EXISTS `prompt_basic` `created_at` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间', `updated_at` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间', `deleted_at` bigint NOT NULL DEFAULT '0' COMMENT '删除时间', + `prompt_type` varchar(64) NOT NULL DEFAULT 'normal' COMMENT 'Prompt类型', PRIMARY KEY (`id`), UNIQUE KEY `uniq_space_id_prompt_key_deleted_at` (`space_id`, `prompt_key`, `deleted_at`), - KEY `idx_created_at` (`created_at`) USING BTREE -) ENGINE = InnoDB - DEFAULT CHARSET = utf8mb4 - COLLATE = utf8mb4_general_ci COMMENT ='Prompt基础表'; \ No newline at end of file + KEY `idx_created_at` (`created_at`) USING BTREE, + KEY `idx_pid_ptype_delat` (`space_id`, `prompt_type`, `deleted_at`) USING BTREE + ) ENGINE = InnoDB + DEFAULT CHARSET = utf8mb4 + COLLATE = utf8mb4_general_ci COMMENT ='Prompt基础表'; \ No newline at end of file diff --git a/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_basic_alter.sql b/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_basic_alter.sql new file mode 100644 index 000000000..a51a2becf --- /dev/null +++ b/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_basic_alter.sql @@ -0,0 +1,2 @@ +ALTER TABLE `prompt_basic` ADD COLUMN `prompt_type` varchar(64) NOT NULL DEFAULT 'normal' COMMENT 'Prompt类型'; +ALTER TABLE `prompt_basic` ADD KEY `idx_pid_ptype_delat` (`space_id`, `prompt_type`, `deleted_at`) USING BTREE; \ No newline at end of file diff --git a/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_commit.sql b/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_commit.sql index a8b4623c8..bbe022bad 100644 --- a/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_commit.sql +++ b/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_commit.sql @@ -18,9 +18,10 @@ CREATE TABLE IF NOT EXISTS `prompt_commit` `ext_info` text COLLATE utf8mb4_general_ci COMMENT '扩展字段', `created_at` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间', `updated_at` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间', + `has_snippets` tinyint(1) NOT NULL DEFAULT 0 COMMENT '是否包含prompt片段', PRIMARY KEY (`id`), UNIQUE KEY `uniq_prompt_id_version` (`prompt_id`, `version`), KEY `idx_prompt_key_version` (`prompt_key`, `version`) USING BTREE -) ENGINE = InnoDB - DEFAULT CHARSET = utf8mb4 - COLLATE = utf8mb4_general_ci COMMENT ='Commit表'; \ No newline at end of file + ) ENGINE = InnoDB + DEFAULT CHARSET = utf8mb4 + COLLATE = utf8mb4_general_ci COMMENT ='Commit表'; \ No newline at end of file diff --git a/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_commit_alter.sql b/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_commit_alter.sql index 4b622b352..d4919bfee 100644 --- a/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_commit_alter.sql +++ b/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_commit_alter.sql @@ -1,2 +1,3 @@ ALTER TABLE `prompt_commit` ADD COLUMN `ext_info` text COLLATE utf8mb4_general_ci COMMENT 'Extended information field'; ALTER TABLE `prompt_commit` ADD COLUMN `metadata` text COLLATE utf8mb4_general_ci COMMENT 'Template metadata field'; +ALTER TABLE `prompt_commit` ADD COLUMN `has_snippets` tinyint(1) NOT NULL DEFAULT 0 COMMENT '是否包含prompt片段'; \ No newline at end of file diff --git a/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_relation.sql b/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_relation.sql new file mode 100644 index 000000000..60844d238 --- /dev/null +++ b/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_relation.sql @@ -0,0 +1,15 @@ +CREATE TABLE IF NOT EXISTS `prompt_relation` ( + `id` bigint unsigned NOT NULL AUTO_INCREMENT COMMENT '主键ID', + `space_id` bigint unsigned NOT NULL COMMENT '空间ID', + `main_prompt_id` bigint unsigned NOT NULL COMMENT '主Prompt ID', + `main_prompt_version` varchar(128) NOT NULL DEFAULT '' COMMENT '主Prompt版本', + `main_draft_user_id` varchar(128) NOT NULL DEFAULT '' COMMENT '主Prompt草稿Owner', + `sub_prompt_id` bigint unsigned NOT NULL COMMENT '子Prompt ID', + `sub_prompt_version` varchar(128) NOT NULL DEFAULT '' COMMENT '子Prompt版本', + `create_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间', + `update_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间', + PRIMARY KEY (`id`), + KEY `idx_main_prompt_id_version` (`main_prompt_id`,`main_prompt_version`) COMMENT '主prompt_id_版本', + KEY `idx_main_prompt_id_user` (`main_prompt_id`,`main_draft_user_id`) COMMENT '主prompt_id_user', + KEY `idx_sub_prompt_id_version_create_time` (`sub_prompt_id`,`sub_prompt_version`, `create_time`) COMMENT '子prompt_id_版本' +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci COMMENT='Prompt关联表'; \ No newline at end of file diff --git a/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_user_draft.sql b/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_user_draft.sql index 804bf143e..be460369f 100644 --- a/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_user_draft.sql +++ b/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_user_draft.sql @@ -17,9 +17,10 @@ CREATE TABLE IF NOT EXISTS `prompt_user_draft` `created_at` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间', `updated_at` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间', `deleted_at` bigint NOT NULL DEFAULT '0' COMMENT '删除时间', + `has_snippets` tinyint(1) NOT NULL DEFAULT 0 COMMENT '是否包含prompt片段', PRIMARY KEY (`id`), UNIQUE KEY `uniq_prompt_id_user_id_deleted_at` (`prompt_id`, `user_id`, `deleted_at`), KEY `idx_prompt_id_user_id` (`prompt_id`, `user_id`) -) ENGINE = InnoDB - DEFAULT CHARSET = utf8mb4 - COLLATE = utf8mb4_general_ci COMMENT ='Draft表'; \ No newline at end of file + ) ENGINE = InnoDB + DEFAULT CHARSET = utf8mb4 + COLLATE = utf8mb4_general_ci COMMENT ='Draft表'; \ No newline at end of file diff --git a/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_user_draft_alter.sql b/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_user_draft_alter.sql index 54323e1e0..8ecb3d12e 100644 --- a/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_user_draft_alter.sql +++ b/release/deployment/helm-chart/charts/app/bootstrap/init/mysql/init-sql/prompt_user_draft_alter.sql @@ -1,2 +1,3 @@ ALTER TABLE `prompt_user_draft` ADD COLUMN `ext_info` text COLLATE utf8mb4_general_ci COMMENT 'Extended information field'; ALTER TABLE `prompt_user_draft` ADD COLUMN `metadata` text COLLATE utf8mb4_general_ci COMMENT 'Template metadata field'; +ALTER TABLE `prompt_user_draft` ADD COLUMN `has_snippets` tinyint(1) NOT NULL DEFAULT 0 COMMENT '是否包含prompt片段'; \ No newline at end of file