diff --git a/Sources/Tokenizers/Tokenizer.swift b/Sources/Tokenizers/Tokenizer.swift
index 4880f5a9..ad820e06 100644
--- a/Sources/Tokenizers/Tokenizer.swift
+++ b/Sources/Tokenizers/Tokenizer.swift
@@ -787,48 +787,41 @@ public class PreTrainedTokenizer: @unchecked Sendable, Tokenizer {
             throw TokenizerError.missingChatTemplate
         }
 
-        let renderedTemplate: String
-        do {
-            let template = try compiledTemplate(for: selectedChatTemplate)
-
-            var context: [String: Jinja.Value] = try [
-                "messages": .array(messages.map { try Value(any: $0) }),
-                "add_generation_prompt": .boolean(addGenerationPrompt),
-            ]
-            if let tools {
-                context["tools"] = try .array(tools.map { try Value(any: $0) })
-            }
-            if let additionalContext {
-                // Additional keys and values to be added to the context provided to the prompt templating engine.
-                // For example, the app could set "tools_in_user_message" to false for Llama 3.1 and 3.2 if a system message is provided.
-                // The default value is true in the Llama 3.1 and 3.2 chat templates, but these models will perform better if the tools are included in a system message.
-                for (key, value) in additionalContext {
-                    context[key] = try Value(any: value)
-                }
+        let template = try compiledTemplate(for: selectedChatTemplate)
+        var context: [String: Jinja.Value] = try [
+            "messages": .array(messages.map { try Value(any: $0) }),
+            "add_generation_prompt": .boolean(addGenerationPrompt),
+        ]
+        if let tools {
+            context["tools"] = try .array(tools.map { try Value(any: $0) })
+        }
+        if let additionalContext {
+            // Additional keys and values to be added to the context provided to the prompt templating engine.
+            // For example, the app could set "tools_in_user_message" to false for Llama 3.1 and 3.2 if a system message is provided.
+            // The default value is true in the Llama 3.1 and 3.2 chat templates, but these models will perform better if the tools are included in a system message.
+            for (key, value) in additionalContext {
+                context[key] = try Value(any: value)
             }
+        }
 
-            for (key, value) in tokenizerConfig.dictionary(or: [:]) {
-                if specialTokenAttributes.contains(key.string), !value.isNull() {
-                    if let stringValue = value.string() {
-                        context[key.string] = .string(stringValue)
-                    } else if let dictionary = value.dictionary() {
-                        if let addedTokenString = addedTokenAsString(Config(dictionary)) {
-                            context[key.string] = .string(addedTokenString)
-                        }
-                    } else if let array: [String] = value.get() {
-                        context[key.string] = .array(array.map { .string($0) })
-                    } else {
-                        context[key.string] = try Value(any: value)
+        for (key, value) in tokenizerConfig.dictionary(or: [:]) {
+            if specialTokenAttributes.contains(key.string), !value.isNull() {
+                if let stringValue = value.string() {
+                    context[key.string] = .string(stringValue)
+                } else if let dictionary = value.dictionary() {
+                    if let addedTokenString = addedTokenAsString(Config(dictionary)) {
+                        context[key.string] = .string(addedTokenString)
                     }
+                } else if let array: [String] = value.get() {
+                    context[key.string] = .array(array.map { .string($0) })
+                } else {
+                    context[key.string] = try Value(any: value)
                 }
             }
-
-            renderedTemplate = try template.render(context)
-        } catch let error as JinjaError {
-            let description = (error as? LocalizedError)?.errorDescription ?? "\(error)"
-            throw TokenizerError.chatTemplate(description)
         }
-        var encodedTokens = encode(text: renderedTemplate, addSpecialTokens: false)
+
+        let rendered = try template.render(context)
+        var encodedTokens = encode(text: rendered, addSpecialTokens: false)
         var maxLength = maxLength ?? encodedTokens.count
         maxLength = min(maxLength, tokenizerConfig.modelMaxLength.integer() ?? maxLength)
         if encodedTokens.count > maxLength {
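
Usage sketch (not part of the patch): the comment in the hunk notes that callers can override template context keys such as "tools_in_user_message" through additionalContext. The snippet below assumes the library's existing public surface (AutoTokenizer.from(pretrained:) and an applyChatTemplate overload taking messages, tools, and additionalContext); the model id, messages, and tool definition are illustrative only.

import Tokenizers

// Hedged sketch: AutoTokenizer.from(pretrained:), the [[String: Any]] message/tool
// shapes, and the applyChatTemplate overload with additionalContext are assumed from
// the existing API, not introduced by this diff; the model id and content are made up.
let tokenizer = try await AutoTokenizer.from(pretrained: "meta-llama/Llama-3.2-1B-Instruct")
let messages: [[String: Any]] = [
    ["role": "system", "content": "You are a helpful assistant."],
    ["role": "user", "content": "What's the weather in Paris?"],
]
let tools: [[String: Any]] = [
    ["type": "function", "function": ["name": "get_weather", "parameters": ["type": "object", "properties": [:]]]],
]
// Ask the template to keep tool definitions in the system message for Llama 3.1/3.2,
// as suggested by the comment in the hunk above.
let tokens = try tokenizer.applyChatTemplate(
    messages: messages,
    tools: tools,
    additionalContext: ["tools_in_user_message": false]
)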