168 changes: 166 additions & 2 deletions benchmarks/configs/learn_compositional_objects.py
@@ -15,8 +15,10 @@

from benchmarks.configs.names import CompositionalLearningExperiments
from benchmarks.configs.pretraining_experiments import supervised_pre_training_base
from tbp.monty.frameworks.actions.action_samplers import ConstantSampler
from tbp.monty.frameworks.config_utils.config_args import (
MontyArgs,
MotorSystemConfigInformedGoalStateDriven,
MotorSystemConfigNaiveScanSpiral,
TwoLMStackedMontyConfig,
get_cube_face_and_corner_views_rotations,
@@ -28,6 +30,7 @@
get_object_names_by_idx,
)
from tbp.monty.frameworks.config_utils.policy_setup_utils import (
make_informed_policy_config,
make_naive_scan_policy_config,
)
from tbp.monty.frameworks.environments import embodied_data as ED
@@ -43,7 +46,15 @@
from tbp.monty.frameworks.models.evidence_matching.learning_module import (
EvidenceGraphLM,
)
from tbp.monty.frameworks.models.motor_policies import NaiveScanPolicy
from tbp.monty.frameworks.models.evidence_matching.model import (
MontyForEvidenceGraphMatching,
)
from tbp.monty.frameworks.models.goal_state_generation import EvidenceGoalStateGenerator
from tbp.monty.frameworks.models.motor_policies import InformedPolicy, NaiveScanPolicy
from tbp.monty.frameworks.models.no_reset_evidence_matching import (
MontyForNoResetEvidenceGraphMatching,
NoResetEvidenceGraphLM,
)
from tbp.monty.simulators.habitat.configs import (
EnvInitArgsTwoLMDistantStackedMount,
TwoLMStackedDistantMountHabitatDatasetArgs,
@@ -72,6 +83,7 @@
# Note graph-delta-thresholds are not used for grid-based models
feature_weights={},
max_graph_size=0.3,
use_multithreading=False,
num_model_voxels_per_dim=200,
max_nodes_per_graph=2000,
object_evidence_threshold=20, # TODO - C: is this reasonable?
@@ -92,6 +104,7 @@
# real similarity measure.
"learning_module_0": {"object_id": 1},
},
use_multithreading=False,
feature_weights={"learning_module_0": {"object_id": 1}},
max_graph_size=0.4,
num_model_voxels_per_dim=200,
@@ -112,7 +125,6 @@
experiment_args=ExperimentArgs(
do_eval=False,
n_train_epochs=len(train_rotations_all),
show_sensor_output=False,
),
monty_config=TwoLMStackedMontyConfig(
monty_args=MontyArgs(num_exploratory_steps=1000),
@@ -139,12 +151,30 @@
),
)

# MINIMAL_3D_OBJECTS = ["016_sphere", "023_mug"]

# supervised_pre_training_minimal_3d_objects = copy.deepcopy(
# supervised_pre_training_flat_objects_wo_logos
# )
# supervised_pre_training_minimal_3d_objects.update(
# train_dataloader_args=EnvironmentDataloaderPerObjectArgs(
# object_names=get_object_names_by_idx(
# 0, len(MINIMAL_3D_OBJECTS), object_list=MINIMAL_3D_OBJECTS
# ),
# object_init_sampler=PredefinedObjectInitializer(
# rotations=train_rotations_all,
# ),
# ),
# )

# For learning the logos, we present them in a single rotation, but at multiple
# positions, as the naive scan policy otherwise samples peripheral points on the models
# poorly. This must be run after supervised_pre_training_flat_objects_wo_logos.
LOGO_POSITIONS = [[0.0, 1.5, 0.0], [-0.03, 1.5, 0.0], [0.03, 1.5, 0.0]]
LOGO_ROTATIONS = [[0.0, 0.0, 0.0]]

MINIMAL_LOGOS = ["021_logo_tbp"]

supervised_pre_training_logos_after_flat_objects = copy.deepcopy(
supervised_pre_training_flat_objects_wo_logos
)
@@ -177,6 +207,40 @@
),
)

# supervised_pre_training_minimal_logos_after_minimal_3d_objects = copy.deepcopy(
# supervised_pre_training_minimal_3d_objects
# )
# supervised_pre_training_minimal_logos_after_minimal_3d_objects.update(
# experiment_args=ExperimentArgs(
# do_eval=False,
# n_train_epochs=len(LOGO_POSITIONS) * len(LOGO_ROTATIONS),
# show_sensor_output=False,
# model_name_or_path=os.path.join(
# fe_pretrain_dir,
# "supervised_pre_training_minimal_3d_objects/pretrained/",
# ),
# ),
# monty_config=TwoLMStackedMontyConfig(
# monty_args=MontyArgs(num_exploratory_steps=1000),
# learning_module_configs=two_stacked_constrained_lms_config,
# motor_system_config=MotorSystemConfigNaiveScanSpiral(
# motor_system_args=dict(
# policy_class=NaiveScanPolicy,
# policy_args=make_naive_scan_policy_config(step_size=1),
# )
# ), # use spiral policy for more even object coverage during learning
# ),
# train_dataloader_args=EnvironmentDataloaderPerObjectArgs(
# object_names=get_object_names_by_idx(
# 0, len(MINIMAL_LOGOS), object_list=MINIMAL_LOGOS
# ),
# object_init_sampler=PredefinedObjectInitializer(
# positions=LOGO_POSITIONS,
# rotations=LOGO_ROTATIONS,
# ),
# ),
# )

# NOTE: we load the model trained on flat objects and logos, but we inherit from the
# config used for 3D "flat" objects, since it is similar in step-size, rotations, etc.
supervised_pre_training_curved_objects_after_flat_and_logo = copy.deepcopy(
@@ -317,14 +381,114 @@
),
)


two_stacked_constrained_lms_config_with_resampling = copy.deepcopy(
two_stacked_constrained_lms_config
)

two_stacked_constrained_lms_config_with_resampling["learning_module_0"][
"learning_module_class"
] = NoResetEvidenceGraphLM
two_stacked_constrained_lms_config_with_resampling["learning_module_0"][
"learning_module_args"
]["evidence_threshold_config"] = "all"
two_stacked_constrained_lms_config_with_resampling["learning_module_0"][
"learning_module_args"
]["object_evidence_threshold"] = 1
two_stacked_constrained_lms_config_with_resampling["learning_module_0"][
"learning_module_args"
]["gsg_class"] = EvidenceGoalStateGenerator
two_stacked_constrained_lms_config_with_resampling["learning_module_0"][
"learning_module_args"
]["gsg_args"] = dict(
goal_tolerances=dict(
location=0.015, # distance in meters
), # Tolerance(s) when determining goal-state success
elapsed_steps_factor=10, # Factor that considers the number of elapsed
# steps as a possible condition for initiating a hypothesis-testing goal
# state; should be set to an integer reflecting a number of steps
min_post_goal_success_steps=20, # Number of necessary steps for a hypothesis
# goal-state to be considered
x_percent_scale_factor=0.75, # Scale x-percent threshold to decide
# when we should focus on pose rather than determining object ID; should
# be bounded between 0:1.0; "mod" for modifier
desired_object_distance=0.03, # Distance from the object to the
# agent that is considered "close enough" to the object
)


OBJECTS_DISK_WITH_LOGO_ONLY = ["007_disk_tbp_horz"]

supervised_pre_training_objects_disk_with_logo_only_and_resampling = copy.deepcopy(
supervised_pre_training_objects_with_logos_lvl2_comp_models
)

# Other possible improvements: use a surface policy during learning (not currently
# set up for stacked LMs); only learn about the mug and the logo (this would require
# a fair amount of retraining); use hypothesis testing with a random saccade policy.

MODEL_PATH_WITH_MINIMAL_TRAINING = os.path.join(
fe_pretrain_dir,
"supervised_pre_training_logos_after_flat_objects/pretrained/",
)

# Debugging: there shouldn't be any issues learning in this second stage, because the
# only LM being updated does not use the hypothesis resampler...
# Except that the Monty class is set globally (MontyForNoResetEvidenceGraphMatching);
# but since we're not concerned with unsupervised inference here, that is probably
# not needed.

# Possibly related to resampling not reaching a terminal condition, so essentially
# no outputs?

temp_few_rotations = [[0, 0, 0]]

supervised_pre_training_objects_disk_with_logo_only_and_resampling.update(
# The low-level LM should use hypothesis resampling during its inference
experiment_args=ExperimentArgs(
do_eval=False,
n_train_epochs=len(temp_few_rotations),
model_name_or_path=MODEL_PATH_WITH_MINIMAL_TRAINING,
supervised_lm_ids=["learning_module_1"],
min_lms_match=2,
show_sensor_output=True,
),
monty_config=TwoLMStackedMontyConfig(
monty_args=MontyArgs(num_exploratory_steps=0, min_train_steps=100),
monty_class=MontyForEvidenceGraphMatching,
learning_module_configs=two_stacked_constrained_lms_config_with_resampling,
motor_system_config=MotorSystemConfigInformedGoalStateDriven(
motor_system_args=dict(
policy_class=InformedPolicy,
policy_args=make_informed_policy_config(
action_space_type="distant_agent_no_translation",
action_sampler_class=ConstantSampler,
rotation_degrees=0.5,
use_goal_state_driven_actions=True,
),
),
),
),
train_dataloader_args=EnvironmentDataloaderPerObjectArgs(
object_names=get_object_names_by_idx(
0, len(OBJECTS_DISK_WITH_LOGO_ONLY), object_list=OBJECTS_DISK_WITH_LOGO_ONLY
),
object_init_sampler=PredefinedObjectInitializer(
rotations=temp_few_rotations,
),
),
)

experiments = CompositionalLearningExperiments(
supervised_pre_training_flat_objects_wo_logos=supervised_pre_training_flat_objects_wo_logos,
supervised_pre_training_logos_after_flat_objects=supervised_pre_training_logos_after_flat_objects,
supervised_pre_training_curved_objects_after_flat_and_logo=supervised_pre_training_curved_objects_after_flat_and_logo,
# supervised_pre_training_minimal_3d_objects=supervised_pre_training_minimal_3d_objects,
# supervised_pre_training_minimal_logos_after_minimal_3d_objects=supervised_pre_training_minimal_logos_after_minimal_3d_objects,
supervised_pre_training_objects_with_logos_lvl1_monolithic_models=supervised_pre_training_objects_with_logos_lvl1_monolithic_models,
supervised_pre_training_objects_with_logos_lvl1_comp_models=supervised_pre_training_objects_with_logos_lvl1_comp_models,
supervised_pre_training_objects_with_logos_lvl2_comp_models=supervised_pre_training_objects_with_logos_lvl2_comp_models,
supervised_pre_training_objects_with_logos_lvl3_comp_models=supervised_pre_training_objects_with_logos_lvl3_comp_models,
supervised_pre_training_objects_with_logos_lvl4_comp_models=supervised_pre_training_objects_with_logos_lvl4_comp_models,
supervised_pre_training_objects_disk_with_logo_only_and_resampling=supervised_pre_training_objects_disk_with_logo_only_and_resampling,
)
CONFIGS = asdict(experiments)
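
For context, a minimal sketch of how the CONFIGS dict defined above could be consumed:
look up one of the registered experiments by name and inspect its config before
launching it. This is not part of the diff; the structure of the nested dicts after
asdict() is an assumption based on the update() calls above.

# Sketch only (assumptions noted above): inspect one experiment config by name.
from benchmarks.configs.learn_compositional_objects import CONFIGS

experiment_name = (
    "supervised_pre_training_objects_disk_with_logo_only_and_resampling"
)
config = CONFIGS[experiment_name]
print(sorted(config.keys()))  # e.g. experiment_args, monty_config, dataloader args
print(config["experiment_args"])  # check epochs, model path, etc. before running
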
3 changes: 3 additions & 0 deletions benchmarks/configs/names.py
@@ -70,11 +70,14 @@ class CompositionalLearningExperiments:
supervised_pre_training_flat_objects_wo_logos: dict
supervised_pre_training_logos_after_flat_objects: dict
supervised_pre_training_curved_objects_after_flat_and_logo: dict
# supervised_pre_training_minimal_3d_objects: dict
# supervised_pre_training_minimal_logos_after_minimal_3d_objects: dict
supervised_pre_training_objects_with_logos_lvl1_monolithic_models: dict
supervised_pre_training_objects_with_logos_lvl1_comp_models: dict
supervised_pre_training_objects_with_logos_lvl2_comp_models: dict
supervised_pre_training_objects_with_logos_lvl3_comp_models: dict
supervised_pre_training_objects_with_logos_lvl4_comp_models: dict
supervised_pre_training_objects_disk_with_logo_only_and_resampling: dict


@dataclass
@@ -461,6 +461,9 @@ def get_output(self):
sender_id=self.learning_module_id,
sender_type="LM",
)
if use_state:
print(f"Sending state with object-ID {object_id_features}")

return hypothesized_state

# ------------------ Getters & Setters ---------------------
@@ -277,7 +277,7 @@ def update_hypotheses(
mapper=mapper,
tracker=tracker,
)
informed_hypotheses = self._sample_informed(
informed_hypotheses, prediction_error = self._sample_informed(
channel_features=features[input_channel],
graph_id=graph_id,
informed_count=informed_count,
@@ -288,7 +288,7 @@ def update_hypotheses(
# We only displace existing hypotheses since the newly resampled hypotheses
# should not be affected by the displacement from the last sensory input.
if existing_count > 0:
existing_hypotheses = (
existing_hypotheses, prediction_error = (
self.hypotheses_displacer.displace_hypotheses_and_compute_evidence(
channel_displacement=displacements[input_channel],
channel_features=features[input_channel],
@@ -331,9 +331,13 @@ def update_hypotheses(
# Update tracker evidence
tracker.update(channel_hypotheses.evidence, input_channel)

# TODO C add calculation of prediction error
prediction_error = 0.0

return (
hypotheses_updates,
resampling_telemetry if self.include_telemetry else None,
prediction_error,
)

def _num_hyps_per_node(self, channel_features: dict) -> int:
@@ -527,14 +531,15 @@ def _sample_informed(
The sampled informed hypotheses.

"""
prediction_error = 0.0
# Return empty arrays for no hypotheses to sample
if informed_count == 0:
return ChannelHypotheses(
input_channel=input_channel,
locations=np.zeros((0, 3)),
poses=np.zeros((0, 3, 3)),
evidence=np.zeros(0),
)
), prediction_error

num_hyps_per_node = self._num_hyps_per_node(channel_features)
# === Calculate selected evidence by top-k indices === #
@@ -619,4 +624,4 @@ def _sample_informed(
locations=selected_locations,
poses=selected_rotations,
evidence=selected_feature_evidence,
)
), prediction_error
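
The hunk at @@ -331,9 +331,13 @@ above leaves prediction_error as a 0.0 placeholder
(see the TODO). A minimal sketch, assuming the prediction error will eventually be
defined as a distance between predicted and observed feature vectors; the function
name and its inputs are hypothetical and not taken from this diff.

import numpy as np

def feature_prediction_error(predicted_features, observed_features) -> float:
    # Mean squared difference between predicted and observed feature vectors;
    # falls back to the placeholder value 0.0 when no comparable prediction exists.
    predicted = np.asarray(predicted_features, dtype=float)
    observed = np.asarray(observed_features, dtype=float)
    if predicted.size == 0 or predicted.shape != observed.shape:
        return 0.0
    return float(np.mean((predicted - observed) ** 2))
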
12 changes: 12 additions & 0 deletions src/tbp/monty/frameworks/models/graph_matching.py
@@ -1159,7 +1159,9 @@ def update_memory(
if graph_id is None:
logger.info("no match found in time, not updating memory")
else:
print(f"Iterating over input channels: {features.keys()}")
for input_channel in features.keys():
print(f"Current input channel: {input_channel}")
(
input_channel_features,
input_channel_locations,
@@ -1512,7 +1514,17 @@ def _extract_entries_with_content(self, features, locations):
"""
# NOTE: Could use any feature here but using pose_fully_defined since it
# is one dimensional and a required feature in each State.
# print("Current features:")
# print(features)
missing_features = np.isnan(features["pose_fully_defined"]).flatten()
print(
f"Current pose fully defined features: {features['pose_fully_defined']} of shape {np.shape(features['pose_fully_defined'])}"
)
print(f"features in buffer: {features.keys()}")
print(f"Shape of missing features: {np.shape(missing_features)}")
print(f"Shape of locations: {np.shape(locations)}")
print(f"Sum of missing features: {np.sum(missing_features)}")
print(f"Sum of not-missing features: {np.sum(~missing_features)}")
# Remove missing features (contain nan values)
locations = locations[~missing_features]
for feature in features.keys():