diff --git a/src/software/thunderscope/constants.py b/src/software/thunderscope/constants.py index 217d87a302..1ba4d0e4a5 100644 --- a/src/software/thunderscope/constants.py +++ b/src/software/thunderscope/constants.py @@ -413,3 +413,18 @@ class RuntimeManagerConstants: RELEASES_URL = "https://api.github.com/repos/UBC-Thunderbots/Software/releases" DOWNLOAD_URL = "https://github.com/UBC-Thunderbots/Software/releases/download/" MAX_RELEASES_FETCHED = 5 + + +class PassResultsConstants: + PASS_RESULTS_DIRECTORY_PATH = "/tmp/tbots/ml" + PASS_RESULTS_FILE_NAME_TEMPLATE = "pass_results_{interval}.csv" + + FRIENDLY_GOAL_SCORE = 10 + ENEMY_GOAL_SCORE = -FRIENDLY_GOAL_SCORE + FRIENDLY_POSSESSION_SCORE = 2 + ENEMY_POSSESSION_SCORE = -FRIENDLY_POSSESSION_SCORE + NEUTRAL_SCORE = 0 + + # the time intervals to log results for after each pass + # so after a pass, wait X seconds and then log game state + INTERVALS_S = [1, 5, 10] diff --git a/src/software/thunderscope/log/stats/BUILD b/src/software/thunderscope/log/stats/BUILD index e1dbd7a41d..9b70246035 100644 --- a/src/software/thunderscope/log/stats/BUILD +++ b/src/software/thunderscope/log/stats/BUILD @@ -8,11 +8,20 @@ py_library( ], ) +py_library( + name = "pass_results", + srcs = ["pass_results.py"], + deps = [ + "//software/thunderscope/log/trackers:tracker_module", + ], +) + py_library( name = "stats", srcs = ["stats.py"], deps = [ ":fullsystem_stats", + ":pass_results", "//software/thunderscope:thread_safe_buffer", ], ) diff --git a/src/software/thunderscope/log/stats/pass_results.py b/src/software/thunderscope/log/stats/pass_results.py index f5e86fb552..88457eb3cb 100644 --- a/src/software/thunderscope/log/stats/pass_results.py +++ b/src/software/thunderscope/log/stats/pass_results.py @@ -1,17 +1,231 @@ +from software.thunderscope.log.trackers import ( + PossessionTracker, + PassTracker, + TrackerBuilder, + RefereeTracker, +) +from software.thunderscope.proto_unix_io import ProtoUnixIO +from datetime import datetime +from dataclasses import dataclass +from software.thunderscope.constants import PassResultsConstants +import os +from proto.import_all_protos import * + + +@dataclass +class PassLog: + pass_: Pass + timestamp: datetime + friendly_score: int + enemy_score: int + + class PassResultsTracker: """Class to track the results of any passes taken i.e looking at if our position in the game got better or worse after certain time intervals """ - def __init__(self, friendly_colour_yellow: bool, buffer_size: int = 5): + PASS_RESULTS_TEMPLATE = ( + "{pass_start_x},{pass_start_y}," + "{pass_end_x},{pass_end_y}," + "{speed}," + "{score}\n" + ) + + def __init__( + self, + proto_unix_io: ProtoUnixIO, + friendly_colour_yellow: bool, + buffer_size: int = 5, + ): + """Initializes the pass resuidxlts tracker + + :param proto_unix_io: the proto unix io to use + :param friendly_colour_yellow: if the friendly color is yellow or not + :param buffer_size: buffer size to use + """ self.friendly_colour_yellow = friendly_colour_yellow - pass + self.tracker = ( + TrackerBuilder(proto_unix_io=proto_unix_io) + .add_tracker(PassTracker, callback=self._add_pass_timestamp) + .add_tracker(PossessionTracker, callback=self._update_friendly_possession) + .add_tracker( + RefereeTracker, + callback=self._update_scores_friendly, + friendly_color_yellow=self.friendly_colour_yellow, + ) + .add_tracker( + RefereeTracker, + callback=self._update_scores_friendly, + friendly_color_yellow=(not self.friendly_colour_yellow), + ) + ) + + self.pass_times_map: dict[int, list[PassLog]] = { + interval: [] for interval in PassResultsConstants.INTERVALS_S + } + + self.is_friendly_possession: bool | None = False + self.friendly_score = 0 + self.enemy_score = 0 + + self.pass_results_file_map = {} + + self.world = None + + def setup(self): + """Creates the relevant directories and a csv file for each of the + intervals in INTERVALS + """ + pass_results_dir = PassResultsConstants.PASS_RESULTS_DIRECTORY_PATH + + # create all directories in path if they doesn't exist + os.makedirs(os.path.dirname(pass_results_dir), exist_ok=True) + + for interval in PassResultsConstants.INTERVALS_S: + file_path = os.path.join( + pass_results_dir, + PassResultsConstants.PASS_RESULTS_FILE_NAME_TEMPLATE.format( + interval=interval + ), + ) + + is_new_file = not os.path.exists(file_path) + + self.pass_results_file_map[interval] = open( + file_path, + "a", + ) + + # write the headers first if the file doesn't already exist + if is_new_file: + self.pass_results_file_map[interval].write( + self._get_pass_result_headers() + ) + self.pass_results_file_map[interval].flush() + + def cleanup(self): + """Flushes content and closes all the files for all intervals""" + for interval in PassResultsConstants.INTERVALS_S: + if self.pass_results_file_map[interval]: + self.pass_results_file_map[interval].flush() + self.pass_results_file_map[interval].close() + + def _update_friendly_possession(self, is_friendly_possession: bool | None) -> None: + self.is_friendly_possession = is_friendly_possession + + def _update_scores_friendly(self, friendly_score: int, *_) -> None: + self.friendly_score = friendly_score + + def _update_scores_enemy(self, enemy_score: int, *_) -> None: + self.enemy_score = enemy_score + + def _add_pass_timestamp(self, pass_: Pass) -> None: + """Adds the given pass, the current timestamp, and the current scores to the lowest interval's list + :param pass_: the pass to add + """ + # TODO: use world timestamp time instead of datetime time + self.pass_times_map[PassResultsConstants.INTERVALS_S[0]].append( + PassLog( + pass_=pass_, + timestamp=datetime.now(), + friendly_score=self.friendly_score, + enemy_score=self.enemy_score, + ) + ) def refresh(self) -> None: - """Refreshes the kick tracker so we stay up to date on new passes""" - pass + """Refreshes the tracker so we stay up to date on new passes + and checks to see if any passes are older than their interval + """ + self.tracker.refresh() + + self._update_pass_timestamps() + + def _log_pass_result(self, logged_pass: PassLog, interval_s: int) -> None: + """For an already recorded pass, calculates and logs its score for the given interval + i.e after seconds following the pass + + :param logged_pass: a pass that already occurred that we want to find the score for + :param interval_s: how long (in seconds) it has been after the pass + """ + pass_score = self._get_pass_score(logged_pass) + + self._log_pass_result_to_file( + file=self.pass_results_file_map[interval_s], + pass_=logged_pass.pass_, + score=pass_score, + ) + + def _get_pass_score(self, logged_pass: PassLog) -> int: + """For the given logged pass, get the score based on the current game state + If the friendly / enemy scores at the time of pass are different + Or if possession has changed, return the corresponding score + Else, return the neutral score + + :param logged_pass: the pass to score + :return: a single integer score for the pass + """ + if self.friendly_score > logged_pass.friendly_score: + return PassResultsConstants.FRIENDLY_GOAL_SCORE + + if self.enemy_score > logged_pass.enemy_score: + return PassResultsConstants.ENEMY_GOAL_SCORE + + if self.is_friendly_possession: + return PassResultsConstants.FRIENDLY_POSSESSION_SCORE + elif self.is_friendly_possession is False: + return PassResultsConstants.ENEMY_POSSESSION_SCORE + + return PassResultsConstants.NEUTRAL_SCORE + + def _get_pass_result_headers(self): + return self.PASS_RESULTS_TEMPLATE.replace("{", "").replace("}", "") + + def _log_pass_result_to_file(self, file, pass_: Pass, score: int) -> None: + """Logs a single pass's result to the given file handle + + :param file: the file handle to write to + :param pass_: the pass to log + :param score: the score for the given pass + """ + pass_result_string = self.PASS_RESULTS_TEMPLATE.format( + pass_start_x=pass_.passer_point.x_meters, + pass_start_y=pass_.passer_point.y_meters, + pass_end_x=pass_.receiver_point.x_meters, + pass_end_y=pass_.receiver_point.y_meters, + speed=pass_.pass_speed_m_per_s, + score=score, + ) + + file.write(pass_result_string) + file.flush() + + def _update_pass_timestamps(self): + """For all currently logged passes, check if the interval they belong to has passed + If so, log their score to the corresponding file + And move them to the next interval if exists + """ + for idx, interval in enumerate(PassResultsConstants.INTERVALS_S): + pass_timestamps = self.pass_times_map[interval] + + # TODO: use world timestamp time instead of datetime time + time_now = datetime.now() + + while ( + pass_timestamps + and (time_now - pass_timestamps[0].timestamp).total_seconds() > interval + ): + pass_with_timestamp = pass_timestamps.pop(0) + print( + f"Pass {pass_with_timestamp.pass_} is older than interval {interval}" + ) + + self._log_pass_result(pass_with_timestamp, interval) - def record_pass_taken(self, pass_taken: Pass): - pass + if idx < len(PassResultsConstants.INTERVALS_S) - 1: + self.pass_times_map[ + PassResultsConstants.INTERVALS_S[idx + 1] + ].append(pass_with_timestamp) diff --git a/src/software/thunderscope/log/stats/stats.py b/src/software/thunderscope/log/stats/stats.py index de7e9074fc..3dbac1077a 100644 --- a/src/software/thunderscope/log/stats/stats.py +++ b/src/software/thunderscope/log/stats/stats.py @@ -1,4 +1,5 @@ from software.thunderscope.log.stats.fullsystem_stats import FullSystemStats +from software.thunderscope.log.stats.pass_results import PassResultsTracker from software.thunderscope.proto_unix_io import ProtoUnixIO from proto.import_all_protos import * @@ -22,12 +23,19 @@ def __init__( record_enemy_stats=record_enemy_stats, ) + self.pass_results = PassResultsTracker( + proto_unix_io=proto_unix_io, friendly_colour_yellow=friendly_color_yellow + ) + def refresh(self): self.fs_stats.refresh() + self.pass_results.refresh() def __enter__(self): self.fs_stats.setup() + self.pass_results.setup() return self def __exit__(self, exc_type, exc_value, traceback): self.fs_stats.cleanup() + self.pass_results.cleanup()