Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions ir/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# IR 模型库

该目录实现了 PolyLLM 使用的中间表示(IR)层,核心目标是提供供应商无关的请求、响应与流式事件模型,以便在不同 LLM API 之间做适配。

## 功能概览

- `request.go`:定义指令、消息、块、采样参数、约束、结构化输出等数据结构。
- `events.go`:描述流式事件模型,支持块级增量。
- `builder.go`:提供构造请求的便捷函数。
- `validate.go`:对请求与事件进行校验,包含可配置的范围检查。
- `normalize.go`:规范化请求,提供去重、合并文本等能力。
- `aggregate.go`:将流式事件聚合为非流式响应,可选择直写文本。
- `meta.go`、`errors.go`:通用元数据及错误定义。

库本身不依赖外部包,可单独在其他项目中复用。

## 示例

```go
req := ir.NewRequest("example-model")
req.AppendDirective(ir.NewDirective(ir.DirectiveSystem, ir.NewTextBlock("d1", "You are helpful")))
req.Messages = []ir.Message{
ir.NewMessage(ir.RoleUser, "m1", ir.NewTextBlock("b1", "Hello")),
}
if err := ir.ValidateRequest(req, ir.WithTemperatureRange(0, 2)); err != nil {
panic(err)
}
```

## 许可

与仓库其他部分一致,遵循 MIT 许可证。
276 changes: 276 additions & 0 deletions ir/aggregate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,276 @@
package ir

import (
"bytes"
"errors"
"fmt"
"io"
)

type aggregatorOptions struct {
writer io.Writer
maxBuffer int
onDelta func(Ref, string)
keepBlocks bool
}

// AggregatorOption 配置聚合器。
type AggregatorOption func(*aggregatorOptions)

// WithWriter 指定增量写入目标。
func WithWriter(w io.Writer) AggregatorOption {
return func(o *aggregatorOptions) {
o.writer = w
}
}

// WithMaxBufferBytes 限制内存缓冲大小。
func WithMaxBufferBytes(n int) AggregatorOption {
return func(o *aggregatorOptions) {
if n < 0 {
n = 0
}
o.maxBuffer = n
}
}

// WithOnDelta 注册文本增量回调。
func WithOnDelta(cb func(ref Ref, s string)) AggregatorOption {
return func(o *aggregatorOptions) {
o.onDelta = cb
}
}

// WithKeepBlocks 控制是否保留块信息。
func WithKeepBlocks(keep bool) AggregatorOption {
return func(o *aggregatorOptions) {
o.keepBlocks = keep
}
}

// Aggregator 将事件序列聚合为 Response。
type Aggregator struct {
opts aggregatorOptions

started bool
finished bool
streamFinished bool
errored bool

currentMessage *Message
currentBlock *Block
blockStack []string

messages []Message

textBuffer bytes.Buffer
bufferLimitHit bool

usage Usage
finishReason *FinishReason
}

// NewAggregator 构造聚合器。
func NewAggregator(opts ...AggregatorOption) *Aggregator {
cfg := aggregatorOptions{keepBlocks: true}
for _, opt := range opts {
opt(&cfg)
}
return &Aggregator{opts: cfg}
}

// Feed 消化单个事件。
func (a *Aggregator) Feed(ev Event) error {
if a.errored {
return ErrStateConflict
}
switch ev.Type {
case EventStreamStart:
if a.started {
return ErrStateConflict
}
a.started = true
case EventStreamError:
if !a.started || a.streamFinished || a.finished {
return ErrStateConflict
}
a.errored = true
return nil
case EventStreamFinish:
if !a.started || a.streamFinished || !a.finished {
return ErrStateConflict
}
a.streamFinished = true
case EventMessageStart:
if !a.started || a.finished || a.currentMessage != nil {
return ErrStateConflict
}
if ev.MsgHdr == nil {
return ErrInvalidArgument
}
msg := *ev.MsgHdr
msg.Meta = msg.Meta.Clone()
msg.Blocks = nil
a.currentMessage = &msg
case EventMessageEnd:
if a.currentMessage == nil || len(a.blockStack) != 0 {
return ErrEventOrder
}
a.messages = append(a.messages, *a.currentMessage)
a.currentMessage = nil
case EventBlockStart:
if a.currentMessage == nil {
return ErrEventOrder
}
if ev.Block == nil {
return ErrInvalidArgument
}
cp := copyBlockHeader(ev.Block)
a.currentBlock = cp
a.blockStack = append(a.blockStack, cp.ID)
case EventBlockEnd:
if len(a.blockStack) == 0 {
return ErrEventOrder
}
last := a.blockStack[len(a.blockStack)-1]
if ev.Ref.BlockID != "" && ev.Ref.BlockID != last {
return ErrEventOrder
}
a.blockStack = a.blockStack[:len(a.blockStack)-1]
if a.currentMessage == nil || a.currentBlock == nil || a.currentBlock.ID != last {
return ErrStateConflict
}
if a.opts.keepBlocks {
a.currentMessage.Blocks = append(a.currentMessage.Blocks, *a.currentBlock)
}
a.currentBlock = nil
case EventTextDelta:
if a.currentBlock == nil || a.currentBlock.Text == nil {
return ErrEventOrder
}
if cb := a.opts.onDelta; cb != nil {
cb(ev.Ref, ev.Text)
}
if a.opts.writer != nil && ev.Text != "" {
if _, err := io.WriteString(a.opts.writer, ev.Text); err != nil {
return fmt.Errorf("write delta: %w", err)
}
}
if a.opts.keepBlocks {
a.currentBlock.Text.Data += ev.Text
}
if !a.bufferLimitHit {
if a.opts.maxBuffer > 0 && a.textBuffer.Len()+len(ev.Text) > a.opts.maxBuffer {
remain := a.opts.maxBuffer - a.textBuffer.Len()
if remain > 0 {
a.textBuffer.WriteString(ev.Text[:remain])
}
a.bufferLimitHit = true
} else {
a.textBuffer.WriteString(ev.Text)
}
}
case EventJSONDelta:
return ErrUnsupported
case EventUsageDelta:
if ev.Usage != nil {
a.usage.PromptTokens += ev.Usage.PromptTokens
a.usage.CompletionTokens += ev.Usage.CompletionTokens
a.usage.TotalTokens += ev.Usage.TotalTokens
a.usage.BilledTokens += ev.Usage.BilledTokens
}
case EventFinish:
if a.finished {
return ErrStateConflict
}
if ev.Finish == nil {
return ErrInvalidArgument
}
if a.currentMessage != nil || len(a.blockStack) != 0 {
return ErrEventOrder
}
fr := *ev.Finish
a.finishReason = &fr
a.finished = true
default:
return ErrUnsupported
}
return nil
}

