Skip to content
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ This repo will primarily contain:
- `JIRA_WEBHOOK_SECRET` – Jira webhook secret you got in **Create Jira Webhook** step.

4. **Create a `.env` file for Jira Configuration:**
- Create a `.env` file called `jira_config.env` and add the following:
- Create a `.env` file in the folder called `jira-verification` and add the following:
```env
JIRA_WEBHOOK_SECRET=<<JIRA_WEBHOOK_SECRET>>
```
Expand Down
3 changes: 2 additions & 1 deletion docker-compose.gpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ services:
- FIND_FINAL_FOLDER_ID_URL=http://hierarchy-validation:8009/find-folder-id
- UPDATE_DATAMODEL_PROGRESS_URL=http://ruuter-private:8088/classifier/datamodel/progress/update
- UPDATE_MODEL_TRAINING_STATUS_ENDPOINT=http://ruuter-private:8088/classifier/datamodel/update/training/status
- GET_DATASET_METADATA_ENDPOINT=http://ruuter-private:8088/classifier/datasetgroup/group/metadata
ports:
- "8003:8003"
networks:
Expand Down Expand Up @@ -491,7 +492,7 @@ services:
ports:
- "3008:3008"
env_file:
- jira_config.env
- ./jira-verification/.env
environment:
RUUTER_PUBLIC_JIRA_URL: http://ruuter-public:8086/internal/jira/accept
networks:
Expand Down
3 changes: 2 additions & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ services:
- FIND_FINAL_FOLDER_ID_URL=http://hierarchy-validation:8009/find-folder-id
- UPDATE_DATAMODEL_PROGRESS_URL=http://ruuter-private:8088/classifier/datamodel/progress/update
- UPDATE_MODEL_TRAINING_STATUS_ENDPOINT=http://ruuter-private:8088/classifier/datamodel/update/training/status
- GET_DATASET_METADATA_ENDPOINT=http://ruuter-private:8088/classifier/datasetgroup/group/metadata
ports:
- "8003:8003"
networks:
Expand Down Expand Up @@ -473,7 +474,7 @@ services:
ports:
- "3008:3008"
env_file:
- jira_config.env
- ./jira-verification/.env
environment:
RUUTER_PUBLIC_JIRA_URL: http://ruuter-public:8086/internal/jira/accept
networks:
Expand Down
1 change: 1 addition & 0 deletions model-inference/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class UpdateRequest(BaseModel):
bestBaseModel:str
updateType: Optional[str] = None
progressSessionId: int
dgId:Optional[int]= None

class OutlookInferenceRequest(BaseModel):
inputId:str
Expand Down
7 changes: 4 additions & 3 deletions model-inference/inference_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def find_missing_classes(self, main_classes, uploaded_classes):



def predict_class(self,text_input):
def predict_class(self,text_input, platform):

logger.info("ENTERING PREDICT CLASS")

Expand All @@ -117,10 +117,11 @@ def predict_class(self,text_input):
self.base_model.to(self.device)

logger.info(f"CLASS HIERARCHY FILE {self.hierarchy_file}")
logger.info(f"PLATFORM IN PREDICT CLASS {platform}")



data = self.hierarchy_file
if platform == 'jira':
data = data['classHierarchy']
parent = 1

logger.info(f"DATA - {data}")
Expand Down
4 changes: 2 additions & 2 deletions model-inference/inference_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,15 @@ def inference(self, text:str, deployment_platform:str):
if(deployment_platform == "jira" and self.active_jira_model):

logger.info("ENTERING JIRA INFERENCE")
predicted_labels, probabilities = self.active_jira_model.predict_class(text)
predicted_labels, probabilities = self.active_jira_model.predict_class(text, deployment_platform)

logger.info(f"PREDICTED LABELS INSIDE .inference() FUNCTION - {predicted_labels}")
logger.info(f"PROBABILITIES INSIDE .inference() FUNCTION - {probabilities}")


if(deployment_platform == "outlook" and self.active_outlook_model):
logger.info("ENTERING OUTLOOK INFERENCE")
predicted_labels, probabilities = self.active_outlook_model.predict_class(text)
predicted_labels, probabilities = self.active_outlook_model.predict_class(text, deployment_platform)


