diff --git a/DSL/Ruuter.private/DSL/POST/classifier/datamodel/re-train.yml b/DSL/Ruuter.private/DSL/POST/classifier/datamodel/retrain.yml similarity index 100% rename from DSL/Ruuter.private/DSL/POST/classifier/datamodel/re-train.yml rename to DSL/Ruuter.private/DSL/POST/classifier/datamodel/retrain.yml diff --git a/README.md b/README.md index bd18ca6b..70b2f4af 100644 --- a/README.md +++ b/README.md @@ -106,7 +106,7 @@ This repo will primarily contain: - `JIRA_WEBHOOK_SECRET` – Jira webhook secret you got in **Create Jira Webhook** step. 4. **Create a `.env` file for Jira Configuration:** - - Create a `.env` file called `jira_config.env` and add the following: + - Create a `.env` file in the folder called `jira-verification` and add the following: ```env JIRA_WEBHOOK_SECRET=<> ``` diff --git a/docker-compose.gpu.yml b/docker-compose.gpu.yml index 34a3e003..2c65845f 100644 --- a/docker-compose.gpu.yml +++ b/docker-compose.gpu.yml @@ -311,6 +311,7 @@ services: - FIND_FINAL_FOLDER_ID_URL=http://hierarchy-validation:8009/find-folder-id - UPDATE_DATAMODEL_PROGRESS_URL=http://ruuter-private:8088/classifier/datamodel/progress/update - UPDATE_MODEL_TRAINING_STATUS_ENDPOINT=http://ruuter-private:8088/classifier/datamodel/update/training/status + - GET_DATASET_METADATA_ENDPOINT=http://ruuter-private:8088/classifier/datasetgroup/group/metadata ports: - "8003:8003" networks: @@ -491,7 +492,7 @@ services: ports: - "3008:3008" env_file: - - jira_config.env + - ./jira-verification/.env environment: RUUTER_PUBLIC_JIRA_URL: http://ruuter-public:8086/internal/jira/accept networks: diff --git a/docker-compose.yml b/docker-compose.yml index 8b4394ec..2035573f 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -305,6 +305,7 @@ services: - FIND_FINAL_FOLDER_ID_URL=http://hierarchy-validation:8009/find-folder-id - UPDATE_DATAMODEL_PROGRESS_URL=http://ruuter-private:8088/classifier/datamodel/progress/update - UPDATE_MODEL_TRAINING_STATUS_ENDPOINT=http://ruuter-private:8088/classifier/datamodel/update/training/status + - GET_DATASET_METADATA_ENDPOINT=http://ruuter-private:8088/classifier/datasetgroup/group/metadata ports: - "8003:8003" networks: @@ -473,7 +474,7 @@ services: ports: - "3008:3008" env_file: - - jira_config.env + - ./jira-verification/.env environment: RUUTER_PUBLIC_JIRA_URL: http://ruuter-public:8086/internal/jira/accept networks: diff --git a/model-inference/constants.py b/model-inference/constants.py index 77309872..7d2f5943 100644 --- a/model-inference/constants.py +++ b/model-inference/constants.py @@ -46,6 +46,7 @@ class UpdateRequest(BaseModel): bestBaseModel:str updateType: Optional[str] = None progressSessionId: int + dgId:Optional[int]= None class OutlookInferenceRequest(BaseModel): inputId:str diff --git a/model-inference/inference_pipeline.py b/model-inference/inference_pipeline.py index 9bbfe854..82b7e1fd 100644 --- a/model-inference/inference_pipeline.py +++ b/model-inference/inference_pipeline.py @@ -104,7 +104,7 @@ def find_missing_classes(self, main_classes, uploaded_classes): - def predict_class(self,text_input): + def predict_class(self,text_input, platform): logger.info("ENTERING PREDICT CLASS") @@ -117,10 +117,11 @@ def predict_class(self,text_input): self.base_model.to(self.device) logger.info(f"CLASS HIERARCHY FILE {self.hierarchy_file}") + logger.info(f"PLATFORM IN PREDICT CLASS {platform}") - - data = self.hierarchy_file + if platform == 'jira': + data = data['classHierarchy'] parent = 1 logger.info(f"DATA - {data}") diff --git a/model-inference/inference_wrapper.py b/model-inference/inference_wrapper.py index 19dcdc5b..6eab9397 100644 --- a/model-inference/inference_wrapper.py +++ b/model-inference/inference_wrapper.py @@ -62,7 +62,7 @@ def inference(self, text:str, deployment_platform:str): if(deployment_platform == "jira" and self.active_jira_model): logger.info("ENTERING JIRA INFERENCE") - predicted_labels, probabilities = self.active_jira_model.predict_class(text) + predicted_labels, probabilities = self.active_jira_model.predict_class(text, deployment_platform) logger.info(f"PREDICTED LABELS INSIDE .inference() FUNCTION - {predicted_labels}") logger.info(f"PROBABILITIES INSIDE .inference() FUNCTION - {probabilities}") @@ -70,7 +70,7 @@ def inference(self, text:str, deployment_platform:str): if(deployment_platform == "outlook" and self.active_outlook_model): logger.info("ENTERING OUTLOOK INFERENCE") - predicted_labels, probabilities = self.active_outlook_model.predict_class(text) + predicted_labels, probabilities = self.active_outlook_model.predict_class(text, deployment_platform) logger.info(f"PREDICTED LABELS INSIDE .inference() FUNCTION - {predicted_labels}") diff --git a/model-inference/model_inference.py b/model-inference/model_inference.py index 61c5f8b1..516a1cf9 100644 --- a/model-inference/model_inference.py +++ b/model-inference/model_inference.py @@ -19,12 +19,13 @@ UPDATE_DATAMODEL_PROGRESS_URL = os.getenv("UPDATE_DATAMODEL_PROGRESS_URL") UPDATE_MODEL_TRAINING_STATUS_ENDPOINT = os.getenv("UPDATE_MODEL_TRAINING_STATUS_ENDPOINT") RUUTER_PRIVATE_URL = os.getenv("RUUTER_PRIVATE_URL") +GET_DATASET_METADATA_ENDPOINT=os.getenv("GET_DATASET_METADATA_ENDPOINT") class ModelInference: def __init__(self): pass - def get_class_hierarchy_by_model_id(self, model_id): + def get_outlook_class_hierarchy_by_model_id(self, model_id): try: logger.info(f"get_class_hierarchy_by_model_id - {model_id}") @@ -123,9 +124,9 @@ def validate_class_hierarchy(self, class_hierarchy, model_id): - def get_class_hierarchy_and_validate(self, model_id): + def get_outlook_class_hierarchy_and_validate(self, model_id): try: - class_hierarchy = self.get_class_hierarchy_by_model_id(model_id) + class_hierarchy = self.get_outlook_class_hierarchy_by_model_id(model_id) if class_hierarchy: is_valid = self.validate_class_hierarchy(class_hierarchy, model_id) return is_valid, class_hierarchy @@ -252,4 +253,33 @@ def create_inference(self, payload): raise RuntimeError(f"Failed to call create inference. Reason: {e}") - + def get_class_hierarchy_by_dg_id(self, cookies, dg_id): + logger.info("********************************************************************") + logger.info(f"****** Calling function get_class_hierarchy_by_dg_id ******") + try: + logger.info(f"get_class_hierarchy_by_dg_id - {dg_id}") + logger.info(f"cookie : {cookies}") + cookies_updated = {"customJwtCookie":cookies} + logger.info(f"cookie_updated : {cookies_updated}") + logger.info(f"GET_DATASET_METADATA_ENDPOINT : {GET_DATASET_METADATA_ENDPOINT}") + + response_hierarchy = requests.get(GET_DATASET_METADATA_ENDPOINT, params={'groupId': dg_id}, cookies=cookies_updated) + + logger.info(f"response_hierarchy : {response_hierarchy}") + + if response_hierarchy.status_code == 200: + logger.info("DATASET HIERARCHY RETREIVAL SUCCESSFUL") + hierarchy = response_hierarchy.json() + logger.info(f"DATASET HIERARCHY - {hierarchy}") + class_hierarchy = hierarchy['response']['data'][0] + logger.info(f"CLASS HIERARCHY - {class_hierarchy}") + return class_hierarchy + + else: + logger.error(f"DATASET HIERARCHY RETRIEVAL FAILED: {response_hierarchy.status_code}") + raise RuntimeError(f"ERROR RESPONSE\n {response_hierarchy.text}") + + + except Exception as e: + logger.error(f"Failed to retrieve the class hierarchy Reason: {e}") + raise RuntimeError(f"Failed to retrieve the class hierarchy Reason: {e}") diff --git a/model-inference/model_inference_api.py b/model-inference/model_inference_api.py index 9627f33d..3b9a0673 100644 --- a/model-inference/model_inference_api.py +++ b/model-inference/model_inference_api.py @@ -70,7 +70,7 @@ async def download_outlook_model(request: Request, model_data:UpdateRequest): model_progress_session_id = model_data.progressSessionId ## Get class hierarchy and validate it - is_valid, class_hierarchy = model_inference.get_class_hierarchy_and_validate(model_data.modelId) + is_valid, class_hierarchy = model_inference.get_outlook_class_hierarchy_and_validate(model_data.modelId) logger.info(f"IS VALID VALUE : {is_valid}") logger.info(f"CLASS HIERARCHY VALUE : {class_hierarchy}") @@ -230,7 +230,7 @@ async def download_jira_model(request: Request, model_data:UpdateRequest): logger.info("JUST ABOUT TO ENTER get_class_hierarchy_by_model_id") - class_hierarchy = model_inference.get_class_hierarchy_by_model_id(model_data.modelId) + class_hierarchy = model_inference.get_class_hierarchy_by_dg_id(cookies=cookie, dg_id=model_data.dgId) logger.info(f"JIRA UPDATE CLASS HIERARCHY - {class_hierarchy}") @@ -354,7 +354,7 @@ async def download_test_model(request: Request, model_data:UpdateRequest): logger.info("JUST ABOUT TO ENTER get_class_hierarchy_by_model_id") - class_hierarchy = model_inference.get_class_hierarchy_by_model_id(model_data.modelId) + class_hierarchy = model_inference.get_class_hierarchy_by_dg_id(cookies=cookie, dg_id=model_data.dgId) logger.info(f"TEST UPDATE CLASS HIERARCHY - {class_hierarchy}") @@ -556,12 +556,10 @@ async def outlook_inference(request:Request, inference_data:OutlookInferenceRequ async def jira_inference(request:Request, inferenceData:JiraInferenceRequest): try: - logger.info(f"INFERENCE DATA IN JIRA INFERENCE - {inferenceData}") model_id = model_inference_wrapper.get_jira_model_id() - if(model_id): # 1 . Check whether the if the Inference Exists is_exist, inference_id = model_inference.check_inference_data_exists(input_id=inferenceData.inputId) diff --git a/model-inference/test_inference_wrapper.py b/model-inference/test_inference_wrapper.py index a8d0ca4b..652133e5 100644 --- a/model-inference/test_inference_wrapper.py +++ b/model-inference/test_inference_wrapper.py @@ -30,7 +30,7 @@ def inference(self, text: str, model_id: int): predicted_labels = None probabilities = None model = self.model_dictionary[model_id] - predicted_labels, probabilities = model.predict_class(text_input=text) + predicted_labels, probabilities = model.predict_class(text_input=text, platform="test") return predicted_labels, probabilities else: raise Exception(f"Model with ID {model_id} not found") diff --git a/model_trainer/datapipeline.py b/model_trainer/datapipeline.py index d8aa3959..0a273b0f 100644 --- a/model_trainer/datapipeline.py +++ b/model_trainer/datapipeline.py @@ -26,9 +26,16 @@ def __init__(self, dg_id,cookie): raise RuntimeError(f"ERROR RESPONSE {response.text}") + logger.info(f"###############################################################") + logger.info(f"****** Calling init function of DataPipeline Class ******") + logger.info(f"Endpoint : {GET_DATASET_METADATA_ENDPOINT}") + logger.info(f"Cookie : {cookies}") + logger.info(f"DGID : {dg_id}") response_hierarchy = requests.get(GET_DATASET_METADATA_ENDPOINT, params={'groupId': dg_id}, cookies=cookies) + logger.info(f"response_hierarchy : {response_hierarchy}") + if response_hierarchy.status_code == 200: logger.info("DATASET HIERARCHY RETREIVAL SUCCESSFUL") hierarchy = response_hierarchy.json() @@ -38,19 +45,6 @@ def __init__(self, dg_id,cookie): logger.error(f"DATASET HIERARCHY RETRIEVAL FAILED: {response_hierarchy.status_code}") logger.error(f"RESPONSE: {response.text}") raise RuntimeError(f"ERROR RESPONSE\n {response_hierarchy.text}") - - - def find_target_column(self,df , filter_list): - - value_set = set(filter_list) - columns_with_exact_or_subset_values = [] - - for column in df.columns: - unique_values = set(df[column].dropna().unique()) - if unique_values and unique_values.issubset(value_set): - columns_with_exact_or_subset_values.append(column) - - return columns_with_exact_or_subset_values def extract_input_columns(self): @@ -110,4 +104,16 @@ def create_dataframes(self): filtered_df = filtered_df.dropna() dfs.append(filtered_df) - return dfs \ No newline at end of file + return dfs + + def find_target_column(self,df , filter_list): + + value_set = set(filter_list) + columns_with_exact_or_subset_values = [] + + for column in df.columns: + unique_values = set(df[column].dropna().unique()) + if unique_values and unique_values.issubset(value_set): + columns_with_exact_or_subset_values.append(column) + + return columns_with_exact_or_subset_values \ No newline at end of file diff --git a/model_trainer/model_trainer.py b/model_trainer/model_trainer.py index d27840b4..af12e5d8 100644 --- a/model_trainer/model_trainer.py +++ b/model_trainer/model_trainer.py @@ -157,7 +157,7 @@ def update_model_training_progress_session(self,training_status, return session_id - def deploy_model(self, best_model_name, progress_session_id): + def deploy_model(self, best_model_name, progress_session_id, dg_id): payload = {} payload["modelId"] = self.new_model_id @@ -167,6 +167,7 @@ def deploy_model(self, best_model_name, progress_session_id): payload["bestBaseModel"] = best_model_name payload["progressSessionId"] = progress_session_id payload["updateType"] = self.update_type + payload["dgId"] = dg_id if self.update_type == "retrain": payload["replaceDeploymentPlatform"] = self.current_deployment_platform @@ -411,7 +412,7 @@ def train(self): else: logger.info(f"INITIATING DEPLOYMENT TO {self.current_deployment_platform}") - self.deploy_model(best_model_name=best_model_name, progress_session_id=session_id) + self.deploy_model(best_model_name=best_model_name, progress_session_id=session_id,dg_id=dg_id) except Exception as e: self.send_error_progress_session(f"RUNTIME CRASHED - ERROR - {str(e)}")