
Conversation

@sfc-gh-ajiang
Collaborator

No description provided.

Collaborator

@sfc-gh-dhung sfc-gh-dhung left a comment

Remember this is a public-facing sample, so please be sure the code quality is high. It's especially important for the code to be simple and readable, with self-documenting variable/function names and sufficient comments for non-experts to understand.

Comment on lines -147 to -149
# NOTE: Remove `target_instances=2` to run training on a single node
# See https://docs.snowflake.com/en/developer-guide/snowflake-ml/ml-jobs/distributed-ml-jobs
@remote(COMPUTE_POOL, stage_name=JOB_STAGE, target_instances=2)
Collaborator

One of the main points of this sample is to demonstrate how easy it is to convert a local pipeline to pushing certain steps down into ML Jobs. Needing to write a separate script file which we submit_file() just for this conversion severely weakens this story. Why can't we just keep using a @remote() decorated function? @remote(...) should convert the function into an MLJobDefinition which we can directly use in pipeline_dag without needing an explicit MLJobDefinition.register() call
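Roughly the usage being asked for (hypothetical sketch only; this is the desired end state rather than an existing API, and the decorator arguments are copied from the snippet above):

```python
from snowflake.ml.jobs import remote  # import path assumed to match the sample

# COMPUTE_POOL and JOB_STAGE are the sample's existing constants.
# Hypothetical target state: the decorated function itself acts as the job definition.
@remote(COMPUTE_POOL, stage_name=JOB_STAGE, target_instances=2)
def train_model(dataset_info: str) -> str:
    ...

# pipeline_dag could then wire train_model into the task graph directly, with no
# separate script file, submit_file() call, or explicit MLJobDefinition.register() step.
```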

Collaborator Author

Currently @remote does not create a job definition; it creates a job directly. We have only merged the PR for phase one, and phase 2 is still in review.

Collaborator

Let's hold off on merging this until @remote is ready then

Collaborator

Since the @remote change is now available, can we now call this as an ML Job directly from pipeline_dag?

Collaborator Author

I am a little confused here. Do you mean we create a job inside the task directly?

Comment on lines -147 to -149
# NOTE: Remove `target_instances=2` to run training on a single node
# See https://docs.snowflake.com/en/developer-guide/snowflake-ml/ml-jobs/distributed-ml-jobs
@remote(COMPUTE_POOL, stage_name=JOB_STAGE, target_instances=2)
Collaborator

Since the @remote change is now available, can we now call this as an ML Job directly from pipeline_dag?

Comment on lines 145 to 190
```python
@remote(COMPUTE_POOL, stage_name=JOB_STAGE, database=DB_NAME, schema=SCHEMA_NAME)
def train_model(dataset_info: Optional[str] = None) -> Optional[str]:
    '''
    ML Job to train a model on the training dataset and register it in the model registry.

def train_model(session: Session) -> str:
    """
    DAG task to train a machine learning model.
    This function trains an XGBoost classifier on the provided training data and registers it in the model registry.
    This function is executed remotely on Snowpark Container Services.

    Args:
        dataset_info (Optional[str]): JSON string containing serialized dataset information for training. If this function is called in a DAG task,
            this argument is passed from the previous DAG task, otherwise it is passed manually.

    Returns:
        Optional[str]: JSON string containing serialized model information for registration. If this function is called in a DAG task,
            this return value is passed to the next DAG task, otherwise it is returned as the ML Job result.
    '''
    session = Session.builder.getOrCreate()
    ctx = None
    config = None

    if dataset_info:
        dataset_info_dicts = json.loads(dataset_info)
    try:
        ctx = TaskContext(session)
        config = run_config.RunConfig.from_task_context(ctx)
        dataset_info_dicts = json.loads(ctx.get_predecessor_return_value("PREPARE_DATA"))
    except SnowparkSQLException:
        print("there is no predecessor return value, fallback to local mode")

    This function is executed as part of the DAG workflow to train a model using the prepared datasets.
    It retrieves dataset information from the previous task, trains the model, evaluates it on both
    training and test sets, and saves the model to a stage for later use.
    datasets = {
        key: DatasetInfo(**info_dict) for key, info_dict in dataset_info_dicts.items()
    }
    train_ds = load_dataset(
        session,
        datasets["full"].fully_qualified_name,
        datasets["full"].version,
    )
    model_obj = modeling.train_model(session, datasets["train"])
    train_metrics = modeling.evaluate_model(
        session, model_obj, train_ds.read.data_sources[0], prefix="train"
    )
    version = f"v{uuid.uuid4().hex}"
    mv = modeling.register_model(session, model_obj, config.model_name if config and config.model_name else "mortgage_model", version, train_ds, metrics={}) if config else modeling.register_model(session, model_obj, "mortgage_model", version, train_ds, metrics=train_metrics)
    if ctx and config:
        ctx.set_return_value(json.dumps({"model_name": mv.fully_qualified_model_name, "version_name": mv.version_name}))
    return json.dumps({"model_name": mv.fully_qualified_model_name, "version_name": mv.version_name})
```
Collaborator

What is the gap that means we still need this?

https://github.com/Snowflake-Labs/sf-samples/pull/250/files#r2695685818

Collaborator Author

I am confused here. I think we should use @remote, right?
Create a job definition -> integrate the definition into the task SDK?

Collaborator

Can we just have pipeline_dag and pipeline_local use the same function with no extra wrapping?

Collaborator Author

Currently, pipeline_dag and pipeline_local use the same function, pipeline_dag.train_model. What do you think?

Collaborator

Why do we need this as a separate file? Looks like it's only used in pipeline_dag currently

Collaborator Author

I was thinking that pipeline_local.py and pipeline_dag.py should focus on orchestration logic, like creating jobs or tasks. Since this class is more about handling task configuration, it might make sense to move it into a separate file for better separation of concerns.

For now, I’ve reverted the changes.


```bash
python src/pipeline_local.py
python src/pipeline_local.py --no-register  # Skip model registration for faster experimentation
```
Collaborator

why remove?

Collaborator Author

@sfc-gh-ajiang sfc-gh-ajiang Feb 6, 2026

That is because we always register the model, but we do not push it to production.
The reason I did it like this is that I got this error #250 (comment) when I saved the model to a file.

Collaborator

Sounds like there is a different bug then; saving the model to a file should definitely work. Please be sure to fix that.

Collaborator Author

I think the root cause is the memory synchronization gap between the head node and the worker nodes. The script runs only on the head node, which acts as a coordinator that sends training instructions to the worker nodes, but those workers keep the resulting model weights in their own local memory once training is complete. It seems like the Snowflake XGBEstimator is designed to be "lazy": the head node does not automatically pull those heavy weights into its own local process. Consequently, when the script immediately tries to evaluate or serialize the model, the head node looks at its own empty memory and crashes, because the "brain" of the model is still physically trapped on the separate worker nodes.

Do you have any suggestions on how to fix it? I am not sure if it is a bug or not.

Collaborator

Are you able to reproduce this behavior in Notebooks? Training then inferencing immediately with the same model is a very common use case, so you should be able to quickly validate that with a multi-node Notebook

Collaborator Author

I cannot even reproduce it in an ML Job. If the ML Job only trains the model and returns the model, everything looks good.

```python
@remote(COMPUTE_POOL, stage_name=JOB_STAGE, target_instances=2)
def train_model(session: Session, input_data: DataSource) -> XGBClassifier:
def train_model(input_data: DataSource) -> Optional[str]:
```
Collaborator

Can you explain why we return a string in the README?

Collaborator Author

Sure, I will update it.

Comment on lines +201 to +203
@remote(COMPUTE_POOL, stage_name=JOB_STAGE, target_instances=2)
def train_model(input_data: DataSource) -> Optional[str]:
...
Collaborator

why repeat this?

Collaborator Author

Oops, forgot to delete it. Will delete it.


```python
mv = register_model(session, model, model_name, version, train_ds, metrics)
# get model version from train model
```
Collaborator

?

Collaborator Author

That is because we always register the model, but we do not push it to production.
The reason I did it like this is that I got this error #250 (comment) when I saved the model to a file.

from snowflake.snowpark import Session

import modeling
import pipeline_dag
Collaborator

pipeline_local should not have a dependency on pipeline_dag. Ideally the two don't know about each other; if necessary, then dag can depend on local.

Collaborator Author

Sure — I’ll update it. One thought is to add @remote to modeling.py. However, inside the job payload we rely on the run config, which would cause pipeline_dag.py to import modeling.py and modeling.py to import pipeline_dag.py. That introduces a circular import.

Collaborator Author

If I add @remote to pipeline_local.py and pipeline_local.py imports pipeline_dag.py to use the run config, then pipeline_dag.py needs to import pipeline_local.py to use the train_model function. That introduces a circular import.

What do you think about moving the run config out of pipeline_dag.py into a separate file?
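Something like this, for example (rough sketch of what the split-out module could look like; only model_name is confirmed anywhere in this thread, and the TaskContext import path is an assumption):

```python
# run_config.py -- hypothetical standalone module for per-run configuration.
from dataclasses import dataclass
from typing import Optional

from snowflake.core.task.context import TaskContext  # import path assumed


@dataclass
class RunConfig:
    model_name: Optional[str] = None

    @classmethod
    def from_task_context(cls, ctx: TaskContext) -> "RunConfig":
        # Read per-run settings from the task graph configuration (empty dict if none was set).
        graph_config = ctx.get_task_graph_config() or {}
        return cls(model_name=graph_config.get("model_name"))
```

Both pipeline files could then import RunConfig from run_config.py without needing to know about each other.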

Collaborator

@sfc-gh-dhung sfc-gh-dhung Feb 6, 2026

Sounds fine. It's also okay for both pipeline definitions to define their own @remote run_train_model functions, which just handle arguments then pass them to modeling.train_model. In this case, each @remote function should just have a few lines of code (maybe 3-4 at most); e.g. the local pipeline just accepts args and directly passes them to modeling.train_model, while the DAG pipeline reads from RunConfig before calling modeling.train_model
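Roughly this shape (illustrative sketch only; the decorator arguments are taken from the snippets in this thread, and the argument plumbing is simplified):

```python
from snowflake.ml.jobs import remote                 # import paths assumed to match the sample
from snowflake.snowpark import Session
from snowflake.core.task.context import TaskContext

import modeling
from run_config import RunConfig                     # assumes run_config.py is split out as discussed


# pipeline_local.py -- thin wrapper: accept arguments and pass them straight to modeling.train_model.
@remote(COMPUTE_POOL, stage_name=JOB_STAGE)
def run_train_model(dataset_info):
    session = Session.builder.getOrCreate()
    return modeling.train_model(session, dataset_info)


# pipeline_dag.py -- thin wrapper: resolve RunConfig from the task context, then call modeling.train_model.
@remote(COMPUTE_POOL, stage_name=JOB_STAGE)
def run_train_model():
    session = Session.builder.getOrCreate()
    config = RunConfig.from_task_context(TaskContext(session))
    return modeling.train_model(session, config)
```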

session = session_builder.getOrCreate()
modeling.ensure_environment(session)
pipeline_dag._ensure_environment(session)
cp.register_pickle_by_value(pipeline_dag)
Collaborator

why?

Collaborator Author

@sfc-gh-ajiang sfc-gh-ajiang Feb 6, 2026

Because it imports pipeline_dag.py to use train_model.

Collaborator

why does that mean we need to pickle it?

}
return json.dumps(dataset_info)

@remote(COMPUTE_POOL, stage_name=JOB_STAGE, database=DB_NAME, schema=SCHEMA_NAME)
Collaborator

multi-node?

Collaborator Author

Sure, will update it.

config = RunConfig.from_task_context(ctx)
dataset_info_dicts = json.loads(ctx.get_predecessor_return_value("PREPARE_DATA"))
except SnowparkSQLException:
print("there is no predecessor return value, fallback to local mode")
Collaborator

Make sure errors/warnings are meaningful to users who aren't already familiar with tasks and ML Jobs. In this case, "predecessor return value" and "local mode" are meaningless/unknown terms.
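For example, the fallback message could read something like this (wording is just a suggestion; the surrounding code is the excerpt quoted below):

```python
try:
    ctx = TaskContext(session)
    config = RunConfig.from_task_context(ctx)
    dataset_info_dicts = json.loads(ctx.get_predecessor_return_value("PREPARE_DATA"))
except SnowparkSQLException:
    # Not running as part of a task graph, so there is no upstream task output to read.
    print(
        "No output found from a preceding task; "
        "using the dataset_info argument passed to this function instead."
    )
```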

Collaborator Author

will update it

Comment on lines +197 to +222
```python
if dataset_info:
    dataset_info_dicts = json.loads(dataset_info)
try:
    ctx = TaskContext(session)
    config = RunConfig.from_task_context(ctx)
    dataset_info_dicts = json.loads(ctx.get_predecessor_return_value("PREPARE_DATA"))
except SnowparkSQLException:
    print("there is no predecessor return value, fallback to local mode")

datasets = {
    key: DatasetInfo(**info_dict) for key, info_dict in dataset_info_dicts.items()
}
train_ds = load_dataset(
    session,
    datasets["full"].fully_qualified_name,
    datasets["full"].version,
)
model_obj = modeling.train_model(session, datasets["train"])
train_metrics = modeling.evaluate_model(
    session, model_obj, train_ds.read.data_sources[0], prefix="train"
)
version = f"v{uuid.uuid4().hex}"
mv = modeling.register_model(session, model_obj, config.model_name if config and config.model_name else "mortgage_model", version, train_ds, metrics={}) if config else modeling.register_model(session, model_obj, "mortgage_model", version, train_ds, metrics=train_metrics)
if ctx and config:
    ctx.set_return_value(json.dumps({"model_name": mv.fully_qualified_model_name, "version_name": mv.version_name}))
return json.dumps({"model_name": mv.fully_qualified_model_name, "version_name": mv.version_name})
```
Collaborator

Add comments and whitespace for readability please. Remember this is a public sample/tutorial
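For instance, the first few lines of this block might read something like the following (illustrative only; the comments just describe the behavior already visible in the snippet above):

```python
# Figure out whether we were called directly or as a DAG task.
if dataset_info:
    # Called directly (e.g. from the local pipeline): dataset info is passed in as a JSON string.
    dataset_info_dicts = json.loads(dataset_info)
try:
    # Called as a DAG task: read the dataset info produced by the upstream PREPARE_DATA task.
    ctx = TaskContext(session)
    config = RunConfig.from_task_context(ctx)
    dataset_info_dicts = json.loads(ctx.get_predecessor_return_value("PREPARE_DATA"))
except SnowparkSQLException:
    print("there is no predecessor return value, fallback to local mode")
```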
