Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions DSL/CronManager/script/train_script_starter.sh
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,68 @@ response_update_job_status=$(curl -s -X POST "$UPDATE_JOB_STATUS" \
-d "{\"jobId\": $job_id, \"jobStatus\": \"training-in-progress\"}")
echo "[DEBUG] Update job status response: '$response_update_job_status'"

# Create the training progress session whose ID is later handed to the
# Python training script via --session_id.
echo "[SESSION] Creating training progress session..."
CREATE_PROGRESS_SESSION_ENDPOINT="http://ruuter-public:8086/global-classifier/datamodels/progress/create"

# Escape backslashes and double quotes so a model name containing them cannot
# break (or inject extra fields into) the hand-built JSON payload.
model_name_json=$(printf '%s' "$model_name" | sed -e 's/\\/\\\\/g' -e 's/"/\\"/g')

response_create_session=$(curl -s -X POST "$CREATE_PROGRESS_SESSION_ENDPOINT" \
-H "Content-Type: application/json" \
-d "{
\"modelId\": $model_id,
\"modelName\": \"$model_name_json\",
\"majorVersion\": $major_version,
\"minorVersion\": $minor_version,
\"latest\": $latest
}")

echo "[DEBUG] Create session response: '$response_create_session'"

# An empty body means curl could not obtain any answer from the endpoint.
if [ -z "$response_create_session" ]; then
    echo "[ERROR] Failed to create training progress session - empty response"
    exit 1
fi

# Check if session creation was successful. Whitespace around the colon is
# tolerated so both compact and pretty-printed JSON are recognised.
if printf '%s\n' "$response_create_session" | grep -Eq '"operationSuccessful"[[:space:]]*:[[:space:]]*true'; then
    # -n with the trailing "p" prints only lines where a sessionId was actually
    # matched, so an unmatched response yields an empty string instead of the
    # raw response text. The ID may be quoted or unquoted.
    session_id=$(printf '%s\n' "$response_create_session" \
        | sed -nE 's/.*"sessionId"[[:space:]]*:[[:space:]]*"?([0-9]+)"?.*/\1/p' \
        | head -n 1)

    # Reject anything that is empty or not purely numeric: the value is later
    # interpolated unquoted into JSON payloads.
    case "$session_id" in
        ''|*[!0-9]*)
            echo "[ERROR] Failed to extract session ID from response"
            echo "[DEBUG] Raw response: '$response_create_session'"
            exit 1
            ;;
    esac

    echo "[SESSION] Training progress session created successfully with ID: $session_id"
else
    echo "[ERROR] Training progress session creation failed"
    echo "[DEBUG] Raw response: '$response_create_session'"
    exit 1
fi

# Publish the initial progress update for the session created above.
echo "[PROGRESS] Updating initial training progress..."
UPDATE_PROGRESS_SESSION_ENDPOINT="http://ruuter-public:8086/global-classifier/datamodels/progress/update"

# JSON body of the initial update, kept in its own variable for readability.
initial_progress_payload="{
\"sessionId\": $session_id,
\"trainingStatus\": \"Initiating Training\",
\"trainingMessage\": \"Download and preparing dataset\",
\"progressPercentage\": 20,
\"processComplete\": false
}"

response_update_progress=$(curl -s -X POST "$UPDATE_PROGRESS_SESSION_ENDPOINT" \
-H "Content-Type: application/json" \
-d "$initial_progress_payload")

echo "[DEBUG] Update progress response: '$response_update_progress'"

# Any non-empty reply is treated as success; an empty one is only logged as a
# warning and the script carries on.
if [ -n "$response_update_progress" ]; then
    echo "[PROGRESS] Initial training progress updated successfully"
else
    echo "[WARNING] Failed to update initial training progress - empty response"
fi

# Get dataset ID
response_get_dataset_id=$(curl -s -X POST "$GET_DATA_MODEL_BY_MODEL_ID_SQL" \
-H "Content-Type: application/json" \
Expand Down Expand Up @@ -242,6 +304,7 @@ python3 "$TRAINING_SCRIPT" \
--minor_version "$minor_version" \
--latest "$latest" \
--deployment_environment "$deployment_environment" \
--session_id "$session_id" \

training_exit_code=$?

Expand All @@ -265,6 +328,33 @@ else
-H "Content-Type: application/json" \
-d "{\"jobId\": $job_id, \"jobStatus\": \"training-failed\"}")

