Skip to content

Commit 044089e

Browse files
jefcodersamugit83
authored andcommitted
update
1 parent df51827 commit 044089e

File tree

2 files changed

+4
-13
lines changed

2 files changed

+4
-13
lines changed

reinforcement_learning/qlearn_agent.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import os
22
import pickle
33
import numpy as np
4-
# Added imports for the neural network mode
54
import tensorflow as tf
65
from tensorflow.keras.models import Sequential, load_model
76
from tensorflow.keras.layers import Dense
87
from tensorflow.keras.optimizers import Adam
9-
import logging # Added for cleaner messages
8+
import logging
109

1110
# Configure logging
1211
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -67,7 +66,7 @@ def __init__(
6766
self.error_list = []
6867
self.load_data()
6968

70-
# +++ Added: Helper to build the neural network +++
69+
7170
def _build_model(self):
7271
"""Builds the neural network model for DQN."""
7372

@@ -83,7 +82,6 @@ def _build_model(self):
8382
logging.info("Built neural network model.")
8483
# model.summary() # Optional: print model summary
8584
return model
86-
# +++ End Added +++
8785

8886
def load_data(self):
8987
"""

tools/rag/rl_meta_rag/rl_meta_rag_retrieve.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
# Helper function to extract categories from prompt template
1212
def _extract_categories_from_prompt(prompt_template, feature_name):
1313
"""Extracts list items for a given feature from the prompt template."""
14-
# Find the line describing the feature
1514
match = re.search(rf"- {feature_name}:.*?\[(.*?)\]", prompt_template.template, re.DOTALL)
1615
if match:
1716
# Extract the content within the brackets and split into items
@@ -131,11 +130,6 @@ def _extract_query_features(self, query):
131130
# Adjust query length specifically if needed
132131
features[5] = len(query.split()) # A reasonable default for query_length
133132

134-
# Note: The original code appended query_length again here.
135-
# This was likely a bug as query_length is already included in the 9 features.
136-
# query_length_calc = len(query.split())
137-
# features.append(query_length_calc) # REMOVED THIS LINE
138-
139133
self.socketio.emit('reasoning_update', {
140134
"message": f"Extracted query features: {features}"
141135
})
@@ -240,7 +234,6 @@ def retrieve(self, query, session_id=None):
240234
if technique_id in self.rag_techniques:
241235
result = self.rag_techniques[technique_id](query)
242236
else:
243-
# Default or fallback technique
244237
self.socketio.emit('reasoning_update', {
245238
"message": f"Warning: Selected technique ID {technique_id} not found. Defaulting to technique 0."
246239
})
@@ -262,12 +255,12 @@ def retrieve(self, query, session_id=None):
262255
redis_client = redis.StrictRedis(host='redis', port=6379, db=0, decode_responses=True)
263256
# Store the original feature list before tuple conversion if needed
264257
data_to_store = {
265-
"state": state_features, # Store the raw feature list
258+
"state": state_features,
266259
"action": technique_id,
267260
"query": query
268261
}
269262
redis_key = f"rl_update:{session_id}"
270-
# Use json.dumps with a handler for non-serializable types if necessary, though list should be fine
263+
271264
try:
272265
redis_client.set(redis_key, json.dumps(data_to_store))
273266
self.socketio.emit('reasoning_update', {

0 commit comments

Comments
 (0)