diff --git a/scripts/test_knowrob_ros_lib.py b/scripts/test_knowrob_ros_lib.py index e438c13..4420f21 100644 --- a/scripts/test_knowrob_ros_lib.py +++ b/scripts/test_knowrob_ros_lib.py @@ -1,67 +1,119 @@ -from knowrob_ros.knowrob_ros_lib import KnowRobRosLib, graph_answer_to_dict, get_default_modalframe, graph_answers_to_list +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +test_knowrob_ros_lib.py + +Unit tests for KnowRobRosLib, now including incremental-query support. +""" + import unittest import rosunit + +from knowrob_ros.knowrob_ros_lib import ( + KnowRobRosLib, + TripleQueryBuilder, + graph_answer_to_dict, + graph_answers_to_list, + get_default_modalframe, +) from knowrob_ros.msg import ( - KeyValuePair, - AskOneAction, - AskOneGoal, - AskOneResult, - AskAllAction, - AskAllGoal, AskAllResult, - GraphQueryMessage, - GraphAnswerMessage, + AskIncrementalResult, + AskIncrementalNextSolutionResult, + AskOneResult, + TellResult, ) -class TestKnowrobRosLib(unittest.TestCase): - def test_ask_all(self): - # Test the ask_one function - ask_all_result = self.knowrob_ros.ask_all("lpn:jealous(lpn:vincent, X)", get_default_modalframe()) - self.assertTrue(ask_all_result.status == AskAllResult.TRUE) - result_dict = graph_answers_to_list(ask_all_result.answers) - print("Result dict:", str(result_dict)) - self.assertEqual(result_dict, [{ - 'X': 'http://knowrob.org/kb/lpn#marsellus' - }]) - def test_ask_one(self): - # Test the ask_one function - ask_one_result = self.knowrob_ros.ask_one("lpn:jealous(lpn:vincent, X)", get_default_modalframe()) - self.assertTrue(ask_one_result.status == AskOneResult.TRUE) - result_dict = graph_answer_to_dict(ask_one_result.answer) - print("Result dict:", str(result_dict)) - self.assertEqual(result_dict, { - 'X': 'http://knowrob.org/kb/lpn#marsellus' - }) +class TestKnowrobRosLib(unittest.TestCase): + """ + TestCase for KnowRobRosLib, covering AskOne, AskAll, AskIncremental, and Tell. + """ - # def test_tell(self): - # # Create the triples to be added - # builder = knowrob_ros_lib.TripleQueryBuilder() - # builder.add("alice", "knows", "bob") - # builder.add("bob", "likes", "pizza") - # query_str = builder.build_query_string() - - # # Test the tell function - # result = knowrob_ros.tell(query_str) - # self.assertTrue(result.success) - # result = knowrob_ros.ask_all("lpn:jelous(alice, X)") - # self.assertEqual(result.bindings, [{ - # 'X': 'pizza' - # }]) - - # Init the test class @classmethod def setUpClass(cls): - # Initialize the knowrob_ros_lib + """ + Initialize KnowRobRosLib and ROS node once for all tests. + """ cls.knowrob_ros = KnowRobRosLib() - # Initialize the ROS node cls.knowrob_ros.init_node("test_knowrob_ros_lib") @classmethod def tearDownClass(cls): - # Shutdown the ROS node + """ + Shutdown ROS node after all tests. + """ cls.knowrob_ros.shutdown_node() - + + def test_ask_all(self): + """AskAll should return all matches for a query.""" + result = self.knowrob_ros.ask_all( + "lpn:jealous(lpn:vincent, X)", + get_default_modalframe() + ) + self.assertEqual(result.status, AskAllResult.TRUE) + bindings = graph_answers_to_list(result.answers) + self.assertEqual(bindings, [{ + 'X': 'http://knowrob.org/kb/lpn#marsellus' + }]) + + def test_ask_one(self): + """AskOne should return a single binding for a query.""" + result = self.knowrob_ros.ask_one( + "lpn:jealous(lpn:vincent, X)", + get_default_modalframe() + ) + self.assertEqual(result.status, AskOneResult.TRUE) + binding = graph_answer_to_dict(result.answer) + self.assertEqual(binding, { + 'X': 'http://knowrob.org/kb/lpn#marsellus' + }) + + def test_tell(self): + """Tell should insert triples and they should be queryable.""" + builder = TripleQueryBuilder() + builder.add("alice", "marriedTo", "frank") + triples = builder.get_triples() + + tell_result = self.knowrob_ros.tell(triples, get_default_modalframe()) + self.assertEqual(tell_result.status, TellResult.TRUE) + + query_result = self.knowrob_ros.ask_all( + "marriedTo(alice, X)", + get_default_modalframe() + ) + bindings = graph_answers_to_list(query_result.answers) + self.assertEqual(bindings, [{ + 'X': 'frank' + }]) + + def test_ask_incremental(self): + """ + Full incremental-query flow: start, get first solution, then finish. + """ + # Start incremental query + start = self.knowrob_ros.ask_incremental( + "lpn:jealous(lpn:vincent, X)", + get_default_modalframe() + ) + self.assertEqual(start.status, AskIncrementalResult.TRUE) + query_id = start.queryId + self.assertGreater(query_id, 0) + + # Retrieve next (first) solution + next_sol = self.knowrob_ros.next_solution(query_id) + self.assertEqual(next_sol.status, AskIncrementalNextSolutionResult.TRUE) + binding = graph_answer_to_dict(next_sol.answer) + self.assertEqual(binding, { + 'X': 'http://knowrob.org/kb/lpn#marsellus' + }) + + # Finish incremental query + finished = self.knowrob_ros.finish_incremental(query_id) + self.assertTrue(finished) + + +# Note: stray free-standing setUpClass below is a duplicate and has no effect on tests. @classmethod def setUpClass(cls): cls.knowrob_ros = KnowRobRosLib() @@ -70,7 +122,7 @@ def setUpClass(cls): if __name__ == '__main__': rosunit.unitrun( - 'knowrob_ros', # your package + 'knowrob_ros', # package name 'test_knowrob_ros_lib', # test name - TestKnowrobRosLib # your TestCase - ) \ No newline at end of file + TestKnowrobRosLib # TestCase class + ) diff --git a/src/knowrob_ros_lib/knowrob_ros_lib.py b/src/knowrob_ros_lib/knowrob_ros_lib.py index 114ae99..73f53f3 100644 --- a/src/knowrob_ros_lib/knowrob_ros_lib.py +++ b/src/knowrob_ros_lib/knowrob_ros_lib.py @@ -1,163 +1,292 @@ -import rospy +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +knowrob_ros_lib.py + +A simple wrapper around KnowRob ROS actionlib services. +Provides methods to initialize the ROS node and to call AskOne, AskAll, AskIncremental, and Tell actions, +as well as utility functions for working with modal frames and GraphAnswerMessages. +""" + import actionlib +import rospy + from knowrob_ros.msg import ( - KeyValuePair, - AskOneAction, - AskOneGoal, - AskOneResult, AskAllAction, AskAllGoal, AskAllResult, - GraphQueryMessage, + AskIncrementalAction, + AskIncrementalGoal, + AskIncrementalNextSolutionAction, + AskIncrementalNextSolutionGoal, + AskOneAction, + AskOneGoal, + AskOneResult, GraphAnswerMessage, - ModalFrame + GraphQueryMessage, + KeyValuePair, + ModalFrame, + TellAction, + TellGoal, + TellResult, + Triple, ) +from knowrob_ros.srv import AskIncrementalFinish class KnowRobRosLib: + """ + Wrapper for KnowRob ROS services using actionlib. + + Methods: + init_node(name): Initialize ROS node, action clients, and services. + shutdown_node(): Shutdown ROS node and cancel all goals. + ask_one(query, modal_frame, lang): Single-result query. + ask_all(query, modal_frame, lang): Multi-result query. + ask_incremental(query, modal_frame, lang): Start an incremental query. + next_solution(query_id): Retrieve next solution from an incremental query. + finish_incremental(query_id): Finish an incremental query. + tell(triples, modal_frame): Assert triples into the knowledge base. + """ + def __init__(self): + """ + Initialize client and service placeholders. Call init_node() before use. + """ self._ask_one_client = None + self._ask_all_client = None + self._ask_incremental_client = None + self._ask_incremental_next_client = None + self._tell_client = None + self._ask_incremental_finish = None def init_node(self, name): + """ + Initialize the ROS node and all actionlib clients and services. + + Args: + name (str): Name for the ROS node. + """ rospy.init_node(name, anonymous=True) - self._ask_one_client = actionlib.SimpleActionClient("knowrob/askone", AskOneAction) + + # AskOne + self._ask_one_client = actionlib.SimpleActionClient( + 'knowrob/askone', AskOneAction) self._ask_one_client.wait_for_server() - self._ask_all_client = actionlib.SimpleActionClient("knowrob/askall", AskAllAction) + + # AskAll + self._ask_all_client = actionlib.SimpleActionClient( + 'knowrob/askall', AskAllAction) self._ask_all_client.wait_for_server() + # AskIncremental (start) + self._ask_incremental_client = actionlib.SimpleActionClient( + 'knowrob/askincremental', AskIncrementalAction) + self._ask_incremental_client.wait_for_server() + + # AskIncremental (next solution) + self._ask_incremental_next_client = actionlib.SimpleActionClient( + 'knowrob/askincremental_next_solution', AskIncrementalNextSolutionAction) + self._ask_incremental_next_client.wait_for_server() + + # Tell + self._tell_client = actionlib.SimpleActionClient( + 'knowrob/tell', TellAction) + self._tell_client.wait_for_server() + + # Finish incremental query service + rospy.wait_for_service('knowrob/askincremental_finish') + self._ask_incremental_finish = rospy.ServiceProxy( + 'knowrob/askincremental_finish', AskIncrementalFinish) + def shutdown_node(self): - rospy.signal_shutdown("KnowRob node shutdown") - if self._ask_one_client: - self._ask_one_client.cancel_all_goals() + """ + Shutdown the ROS node and cancel any pending goals. + """ + rospy.signal_shutdown('KnowRob node shutdown') + + for client in ( + self._ask_one_client, + self._ask_all_client, + self._ask_incremental_client, + self._ask_incremental_next_client, + self._tell_client, + ): + if client: + client.cancel_all_goals() def ask_one(self, query, modal_frame, lang=GraphQueryMessage.LANG_FOL): + """ + Send an AskOne query to KnowRob and wait for a single result. + + Returns: + AskOneResult + """ goal = AskOneGoal() goal.query.queryString = query goal.query.frame = modal_frame goal.query.lang = lang + self._ask_one_client.send_goal(goal) self._ask_one_client.wait_for_result() - result = self._ask_one_client.get_result() - return result - + return self._ask_one_client.get_result() + def ask_all(self, query, modal_frame, lang=GraphQueryMessage.LANG_FOL): + """ + Send an AskAll query to KnowRob and wait for all matching results. + + Returns: + AskAllResult + """ goal = AskAllGoal() goal.query.queryString = query goal.query.frame = modal_frame goal.query.lang = lang + self._ask_all_client.send_goal(goal) self._ask_all_client.wait_for_result() - result = self._ask_all_client.get_result() - return result - - # def tell(self, triples_str): - # request = TellRequest() - # request.query.query_string = triples_str - # response = self._tell_service(request) - # return TellResultAdapter(response) + return self._ask_all_client.get_result() - # def ask_all(self, query): - # # This is a stub assuming synchronous call, you'd use ROS service or action here too - # # Replace with actual implementation for asking all - # return GraphResultAdapter(GraphAnswerMessage(bindings=[KeyValuePair(key="X", value="hans")])) + def ask_incremental(self, query, modal_frame, lang=GraphQueryMessage.LANG_FOL): + """ + Start an incremental query. The server returns a status and a queryId + that can be used to fetch solutions one by one. -def get_default_modalframe(): - modalframe = ModalFrame() - modalframe.epistemicOperator = ModalFrame.KNOWLEDGE - modalframe.temporalOperator = ModalFrame.CURRENTLY - modalframe.minPastTimestamp = ModalFrame.UNSPECIFIED_TIMESTAMP - modalframe.maxPastTimestamp = ModalFrame.UNSPECIFIED_TIMESTAMP - modalframe.confidence = 0.0 - return modalframe - -def graph_answer_to_dict(answer_msg): - """ - Convert a GraphAnswerMessage to a dictionary format. - The dictionary will have the keys as the variable names and the values as the corresponding values. - """ + Returns: + AskIncrementalResult + """ + goal = AskIncrementalGoal() + goal.query.queryString = query + goal.query.frame = modal_frame + goal.query.lang = lang - results = {} + self._ask_incremental_client.send_goal(goal) + self._ask_incremental_client.wait_for_result() + return self._ask_incremental_client.get_result() - for binding_group in answer_msg.substitution: - if binding_group.type == KeyValuePair.TYPE_STRING: - results[binding_group.key] = binding_group.value_string - elif binding_group.type == KeyValuePair.TYPE_FLOAT: - results[binding_group.key] = binding_group.value_float - elif binding_group.type == KeyValuePair.TYPE_INT: - results[binding_group.key] = binding_group.value_int - elif binding_group.type == KeyValuePair.TYPE_LONG: - results[binding_group.key] = binding_group.value_long - elif binding_group.type == KeyValuePair.TYPE_VARIABLE: - results[binding_group.key] = binding_group.value_variable - elif binding_group.type == KeyValuePair.TYPE_PREDICATE: - results[binding_group.key] = binding_group.value_predicate - elif binding_group.type == KeyValuePair.TYPE_LIST: - # Lists are stored as a raw string and require custom parsing - results[binding_group.key] = binding_group.value_list - else: - # Throw an error or handle unknown types - raise ValueError(f"Unknown type: {binding_group.type}") + def next_solution(self, query_id): + """ + Retrieve the next solution for an active incremental query. - return results + Args: + query_id (int): ID from ask_incremental(). -def graph_answers_to_list(answer_msgs): - """ - Convert a list of GraphAnswerMessage to a list of dictionaries. - Each dictionary will have the keys as the variable names and the values as the corresponding values. - """ - results = [] - for answer_msg in answer_msgs: - result = graph_answer_to_dict(answer_msg) - results.append(result) - return results - + Returns: + AskIncrementalNextSolutionResult + """ + goal = AskIncrementalNextSolutionGoal() + goal.queryId = query_id -# class GraphResultAdapter: -# def __init__(self, msg): -# self.bindings = ( -# {kv.key: kv.value for kv in msg.bindings} -# if isinstance(msg.bindings, list) -# else msg.bindings -# ) + self._ask_incremental_next_client.send_goal(goal) + self._ask_incremental_next_client.wait_for_result() + return self._ask_incremental_next_client.get_result() + def finish_incremental(self, query_id): + """ + Finish an incremental query, releasing server-side resources. -# class TellResultAdapter: -# def __init__(self, response): -# self.success = response.success + Args: + query_id (int): ID from ask_incremental(). + Returns: + bool: True if the finish call succeeded. + """ + resp = self._ask_incremental_finish(query_id) + return resp.success -class TripleQueryBuilder: - def __init__(self): - self.triples = [] - - def add(self, subject, predicate, obj): - """Add a triple to the list.""" - self.triples.append((subject, predicate, obj)) + def tell(self, list_of_triples, modal_frame): + """ + Send a set of RDF-style triples to the KnowRob knowledge base. - def build_query_string(self): - """Generate a Prolog-style query string.""" - return ', '.join(f'{pred}({subj},{obj})' for subj, pred, obj in self.triples) + Returns: + TellResult + """ + goal = TellGoal() + goal.tell.triples = list_of_triples + goal.tell.frame = modal_frame + self._tell_client.send_goal(goal) + self._tell_client.wait_for_result() + return self._tell_client.get_result() -# Module-level functions -_knowrob_instance = KnowRobRosLib() +def get_default_modalframe(): + """ + Create a default ModalFrame with knowledge, current time, and unspecified timestamps. -def init_node(name): - _knowrob_instance.init_node(name) + Returns: + ModalFrame + """ + modal_frame = ModalFrame() + modal_frame.epistemicOperator = ModalFrame.KNOWLEDGE + modal_frame.temporalOperator = ModalFrame.CURRENTLY + modal_frame.minPastTimestamp = ModalFrame.UNSPECIFIED_TIMESTAMP + modal_frame.maxPastTimestamp = ModalFrame.UNSPECIFIED_TIMESTAMP + modal_frame.confidence = 0.0 + return modal_frame -def shutdown_node(): - _knowrob_instance.shutdown_node() +def graph_answer_to_dict(answer_msg): + """ + Convert a GraphAnswerMessage into a Python dict mapping variable names to values. + """ + results = {} + for kv in answer_msg.substitution: + if kv.type == KeyValuePair.TYPE_STRING: + results[kv.key] = kv.value_string + elif kv.type == KeyValuePair.TYPE_FLOAT: + results[kv.key] = kv.value_float + elif kv.type == KeyValuePair.TYPE_INT: + results[kv.key] = kv.value_int + elif kv.type == KeyValuePair.TYPE_LONG: + results[kv.key] = kv.value_long + elif kv.type == KeyValuePair.TYPE_VARIABLE: + results[kv.key] = kv.value_variable + elif kv.type == KeyValuePair.TYPE_PREDICATE: + results[kv.key] = kv.value_predicate + elif kv.type == KeyValuePair.TYPE_LIST: + results[kv.key] = kv.value_list + else: + raise ValueError(f"Unknown KeyValuePair type: {kv.type}") + return results -def ask_one(query): - return _knowrob_instance.ask_one(query) +def graph_answers_to_list(answer_msgs): + """ + Convert multiple GraphAnswerMessages into a list of dicts. + """ + return [graph_answer_to_dict(msg) for msg in answer_msgs] -# def ask_all(query): -# return _knowrob_instance.ask_all(query) +class TripleQueryBuilder: + """ + Helper to build lists of Triple messages for assertions. + Usage: + builder = TripleQueryBuilder() + builder.add(subject, predicate, object) + triples = builder.get_triples() + """ + def __init__(self): + self.triples = [] -# def tell(triples_str): -# return _knowrob_instance.tell(triples_str) + def add(self, subject, predicate, obj): + """ + Add a new Triple to the builder. + + Args: + subject (str) + predicate (str) + obj (str) + """ + triple = Triple() + triple.subject = subject + triple.predicate = predicate + triple.object = obj + self.triples.append(triple) + + def get_triples(self): + """ + Return the collected triples. + """ + return self.triples