@@ -32,10 +32,10 @@ class MLXService {
     ]
 
     /// Cache to store loaded model containers to avoid reloading.
-    private var modelCache: [String: ModelContainer] = [:]
+    private let modelCache = NSCache<NSString, ModelContainer>()
 
     /// Stores a prompt cache for each loaded model
-    private var promptCache: [String: PromptCache] = [:]
+    private let promptCache = NSCache<NSString, PromptCache>()
 
     /// Tracks the current model download progress.
     /// Access this property to monitor model download status.
@@ -51,9 +51,10 @@ class MLXService {
5151 MLX . GPU. set ( cacheLimit: 20 * 1024 * 1024 )
5252
5353 // Return cached model if available to avoid reloading
54- if let container = modelCache [ model. name] {
54+ if let container = modelCache. object ( forKey : model. name as NSString ) {
5555 return container
5656 } else {
57+ print ( " Model not loaded \( model. name) , loading model... " )
5758 // Select appropriate factory based on model type
5859 let factory : ModelFactory =
5960 switch model. type {
@@ -71,9 +72,13 @@ class MLXService {
                     self.modelDownloadProgress = progress
                 }
             }
-
+
+            // Clear out the promptCache for this model
+            promptCache.removeObject(forKey: model.name as NSString)
+
             // Cache the loaded model for future use
-            modelCache[model.name] = container
+            modelCache.setObject(container, forKey: model.name as NSString)
+
             return container
         }
     }
@@ -118,32 +123,41 @@ class MLXService {
 
         let parameters = GenerateParameters(temperature: 0.7)
 
-        // Get the prompt cache
-        let cache: PromptCache
-        if let existingCache = self.promptCache[model.name] {
-            cache = existingCache
-        } else {
-            // Create cache if it doesn't exist yet
-            cache = PromptCache(cache: context.model.newCache(parameters: parameters))
-            promptCache[model.name] = cache
-        }
-
-        let lmInput: LMInput
-
-        /// Remove prefix from prompt that is already in cache
-        if let suffix = cache.getUncachedSuffix(prompt: fullPrompt.text.tokens) {
-            lmInput = LMInput(text: LMInput.Text(tokens: suffix))
-        } else {
-            // If suffix is nil, the cache is inconsistent with the new prompt
-            // and the cache doesn't support trimming so create a new one here.
-            self.promptCache[model.name] = PromptCache(cache: context.model.newCache(parameters: parameters))
-            lmInput = fullPrompt
-        }
+        // TODO: Prompt cache access isn't isolated
+        // Get the prompt cache and trim the new prompt to drop the prefix
+        // that is already cached; if the cache is inconsistent with the
+        // new prompt, it is replaced with a fresh one.
+        let (cache, lmInput) = getPromptCache(fullPrompt: fullPrompt, parameters: parameters, context: context, modelName: model.name)
 
-        // TODO: cache.perform ...
         // TODO: The generated tokens should be added to the prompt cache but not possible with AsyncStream
         return try MLXLMCommon.generate(
             input: lmInput, parameters: parameters, context: context, cache: cache.cache)
     }
 }
+
+    func getPromptCache(fullPrompt: LMInput, parameters: GenerateParameters, context: ModelContext, modelName: String) -> (PromptCache, LMInput) {
+        let cache: PromptCache
+        if let existingCache = promptCache.object(forKey: modelName as NSString) {
+            cache = existingCache
+        } else {
+            // Create the cache if it doesn't exist yet
+            cache = PromptCache(cache: context.model.newCache(parameters: parameters))
+            self.promptCache.setObject(cache, forKey: modelName as NSString)
+        }
+
+        let lmInput: LMInput
+
+        // Remove the prefix of the prompt that is already in the cache
+        if let suffix = cache.getUncachedSuffix(prompt: fullPrompt.text.tokens) {
+            lmInput = LMInput(text: LMInput.Text(tokens: suffix))
+        } else {
+            // The cache is inconsistent with the new prompt and doesn't support
+            // trimming, so replace it and return the fresh cache for generation.
+            let newCache = PromptCache(cache: context.model.newCache(parameters: parameters))
+            self.promptCache.setObject(newCache, forKey: modelName as NSString)
+            return (newCache, fullPrompt)
+        }
+
+        return (cache, lmInput)
+    }
 }
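
On the "Prompt cache access isn't isolated" TODO: NSCache is thread-safe for
individual reads and writes, but the check-then-create and check-then-replace
sequences in getPromptCache are not atomic, so two overlapping generations for
the same model could race. A minimal sketch of one way to close that gap,
confining the caches to an actor (PromptCacheStore and its method names are
hypothetical, not part of this change):

    // Hypothetical sketch: serialize all prompt cache access through an
    // actor so lookup-or-create and replace become atomic operations.
    actor PromptCacheStore {
        private var caches: [String: PromptCache] = [:]

        // Returns the existing cache for the model, creating one atomically
        // if none exists yet.
        func cache(for modelName: String, create: () -> PromptCache) -> PromptCache {
            if let existing = caches[modelName] {
                return existing
            }
            let fresh = create()
            caches[modelName] = fresh
            return fresh
        }

        // Replaces a cache that has become inconsistent with the new prompt.
        func replace(with cache: PromptCache, for modelName: String) {
            caches[modelName] = cache
        }

        // Drops the cache when its model container is reloaded.
        func remove(for modelName: String) {
            caches.removeValue(forKey: modelName)
        }
    }

Callers would then await these methods instead of touching NSCache directly,
passing the PromptCache construction as the create closure.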
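For context on the prefix reuse that getUncachedSuffix provides: the cached
key/value state is only valid if the new prompt begins with exactly the tokens
the cache has already consumed. An illustrative sketch of that check over plain
Int token arrays (the real PromptCache works on the model's token arrays; this
is a simplified assumption about the shape, not the actual implementation):

    // Illustrative sketch of the prefix check behind getUncachedSuffix.
    func uncachedSuffix(cachedTokens: [Int], prompt: [Int]) -> [Int]? {
        // The cached state is reusable only if the new prompt starts with
        // the exact token sequence the cache has already processed.
        guard prompt.count >= cachedTokens.count,
            Array(prompt.prefix(cachedTokens.count)) == cachedTokens
        else {
            // Inconsistent prefix: the caller discards this cache and
            // evaluates the full prompt against a fresh one.
            return nil
        }
        // Only the tokens after the cached prefix still need evaluation.
        return Array(prompt.suffix(from: cachedTokens.count))
    }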