Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions morph_utils/ccf.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def get_ccf_structure(voxel, name_map=None, annotation=None, coordinate_to_voxel
return name_map[structure_id]

def projection_matrix_for_swc(input_swc_file, mask_method = "tip_and_branch",
tip_count = False, annotation=None,
count_method = "node", annotation=None,
annotation_path = None, volume_shape=(1320, 800, 1140),
resolution=10, node_type_list=[2],
resample_spacing=None):
Expand All @@ -304,7 +304,8 @@ def projection_matrix_for_swc(input_swc_file, mask_method = "tip_and_branch",
'tip_and_branch' will return a projection matrix masking only structures with tip and branch nodes. If 'tip'
will only look at structures with tip nodes. And last, if 'branch' will only look at structures with
branch nodes.
tip_count (bool): if True, will count number of tips instead of the length of axon
count_method (str): ['node','tip','branch']. When 'node', will measure axon length by multiplying by internode spacing.
Otherwise will return the count of tip or branch nodes in each structure
annotation (array, optional): 3 dimensional ccf annotation array. Defaults to None.
annotation_path (str, optional): path to nrrd file to use (optional). Defaults to None.
volume_shape (tuple, optional): the size in voxels of the ccf atlas (annotation volume). Defaults to (1320, 800, 1140).
Expand All @@ -317,7 +318,7 @@ def projection_matrix_for_swc(input_swc_file, mask_method = "tip_and_branch",
filename (str)

specimen_projection_summary (dict): keys are strings of structures and values are the quantitative projection
values. Either axon length, or number numbe of nodes depending on tip_count argument.
values. Either axon length, or the number of nodes depending on count_method argument.

"""

Expand All @@ -329,8 +330,11 @@ def projection_matrix_for_swc(input_swc_file, mask_method = "tip_and_branch",
print(f"WARNING: Annotation path provided does not exist, defaulting to 10um resolution, (1320,800, 1140) ccf.\n{annotation_path}")
annotation_path = None
annotation = open_ccf_annotation(with_nrrd=True, annotation_path=annotation_path)



if count_method not in ['node','tip','branch']:
msg = f"count_method must be 'node','tip', or 'branch'. You passed in: {count_method}"
raise ValueError(msg)

sg_df = load_structure_graph()
name_map = NAME_MAP
full_name_to_abbrev_dict = dict(zip(sg_df.name, sg_df.index))
Expand Down Expand Up @@ -358,7 +362,9 @@ def projection_matrix_for_swc(input_swc_file, mask_method = "tip_and_branch",

# annotate each node
if morph_df.empty:
print("Its empty")

msg = "morph_df is empty, possibly caused by `morph_df = morph_df[morph_df['type'].isin(node_type_list)]`"
warnings.warn(msg)
return input_swc_file, {}

morph_df['ccf_structure'] = morph_df.apply(lambda rw: full_name_to_abbrev_dict[get_ccf_structure( np.array([rw.x, rw.y, rw.z]) , name_map, annotation, True)], axis=1)
Expand All @@ -379,15 +385,12 @@ def node_ider(morph,i):

# determine ipsi/contra projections
morph_df["ccf_structure_sided"] = morph_df.apply(lambda row: "ipsi_{}".format(row.ccf_structure) if row.z<z_midline else "contra_{}".format(row.ccf_structure), axis=1)


# mask the morphology dataframe accordingly
if mask_method!="None":

keep_structs = []
for struct, struct_df in morph_df.groupby("ccf_structure_sided"):
node_types_in_struct = struct_df.node_type.unique().tolist()

if (mask_method == 'tip') & ("tip" in node_types_in_struct):
keep_structs.append(struct)

Expand All @@ -403,23 +406,22 @@ def node_ider(morph,i):
morph_df_masked = morph_df[morph_df['ccf_structure_sided'].isin(keep_structs)]

else:
print("Not masking projection matrix...")
morph_df_masked = morph_df

# remove ventral targets and out of brain
ventral_targs = ["ipsi_{}".format(v) for v in vs_acronyms] + ["contra_{}".format(v) for v in vs_acronyms]
targets_to_remove = ["ipsi_Out Of Cortex", "ipsi_root","contra_Out Of Cortex", "contra_root"] + ventral_targs
morph_df_masked = morph_df_masked[~morph_df_masked['ccf_structure_sided'].isin(targets_to_remove)]

# accommodate tip/branch counting instead of axon length
if tip_count:
morph_df_masked = morph_df_masked[morph_df_masked['node_type']=='tip']
if count_method != 'node':
morph_df_masked = morph_df_masked[morph_df_masked['node_type']==count_method]
spacing = 1

# quantify projections per structure
n_nodes_per_structure = morph_df_masked.ccf_structure_sided.value_counts()
axon_length_per_structure = n_nodes_per_structure*spacing
specimen_projection_summary = axon_length_per_structure.to_dict()

return input_swc_file, specimen_projection_summary


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import os
from tqdm import tqdm
import pandas as pd
import argschema as ags
from morph_utils.ccf import projection_matrix_for_swc

class IO_Schema(ags.ArgSchema):
    """Argument schema for aggregating per-cell projection CSVs into one matrix."""
    # directory holding the single-cell projection csv files to aggregate
    output_directory = ags.fields.OutputDir(description="output directory")
    # base path for the aggregated projection matrix csv
    output_projection_csv = ags.fields.OutputFile(description="output projection csv")
    # mask method tag; appended to the output filename for traceability
    mask_method = ags.fields.Str(default="tip_and_branch",description = " 'tip_and_branch', 'branch', 'tip', or 'tip_or_branch' ")
    # projection values below this are zeroed (0 disables thresholding)
    projection_threshold = ags.fields.Int(default=0)
    # when True, also write a per-row (per-cell) normalized projection csv
    normalize_proj_mat = ags.fields.Boolean(default=True)


def normalize_projection_columns_per_cell(input_df, projection_column_identifiers=['ipsi', 'contra']):
    """Row-normalize the projection columns of a projection matrix in place.

    Projection columns are identified by substring match against
    ``projection_column_identifiers``; all other (metadata) columns are left
    untouched. Each row's projection values are divided by that row's total
    so its projection columns sum to 1.

    :param input_df: input projection df (one row per cell); modified in place
    :param projection_column_identifiers: list of identifiers for projection columns,
        i.e. substrings that distinguish projection columns from metadata columns
    :return: normalized projection matrix. Rows whose projections sum to 0 are
        left as all zeros instead of becoming NaN (guards the divide-by-zero).
    """
    proj_cols = [c for c in input_df.columns
                 if any(ider in c for ider in projection_column_identifiers)]
    input_df[proj_cols] = input_df[proj_cols].fillna(0)

    row_totals = input_df[proj_cols].sum(axis=1)
    # cells with no projection signal would divide by zero and fill the row
    # with NaN; substituting 1 keeps those rows as all zeros
    res = input_df[proj_cols].T / row_totals.replace(0, 1)
    input_df[proj_cols] = res.T

    return input_df


def main(output_directory,
         output_projection_csv,
         projection_threshold,
         mask_method,
         normalize_proj_mat,
         **kwargs):
    """Aggregate single-cell projection CSVs into one projection matrix.

    Reads every per-cell csv in ``output_directory`` (skipping previously
    written ``*_norm.csv`` outputs), stacks them into one wide matrix indexed
    by the absolute source swc path, and writes it to ``output_projection_csv``
    (suffixed with ``mask_method``). Optionally writes thresholded and
    row-normalized variants as additional csvs.

    :param output_directory: directory containing single-cell projection csvs
    :param output_projection_csv: base path for the aggregated output csv
    :param projection_threshold: values below this are zeroed (0 disables)
    :param mask_method: mask method tag appended to the output filename
    :param normalize_proj_mat: when True, also write a per-row normalized csv
    """
    per_cell_csvs = [f for f in os.listdir(output_directory)
                     if f.endswith(".csv") and not f.endswith("_norm.csv")]
    output_projection_csv = output_projection_csv.replace(".csv", f"_{mask_method}.csv")

    projection_records = {}
    for csv_fn in per_cell_csvs:
        df = pd.read_csv(os.path.join(output_directory, csv_fn), index_col=0)
        # each per-cell csv is indexed by its source swc file path
        src_file = df.index[0]
        record_key = os.path.abspath(src_file)
        projection_records[record_key] = df.loc[src_file].to_dict()

    proj_df = pd.DataFrame(projection_records).T.fillna(0)
    proj_df.to_csv(output_projection_csv)

    if projection_threshold != 0:
        output_projection_csv = output_projection_csv.replace(
            ".csv", "{}thresh.csv".format(projection_threshold))
        # zero out sub-threshold projections, preserving index and columns
        proj_df = proj_df.where(proj_df >= projection_threshold, 0)
        proj_df.to_csv(output_projection_csv)

    if normalize_proj_mat:
        output_projection_csv = output_projection_csv.replace(".csv", "_norm.csv")
        proj_df = normalize_projection_columns_per_cell(proj_df)
        proj_df.to_csv(output_projection_csv)

def console_script():
    """Console entry point: parse argschema arguments and dispatch to main()."""
    main(**ags.ArgSchemaParser(schema_type=IO_Schema).args)

if __name__ == "__main__":
    # direct execution mirrors the installed console_script entry point
    console_script()
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class IO_Schema(ags.ArgSchema):
projection_threshold = ags.fields.Int(default=0)
normalize_proj_mat = ags.fields.Boolean(default=True)
mask_method = ags.fields.Str(default="tip_and_branch",description = " 'tip_and_branch', 'branch', 'tip', or 'tip_or_branch' ")
tip_count = ags.fields.Boolean(default=False, description="when true, this will measure a matrix of number of tips instead of number of nodes")
count_method = ags.fields.String(default="node", description="should be a member of ['node','tip','branch']")
annotation_path = ags.fields.Str(default="",description = "Optional. Path to annotation .nrrd file. Defaults to 10um ccf atlas")
resolution = ags.fields.Int(default=10, description="Optional. ccf resolution (micron/pixel")
volume_shape = ags.fields.List(ags.fields.Int, default=[1320, 800, 1140], description = "Optional. Size of input annotation")
Expand Down Expand Up @@ -38,18 +38,21 @@ def main(input_swc_file,
projection_threshold,
normalize_proj_mat,
mask_method,
tip_count,
count_method,
annotation_path,
volume_shape,
resample_spacing,
**kwargs):

if annotation_path == "":
annotation_path = None

if mask_method not in [None,'tip_and_branch', 'branch', 'tip', 'tip_or_branch']:
raise ValueError(f"Invalid mask_method provided {mask_method}")

results = []
res = projection_matrix_for_swc(input_swc_file=input_swc_file,
tip_count = tip_count,
count_method = count_method,
mask_method = mask_method,
annotation=None,
annotation_path = annotation_path,
Expand All @@ -74,7 +77,7 @@ def main(input_swc_file,

proj_df.to_csv(output_projection_csv)
# proj_df_mask.to_csv(output_projection_csv_tip_branch_mask)

print(proj_df.head())
if projection_threshold != 0:
output_projection_csv = output_projection_csv.replace(".csv",
"{}thresh.csv".format(projection_threshold))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ class IO_Schema(ags.ArgSchema):
projection_threshold = ags.fields.Int(default=0)
normalize_proj_mat = ags.fields.Boolean(default=True)
mask_method = ags.fields.Str(default="tip_and_branch",description = " 'tip_and_branch', 'branch', 'tip', or 'tip_or_branch' ")
tip_count = ags.fields.Boolean(default=False, description="when true, this will measure a matrix of number of tips instead of number of nodes")
count_method = ags.fields.String(default="node", description="should be a member of ['node','tip','branch']")
annotation_path = ags.fields.Str(default="",description = "Optional. Path to annotation .nrrd file. Defaults to 10um ccf atlas")
resolution = ags.fields.Int(default=10, description="Optional. ccf resolution (micron/pixel")
volume_shape = ags.fields.List(ags.fields.Int, default=[1320, 800, 1140], description = "Optional. Size of input annotation")
resample_spacing = ags.fields.Float(allow_none=True, default=None, description = 'internode spacing to resample input morphology with')
run_host = ags.fields.String(default='local',description='either ["local" or "hpc"]. Will run either locally or submit jobs to the HPC')
virtual_env_name = ags.fields.Str(default='skeleton_keys_4',description='Name of virtual conda env to activate on hpc. not needed if running local')
output_projection_csv = ags.fields.OutputFile(allow_none=True,default=None, description="output projection csv, when running local only")
output_projection_csv = ags.fields.OutputFile(description="output projection csv, when running local only")


def normalize_projection_columns_per_cell(input_df, projection_column_identifiers=['ipsi', 'contra']):
Expand All @@ -44,7 +44,7 @@ def main(ccf_swc_directory,
projection_threshold,
normalize_proj_mat,
mask_method,
tip_count,
count_method,
annotation_path,
volume_shape,
resample_spacing,
Expand All @@ -55,6 +55,8 @@ def main(ccf_swc_directory,

if run_host not in ['local','hpc']:
raise ValueError(f"Invalid run_host parameter entered ({run_host})")
if mask_method not in [None,'tip_and_branch', 'branch', 'tip', 'tip_or_branch']:
raise ValueError(f"Invalid mask_method provided {mask_method}")

if annotation_path == "":
annotation_path = None
Expand All @@ -69,13 +71,14 @@ def main(ccf_swc_directory,
os.mkdir(dd)

results = []
single_cell_job_ids = []
for swc_fn in tqdm([f for f in os.listdir(ccf_swc_directory) if ".swc" in f]):

swc_pth = os.path.abspath(os.path.join(ccf_swc_directory, swc_fn))

if run_host=='local':
res = projection_matrix_for_swc(input_swc_file=swc_pth,
tip_count = tip_count,
count_method= count_method,
mask_method = mask_method,
annotation=None,
annotation_path = annotation_path,
Expand All @@ -86,24 +89,25 @@ def main(ccf_swc_directory,

else:

output_projection_csv = os.path.join(single_sp_proj_dir, swc_fn.replace(".swc",".csv"))
this_output_projection_csv = os.path.join(single_sp_proj_dir, swc_fn.replace(".swc",".csv"))

if not os.path.exists(output_projection_csv):
if not os.path.exists(this_output_projection_csv):

job_file = os.path.join(job_dir,swc_fn.replace(".swc",".sh"))
log_file = os.path.join(job_dir,swc_fn.replace(".swc",".log"))

command = "morph_utils_extract_projection_matrix_single_cell "
command = command+ f" --input_swc_file '{swc_pth}'"
command = command+ f" --output_projection_csv {output_projection_csv}"
command = command+ f" --output_projection_csv {this_output_projection_csv}"
command = command+ f" --projection_threshold {projection_threshold}"
command = command+ f" --normalize_proj_mat {normalize_proj_mat}"
command = command+ f" --mask_method {mask_method}"
command = command+ f" --tip_count {tip_count}"
command = command+ f" --count_method {count_method}"
command = command+ f" --annotation_path {annotation_path}"
command = command+ f" --resolution {resolution}"
# command = command+ f" --volume_shape {volume_shape}"
command = command+ f" --resample_spacing {resample_spacing}"
if resample_spacing is not None:
command = command+ f" --resample_spacing {resample_spacing}"


activate_command = f"source activate {virtual_env_name}"
Expand Down Expand Up @@ -150,8 +154,70 @@ def main(ccf_swc_directory,
std_out = result.stdout.decode('utf-8')

job_id = std_out.split("Submitted batch job ")[-1].replace("\n", "")
time.sleep(0.1)
single_cell_job_ids.append(job_id)
# time.sleep(0.1)

if run_host!='local':
# aggregate single projection files into proj mat
job_file = os.path.join(job_dir,"Projection_Aggregation.sh")
log_file = os.path.join(job_dir,"Projection_Aggregation.log")
agg_outdir = os.path.join(output_directory,'SingleCellProjections')
command = "morph_utils_aggregate_single_cell_projs "
command = command+ f" --output_projection_csv {output_projection_csv}"
command = command+ f" --output_directory {agg_outdir}"
command = command+ f" --projection_threshold {projection_threshold}"
command = command+ f" --normalize_proj_mat {normalize_proj_mat}"
command = command+ f" --mask_method {mask_method}"

activate_command = f"source activate {virtual_env_name}"
command_list = [activate_command, command]

slurm_kwargs = {
"--job-name": "AggregateProjs",
"--mail-type": "NONE",
"--cpus-per-task": "1",
"--nodes": "1",
"--kill-on-invalid-dep": "yes",
"--mem": "4gb",
"--time": "1:00:00",
"--partition": "celltypes",
"--output": log_file
}

dag_node = {
"job_file":job_file,
"slurm_kwargs":slurm_kwargs,
"slurm_commands":command_list
}

job_file = dag_node["job_file"]
slurm_kwargs = dag_node["slurm_kwargs"]
command_list = dag_node["slurm_commands"]

job_string_list = [f"#SBATCH {k}={v}" for k, v in slurm_kwargs.items()]
job_string_list = job_string_list + command_list
job_string_list = ["#!/bin/bash"] + job_string_list

if os.path.exists(job_file):
os.remove(job_file)

with open(job_file, 'w') as job_f:
for val in job_string_list:
job_f.write(val)
job_f.write('\n')

command = "sbatch --dependency=afterany"
for p_jid in single_cell_job_ids:
command = command + f":{p_jid}"
command = command + " {}".format(job_file)
command_list = command.split(" ")
# print(command)
result = subprocess.run(command_list, stdout=subprocess.PIPE)
std_out = result.stdout.decode('utf-8')

job_id = std_out.split("Submitted batch job ")[-1].replace("\n", "")


if results != []:

output_projection_csv = output_projection_csv.replace(".csv", f"_{mask_method}.csv")
Expand Down
Loading