Skip to content
173 changes: 79 additions & 94 deletions Dash_interface/chart_section_n.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import sys

from dash import Dash, html, dcc, Input, Output, State, dash_table, Patch
from dash.exceptions import PreventUpdate
import pickle
Expand Down Expand Up @@ -295,22 +297,30 @@ def update_peaks(data): # , slider_value):
if data == None:
return {}, {"display": "none"}
peaksObj = pickle.loads(base64.b64decode(data))

main_compound_peaks = peaksObj["main_compound_peaks"]
mod_compound_peaks = peaksObj["mod_compound_peaks"]
matched_peaks = peaksObj["matched_peaks"]
args = peaksObj["args"]
main_precursor_mz = peaksObj["main_precursor_mz"]
mod_precursor_mz = peaksObj["mod_precursor_mz"]

# Convert m/z values back down from keys
main_compound_peaks = [(mz/1e6, intensity) for mz, intensity in main_compound_peaks]
mod_compound_peaks = [(mz/1e6, intensity) for mz, intensity in mod_compound_peaks]
matched_peaks = [(main_mz/1e6, mod_mz/1e6) for main_mz, mod_mz in matched_peaks]

fig = go.Figure()
typesInxMain = {"matched_shifted": [], "matched_unshifted": [], "unmatched": []}

### Assemble matched and unmatched peaks for main compound

x1 = []
y1 = []
for peak in main_compound_peaks:
x1.append(peak[0])
y1.append(peak[1])