logger.info(f"PREDICTED LABELS INSIDE .inference() FUNCTION - {predicted_labels}")
Expand Down
38 changes: 34 additions & 4 deletions model-inference/model_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@
UPDATE_DATAMODEL_PROGRESS_URL = os.getenv("UPDATE_DATAMODEL_PROGRESS_URL")
UPDATE_MODEL_TRAINING_STATUS_ENDPOINT = os.getenv("UPDATE_MODEL_TRAINING_STATUS_ENDPOINT")
RUUTER_PRIVATE_URL = os.getenv("RUUTER_PRIVATE_URL")
GET_DATASET_METADATA_ENDPOINT=os.getenv("GET_DATASET_METADATA_ENDPOINT")

class ModelInference:
def __init__(self):
pass

def get_class_hierarchy_by_model_id(self, model_id):
def get_outlook_class_hierarchy_by_model_id(self, model_id):

try:
logger.info(f"get_class_hierarchy_by_model_id - {model_id}")
Expand Down Expand Up @@ -123,9 +124,9 @@ def validate_class_hierarchy(self, class_hierarchy, model_id):



def get_class_hierarchy_and_validate(self, model_id):
def get_outlook_class_hierarchy_and_validate(self, model_id):
try:
class_hierarchy = self.get_class_hierarchy_by_model_id(model_id)
class_hierarchy = self.get_outlook_class_hierarchy_by_model_id(model_id)
if class_hierarchy:
is_valid = self.validate_class_hierarchy(class_hierarchy, model_id)
return is_valid, class_hierarchy
Expand Down Expand Up @@ -252,4 +253,33 @@ def create_inference(self, payload):
raise RuntimeError(f"Failed to call create inference. Reason: {e}")



def get_class_hierarchy_by_dg_id(self, cookies, dg_id):
logger.info("********************************************************************")
logger.info(f"****** Calling function get_class_hierarchy_by_dg_id ******")
try:
logger.info(f"get_class_hierarchy_by_dg_id - {dg_id}")
logger.info(f"cookie : {cookies}")
cookies_updated = {"customJwtCookie":cookies}
logger.info(f"cookie_updated : {cookies_updated}")
logger.info(f"GET_DATASET_METADATA_ENDPOINT : {GET_DATASET_METADATA_ENDPOINT}")

response_hierarchy = requests.get(GET_DATASET_METADATA_ENDPOINT, params={'groupId': dg_id}, cookies=cookies_updated)

logger.info(f"response_hierarchy : {response_hierarchy}")

if response_hierarchy.status_code == 200:
logger.info("DATASET HIERARCHY RETREIVAL SUCCESSFUL")
hierarchy = response_hierarchy.json()
logger.info(f"DATASET HIERARCHY - {hierarchy}")
class_hierarchy = hierarchy['response']['data'][0]
logger.info(f"CLASS HIERARCHY - {class_hierarchy}")
return class_hierarchy

else:
logger.error(f"DATASET HIERARCHY RETRIEVAL FAILED: {response_hierarchy.status_code}")
raise RuntimeError(f"ERROR RESPONSE\n {response_hierarchy.text}")


except Exception as e:
logger.error(f"Failed to retrieve the class hierarchy Reason: {e}")
raise RuntimeError(f"Failed to retrieve the class hierarchy Reason: {e}")
8 changes: 3 additions & 5 deletions model-inference/model_inference_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ async def download_outlook_model(request: Request, model_data:UpdateRequest):
model_progress_session_id = model_data.progressSessionId

## Get class hierarchy and validate it
is_valid, class_hierarchy = model_inference.get_class_hierarchy_and_validate(model_data.modelId)
is_valid, class_hierarchy = model_inference.get_outlook_class_hierarchy_and_validate(model_data.modelId)

logger.info(f"IS VALID VALUE : {is_valid}")
logger.info(f"CLASS HIERARCHY VALUE : {class_hierarchy}")
Expand Down Expand Up @@ -230,7 +230,7 @@ async def download_jira_model(request: Request, model_data:UpdateRequest):

logger.info("JUST ABOUT TO ENTER get_class_hierarchy_by_model_id")

class_hierarchy = model_inference.get_class_hierarchy_by_model_id(model_data.modelId)
class_hierarchy = model_inference.get_class_hierarchy_by_dg_id(cookies=cookie, dg_id=model_data.dgId)

logger.info(f"JIRA UPDATE CLASS HIERARCHY - {class_hierarchy}")

Expand Down Expand Up @@ -354,7 +354,7 @@ async def download_test_model(request: Request, model_data:UpdateRequest):

logger.info("JUST ABOUT TO ENTER get_class_hierarchy_by_model_id")

class_hierarchy = model_inference.get_class_hierarchy_by_model_id(model_data.modelId)
class_hierarchy = model_inference.get_class_hierarchy_by_dg_id(cookies=cookie, dg_id=model_data.dgId)

