diff --git a/include/mppi/controllers/controller.cuh b/include/mppi/controllers/controller.cuh index c356d107..9beb8b0d 100644 --- a/include/mppi/controllers/controller.cuh +++ b/include/mppi/controllers/controller.cuh @@ -326,8 +326,9 @@ public: * @param rel_time * @return */ - virtual control_array getCurrentControl(state_array& state, double rel_time, state_array& target_nominal_state, - control_trajectory& c_traj, TEMPLATED_FEEDBACK_STATE& fb_state) + virtual control_array getCurrentControl(Eigen::Ref state, double rel_time, + Eigen::Ref target_nominal_state, + Eigen::Ref c_traj, TEMPLATED_FEEDBACK_STATE& fb_state) { // MPPI control control_array u_ff = interpolateControls(rel_time, c_traj); @@ -348,7 +349,8 @@ public: * @param steps - number of dt's to slide control sequence forward * Slide the control sequence forwards by 'steps' */ - virtual void slideControlSequence(int steps) { + virtual void slideControlSequence(int steps) + { // Save the control history this->saveControlHistoryHelper(steps, this->control_, this->control_history_); @@ -360,7 +362,7 @@ public: * @param rel_time time since the solution was calculated * @return */ - virtual control_array interpolateControls(double rel_time, control_trajectory& c_traj) + virtual control_array interpolateControls(double rel_time, Eigen::Ref c_traj) { int lower_idx = (int)(rel_time / getDt()); int upper_idx = lower_idx + 1; @@ -370,14 +372,10 @@ public: control_array prev_cmd = c_traj.col(lower_idx); control_array next_cmd = c_traj.col(upper_idx); interpolated_control = (1 - alpha) * prev_cmd + alpha * next_cmd; - - // printf("prev: %d %f, %f\n", lower_idx, prev_cmd[0], prev_cmd[1]); - // printf("next: %d %f, %f\n", upper_idx, next_cmd[0], next_cmd[1]); - // printf("smoother: %f\n", alpha); return interpolated_control; } - virtual state_array interpolateState(state_trajectory& s_traj, double rel_time) + virtual state_array interpolateState(Eigen::Ref s_traj, double rel_time) { int lower_idx = (int)(rel_time / getDt()); int upper_idx = lower_idx + 1; @@ -392,8 +390,8 @@ public: * @param rel_time * @return */ - virtual control_array interpolateFeedback(state_array& state, state_array& target_nominal_state, double rel_time, - TEMPLATED_FEEDBACK_STATE& fb_state) + virtual control_array interpolateFeedback(Eigen::Ref state, Eigen::Ref target_nominal_state, + double rel_time, TEMPLATED_FEEDBACK_STATE& fb_state) { return fb_controller_->interpolateFeedback_(state, target_nominal_state, rel_time, fb_state); } diff --git a/include/mppi/core/base_plant.hpp b/include/mppi/core/base_plant.hpp index 5203f97e..5a9ca886 100644 --- a/include/mppi/core/base_plant.hpp +++ b/include/mppi/core/base_plant.hpp @@ -65,7 +65,7 @@ class BasePlant std::atomic has_new_dynamics_params_{ false }; std::atomic has_new_cost_params_{ false }; std::atomic has_new_controller_params_{ false }; - std::atomic enabled_{ false }; + std::atomic has_received_state_{ false }; // Values needed s_array init_state_ = s_array::Zero(); @@ -286,7 +286,7 @@ class BasePlant * @param state the most recent state from state estimator * @param time the time of the most recent state from the state estimator */ - virtual void updateState(s_array& state, double time) + virtual void updateState(Eigen::Ref state, double time) { // calculate and update all timing variables double temp_last_state_update_time = last_used_state_update_time_; @@ -295,8 +295,9 @@ class BasePlant state_ = state; state_time_ = time; + has_received_state_ = true; - if (last_used_state_update_time_ < 0) + if (num_iter_ == 0) { // we have not optimized yet so no reason to publish controls return; @@ -446,13 +447,16 @@ class BasePlant double temp_last_state_time = getStateTime(); double temp_last_used_state_update_time = last_used_state_update_time_; + // If it is the first iteration and we have received state, we should not wait for timestamps to differ + bool skip_first_loop = num_iter_ == 0 && has_received_state_; + // wait for a new state to compute control sequence from - int counter = 0; - while (temp_last_used_state_update_time == temp_last_state_time && is_alive->load()) + while (temp_last_used_state_update_time == temp_last_state_time && !skip_first_loop && is_alive->load()) { usleep(50); temp_last_state_time = getStateTime(); - counter++; + // In case when runControlIteration is ran before getting state and state time is specifically 0 + skip_first_loop = num_iter_ == 0 && has_received_state_; } if (!is_alive->load()) { @@ -487,7 +491,7 @@ class BasePlant // calculate how much we should slide the control sequence double dt = temp_last_state_time - temp_last_used_state_update_time; - if (temp_last_used_state_update_time == -1) + if (num_iter_ == 0) { // // should only happen on the first iteration dt = 0; @@ -518,21 +522,21 @@ class BasePlant { std::cerr << "ERROR: Nan in control inside plant" << std::endl; std::cerr << control_traj << std::endl; - exit(-1); + throw std::runtime_error("Control Trajectory inside plant has a NaN"); } s_traj state_traj = controller_->getTargetStateSeq(); if (!state_traj.allFinite()) { std::cerr << "ERROR: Nan in state inside plant" << std::endl; std::cerr << state_traj << std::endl; - exit(-1); + throw std::runtime_error("State Trajectory inside plant has a NaN"); } o_traj output_traj = controller_->getTargetOutputSeq(); - if (!state_traj.allFinite()) + if (!output_traj.allFinite()) { - std::cerr << "ERROR: Nan in state inside plant" << std::endl; - std::cerr << state_traj << std::endl; - exit(-1); + std::cerr << "ERROR: Nan in output inside plant" << std::endl; + std::cerr << output_traj << std::endl; + throw std::runtime_error("Output Trajectory inside plant has a NaN"); } optimization_duration_ = mppi::math::timeDiffms(std::chrono::steady_clock::now(), optimization_start); diff --git a/include/mppi/core/buffered_plant.hpp b/include/mppi/core/buffered_plant.hpp index 7039961e..fe4ec405 100644 --- a/include/mppi/core/buffered_plant.hpp +++ b/include/mppi/core/buffered_plant.hpp @@ -75,6 +75,16 @@ class BufferedPlant : public BasePlant buffer_.clearBuffers(); } + double getBufferDt() const + { + return buffer_dt_; + } + + void setBufferDt(const double buff_dt) + { + buffer_dt_ = buff_dt; + } + protected: Buffer buffer_; @@ -83,8 +93,4 @@ class BufferedPlant : public BasePlant double buffer_dt_ = 0.02; // the spacing between well sampled buffer positions }; -template class BufferMessage; -template class BufferMessage; -template class BufferMessage; - #endif // MPPIGENERIC_BUFFERED_PLANT_H diff --git a/tests/include/mppi_test/mock_classes/mock_controller.h b/tests/include/mppi_test/mock_classes/mock_controller.h index 3d6f423e..ff1029ce 100644 --- a/tests/include/mppi_test/mock_classes/mock_controller.h +++ b/tests/include/mppi_test/mock_classes/mock_controller.h @@ -19,11 +19,12 @@ class MockController MOCK_METHOD0(resetControls, void()); MOCK_METHOD1(computeFeedback, void(const Eigen::Ref& state)); MOCK_METHOD1(slideControlSequence, void(int stride)); - MOCK_METHOD5(getCurrentControl, - control_array(state_array&, double, state_array&, control_trajectory&, TEMPLATED_FEEDBACK_STATE&)); + MOCK_METHOD5(getCurrentControl, control_array(Eigen::Ref, double, Eigen::Ref, + Eigen::Ref, TEMPLATED_FEEDBACK_STATE&)); MOCK_METHOD2(computeControl, void(const Eigen::Ref& state, int optimization_stride)); MOCK_METHOD(control_trajectory, getControlSeq, (), (const, override)); MOCK_METHOD(state_trajectory, getTargetStateSeq, (), (const, override)); + // MOCK_METHOD(output_trajectory, getTargetOutputSeq, (), (const, override)); MOCK_METHOD(TEMPLATED_FEEDBACK_STATE, getFeedbackState, (), (const, override)); MOCK_METHOD(control_array, getFeedbackControl, (const Eigen::Ref&, const Eigen::Ref&, int), (override)); diff --git a/tests/mppi_core/base_plant_tester.cu b/tests/mppi_core/base_plant_tester.cu index ef189666..8e5c908f 100644 --- a/tests/mppi_core/base_plant_tester.cu +++ b/tests/mppi_core/base_plant_tester.cu @@ -181,7 +181,7 @@ TEST_F(BasePlantTest, Constructor) EXPECT_EQ(plant->getHz(), 20); EXPECT_EQ(plant->getTargetOptimizationStride(), 1); EXPECT_EQ(plant->getNumIter(), 0); - EXPECT_EQ(plant->getLastUsedPoseUpdateTime(), -1); + EXPECT_EQ(plant->getLastUsedPoseUpdateTime(), 0); EXPECT_EQ(plant->getStatus(), 1); EXPECT_EQ(mockController->getFeedbackEnabled(), false); EXPECT_EQ(plant->hasNewCostParams(), false); @@ -275,7 +275,6 @@ TEST_F(BasePlantTest, updateParametersAllTrue) TEST_F(BasePlantTest, updateStateOutsideTimeTest) { - mockController->setDt(DT); plant->setLastTime(0); EXPECT_CALL(*mockController, getCurrentControl(testing::_, testing::_, testing::_, testing::_, testing::_)).Times(0); @@ -293,7 +292,6 @@ TEST_F(BasePlantTest, updateStateOutsideTimeTest) TEST_F(BasePlantTest, updateStateTest) { - mockController->setDt(DT); plant->setLastTime(0); MockController::state_array state = MockController::state_array::Zero(); @@ -303,16 +301,73 @@ TEST_F(BasePlantTest, updateStateTest) EXPECT_EQ(plant->pubControlCalled, 0); EXPECT_EQ(plant->pubNominalStateCalled, 0); - EXPECT_CALL(*mockController, getCurrentControl(testing::_, testing::_, testing::_, testing::_, testing::_)).Times(1); + // Calling updateState() should not pub controls when none have been calculated as of yet + EXPECT_CALL(*mockController, getCurrentControl(testing::_, testing::_, testing::_, testing::_, testing::_)).Times(0); plant->setLastUsedTime(11); plant->updateState(state, 12); EXPECT_EQ(plant->getState(), state); - EXPECT_EQ(plant->pubControlCalled, 1); + EXPECT_EQ(plant->pubControlCalled, 0); EXPECT_EQ(plant->pubNominalStateCalled, 0); // TODO in debug should pub nominal state } +TEST_F(BasePlantTest, pubControlOnlyAfterControlAreCalculatedTest) +{ + ::testing::Sequence s1; + // Step 1: calling updateState() before controls are calculated should not call controller->getCurrentControl() + MockController::state_array state = MockController::state_array::Zero(); + EXPECT_CALL(*mockController, getCurrentControl(testing::_, testing::_, testing::_, testing::_, testing::_)).Times(0).InSequence(s1); + double curr_time = 0.0; + plant->updateState(state, curr_time); + EXPECT_EQ(plant->pubControlCalled, 0); + EXPECT_EQ(plant->pubNominalStateCalled, 0); + // ::testing::Mock::VerifyAndClearExpectations(mockController.get()); + + // Step 2: run control iteration inside plant + // Create valid outputs from gmock methods to prevent nan detection from triggering + MockController::control_trajectory valid_control_seq = MockController::control_trajectory::Zero(MockDynamics::CONTROL_DIM, NUM_TIMESTEPS); + MockController::state_trajectory valid_state_seq = MockController::state_trajectory::Zero(MockDynamics::STATE_DIM, NUM_TIMESTEPS); + EXPECT_CALL(*mockController, computeControl(testing::_, testing::_)).Times(1).InSequence(s1); + EXPECT_CALL(*mockController, getControlSeq()).Times(1).WillRepeatedly(testing::Return(valid_control_seq)); + EXPECT_CALL(*mockController, getTargetStateSeq()).Times(1).WillRepeatedly(testing::Return(valid_state_seq)); + // EXPECT_CALL(*mockController, getTargetOutputSeq()).Times(1); + std::atomic is_alive(true); + plant->runControlIteration(&is_alive); + + // Step 3: calling updateState() now should use controller->getCurrentControl() + EXPECT_CALL(*mockController, getCurrentControl(testing::_, testing::_, testing::_, testing::_, testing::_)).Times(1).InSequence(s1); + curr_time++; + plant->updateState(state, curr_time); + EXPECT_EQ(plant->pubControlCalled, 1); + EXPECT_EQ(plant->pubNominalStateCalled, 0); +} + +TEST_F(BasePlantTest, EnsureReceivingStateCompletesRunControlIterationTest) +{ + std::atomic is_alive(true); + // Create valid outputs from gmock methods to prevent nan detection from triggering + MockController::control_trajectory valid_control_seq = MockController::control_trajectory::Zero(MockDynamics::CONTROL_DIM, NUM_TIMESTEPS); + MockController::state_trajectory valid_state_seq = MockController::state_trajectory::Zero(MockDynamics::STATE_DIM, NUM_TIMESTEPS); + EXPECT_CALL(*mockController, computeControl(testing::_, testing::_)).Times(1); + EXPECT_CALL(*mockController, getControlSeq()).Times(1).WillRepeatedly(testing::Return(valid_control_seq)); + EXPECT_CALL(*mockController, getTargetStateSeq()).Times(1).WillRepeatedly(testing::Return(valid_state_seq)); + // EXPECT_CALL(*mockController, getTargetOutputSeq()).Times(1); + std::thread new_thread(&MockTestPlant::runControlIteration, plant.get(), &is_alive); + // Wait some period of time and then call updateState() + std::cout << "Wait for new state" << std::endl; + std::this_thread::sleep_for(std::chrono::milliseconds(300)); + EXPECT_CALL(*mockController, getCurrentControl(testing::_, testing::_, testing::_, testing::_, testing::_)).Times(0); + MockController::state_array state = MockController::state_array::Zero(); + double curr_time = 0.0; + plant->updateState(state, curr_time); + std::cout << "State sent" << std::endl; + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + // is_alive.store(false); + new_thread.join(); + +} + TEST_F(BasePlantTest, runControlIterationStoppedTest) { EXPECT_CALL(*mockController, slideControlSequence(testing::_)).Times(0); diff --git a/tests/mppi_core/buffered_plant_tester.cu b/tests/mppi_core/buffered_plant_tester.cu index 2fc02c7f..17e3db05 100644 --- a/tests/mppi_core/buffered_plant_tester.cu +++ b/tests/mppi_core/buffered_plant_tester.cu @@ -122,10 +122,6 @@ public: { return this->buffer_tau_; } - double getBufferDt() - { - return this->buffer_dt_; - } void setLastUsedUpdateTime(double time) { this->last_used_state_update_time_ = time; @@ -198,6 +194,13 @@ TEST_F(BufferedPlantTest, Constructor) EXPECT_FLOAT_EQ(plant->getBufferDt(), 0.02); } +TEST_F(BufferedPlantTest, setBufferDt) +{ + double new_buffer_dt = 30.0; + plant->setBufferDt(new_buffer_dt); + EXPECT_FLOAT_EQ(plant->getBufferDt(), new_buffer_dt); +} + TEST_F(BufferedPlantTest, interpNew) { Eigen::Vector3f pos = Eigen::Vector3f::Ones(); @@ -208,7 +211,9 @@ TEST_F(BufferedPlantTest, interpNew) MockDynamics::state_array state = MockDynamics::state_array::Random(); EXPECT_CALL(mockDynamics, stateFromMap(testing::_)).Times(2).WillRepeatedly(testing::Return(state)); - EXPECT_CALL(*mockController, getDt()).Times(2); + // Controls never calculated so no calls to controller in updateState() + EXPECT_CALL(*mockController, getDt()).Times(0); + EXPECT_CALL(*mockController, getCurrentControl(testing::_, testing::_, testing::_, testing::_, testing::_)).Times(0); plant->setLastUsedUpdateTime(0); plant->updateOdometry(pos, quat, vel, omega, 0.0); @@ -380,7 +385,9 @@ TEST_F(BufferedPlantTest, updateOdometry) MockDynamics::state_array state = MockDynamics::state_array::Random(); EXPECT_CALL(mockDynamics, stateFromMap(testing::_)).Times(2).WillRepeatedly(testing::Return(state)); - EXPECT_CALL(*mockController, getDt()).Times(2); + // Controls never calculated so no calls to controller in updateState() + EXPECT_CALL(*mockController, getDt()).Times(0); + EXPECT_CALL(*mockController, getCurrentControl(testing::_, testing::_, testing::_, testing::_, testing::_)).Times(0); plant->setLastUsedUpdateTime(0.0); plant->updateOdometry(pos, quat, vel, omega, 0.0); @@ -448,7 +455,9 @@ TEST_F(BufferedPlantTest, getInterpState) plant->setLastUsedUpdateTime(0.0); EXPECT_CALL(mockDynamics, stateFromMap(testing::_)).Times(2).WillRepeatedly(testing::Return(state)); - EXPECT_CALL(*mockController, getDt()).Times(2); + // Controls never calculated so no calls to controller in updateState() + EXPECT_CALL(*mockController, getDt()).Times(0); + EXPECT_CALL(*mockController, getCurrentControl(testing::_, testing::_, testing::_, testing::_, testing::_)).Times(0); plant->updateOdometry(pos, quat, vel, omega, 0.0); plant->updateControls(u, 0.0); @@ -509,7 +518,9 @@ TEST_F(BufferedPlantTest, getInterpBuffer) plant->setLastUsedUpdateTime(0.0); EXPECT_CALL(mockDynamics, stateFromMap(testing::_)).Times(2).WillRepeatedly(testing::Return(state)); - EXPECT_CALL(*mockController, getDt()).Times(2); + // Controls never calculated so no calls to controller in updateState() + EXPECT_CALL(*mockController, getDt()).Times(0); + EXPECT_CALL(*mockController, getCurrentControl(testing::_, testing::_, testing::_, testing::_, testing::_)).Times(0); plant->updateOdometry(pos, quat, vel, omega, 0.0); plant->updateControls(u, 0.0);