// Done 完成聚合并返回 Response。
func (a *Aggregator) Done(base Response) (Response, error) {
if !a.started {
return Response{}, ErrStateConflict
}
if a.errored {
return Response{}, ErrStateConflict
}
if !a.streamFinished {
return Response{}, errors.Join(ErrEventOrder, fmt.Errorf("stream_finish missing"))
}
if !a.finished {
return Response{}, errors.Join(ErrEventOrder, fmt.Errorf("finish missing"))
}
base.Output.Messages = append([]Message(nil), a.messages...)
base.Output.TextConcat = a.textBuffer.String()
base.Usage = a.usage
if a.finishReason != nil {
base.Finish = *a.finishReason
}
return base, nil
}

// CurrentText 返回已聚合的文本。
func (a *Aggregator) CurrentText() string {
return a.textBuffer.String()
}

// CurrentUsage 返回聚合的用量。
func (a *Aggregator) CurrentUsage() Usage {
return a.usage
}

// CurrentFinish 返回已记录的结束原因。
func (a *Aggregator) CurrentFinish() *FinishReason {
return a.finishReason
}

func copyBlockHeader(b *Block) *Block {
cp := *b
cp.Meta = cp.Meta.Clone()
if b.Text != nil {
tb := *b.Text
tb.Data = ""
cp.Text = &tb
}
if b.JSON != nil {
jb := *b.JSON
cp.JSON = &jb
}
if b.Image != nil {
ib := *b.Image
cp.Image = &ib
}
if b.Audio != nil {
ab := *b.Audio
cp.Audio = &ab
}
if b.FileRef != nil {
fb := *b.FileRef
cp.FileRef = &fb
}
if b.ToolCall != nil {
tc := *b.ToolCall
tc.Arguments = CloneArguments(b.ToolCall.Arguments)
cp.ToolCall = &tc
}
if b.ToolResult != nil {
tr := *b.ToolResult
if b.ToolResult.Output != nil {
tr.Output = CloneArguments(b.ToolResult.Output)
}
cp.ToolResult = &tr
}
return &cp
}
74 changes: 74 additions & 0 deletions ir/builder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package ir

import "encoding/json"

// NewRequest 创建一个指定模型的 Request。
func NewRequest(model string) *Request {
return &Request{Model: model}
}

// NewDirective 构造 Directive。
func NewDirective(kind DirectiveKind, blocks ...Block) Directive {
return Directive{Kind: kind, Blocks: append([]Block(nil), blocks...)}
}

// NewTextBlock 构造文本块。
func NewTextBlock(id, text string) Block {
return Block{
ID: id,
Type: BlockText,
Text: &TextBlock{Data: text},
}
}

// NewJSONBlock 构造 JSON 块。
func NewJSONBlock(id string, data any) Block {
return Block{
ID: id,
Type: BlockJSON,
JSON: &JSONBlock{Data: data},
}
}

// NewMessage 构造 Message。
func NewMessage(role Role, id string, blocks ...Block) Message {
return Message{
ID: id,
Role: role,
Blocks: append([]Block(nil), blocks...),
}
}

// AppendDirective 追加指令。
func (r *Request) AppendDirective(d Directive) {
if r == nil {
return
}
r.Directives = append(r.Directives, d)
}

// AppendUserText 追加用户文本消息。
func (r *Request) AppendUserText(id, text string) {
if r == nil {
return
}
r.Messages = append(r.Messages, NewMessage(RoleUser, id, NewTextBlock(id+"_b", text)))
}

// AppendAssistantText 追加助手文本消息。
func (r *Request) AppendAssistantText(id, text string) {
if r == nil {
return
}
r.Messages = append(r.Messages, NewMessage(RoleAssistant, id, NewTextBlock(id+"_b", text)))
}

// CloneArguments 返回工具参数的深拷贝。
func CloneArguments(raw *json.RawMessage) *json.RawMessage {
if raw == nil {
return nil
}
dup := make(json.RawMessage, len(*raw))
copy(dup, *raw)
return &dup
}
11 changes: 11 additions & 0 deletions ir/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Package ir 提供一个供应商无关的 LLM 请求/响应/事件中间表示。
//
// 该包包含如下能力:
// - 请求模型:指令、消息、内容块、采样与输出约束。
// - 流式事件:统一的块级增量模型,支持多模态与工具调用。
// - 校验与规范化:对请求与事件进行一致性检查与轻量规范化。
// - 聚合器:将流式事件恢复为非流式响应,支持边写边收缩缓冲。
//
// 包的设计参考了 README 中的总览,重点在于为上层 PolyLLM 提供稳定的
// 适配层,使不同厂商的 API 可以映射到统一的数据模型中。
package ir
Loading
Loading