Skip to content

Commit dc7a217

Browse files
[PRO-16170] Give ability to set keys, create submission with cluster report, and expose cluster events (#99)
* Give ability to set keys, create submission with cluster report, and expose cluster events --------- Co-authored-by: Brandon Kaplan <Bkaplan31@gmail.com>
1 parent 102cd8a commit dc7a217

File tree

7 files changed

+219
-19
lines changed

7 files changed

+219
-19
lines changed

sync/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
"""Library for leveraging the power of Sync"""
2-
__version__ = "1.0.3"
2+
__version__ = "1.1.0"
33

44
TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"

sync/_databricks.py

Lines changed: 86 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,14 @@
1818
from sync.api import projects
1919
from sync.clients.databricks import get_default_client
2020
from sync.config import CONFIG # noqa F401
21-
from sync.models import DatabricksAPIError, DatabricksClusterReport, DatabricksError, Response
21+
from sync.models import (
22+
DatabricksAPIError,
23+
DatabricksClusterReport,
24+
DatabricksError,
25+
DatabricksComputeType,
26+
DatabricksPlanType,
27+
Response
28+
)
2229
from sync.utils.dbfs import format_dbfs_filepath, read_dbfs_file
2330

2431
logger = logging.getLogger(__name__)
@@ -56,6 +63,58 @@ def get_cluster(cluster_id: str) -> Response[dict]:
5663
return Response(result=cluster)
5764

5865

66+
def create_submission_with_cluster_info(
67+
run_id: str,
68+
project_id: str,
69+
cluster: Dict,
70+
cluster_info: Dict,
71+
cluster_activity_events: Dict,
72+
plan_type: DatabricksPlanType,
73+
compute_type: DatabricksComputeType,
74+
) -> Response[str]:
75+
"""Create a Submission for the specified Databricks run given a cluster report"""
76+
77+
run = get_default_client().get_run(run_id)
78+
79+
if "error_code" in run:
80+
return Response(error=DatabricksAPIError(**run))
81+
82+
project_response = projects.get_project(project_id)
83+
if project_response.error:
84+
return project_response
85+
cluster_path = project_response.result.get("cluster_path")
86+
87+
project_cluster_tasks = _get_project_cluster_tasks(run, project_id, cluster_path)
88+
89+
cluster_tasks = project_cluster_tasks.get(project_id)
90+
if not cluster_tasks:
91+
return Response(
92+
error=DatabricksError(
93+
message=f"Failed to locate cluster in run {run_id} for project {project_id}"
94+
)
95+
)
96+
97+
_, tasks = cluster_tasks
98+
99+
cluster_report = _create_cluster_report(
100+
cluster=cluster,
101+
cluster_info=cluster_info,
102+
cluster_activity_events=cluster_activity_events,
103+
tasks=tasks,
104+
plan_type=plan_type,
105+
compute_type=compute_type
106+
)
107+
eventlog = _get_event_log_from_cluster(cluster, tasks).result
108+
109+
return projects.create_project_submission_with_eventlog_bytes(
110+
get_default_client().get_platform(),
111+
cluster_report.dict(exclude_none=True),
112+
"eventlog.zip",
113+
eventlog,
114+
project_id,
115+
)
116+
117+
59118
def create_submission_for_run(
60119
run_id: str,
61120
plan_type: str,
@@ -160,19 +219,27 @@ def _get_run_information(
160219
cluster_report = cluster_report_response.result
161220
if cluster_report:
162221
cluster = cluster_report.cluster
163-
spark_context_id = _get_run_spark_context_id(tasks)
164-
end_time = max(task["end_time"] for task in tasks)
165-
eventlog_response = _get_eventlog(cluster, spark_context_id.result, end_time)
166-
222+
eventlog_response = _get_event_log_from_cluster(cluster, tasks)
167223
eventlog = eventlog_response.result
168224
if eventlog:
169-
# TODO - allow submissions w/out eventlog. Best way to make eventlog optional?..
170225
return Response(result=(cluster_report, eventlog))
171226

172-
return eventlog_response
173227
return cluster_report_response
174228

175229

230+
def _get_event_log_from_cluster(cluster: Dict, tasks: List[Dict]) -> Response[bytes]:
231+
spark_context_id = _get_run_spark_context_id(tasks)
232+
end_time = max(task["end_time"] for task in tasks)
233+
eventlog_response = _get_eventlog(cluster, spark_context_id.result, end_time)
234+
235+
eventlog = eventlog_response.result
236+
if eventlog:
237+
# TODO - allow submissions w/out eventlog. Best way to make eventlog optional?..
238+
return Response(result=eventlog)
239+
240+
return eventlog_response # return eventlog response with errors
241+
242+
176243
def get_cluster_report(
177244
run_id: str,
178245
plan_type: str,
@@ -240,6 +307,17 @@ def _get_cluster_report(
240307
raise NotImplementedError()
241308

242309

310+
def _create_cluster_report(
311+
cluster: dict,
312+
cluster_info: dict,
313+
cluster_activity_events: dict,
314+
tasks: List[dict],
315+
plan_type: DatabricksPlanType,
316+
compute_type: DatabricksComputeType
317+
) -> DatabricksClusterReport:
318+
raise NotImplementedError()
319+
320+
243321
def _get_cluster_instances_from_dbfs(filepath: str):
244322
filepath = format_dbfs_filepath(filepath)
245323
dbx_client = get_default_client()
@@ -1493,7 +1571,7 @@ def _get_eventlog(
14931571
return Response(error=DatabricksError(message=f"Unknown log destination: {filesystem}"))
14941572

14951573

1496-
def _get_all_cluster_events(cluster_id: str):
1574+
def get_all_cluster_events(cluster_id: str):
14971575
"""Fetches all ClusterEvents for a given Databricks cluster, optionally within a time window.
14981576
Pages will be followed and returned as 1 object
14991577
"""

sync/awsdatabricks.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import sync._databricks
1313
from sync._databricks import (
1414
_cluster_log_destination,
15-
_get_all_cluster_events,
15+
get_all_cluster_events,
1616
_get_cluster_instances_from_dbfs,
1717
_update_monitored_timelines,
1818
_wait_for_cluster_termination,
@@ -22,6 +22,7 @@
2222
create_cluster,
2323
create_run,
2424
create_submission_for_run,
25+
create_submission_with_cluster_info,
2526
get_cluster,
2627
get_cluster_report,
2728
get_project_cluster,
@@ -48,6 +49,8 @@
4849
AccessStatusCode,
4950
AWSDatabricksClusterReport,
5051
DatabricksError,
52+
DatabricksPlanType,
53+
DatabricksComputeType,
5154
Response,
5255
)
5356
from sync.utils.dbfs import format_dbfs_filepath, write_dbfs_file
@@ -57,7 +60,9 @@
5760
"get_access_report",
5861
"run_and_record_job",
5962
"create_submission_for_run",
63+
"create_submission_with_cluster_info",
6064
"get_cluster_report",
65+
"get_all_cluster_events",
6166
"monitor_cluster",
6267
"create_cluster",
6368
"get_cluster",
@@ -217,7 +222,7 @@ def _get_cluster_report(
217222
else:
218223
timelines = timeline_response.result
219224

220-
cluster_events = _get_all_cluster_events(cluster_id)
225+
cluster_events = get_all_cluster_events(cluster_id)
221226
return Response(
222227
result=AWSDatabricksClusterReport(
223228
plan_type=plan_type,
@@ -232,12 +237,33 @@ def _get_cluster_report(
232237
)
233238

234239

240+
def _create_cluster_report(
241+
cluster: dict,
242+
cluster_info: dict,
243+
cluster_activity_events: dict,
244+
tasks: List[dict],
245+
plan_type: DatabricksPlanType,
246+
compute_type: DatabricksComputeType,
247+
) -> AWSDatabricksClusterReport:
248+
return AWSDatabricksClusterReport(
249+
plan_type=plan_type,
250+
compute_type=compute_type,
251+
cluster=cluster,
252+
cluster_events=cluster_activity_events,
253+
tasks=tasks,
254+
instances=cluster_info.get("instances"),
255+
volumes=cluster_info.get("volumes"),
256+
instance_timelines=cluster_info.get("instance_timelines"),
257+
)
258+
259+
235260
if getattr(sync._databricks, "__claim", __name__) != __name__:
236261
raise RuntimeError(
237262
"Databricks modules for different cloud providers cannot be used in the same context"
238263
)
239264

240265
sync._databricks._get_cluster_report = _get_cluster_report
266+
sync._databricks._create_cluster_report = _create_cluster_report
241267
setattr(sync._databricks, "__claim", __name__)
242268

243269

@@ -328,6 +354,7 @@ def monitor_cluster(
328354
cluster_id: str,
329355
polling_period: int = 20,
330356
cluster_report_destination_override: dict = None,
357+
kill_on_termination: bool = False,
331358
) -> None:
332359
cluster = get_default_client().get_cluster(cluster_id)
333360
spark_context_id = cluster.get("spark_context_id")
@@ -350,6 +377,7 @@ def monitor_cluster(
350377
cluster_id,
351378
spark_context_id,
352379
polling_period,
380+
kill_on_termination,
353381
)
354382
else:
355383
logger.warning("Unable to monitor cluster due to missing cluster log destination - exiting")
@@ -360,6 +388,7 @@ def _monitor_cluster(
360388
cluster_id: str,
361389
spark_context_id: int,
362390
polling_period: int,
391+
kill_on_termination: bool = False,
363392
) -> None:
364393

365394
(log_url, filesystem, bucket, base_prefix) = cluster_log_destination
@@ -377,14 +406,16 @@ def _monitor_cluster(
377406
active_timelines_by_id = {}
378407
retired_timelines = []
379408
recorded_volumes_by_id = {}
380-
while True:
409+
410+
while_condition = True
411+
while while_condition:
381412
try:
382413
current_insts = _get_ec2_instances(cluster_id, ec2)
383414
recorded_volumes_by_id.update(
384415
{v["VolumeId"]: v for v in _get_ebs_volumes_for_instances(current_insts, ec2)}
385416
)
386417

387-
# Record new (or overrwite) existing instances.
418+
# Record new (or overwrite) existing instances.
388419
# Separately record the ids of those that are in the "running" state.
389420
running_inst_ids = set({})
390421
for inst in current_insts:
@@ -412,6 +443,11 @@ def _monitor_cluster(
412443
"utf-8",
413444
)
414445
)
446+
447+
if kill_on_termination:
448+
cluster_state = get_default_client().get_cluster(cluster_id).get("state")
449+
if cluster_state == "TERMINATED":
450+
while_condition = False
415451
except Exception as e:
416452
logger.error(f"Exception encountered while polling cluster: {e}")
417453

sync/azuredatabricks.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import sync._databricks
1616
from sync._databricks import (
1717
_cluster_log_destination,
18-
_get_all_cluster_events,
18+
get_all_cluster_events,
1919
_get_cluster_instances_from_dbfs,
2020
_update_monitored_timelines,
2121
_wait_for_cluster_termination,
@@ -25,6 +25,7 @@
2525
create_cluster,
2626
create_run,
2727
create_submission_for_run,
28+
create_submission_with_cluster_info,
2829
get_cluster,
2930
get_cluster_report,
3031
get_project_cluster,
@@ -50,6 +51,8 @@
5051
AccessStatusCode,
5152
AzureDatabricksClusterReport,
5253
DatabricksError,
54+
DatabricksPlanType,
55+
DatabricksComputeType,
5356
Response,
5457
)
5558
from sync.utils.dbfs import format_dbfs_filepath, write_dbfs_file
@@ -62,7 +65,9 @@
6265
"create_cluster",
6366
"get_cluster",
6467
"create_submission_for_run",
68+
"create_submission_with_cluster_info",
6569
"get_cluster_report",
70+
"get_all_cluster_events",
6671
"handle_successful_job_run",
6772
"record_run",
6873
"get_project_cluster",
@@ -209,7 +214,7 @@ def _get_cluster_report(
209214
else:
210215
return instances
211216

212-
cluster_events = _get_all_cluster_events(cluster_id)
217+
cluster_events = get_all_cluster_events(cluster_id)
213218
return Response(
214219
result=AzureDatabricksClusterReport(
215220
plan_type=plan_type,
@@ -223,6 +228,25 @@ def _get_cluster_report(
223228
)
224229

225230

231+
def _create_cluster_report(
232+
cluster: dict,
233+
cluster_info: dict,
234+
cluster_activity_events: dict,
235+
tasks: List[dict],
236+
plan_type: DatabricksPlanType,
237+
compute_type: DatabricksComputeType
238+
) -> AzureDatabricksClusterReport:
239+
return AzureDatabricksClusterReport(
240+
plan_type=plan_type,
241+
compute_type=compute_type,
242+
cluster=cluster,
243+
cluster_events=cluster_activity_events,
244+
tasks=tasks,
245+
instances=cluster_info.get("instances"),
246+
instance_timelines=cluster_info.get("timelines")
247+
)
248+
249+
226250
if getattr(sync._databricks, "__claim", __name__) != __name__:
227251
# Unless building documentation you can't load both databricks modules in the same program
228252
if not sys.argv[0].endswith("sphinx-build"):
@@ -231,6 +255,7 @@ def _get_cluster_report(
231255
)
232256

233257
sync._databricks._get_cluster_report = _get_cluster_report
258+
sync._databricks._create_cluster_report = _create_cluster_report
234259
setattr(sync._databricks, "__claim", __name__)
235260

236261

sync/config.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,14 @@ def get_api_key() -> APIKey:
116116
return _api_key
117117

118118

119+
def set_api_key(api_key: APIKey):
120+
global _api_key
121+
if _api_key is not None:
122+
raise RuntimeError("Sync API key/secret has already been set and the library does not support resetting "
123+
"credentials")
124+
_api_key = api_key
125+
126+
119127
def get_config() -> Configuration:
120128
"""Gets configuration
121129
@@ -138,6 +146,14 @@ def get_databricks_config() -> DatabricksConf:
138146
return _db_config
139147

140148

149+
def set_databricks_config(db_config: DatabricksConf):
150+
global _db_config
151+
if _db_config is not None:
152+
raise RuntimeError("Databricks config has already been set and the library does not support resetting "
153+
"credentials")
154+
_db_config = db_config
155+
156+
141157
CONFIG: Configuration
142158
_config = None
143159
API_KEY: APIKey

0 commit comments

Comments (0)