Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions RootInteractive/InteractiveDrawing/bokeh/bokehTools.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ def bokehDrawArray(dataFrame, query, figureArray, histogramArray=[], parameterAr
evaluator = ColumnEvaluator(context, cdsDict, paramDict, jsFunctionDict, i["expr"], aliasDict)
result = evaluator.visit(exprTree.body)
if result["type"] == "javascript":
func = "return "+result["implementation"]
func = evaluator.make_vfunc(result["implementation"])
fields = list(evaluator.aliasDependencies.values())
parameters = [i for i in evaluator.paramDependencies if "options" not in paramDict[i]]
variablesParam = [i for i in evaluator.paramDependencies if "options" in paramDict[i]]
Expand All @@ -499,7 +499,7 @@ def bokehDrawArray(dataFrame, query, figureArray, histogramArray=[], parameterAr
variablesAlias.append(paramDict[j]["value"])
fields.append(j)
nvars_local = nvars_local+1
transform = CustomJSNAryFunction(parameters=customJsArgList, fields=fields.copy(), func=func)
transform = CustomJSNAryFunction(parameters=customJsArgList, fields=fields.copy(), v_func=func)
fields = variablesAlias
else:
aliasDict[i["name"]] = result["name"]
Expand Down Expand Up @@ -2245,12 +2245,12 @@ def make_transform(transform, paramDict, aliasDict, cdsDict, jsFunctionDict, par
if transform_parsed["type"] == "js_lambda":
if transform_parsed["n_args"] == 1:
transform_customjs = CustomJSTransform(args=transform_parameters, v_func=f"""
return xs.map({transform_parsed["implementation"]});
return xs.map({evaluator.make_scalar_func(transform_parsed["implementation"])});
""")
elif transform_parsed["n_args"] == 2:
transform_customjs = CustomJSTransform(args=transform_parameters, v_func=f"""
const ys = data_source.get_column(varY);
return xs.map((x, i) => ({transform_parsed["implementation"]}{"(x, ys[i])" if orientation==1 else "(ys[i],x)"}));
return xs.map((x, i) => ({evaluator.make_scalar_func(transform_parsed["implementation"])}{"(x, ys[i])" if orientation==1 else "(ys[i],x)"}));
""")
elif transform_parsed["type"] == "parameter":
if "options" not in paramDict[transform_parsed["name"]]:
Expand Down
80 changes: 59 additions & 21 deletions RootInteractive/InteractiveDrawing/bokeh/compileVarName.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,13 @@
"arctanh": np.arctanh
}

# JavaScript preamble shared by generated vectorized functions: (re)allocates the
# $output array to match the data source length and zero-fills it.
# NOTE: Array.prototype.fill(value, start) fills from `start` to the end, so
# `.fill(0, len)` on a length-`len` array filled NOTHING — use `.fill(0)` to
# zero the whole array.
resize_out_boilerplate = """
const len = data_source.get_length();
if($output == null || $output.length !== len){
    $output = new Array(len).fill(0);
}
"""

class ColumnEvaluator:
# This class walks the Python abstract syntax tree of the expressions to detect its dependencies
def __init__(self, context, cdsDict, paramDict, funcDict, code, aliasDict, firstGeneratedID=0):
Expand All @@ -74,6 +81,7 @@ def __init__(self, context, cdsDict, paramDict, funcDict, code, aliasDict, first
self.funcDict = funcDict
self.context = context
self.dependencies = set()
self.dependencies_table = set()
self.paramDependencies = set()
self.aliasDependencies = {}
self.firstGeneratedID = firstGeneratedID
Expand Down Expand Up @@ -148,7 +156,19 @@ def visit_Num(self, node: ast.Constant):
def visit_Subscript(self, node: ast.Subscript):
    """Handle a subscript expression ``value[slice]``.

    Only constant indices into parameters are supported on the client:
    the result is emitted as the JS expression ``<param>[<index>]`` and the
    parameter is recorded as a dependency. Anything else raises.

    Raises:
        NotImplementedError: if the index is not a constant, or the
            subscripted value is not a parameter.
    """
    # Removed leftover dead `return {}` that made the logic below unreachable.
    value = self.visit(node.value)
    sliceValue = self.visit(node.slice)
    if sliceValue["type"] != "constant":
        raise NotImplementedError("Only constant subscripts are supported on the client")
    if value["type"] == "parameter":
        # Element access forces client-side evaluation rather than a raw source column.
        self.isSource = False
        param_name = value["name"]
        index = sliceValue["value"]
        self.paramDependencies.add(param_name)
        return {
            "name": self.code,
            "implementation": f"{param_name}[{index}]",
            "type": "parameter_element"
        }
    raise NotImplementedError("Subscripted expressions are only supported for parameters on the client")

def visit_Attribute(self, node: ast.Attribute):
if self.context in self.aliasDict and node.attr in self.aliasDict[self.context]:
Expand Down Expand Up @@ -213,18 +233,20 @@ def visit_Attribute(self, node: ast.Attribute):
self.helper_idx += 1
return {
"name": node.attr,
"implementation": self.aliasDependencies[attrChainStr],
"implementation": self.aliasDependencies[attrChainStr]+"{index_suffix}",
"type": "column"
}
if not isinstance(node.value, ast.Name):
raise ValueError("Column data source name cannot be a function call")
if node.value.id != "self":
if self.context is not None:
if node.value.id != self.context:
raise ValueError("Incompatible data sources: " + node.value.id + "." + node.attr + ", " + self.context)
if node.value.id not in self.cdsDict:
raise KeyError("Data source not found: " + node.value.id)
self.context = node.value.id
if self.context is not None:
if node.value.id != self.context:
self.dependencies_table.add((node.value.id, node.attr))
# raise ValueError("Incompatible data sources: " + node.value.id + "." + node.attr + ", " + self.context)
else:
self.context = node.value.id
if self.context in self.cdsDict and self.cdsDict[self.context]["type"] == "stack":
self.isSource = False
if node.attr != "$source_index":
Expand All @@ -249,7 +271,7 @@ def visit_Attribute(self, node: ast.Attribute):
is_boolean = False
return {
"name": node.attr,
"implementation": node.attr,
"implementation": node.attr+"{index_suffix}",
"type": "column",
"is_boolean": is_boolean
}
Expand Down Expand Up @@ -295,7 +317,7 @@ def visit_Name(self, node: ast.Name):
# Detect if parameter is a lambda here?
return {
"name": node.id,
"implementation": node.id,
"implementation": node.id if "options" not in self.paramDict[node.id] else node.id+"{index_suffix}",
"type": "paramTensor" if isinstance(self.paramDict[node.id]["value"], list) else "parameter"
}
if node.id in [self.context, "self"]:
Expand All @@ -312,7 +334,7 @@ def visit_Name(self, node: ast.Name):
"attrChain": [node.id]
}
attrNode = ast.Attribute(value=ast.Name(id="self", ctx=ast.Load()), attr=node.id)
return self.visit(attrNode)
return self.visit_Attribute(attrNode)

def visit_Name_histogram(self, id: str):
self.isSource = False
Expand All @@ -339,14 +361,10 @@ def visit_Name_histogram(self, id: str):
isOK = True
if not isOK:
raise KeyError("Column " + id + " not found in histogram " + histogram["name"])
#return {
# "error": KeyError,
# "msg": "Column " + id + " not found in histogram " + histogram["name"]
#}
self.aliasDependencies[id] = id
return {
"name": id,
"implementation": id,
"implementation": id+"{index_suffix}",
"type": "column"
}

Expand Down Expand Up @@ -387,8 +405,6 @@ def visit_BinOp(self, node):
operator_infix = " << "
elif isinstance(op, ast.RShift):
operator_infix = " >> "
elif isinstance(op, ast.RShift):
operator_infix = " >> "
elif isinstance(op, ast.BitOr):
operator_infix = " | "
elif isinstance(op, ast.BitXor):
Expand Down Expand Up @@ -496,12 +512,30 @@ def visit_Lambda(self, node:ast.Lambda):
"n_args": len(args),
"implementation": f"(({impl_args})=>({impl_body}))"
}


def make_vfunc(self, body: str):
    """Assemble the JS body of a vectorized n-ary function.

    The generated code resizes/zero-fills ``$output`` (see
    ``resize_out_boilerplate``), binds each dependency column from its
    table, then loops over all rows evaluating ``body`` per element
    (``{index_suffix}`` placeholders become ``[$i]``).

    Args:
        body: JS expression template produced by the visitor, with
            ``{index_suffix}`` marking per-row column accesses.

    Returns:
        The complete JS function body as a string.
    """
    func = resize_out_boilerplate
    # The column name must be a quoted string literal: the unquoted form
    # `const c = t.getColumn(c);` references `c` inside its own initializer,
    # which is a TDZ ReferenceError in JS. Use get_column(), consistent with
    # the data-source API used elsewhere in this file.
    func += "\n".join(
        f'const {col} = {table}.get_column("{col}");'
        for table, col in self.dependencies_table
    )
    func += f"""
for(let $i=0; $i<$output.length; $i++){{
    $output[$i] = {body.replace("{index_suffix}", "[$i]")};
}}
return $output;
"""
    return func

def make_scalar_func(self, body: str):
a = "\n".join([f"const {i[1]} = {i[0]}.getColumn({i[1]});" for i in self.dependencies_table])
a += f"""
return {body.replace("{index_suffix}", "")};
"""
return a

def checkColumn(columnKey, tableKey, cdsDict):
    # Stub: unconditionally reports the column as absent, ignoring all
    # arguments — callers therefore always fall through to (re)creating the
    # column. NOTE(review): presumably a placeholder for a real existence
    # check against cdsDict — confirm intent before relying on it.
    return False

def getOrMakeColumns(variableNames, context = None, cdsDict: dict = {}, paramDict: dict = {}, funcDict: dict = {},
memoizedColumns: dict = None, aliasDict: dict = None, forbiddenColumns: set = set()):
memoizedColumns: dict = None, aliasDict: dict = None, forbiddenColumns: set = set(), vfunc_feature_flag: bool = True):
if variableNames is None or len(variableNames) == 0:
return variableNames, context, memoizedColumns, set()
if not isinstance(variableNames, list):
Expand Down Expand Up @@ -550,7 +584,6 @@ def getOrMakeColumns(variableNames, context = None, cdsDict: dict = {}, paramDic
if i_context not in aliasDict:
aliasDict[i_context] = {}
columnName = column["name"]
func = "return "+column["implementation"]
variablesAlias = list(evaluator.aliasDependencies.keys())
fieldsAlias = list(evaluator.aliasDependencies.values())
parameters = {i:paramDict[i]["value"] for i in evaluator.paramDependencies if "options" not in paramDict[i]}
Expand All @@ -570,7 +603,12 @@ def getOrMakeColumns(variableNames, context = None, cdsDict: dict = {}, paramDic
if is_customjs_func:
transform = funcDict[queryAST.body.func.id]
else:
transform = CustomJSNAryFunction(parameters=parameters, fields=fieldsAlias, func=func)
if vfunc_feature_flag:
func = evaluator.make_vfunc(column["implementation"])
transform = CustomJSNAryFunction(parameters=parameters, fields=fieldsAlias, v_func=func)
else:
func = evaluator.make_scalar_func(column["implementation"])
transform = CustomJSNAryFunction(parameters=parameters, fields=fieldsAlias, func=func)
for j in parameters:
if "subscribed_events" not in paramDict[j]:
paramDict[j]["subscribed_events"] = []
Expand All @@ -593,7 +631,7 @@ def getOrMakeColumns(variableNames, context = None, cdsDict: dict = {}, paramDic
dependency_columns = [i[1] for i in direct_dependencies]
dependency_tables = [i[0] for i in direct_dependencies]
_, _, memoizedColumns, sources_local = getOrMakeColumns(dependency_columns, dependency_tables, cdsDict, paramDict, funcDict,
memoizedColumns, aliasDict, forbiddenColumns | {(i_context, i_var)})
memoizedColumns, aliasDict, forbiddenColumns | {(i_context, i_var)}, vfunc_feature_flag)
used_names.update(sources_local)
if i_context in memoizedColumns:
memoizedColumns[i_context][i_var] = column
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pandas as pd
import base64
import timeit
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression, Ridge
Expand Down Expand Up @@ -76,6 +77,10 @@
s = onx_ridge.SerializeToString()
onx_ridge_b64 = base64.b64encode(s).decode('utf-8')

print(timeit.timeit("rfr.predict(df2[['A', 'B', 'C', 'D']])", globals=globals(), number=1))
print(timeit.timeit("rfr50.predict(df2[['A', 'B', 'C', 'D']])", globals=globals(), number=1))
print(timeit.timeit("ridgeReg.predict(df2[['A', 'B', 'C', 'D']])", globals=globals(), number=1))


def test_onnx_bokehDrawArray():
output_file("test_ort_web_bokehDrawSA.html", "Test ONNX runtime web")
Expand Down Expand Up @@ -136,7 +141,11 @@ def test_onnx_multimodels():
}
return $output
""", "parameters":{"intercept":ridgeReg.intercept_, "coefs":ridgeReg.coef_}, "fields":["A","B","C","D"]},
{"name":"y_pred_customjs_ridge_naive","func":"return intercept + coefs[0]*A + coefs[1]*B + coefs[2]*C + coefs[3]*D","parameters":{"intercept":ridgeReg.intercept_, "coefs":ridgeReg.coef_}, "variables":["A","B","C","D"]}
{"name":"y_pred_customjs_ridge_naive","expr":"intercept + coefs[0]*A + coefs[1]*B + coefs[2]*C + coefs[3]*D"}
]
parameterArray += [
{"name":"intercept","value":ridgeReg.intercept_},
{"name":"coefs","value":ridgeReg.coef_}
]
widgetParams = mergeFigureArrays(widgetParams, [["range", ["A"],{"name":"A"}],["range",["B"],{"name":"B"}]])
widgetLayoutDesc["Select"] += [["A","B"]]
Expand Down