# Send the model ID to the update-training_status-failed endpoint.
echo "[MODEL] Updating model training status to failed..."
UPDATE_MODEL_TRAINING_STATUS_FAILED="http://resql:8082/global-classifier/update-training_status-failed"
model_failed_payload="{\"model_id\": $model_id}"
response_update_model_status=$(curl -s -X POST "$UPDATE_MODEL_TRAINING_STATUS_FAILED" \
-H "Content-Type: application/json" \
-d "$model_failed_payload")

echo "[DEBUG] Update model training status response: '$response_update_model_status'"

# Push a "Training Failed" update to the progress session.
echo "[PROGRESS] Updating progress session to show training failure..."
failure_progress_payload="{
\"sessionId\": $session_id,
\"trainingStatus\": \"Training Failed\",
\"trainingMessage\": \"Model training has failed\",
\"progressPercentage\": 100,
\"processComplete\": false
}"
response_update_progress_failure=$(curl -s -X POST "$UPDATE_PROGRESS_SESSION_ENDPOINT" \
-H "Content-Type: application/json" \
-d "$failure_progress_payload")

echo "[DEBUG] Update progress failure response: '$response_update_progress_failure'"

# Any non-empty reply is treated as success; an empty one is only a warning.
if [ -n "$response_update_progress_failure" ]; then
    echo "[PROGRESS] Progress session updated with failure status successfully"
else
    echo "[WARNING] Failed to update progress session with failure status"
fi

exit 1
fi

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
-- Mark one data model's training as failed and refresh its update timestamp.
UPDATE public.data_models
SET training_status   = 'training_failed',
    updated_timestamp = NOW()
WHERE model_id = :model_id;
6 changes: 4 additions & 2 deletions src/model-training/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,11 @@
TRAINING_FAILED_STATUS_MESSAGE = "Model training has failed"


INITIATING_TRAINING_PROGRESS_PERCENTAGE = 30
INITIATING_TRAINING_PROGRESS_PERCENTAGE = 20

TRAINING_IN_PROGRESS_PROGRESS_PERCENTAGE = 50
TRAINING_IN_PROGRESS_PROGRESS_PERCENTAGE = 30

TRAINING_IN_PROGRESS_PROGRESS_PERCENTAGE_AFTER_DATA_PREPARATION = 40

DEPLOYING_MODEL_PROGRESS_PERCENTAGE = 80

Expand Down
90 changes: 16 additions & 74 deletions src/model-training/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,20 @@
SEQUENCE_LENGTH,
MODEL_TRAINING_SOURCE_PATH,
DEPLOYMENT_ENDPOINT,
CREATE_TRAINING_PROGRESS_SESSION_ENDPOINT,
UPDATE_TRAINING_PROGRESS_SESSION_ENDPOINT,
UPDATE_MODEL_TRAINING_STATUS_ENDPOINT,
INITIATING_TRAINING_PROGRESS_STATUS,
TRAINING_IN_PROGRESS_PROGRESS_STATUS,
DEPLOYING_MODEL_PROGRESS_STATUS,
MODEL_TRAINED_AND_DEPLOYED_PROGRESS_STATUS,
TRAINING_FAILED_STATUS,
INITIATING_TRAINING_PROGRESS_PERCENTAGE,
TRAINING_IN_PROGRESS_PROGRESS_PERCENTAGE,
MODEL_TRAINED_AND_DEPLOYED_PROGRESS_PERCENTAGE,
INITIATING_TRAINING_PROGRESS_MESSAGE,
TRAINING_IN_PROGRESS_PROGRESS_MESSAGE,
DEPLOYING_MODEL_PROGRESS_MESSAGE,
MODEL_TRAINED_AND_DEPLOYED_PROGRESS_MESSAGE,
TRAINING_FAILED_STATUS_MESSAGE,
TRAINING_FAILED_PROGRESS_PERCENTAGE,
TRAINING_IN_PROGRESS_PROGRESS_PERCENTAGE_AFTER_DATA_PREPARATION,
)

