Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ This repo will primarily contain:
- `JIRA_WEBHOOK_SECRET` – Jira webhook secret you got in **Create Jira Webhook** step.

4. **Create a `.env` file for Jira Configuration:**
- Create a `.env` file called `jira_config.env` and add the following:
- Create a `.env` file in the folder called `jira-verification` and add the following:
```env
JIRA_WEBHOOK_SECRET=<<JIRA_WEBHOOK_SECRET>>
```
Expand Down
3 changes: 2 additions & 1 deletion docker-compose.gpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ services:
- FIND_FINAL_FOLDER_ID_URL=http://hierarchy-validation:8009/find-folder-id
- UPDATE_DATAMODEL_PROGRESS_URL=http://ruuter-private:8088/classifier/datamodel/progress/update
- UPDATE_MODEL_TRAINING_STATUS_ENDPOINT=http://ruuter-private:8088/classifier/datamodel/update/training/status
- GET_DATASET_METADATA_ENDPOINT=http://ruuter-private:8088/classifier/datasetgroup/group/metadata
ports:
- "8003:8003"
networks:
Expand Down Expand Up @@ -491,7 +492,7 @@ services:
ports:
- "3008:3008"
env_file:
- jira_config.env
- ./jira-verification/.env
environment:
RUUTER_PUBLIC_JIRA_URL: http://ruuter-public:8086/internal/jira/accept
networks:
Expand Down
3 changes: 2 additions & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ services:
- FIND_FINAL_FOLDER_ID_URL=http://hierarchy-validation:8009/find-folder-id
- UPDATE_DATAMODEL_PROGRESS_URL=http://ruuter-private:8088/classifier/datamodel/progress/update
- UPDATE_MODEL_TRAINING_STATUS_ENDPOINT=http://ruuter-private:8088/classifier/datamodel/update/training/status
- GET_DATASET_METADATA_ENDPOINT=http://ruuter-private:8088/classifier/datasetgroup/group/metadata
ports:
- "8003:8003"
networks:
Expand Down Expand Up @@ -473,7 +474,7 @@ services:
ports:
- "3008:3008"
env_file:
- jira_config.env
- ./jira-verification/.env
environment:
RUUTER_PUBLIC_JIRA_URL: http://ruuter-public:8086/internal/jira/accept
networks:
Expand Down
1 change: 1 addition & 0 deletions model-inference/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class UpdateRequest(BaseModel):
bestBaseModel:str
updateType: Optional[str] = None
progressSessionId: int
dgId:Optional[int]= None

class OutlookInferenceRequest(BaseModel):
inputId:str
Expand Down
7 changes: 4 additions & 3 deletions model-inference/inference_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def find_missing_classes(self, main_classes, uploaded_classes):



def predict_class(self,text_input):
def predict_class(self,text_input, platform):

logger.info("ENTERING PREDICT CLASS")

Expand All @@ -117,10 +117,11 @@ def predict_class(self,text_input):
self.base_model.to(self.device)

logger.info(f"CLASS HIERARCHY FILE {self.hierarchy_file}")
logger.info(f"PLATFORM IN PREDICT CLASS {platform}")



data = self.hierarchy_file
if platform == 'jira':
data = data['classHierarchy']
parent = 1

logger.info(f"DATA - {data}")
Expand Down
4 changes: 2 additions & 2 deletions model-inference/inference_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,15 @@ def inference(self, text:str, deployment_platform:str):
if(deployment_platform == "jira" and self.active_jira_model):

logger.info("ENTERING JIRA INFERENCE")
predicted_labels, probabilities = self.active_jira_model.predict_class(text)
predicted_labels, probabilities = self.active_jira_model.predict_class(text, deployment_platform)

logger.info(f"PREDICTED LABELS INSIDE .inference() FUNCTION - {predicted_labels}")
logger.info(f"PROBABILITIES INSIDE .inference() FUNCTION - {probabilities}")


if(deployment_platform == "outlook" and self.active_outlook_model):
logger.info("ENTERING OUTLOOK INFERENCE")
predicted_labels, probabilities = self.active_outlook_model.predict_class(text)
predicted_labels, probabilities = self.active_outlook_model.predict_class(text, deployment_platform)


logger.info(f"PREDICTED LABELS INSIDE .inference() FUNCTION - {predicted_labels}")
Expand Down
38 changes: 34 additions & 4 deletions model-inference/model_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@
UPDATE_DATAMODEL_PROGRESS_URL = os.getenv("UPDATE_DATAMODEL_PROGRESS_URL")
UPDATE_MODEL_TRAINING_STATUS_ENDPOINT = os.getenv("UPDATE_MODEL_TRAINING_STATUS_ENDPOINT")
RUUTER_PRIVATE_URL = os.getenv("RUUTER_PRIVATE_URL")
GET_DATASET_METADATA_ENDPOINT=os.getenv("GET_DATASET_METADATA_ENDPOINT")

class ModelInference:
def __init__(self):
    # Stateless helper: all endpoints/configuration are read from
    # module-level environment variables at call time, so there is
    # nothing to initialize here.
    pass

