Update chatbot/train.py and chatbot/app.py to improve model performance #200
base: main
Conversation
Walkthrough

The chatbot module is enhanced with NLTK-based NLP standardization (`word_tokenize` and `PorterStemmer`) integrated into both training and inference pipelines. Training configuration is externalized into constants, epoch-level loss monitoring is added, and the dataset now returns PyTorch tensors instead of raw numpy arrays for consistency.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks: ❌ Failed checks (2 warnings) · ✅ Passed checks (3 passed)
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
chatbot/train.py (2)
84-91: Type mismatch: `bag_of_words` returns tensors, but the code expects numpy arrays.

`bag_of_words` returns a `torch.tensor` (line 32), but these are collected into a list and converted to `np.array` (line 90). This creates a numpy object array of tensors. Later, `__getitem__` (line 111) calls `torch.from_numpy()` on these tensor objects, which will fail because `from_numpy` expects a numpy ndarray, not a PyTorch tensor. Either convert the `bag_of_words` output to numpy before appending, or handle tensors directly.

Option 1: Convert to numpy in `bag_of_words` (consistent with the original design)

```diff
 def bag_of_words(tokenized_sentence, words):
     sentence_words = [stem(word) for word in tokenized_sentence]
     bag = [1.0 if word in sentence_words else 0.0 for word in words]
-    return torch.tensor(bag, dtype=torch.float32)
+    return np.array(bag, dtype=np.float32)
```

Option 2: Stack tensors directly (avoid the numpy conversion)

```diff
-X_train = np.array(X_train)
+X_train = torch.stack(X_train)
 y_train = np.array(y_train)
```

And update `__getitem__`:

```diff
 def __getitem__(self, index):
-    return torch.from_numpy(self.x_data[index]), torch.tensor(self.y_data[index])
+    return self.x_data[index], torch.tensor(self.y_data[index])
```
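For illustration, a minimal standalone sketch of Option 2; the toy vectors are hypothetical, not the PR's data:

```python
import torch

# Toy per-sentence vectors, standing in for the updated bag_of_words() output.
X_train = [torch.tensor([1.0, 0.0, 1.0]), torch.tensor([0.0, 1.0, 0.0])]

# Option 2: stack the per-sentence tensors into one (N, vocab_size) tensor,
# so __getitem__ can index it directly without torch.from_numpy().
X_stacked = torch.stack(X_train)
print(X_stacked.shape)   # torch.Size([2, 3])
print(X_stacked[0])      # tensor([1., 0., 1.])
```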
117-121: `num_workers=0` does not meet the linked issue requirement.

Issue #198 specifically requested "set DataLoader `num_workers > 0` to enable parallel data fetching" for performance improvement. The current value is `0`, which disables parallel loading.

```diff
 train_loader = DataLoader(dataset=dataset,
                           batch_size=batch_size,
                           shuffle=True,
-                          num_workers=0)
+                          num_workers=2)
```

Note: If Windows compatibility is a concern, consider using a conditional:

```python
import os
num_workers = 0 if os.name == 'nt' else 2
```
🧹 Nitpick comments (2)
chatbot/train.py (2)
8-11: Consider guarding the NLTK resource download.

Calling `nltk.download('punkt')` at module import time can cause unexpected network calls in production. Consider wrapping it in a try/except that checks whether the resource already exists, or moving it to a setup/initialization function.

```diff
 import nltk
 from nltk.stem.porter import PorterStemmer
-nltk.download('punkt')
+try:
+    nltk.data.find('tokenizers/punkt')
+except LookupError:
+    nltk.download('punkt')
```
20-32: Code duplication with `app.py`.

The `tokenize`, `stem`, and `bag_of_words` functions are duplicated in `app.py` (lines 14-24). Consider extracting them into a shared module (e.g., `nlp_utils.py`) to ensure consistency and reduce maintenance burden; a sketch follows below.
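One possible shape for that shared module, assembled from the functions the review quotes and following Option 1's numpy return type (the `nlp_utils.py` name is a suggestion, not existing code):

```python
# nlp_utils.py -- preprocessing shared by train.py and app.py
import nltk
import numpy as np
from nltk.stem.porter import PorterStemmer

stemmer = PorterStemmer()

def tokenize(sentence):
    """Split a sentence into word tokens."""
    return nltk.word_tokenize(sentence)

def stem(word):
    """Lowercase and stem a single word."""
    return stemmer.stem(word.lower())

def bag_of_words(tokenized_sentence, words):
    """Return a float32 indicator vector over the known vocabulary."""
    sentence_words = [stem(word) for word in tokenized_sentence]
    return np.array([1.0 if w in sentence_words else 0.0 for w in words],
                    dtype=np.float32)
```

Both scripts would then share one implementation via `from nlp_utils import tokenize, stem, bag_of_words`.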
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- chatbot/app.py (3 hunks)
- chatbot/train.py (7 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
chatbot/app.py (1)
chatbot/train.py (3)
`stem` (24-25), `tokenize` (21-22), `bag_of_words` (27-32)
chatbot/train.py (1)
chatbot/app.py (3)
`stem` (18-19), `tokenize` (15-16), `bag_of_words` (21-24)
🪛 Ruff (0.14.8)
chatbot/app.py
96-96: Possible binding to all interfaces
(S104)
96-96: Use of debug=True in Flask app detected
(S201)
🔇 Additional comments (6)
chatbot/train.py (3)
13-18: Good use of configuration constants.

Extracting `INTENTS_FILE` and `MODEL_SAVE_FILE` into constants improves maintainability. The global stemmer instance aligns with the approach in `app.py`.
132-159: Good implementation of epoch-averaged loss tracking.

The weighted loss accumulation (`loss.item() * words.size(0)`) and the division by `len(dataset)` correctly compute the average loss per sample across the epoch, fulfilling the linked issue requirement.
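For readers unfamiliar with the pattern, here is a self-contained sketch using a synthetic stand-in dataset; the model, sizes, and hyperparameters are illustrative, not the PR's:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for the bag-of-words dataset.
X = torch.randn(10, 8)
y = torch.randint(0, 3, (10,))
dataset = TensorDataset(X, y)
train_loader = DataLoader(dataset, batch_size=4, shuffle=True)

model = nn.Linear(8, 3)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(3):
    epoch_loss = 0.0
    for words, labels in train_loader:
        loss = criterion(model(words), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Weight by batch size so the final partial batch (here, 2 samples)
        # does not skew the per-sample average.
        epoch_loss += loss.item() * words.size(0)
    print(f"epoch {epoch + 1}: avg loss {epoch_loss / len(dataset):.4f}")
```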
161-172: LGTM!

Model saving correctly uses the `MODEL_SAVE_FILE` constant and includes all necessary metadata for inference.

chatbot/app.py (3)
8-24: LGTM - NLP preprocessing aligns with training.

The tokenization and stemming implementation correctly mirrors `train.py`, ensuring consistent preprocessing between training and inference.
47-48: Verify `torch.load` behavior with the current PyTorch version.

In PyTorch 2.6+, `torch.load` defaults to `weights_only=True`, which would fail here since the saved data includes non-tensor objects (`all_words`, `tags`). Consider explicitly setting the parameter for forward compatibility.

```diff
 FILE = "data.pth"
-data = torch.load(FILE)
+data = torch.load(FILE, weights_only=False)
```

Alternatively, for better security, use `torch.load(FILE, weights_only=True, map_location=device)` and save the metadata separately in a JSON file.
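A minimal sketch of that safer split, using a stand-in model and toy metadata; the file names and values are hypothetical:

```python
import json
import torch
import torch.nn as nn

device = torch.device("cpu")
model = nn.Linear(4, 2)  # stand-in for the chatbot's network

# Save: tensors go in the .pth checkpoint, non-tensor metadata in JSON.
torch.save(model.state_dict(), "model_state.pth")
meta = {"all_words": ["hi", "bye"], "tags": ["greeting", "goodbye"],
        "input_size": 4, "hidden_size": 8, "output_size": 2}
with open("model_meta.json", "w") as f:
    json.dump(meta, f)

# Load: weights_only=True is safe because the checkpoint holds only tensors.
state = torch.load("model_state.pth", weights_only=True, map_location=device)
with open("model_meta.json") as f:
    meta = json.load(f)
model.load_state_dict(state)
```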
63-93: Chat endpoint logic is correct.

The inference pipeline properly uses the updated NLP functions and handles uncertain predictions gracefully.

One minor consideration: returning raw exception messages (`str(e)`) in line 93 could expose internal details. In production, consider logging the full error server-side and returning a generic message to clients.
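A sketch of that logging pattern; the route body is a placeholder, not the PR's actual handler:

```python
import logging
from flask import Flask, jsonify, request

app = Flask(__name__)
logger = logging.getLogger(__name__)

@app.route("/chat", methods=["POST"])
def chat():
    try:
        message = request.json["message"]
        reply = f"echo: {message}"  # placeholder for the real inference call
        return jsonify({"response": reply})
    except Exception:
        # Full traceback stays in the server logs; the client sees a
        # generic message instead of str(e).
        logger.exception("chat endpoint failed")
        return jsonify({"error": "Internal server error"}), 500
```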
```diff
 if __name__ == '__main__':
-    app.run(host="0.0.0.0", port=5000,debug=True)
+    app.run(host="0.0.0.0", port=5000, debug=True)
```
Development server configuration flagged by static analysis.
`debug=True` and binding to `0.0.0.0` are appropriate for development but pose security risks in production:

- `debug=True` enables the interactive debugger, which can execute arbitrary code
- `0.0.0.0` exposes the service to all network interfaces
Consider using environment variables or a configuration flag:
```diff
+import os
+
 if __name__ == '__main__':
-    app.run(host="0.0.0.0", port=5000, debug=True)
+    debug_mode = os.environ.get('FLASK_DEBUG', 'false').lower() == 'true'
+    app.run(host="0.0.0.0", port=5000, debug=debug_mode)
```

For production, use a WSGI server (e.g., Gunicorn) instead of the Flask development server.
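As a concrete example of that setup, and assuming the Flask instance is named `app` in `chatbot/app.py`, a typical invocation would be `gunicorn --workers 2 --bind 0.0.0.0:5000 app:app`, run from the `chatbot/` directory.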
🧰 Tools
🪛 ast-grep (0.40.0)
[warning] 95-95: Running flask app with host 0.0.0.0 could expose the server publicly.
Context: app.run(host="0.0.0.0", port=5000, debug=True)
Note: [CWE-668]: Exposure of Resource to Wrong Sphere [OWASP A01:2021]: Broken Access Control [REFERENCES]
https://owasp.org/Top10/A01_2021-Broken_Access_Control
(avoid_app_run_with_bad_host-python)
[warning] 95-95: Detected Flask app with debug=True. Do not deploy to production with this flag enabled as it will leak sensitive information. Instead, consider using Flask configuration variables or setting 'debug' using system environment variables.
Context: app.run(host="0.0.0.0", port=5000, debug=True)
Note: [CWE-489] Active Debug Code. [REFERENCES]
- https://labs.detectify.com/2015/10/02/how-patreon-got-hacked-publicly-exposed-werkzeug-debugger/
(debug-enabled-python)
🪛 Ruff (0.14.8)
96-96: Possible binding to all interfaces
(S104)
96-96: Use of debug=True in Flask app detected
(S201)
🤖 Prompt for AI Agents
In chatbot/app.py around lines 95-96, the app is started with debug=True and
host="0.0.0.0", which is unsafe for production; change the startup to read
configuration from environment variables (e.g., FLASK_ENV or APP_DEBUG and
APP_HOST/APP_PORT), default debug to False, and only enable debug and binding to
0.0.0.0 when an explicit development flag is set; for production remove the
built-in server invocation and run under a WSGI server like Gunicorn (document
example command) so the code no longer enables the interactive debugger or
indiscriminately binds all interfaces by default.
```diff
@@ -1,23 +1,34 @@
 from operator import index
```
Remove unused import.

The `index` import from `operator` is never used in this file.

```diff
-from operator import index
```

📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
🤖 Prompt for AI Agents
In chatbot/train.py at line 1, the file imports `index` from the standard
library `operator` module but never uses it; remove the unused import line `from
operator import index` to clean up imports and avoid lint warnings, ensuring
there are no other references to `index` elsewhere in the file before committing
the change.
Description
In this PR, I have updated `chatbot/app.py` to work with the improved `chatbot/train.py`.

Let me know if there is anything else I need to take care of.
Fixes #198
Summary by CodeRabbit
New Features
- NLTK-based tokenization and stemming (`word_tokenize`, `PorterStemmer`) in both training and inference
- Epoch-level loss monitoring during training

Refactor
- Training configuration externalized into constants
- Dataset now returns PyTorch tensors instead of raw numpy arrays