import requests
Expand Down Expand Up @@ -69,7 +66,7 @@ def __init__(
self.current_deployment_platform = current_deployment_env
self.target_deployment_platform = target_deployment_platform

self.progress_session_id = ""
self.progress_session_id = progress_session_id

except Exception as e:
logger.error(f"EXCEPTION IN MODEL_TRAINER INIT : {e}")
Expand All @@ -90,68 +87,6 @@ def get_current_timestamp(self):
current_timestamp = int(datetime.now(timezone.utc).timestamp())
return current_timestamp

def create_training_progress_session(self):
    """
    Create a training progress session and remember its ID.

    POSTs the model's ID, name, major/minor version and latest flag to
    CREATE_TRAINING_PROGRESS_SESSION_ENDPOINT, stores the returned
    ``response.sessionId`` in ``self.progress_session_id`` and returns the
    decoded JSON body.

    Returns:
        The decoded JSON body of the create-session response.

    Raises:
        requests.HTTPError: If the endpoint answers with a 4xx/5xx status.
        requests.RequestException: On other request-level failures.
        Exception: Any other failure (for example an unexpected response
            shape); it is logged and re-raised unchanged.
    """
    logger.info("Creating training progress session")

    payload = {
        "modelId": int(self.model_id),
        "modelName": self.model_name,
        "majorVersion": self.major_version,
        "minorVersion": self.minor_version,
        "latest": self.latest,
    }

    logger.info(f"Prepared training progress session payload {payload}")

    try:
        # Make request to create training progress session endpoint
        response = requests.post(
            url=CREATE_TRAINING_PROGRESS_SESSION_ENDPOINT,
            json=payload,
            headers={"Content-Type": "application/json"},
            timeout=300,  # 5 minute timeout for creating progress session
        )

        logger.info(
            f"Create training progress session response - {response.status_code} - {response.text}"
        )

        # Turn 4xx/5xx answers into requests.HTTPError so they reach the
        # dedicated handler below instead of being parsed as a success.
        response.raise_for_status()

        # Decode the body once and reuse it for both the ID and the return.
        session_data = response.json()
        self.progress_session_id = session_data["response"]["sessionId"]

        logger.info("Training progress session created successfully")

        return session_data

    except requests.HTTPError as e:
        error_msg = f"HTTP error during creating training progress session: {e.response.status_code} - {e.response.text}"
        logger.error(
            error_msg, model_id=self.model_id, status_code=e.response.status_code
        )
        raise

    except requests.RequestException as e:
        error_msg = (
            f"Network error during creating training progress session: {str(e)}"
        )
        logger.error(error_msg, model_id=self.model_id)
        raise

    except Exception as e:
        error_msg = (
            f"Unexpected error during creating training progress session: {str(e)}"
        )
        logger.error(error_msg, model_id=self.model_id)
        raise

def update_training_progression_session(
self,
training_status: str,
Expand All @@ -175,7 +110,7 @@ def update_training_progression_session(

else:
payload = {
"sessionId": self.progress_session_id,
"sessionId": int(self.progress_session_id),
"trainingStatus": training_status,
"trainingMessage": training_message,
"progressPercentage": progress_percentage,
Expand Down Expand Up @@ -318,10 +253,12 @@ def train(self):
logger.info("ENTERING UNIFIED TRAINING FUNCTION")
logger.info(f"DEPLOYMENT PLATFORM - {self.current_deployment_platform}")

# Initial progress is now handled in bash script
# Start with data preparation progress update
trainer.update_training_progression_session(
training_status=INITIATING_TRAINING_PROGRESS_STATUS,
training_message=INITIATING_TRAINING_PROGRESS_MESSAGE,
progress_percentage=INITIATING_TRAINING_PROGRESS_PERCENTAGE,
training_status=TRAINING_IN_PROGRESS_PROGRESS_STATUS,
training_message=TRAINING_IN_PROGRESS_PROGRESS_MESSAGE,
progress_percentage=TRAINING_IN_PROGRESS_PROGRESS_PERCENTAGE,
process_complete=False,
)

Expand All @@ -343,7 +280,7 @@ def train(self):
trainer.update_training_progression_session(
training_status=TRAINING_IN_PROGRESS_PROGRESS_STATUS,
training_message=TRAINING_IN_PROGRESS_PROGRESS_MESSAGE,
progress_percentage=TRAINING_IN_PROGRESS_PROGRESS_PERCENTAGE,
progress_percentage=TRAINING_IN_PROGRESS_PROGRESS_PERCENTAGE_AFTER_DATA_PREPARATION,
process_complete=False,
)

Expand Down Expand Up @@ -735,6 +672,12 @@ def parse_args():
required=True,
help="Deployment Environment",
)
parser.add_argument(
"--session_id",
type=str,
required=True,
help="Training Progress Session ID",
)
return parser.parse_args()


Expand All @@ -752,7 +695,7 @@ def parse_args():
minor_version = args.minor_version
latest = args.latest.lower() == "true"
current_deployment_env = "undeployed"
progress_session_id = args.job_id
progress_session_id = args.session_id
target_deployment_platform = args.deployment_environment

trainer = ModelTrainer(
Expand All @@ -768,6 +711,5 @@ def parse_args():
target_deployment_platform=target_deployment_platform,
)

trainer.create_training_progress_session()
trainer.train()
trainer.deploy()
Loading