def get_class_hierarchy_by_model_id(self, model_id):
def get_outlook_class_hierarchy_by_model_id(self, model_id):

try:
logger.info(f"get_class_hierarchy_by_model_id - {model_id}")
Expand Down Expand Up @@ -123,9 +124,9 @@ def validate_class_hierarchy(self, class_hierarchy, model_id):



def get_class_hierarchy_and_validate(self, model_id):
def get_outlook_class_hierarchy_and_validate(self, model_id):
try:
class_hierarchy = self.get_class_hierarchy_by_model_id(model_id)
class_hierarchy = self.get_outlook_class_hierarchy_by_model_id(model_id)
if class_hierarchy:
is_valid = self.validate_class_hierarchy(class_hierarchy, model_id)
return is_valid, class_hierarchy
Expand Down Expand Up @@ -252,4 +253,33 @@ def create_inference(self, payload):
raise RuntimeError(f"Failed to call create inference. Reason: {e}")



def get_class_hierarchy_by_dg_id(self, cookies, dg_id):
    """Fetch the class hierarchy of a dataset group from the metadata endpoint.

    Args:
        cookies: custom JWT cookie value, forwarded as ``customJwtCookie``.
        dg_id: dataset-group id, sent as the ``groupId`` query parameter.

    Returns:
        The first element of ``response.data`` in the metadata payload
        (presumably the class-hierarchy record — TODO confirm schema
        against the ruuter ``datasetgroup/group/metadata`` endpoint).

    Raises:
        RuntimeError: if the HTTP call fails, the response status is not
            200, or the payload does not have the expected shape.
    """
    logger.info("********************************************************************")
    logger.info("****** Calling function get_class_hierarchy_by_dg_id ******")
    try:
        logger.info(f"get_class_hierarchy_by_dg_id - {dg_id}")
        # NOTE(review): do not log the raw cookie — it is a JWT credential.
        cookies_updated = {"customJwtCookie": cookies}
        logger.info(f"GET_DATASET_METADATA_ENDPOINT : {GET_DATASET_METADATA_ENDPOINT}")

        response_hierarchy = requests.get(
            GET_DATASET_METADATA_ENDPOINT,
            params={'groupId': dg_id},
            cookies=cookies_updated,
        )

        logger.info(f"response_hierarchy : {response_hierarchy}")

        if response_hierarchy.status_code == 200:
            logger.info("DATASET HIERARCHY RETRIEVAL SUCCESSFUL")
            hierarchy = response_hierarchy.json()
            logger.info(f"DATASET HIERARCHY - {hierarchy}")
            class_hierarchy = hierarchy['response']['data'][0]
            logger.info(f"CLASS HIERARCHY - {class_hierarchy}")
            return class_hierarchy

        logger.error(f"DATASET HIERARCHY RETRIEVAL FAILED: {response_hierarchy.status_code}")
        raise RuntimeError(f"ERROR RESPONSE\n {response_hierarchy.text}")

    except RuntimeError:
        # Already typed/logged above — re-raise without re-wrapping so the
        # original error text (including the server response) is preserved.
        raise
    except Exception as e:
        logger.error(f"Failed to retrieve the class hierarchy Reason: {e}")
        raise RuntimeError(f"Failed to retrieve the class hierarchy Reason: {e}") from e
8 changes: 3 additions & 5 deletions model-inference/model_inference_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ async def download_outlook_model(request: Request, model_data:UpdateRequest):
model_progress_session_id = model_data.progressSessionId

## Get class hierarchy and validate it
is_valid, class_hierarchy = model_inference.get_class_hierarchy_and_validate(model_data.modelId)
is_valid, class_hierarchy = model_inference.get_outlook_class_hierarchy_and_validate(model_data.modelId)

logger.info(f"IS VALID VALUE : {is_valid}")
logger.info(f"CLASS HIERARCHY VALUE : {class_hierarchy}")
Expand Down Expand Up @@ -230,7 +230,7 @@ async def download_jira_model(request: Request, model_data:UpdateRequest):

logger.info("JUST ABOUT TO ENTER get_class_hierarchy_by_model_id")

class_hierarchy = model_inference.get_class_hierarchy_by_model_id(model_data.modelId)
class_hierarchy = model_inference.get_class_hierarchy_by_dg_id(cookies=cookie, dg_id=model_data.dgId)

logger.info(f"JIRA UPDATE CLASS HIERARCHY - {class_hierarchy}")

Expand Down Expand Up @@ -354,7 +354,7 @@ async def download_test_model(request: Request, model_data:UpdateRequest):

logger.info("JUST ABOUT TO ENTER get_class_hierarchy_by_model_id")

class_hierarchy = model_inference.get_class_hierarchy_by_model_id(model_data.modelId)
class_hierarchy = model_inference.get_class_hierarchy_by_dg_id(cookies=cookie, dg_id=model_data.dgId)

logger.info(f"TEST UPDATE CLASS HIERARCHY - {class_hierarchy}")