logger.info(f"TEST UPDATE CLASS HIERARCHY - {class_hierarchy}")

Expand Down Expand Up @@ -556,12 +556,10 @@ async def outlook_inference(request:Request, inference_data:OutlookInferenceRequ
async def jira_inference(request:Request, inferenceData:JiraInferenceRequest):
try:


logger.info(f"INFERENCE DATA IN JIRA INFERENCE - {inferenceData}")

model_id = model_inference_wrapper.get_jira_model_id()


if(model_id):
# 1 . Check whether the if the Inference Exists
is_exist, inference_id = model_inference.check_inference_data_exists(input_id=inferenceData.inputId)
Expand Down
2 changes: 1 addition & 1 deletion model-inference/test_inference_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def inference(self, text: str, model_id: int):
predicted_labels = None
probabilities = None
model = self.model_dictionary[model_id]
predicted_labels, probabilities = model.predict_class(text_input=text)
predicted_labels, probabilities = model.predict_class(text_input=text, platform="test")
return predicted_labels, probabilities
else:
raise Exception(f"Model with ID {model_id} not found")
Expand Down
34 changes: 20 additions & 14 deletions model_trainer/datapipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,16 @@ def __init__(self, dg_id,cookie):

raise RuntimeError(f"ERROR RESPONSE {response.text}")

logger.info(f"###############################################################")
logger.info(f"****** Calling init function of DataPipeline Class ******")
logger.info(f"Endpoint : {GET_DATASET_METADATA_ENDPOINT}")
logger.info(f"Cookie : {cookies}")
logger.info(f"DGID : {dg_id}")

response_hierarchy = requests.get(GET_DATASET_METADATA_ENDPOINT, params={'groupId': dg_id}, cookies=cookies)

logger.info(f"response_hierarchy : {response_hierarchy}")

if response_hierarchy.status_code == 200:
logger.info("DATASET HIERARCHY RETREIVAL SUCCESSFUL")
hierarchy = response_hierarchy.json()
Expand All @@ -38,19 +45,6 @@ def __init__(self, dg_id,cookie):
logger.error(f"DATASET HIERARCHY RETRIEVAL FAILED: {response_hierarchy.status_code}")
logger.error(f"RESPONSE: {response.text}")
raise RuntimeError(f"ERROR RESPONSE\n {response_hierarchy.text}")


def find_target_column(self,df , filter_list):

value_set = set(filter_list)
columns_with_exact_or_subset_values = []

for column in df.columns:
unique_values = set(df[column].dropna().unique())
if unique_values and unique_values.issubset(value_set):
columns_with_exact_or_subset_values.append(column)

return columns_with_exact_or_subset_values

def extract_input_columns(self):

Expand Down Expand Up @@ -110,4 +104,16 @@ def create_dataframes(self):
filtered_df = filtered_df.dropna()
dfs.append(filtered_df)

return dfs
return dfs

def find_target_column(self,df , filter_list):

value_set = set(filter_list)
columns_with_exact_or_subset_values = []

for column in df.columns:
unique_values = set(df[column].dropna().unique())
if unique_values and unique_values.issubset(value_set):
columns_with_exact_or_subset_values.append(column)

return columns_with_exact_or_subset_values
5 changes: 3 additions & 2 deletions model_trainer/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def update_model_training_progress_session(self,training_status,
return session_id


def deploy_model(self, best_model_name, progress_session_id):
def deploy_model(self, best_model_name, progress_session_id, dg_id):

payload = {}
payload["modelId"] = self.new_model_id
Expand All @@ -167,6 +167,7 @@ def deploy_model(self, best_model_name, progress_session_id):
payload["bestBaseModel"] = best_model_name
payload["progressSessionId"] = progress_session_id
payload["updateType"] = self.update_type
payload["dgId"] = dg_id

if self.update_type == "retrain":
payload["replaceDeploymentPlatform"] = self.current_deployment_platform
Expand Down Expand Up @@ -411,7 +412,7 @@ def train(self):
else:

logger.info(f"INITIATING DEPLOYMENT TO {self.current_deployment_platform}")
self.deploy_model(best_model_name=best_model_name, progress_session_id=session_id)
self.deploy_model(best_model_name=best_model_name, progress_session_id=session_id,dg_id=dg_id)

except Exception as e:
self.send_error_progress_session(f"RUNTIME CRASHED - ERROR - {str(e)}")
Expand Down
Loading