# topPeakCount = slider_value
topPeakCount = max(
len(main_compound_peaks),
len(mod_compound_peaks),
Expand All @@ -319,29 +329,32 @@ def update_peaks(data): # , slider_value):
hoverData = {"main": [], "modified": []}
for i in topPeaksInxModif:
flag = False
for j in matched_peaks:
if j[0] == i:
for main_match_mz, mod_match_mz in matched_peaks:
if abs(main_compound_peaks[i][0] - main_match_mz) < 1e-6: # We have found a match for our specific peak
if (
abs(
main_compound_peaks[i][0]
- mod_compound_peaks[j[1]][0]
- mod_match_mz
)
> args["mz_tolerance"]
):
typesInxMain["matched_shifted"].append(i)
hoverData["main"].append(j[1])
typesInxMain["matched_shifted"].append([main_match_mz, y1[i], f"{mod_match_mz:.2f}:{main_compound_peaks[i][0]:.2f}"])
else:
typesInxMain["matched_unshifted"].append(i)
typesInxMain["matched_unshifted"].append([main_match_mz, y1[i], f"{mod_match_mz:.2f}:{main_compound_peaks[i][0]:.2f}"])
flag = True
break
break
if not flag:
typesInxMain["unmatched"].append(i)
typesInxMain["unmatched"].append([main_compound_peaks[i][0], y1[i], "Unmatched"])


typesInxModified = {
"matched_shifted": [],
"matched_unshifted": [],
"unmatched": [],
}

### Assemble matched and unmatched peaks for modified compound

x2 = []
y2 = []
for peak in mod_compound_peaks:
Expand All @@ -351,42 +364,45 @@ def update_peaks(data): # , slider_value):
topPeaksInxModif = sorted(range(len(y2)), key=lambda i: y2[i])[-topPeakCount:]
for i in topPeaksInxModif:
flag = False
for j in matched_peaks:
if j[1] == i:
for main_match_mz, mod_match_mz in matched_peaks:

if abs(mod_compound_peaks[i][0] - mod_match_mz) < 1e-6: # We have found a match for our specific peak
if (
abs(
main_compound_peaks[j[0]][0]
- mod_compound_peaks[j[1]][0]
mod_compound_peaks[i][0]
- main_match_mz
)
> 0.1
> args["mz_tolerance"]
):
typesInxModified["matched_shifted"].append([i, j[0]])
hoverData["modified"].append(j[0])
typesInxModified["matched_shifted"].append([mod_match_mz, -y2[i], f"{main_match_mz:.2f}:{mod_compound_peaks[i][0]:.2f}"])
# hoverData["modified"].append(j[0])
else:
typesInxModified["matched_unshifted"].append([i, j[0]])
typesInxModified["matched_unshifted"].append([mod_match_mz, -y2[i], f"{main_match_mz:.2f}:{mod_compound_peaks[i][0]:.2f}"])
flag = True
break
if not flag:
typesInxModified["unmatched"].append([i, -1])
typesInxModified["unmatched"].append([mod_compound_peaks[i][0], -y2[i], "Unmatched"])

minX = min(min(x1), min(x2))
maxX = max(max(x1), max(x2))
minX = min(minX, main_precursor_mz, mod_precursor_mz)
maxX = max(maxX, main_precursor_mz, mod_precursor_mz)

### Plotting

for inx_type in typesInxMain:
x_main = [round(x1[j], 4) for j in typesInxMain[inx_type]]
y1_ = [y1[j] for j in typesInxMain[inx_type]]
y_main = [y / max(y1_) * 100 for y in y1_]
x_modified = [round(x2[j[0]], 4) for j in typesInxModified[inx_type]]
y2_ = [y2[j[0]] for j in typesInxModified[inx_type]]
y_modified = [-j / max(y2_) * 100 for j in y2_]
indicis = typesInxMain[inx_type] + [
j[0] for j in typesInxModified[inx_type]
]
x_ = x_main + x_modified
y_ = y_main + y_modified
colors = [colorsInxMain[inx_type]] * len(x_)

x = [j[0] for j in typesInxMain[inx_type]] + [j[0] for j in typesInxModified[inx_type]]
y = [j[1] for j in typesInxMain[inx_type]] + [j[1] for j in typesInxModified[inx_type]]
# Separate norm constants for pos and neg y
if len(y) == 0:
continue

max_y = max(y) if max(y) > 0 else 1
min_y = min(y) if min(y) < 0 else -1
y = [y_i / max_y * 100 if y_i > 0 else -(y_i / min_y) * 100 for y_i in y]
hovertext = [j[2] for j in typesInxMain[inx_type]] + [j[2] for j in typesInxModified[inx_type]]
colors = [colorsInxMain[inx_type]] * len(x)
if inx_type == "unmatched":
visibility = "legendonly"
if len(typesInxModified["matched_shifted"]) == 0 and len(
Expand All @@ -396,36 +412,20 @@ def update_peaks(data): # , slider_value):

fig.add_trace(
go.Bar(
x=x_,
y=y_,
x=x,
y=y,
width=(maxX - minX) / 500,
hovertext=indicis,
hovertext=hovertext,
name=inx_type,
visible=visibility,
marker_color=colors,
)
)
elif inx_type == "matched_shifted":
hovertext = []
for i in range(len(x_main)):
hovertext.append(
str(indicis[i])
+ " "
+ "matched to:"
+ str(hoverData["main"][i])
)
for i in range(len(x_main), len(x_main) + len(x_modified)):
hovertext.append(
str(indicis[i])
+ " "
+ "matched to:"
+ str(hoverData["modified"][i - len(x_main)])
)

fig.add_trace(
go.Bar(
x=x_,
y=y_,
x=x,
y=y,
hovertext=hovertext,
name=inx_type,
width=(maxX - minX) / 500,
Expand All @@ -435,9 +435,9 @@ def update_peaks(data): # , slider_value):
else:
fig.add_trace(
go.Bar(
x=x_,
y=y_,
hovertext=indicis,
x=x,
y=y,
hovertext=hovertext,
name=inx_type,
width=(maxX - minX) / 500,
marker_color=colors,
Expand All @@ -452,8 +452,6 @@ def update_peaks(data): # , slider_value):
mode="lines",
line=go.scatter.Line(color="black", dash="dash", width= (maxX - minX) / 600),
name='known precursor m/z',
# showlegend=False,
# hoverinfo='skip'
)
)
fig.add_trace(
Expand All @@ -463,17 +461,9 @@ def update_peaks(data): # , slider_value):
mode="lines",
line=go.scatter.Line(color="black", dash="dot", width= (maxX - minX) / 600),
name='modified precursor m/z',
# showlegend=False,
# hoverinfo='skip'
)
)

# minX = min(minX, main_precursor_mz, mod_precursor_mz)
# maxX = max(maxX, main_precursor_mz, mod_precursor_mz)

# fig.update_traces(
# width=(maxX - minX) / 400,
# )
fig.update_layout(
title="Alignment of Peaks",
bargap=0,
Expand Down Expand Up @@ -501,18 +491,6 @@ def update_peaks(data): # , slider_value):
"zIndex": "1",
}


# @app.callback(
# Output("peak_info", "children", allow_duplicate=True),
# Input("siteLocatorObj", "data"),
# prevent_initial_call=True,
# )
# def clear_peak_info(data):
# if data == None:
# return ""
# else:
# return "Select a peak to see its fragments"

@app.callback(
Output("peak_info", "children", allow_duplicate=True),
Input("peaks", "clickData"),
Expand All @@ -533,17 +511,24 @@ def display_click_data(clickData, fragmentsObj):

structure = fragmentsObj["structure"]
frags_map = fragmentsObj["frags_map"]
peaks = fragmentsObj["peaks"]
peak_keys = [int(x[0]) for x in fragmentsObj["peaks"]]

peak_index = -1
for i, peak in enumerate(peaks):
if abs(peak[0]- clicked_peak_x)/clicked_peak_x*1000000 < 40:
peak_index = i
peak_key = None
for k in peak_keys:
if abs((k/1e6)- clicked_peak_x)/clicked_peak_x*1000000 < 40:
peak_key = k # Cast to int (numpy ints won't key)
break
if peak_index == -1:
return "error in finding peak index"
if peak_key is None:
raise ValueError(f"Clicked peak not found in peaks list "f"(clicked_peak_x: {clicked_peak_x}, peaks: {peak_keys})")

try:
fragments = list(frags_map[peak_key])
except KeyError:
# Check for the closest key
closest_key = min(frags_map.keys(), key=lambda k: abs(k/1e6 - clicked_peak_x))

raise ValueError(f"Fragment map does not contain peak key {peak_key} (type {type(peak_key)}), closest key is {closest_key} (type {type(closest_key)} with m/z {closest_key/1e6}, clicked m/z was {clicked_peak_x}")

fragments = list(frags_map[peak_index])
result_posibility_indicies = []
for fragment in fragments:
fragment_indicies = []
Expand All @@ -564,9 +549,9 @@ def display_click_data(clickData, fragmentsObj):
)
except:
import traceback

traceback.print_exc()
traceback.print_exc(file=sys.stderr)
return "siteLocator object not found"

return None

# change the color of the bar when clicked
Expand Down Expand Up @@ -599,9 +584,9 @@ def change_bar_color(clickData, figure):
# figure["data"][i]["marker"]["color"][j] = "green"
# if matched shifted peak, highlight the corresponding peak in the other bar
if figure["data"][i]["name"] == "matched_shifted":
index = figure["data"][i]["hovertext"][j].split(":")[1]
peak_x = str(figure["data"][i]["hovertext"][j].split(":")[0]).strip()
for l in range(len(figure["data"][i]["x"])):
if (figure["data"][i]["hovertext"][l].split(" ")[0] == index and figure["data"][i]["y"][l] < 0):
if (str(figure["data"][i]["hovertext"][l].split(':')[1]).strip() == peak_x and figure["data"][i]["y"][l] < 0):
patched_figure["data"][i]["marker"]["color"][l] = "olive"
break

Expand All @@ -617,7 +602,7 @@ def change_bar_color(clickData, figure):
@app.callback(
[Output("siteLocatorObj", "data", allow_duplicate=True),
Output("peak_info", "children", allow_duplicate=True),
Output('fragmentsObj', 'data', allow_duplicate=True)],
Output("fragmentsObj", "data", allow_duplicate=True)],
Input(FragmentsDisplayAIO.ids.fragment_data("fragmentDisplay"), "data"),
State("siteLocatorObj", "data"),
prevent_initial_call=True,
Expand All @@ -632,21 +617,21 @@ def apply_structure_filter(data, siteLocatorObj):
modified_compound_id = siteLocator._get_unknown()
main_compound_id = siteLocator._get_known_neighbor(modified_compound_id)
main_compound = siteLocator.network.nodes[main_compound_id]['compound']
main_compound_peaks = [(main_compound.spectrum.mz[i], main_compound.spectrum.intensity[i]) for i in range(len(main_compound.spectrum.mz))]
main_compound_peaks = [(main_compound.spectrum.mz_key[i], main_compound.spectrum.intensity[i]) for i in range(len(main_compound.spectrum.mz_key))]
modified_compound = siteLocator.network.nodes[modified_compound_id]['compound']

ind = main_compound.spectrum.get_peak_indexes(data["mz"])
main_compound.spectrum.peak_fragments_map[ind[0]] = [data["all_fragments"][i] for i in data["selected_fragments"]]
mzs = data["mz"]
main_compound.spectrum.peak_fragment_dict[int(mzs[0])] = [data["all_fragments"][i] for i in data["selected_fragments"]]

fragmentsObj = {
"frags_map": main_compound.spectrum.peak_fragments_map,
"frags_map": main_compound.spectrum.peak_fragment_dict,
"structure": main_compound.structure,
"peaks": main_compound_peaks,
"Precursor_MZ": main_compound.spectrum.precursor_mz,
}


fragments = list(main_compound.spectrum.peak_fragments_map[ind[0]])
fragments = list(main_compound.spectrum.peak_fragment_dict[int(mzs[0])])
result_posibility_indicies = []
for fragment in fragments:
fragment_indicies = []
Expand Down
Loading