Expand Down Expand Up @@ -556,12 +556,10 @@ async def outlook_inference(request:Request, inference_data:OutlookInferenceRequ
async def jira_inference(request:Request, inferenceData:JiraInferenceRequest):
try:


logger.info(f"INFERENCE DATA IN JIRA INFERENCE - {inferenceData}")

model_id = model_inference_wrapper.get_jira_model_id()


if(model_id):
# 1 . Check whether the if the Inference Exists
is_exist, inference_id = model_inference.check_inference_data_exists(input_id=inferenceData.inputId)
Expand Down
2 changes: 1 addition & 1 deletion model-inference/test_inference_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def inference(self, text: str, model_id: int):
predicted_labels = None
probabilities = None
model = self.model_dictionary[model_id]
predicted_labels, probabilities = model.predict_class(text_input=text)
predicted_labels, probabilities = model.predict_class(text_input=text, platform="test")
return predicted_labels, probabilities
else:
raise Exception(f"Model with ID {model_id} not found")
Expand Down
34 changes: 20 additions & 14 deletions model_trainer/datapipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,16 @@ def __init__(self, dg_id,cookie):

raise RuntimeError(f"ERROR RESPONSE {response.text}")

logger.info(f"###############################################################")
logger.info(f"****** Calling init function of DataPipeline Class ******")
logger.info(f"Endpoint : {GET_DATASET_METADATA_ENDPOINT}")
logger.info(f"Cookie : {cookies}")
logger.info(f"DGID : {dg_id}")

response_hierarchy = requests.get(GET_DATASET_METADATA_ENDPOINT, params={'groupId': dg_id}, cookies=cookies)

logger.info(f"response_hierarchy : {response_hierarchy}")

if response_hierarchy.status_code == 200:
logger.info("DATASET HIERARCHY RETREIVAL SUCCESSFUL")
hierarchy = response_hierarchy.json()
Expand All @@ -38,19 +45,6 @@ def __init__(self, dg_id,cookie):
logger.error(f"DATASET HIERARCHY RETRIEVAL FAILED: {response_hierarchy.status_code}")
logger.error(f"RESPONSE: {response.text}")
raise RuntimeError(f"ERROR RESPONSE\n {response_hierarchy.text}")


def find_target_column(self, df, filter_list):
    """Return the columns of *df* that look like the label/target column.

    A column qualifies when its set of non-null unique values is non-empty
    and is a subset of *filter_list* (the known class labels).

    Args:
        df: pandas DataFrame to scan.
        filter_list: iterable of allowed label values.

    Returns:
        List of column names whose values all come from *filter_list*.
    """
    value_set = set(filter_list)
    columns_with_exact_or_subset_values = []

    for column in df.columns:
        unique_values = set(df[column].dropna().unique())
        # Empty columns (all-null) are excluded: an empty set is falsy
        # even though it is trivially a subset of value_set.
        if unique_values and unique_values.issubset(value_set):
            columns_with_exact_or_subset_values.append(column)

    return columns_with_exact_or_subset_values

def extract_input_columns(self):

Expand Down Expand Up @@ -110,4 +104,16 @@ def create_dataframes(self):
filtered_df = filtered_df.dropna()
dfs.append(filtered_df)

return dfs
return dfs

def find_target_column(self, df, filter_list):
    """Return the columns of *df* whose non-null unique values all appear
    in *filter_list*.

    A candidate target column must have at least one non-null value, and
    every distinct value it holds must be one of the allowed labels.

    Args:
        df: pandas DataFrame to scan.
        filter_list: iterable of allowed label values.

    Returns:
        List of matching column names, in DataFrame column order.
    """
    allowed = set(filter_list)
    return [
        name
        for name in df.columns
        # All-null columns yield an empty set, which is falsy and excluded.
        if (observed := set(df[name].dropna().unique())) and observed <= allowed
    ]
5 changes: 3 additions & 2 deletions model_trainer/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def update_model_training_progress_session(self,training_status,
return session_id


def deploy_model(self, best_model_name, progress_session_id):
def deploy_model(self, best_model_name, progress_session_id, dg_id):

payload = {}
payload["modelId"] = self.new_model_id
Expand All @@ -167,6 +167,7 @@ def deploy_model(self, best_model_name, progress_session_id):
payload["bestBaseModel"] = best_model_name
payload["progressSessionId"] = progress_session_id
payload["updateType"] = self.update_type
payload["dgId"] = dg_id

if self.update_type == "retrain":
payload["replaceDeploymentPlatform"] = self.current_deployment_platform
Expand Down Expand Up @@ -411,7 +412,7 @@ def train(self):
else:

logger.info(f"INITIATING DEPLOYMENT TO {self.current_deployment_platform}")
self.deploy_model(best_model_name=best_model_name, progress_session_id=session_id)
self.deploy_model(best_model_name=best_model_name, progress_session_id=session_id,dg_id=dg_id)

except Exception as e:
self.send_error_progress_session(f"RUNTIME CRASHED - ERROR - {str(e)}")
Expand Down
Loading