diff --git a/usersimcrs/simulator/neural/core/transformer.py b/usersimcrs/simulator/neural/core/transformer.py
index 6727112b..4b91e354 100644
--- a/usersimcrs/simulator/neural/core/transformer.py
+++ b/usersimcrs/simulator/neural/core/transformer.py
@@ -1,4 +1,11 @@
-"""Encoder-only transformer model for neural user simulator."""
+"""Encoder-only transformer model for neural user simulator.
+
+Implementation inspired by PyTorch documentation and TUS's transformer model.
+
+Sources:
+https://colab.research.google.com/github/pytorch/tutorials/blob/gh-pages/_downloads/dca13261bbb4e9809d1a3aa521d22dd7/transformer_tutorial.ipynb#scrollTo=R8veciavth40
+https://gitlab.cs.uni-duesseldorf.de/general/dsml/tus_public/-/blob/master/convlab2/policy/tus/multiwoz/transformer.py?ref_type=heads
+"""
 
 import math
 
@@ -54,7 +61,6 @@ def __init__(
         nhead: int,
         hidden_dim: int,
         num_encoder_layers: int,
-        num_token: int,
         dropout: float = 0.5,
     ) -> None:
         """Initializes an encoder-only transformer model.
@@ -69,15 +75,15 @@
             dropout: Dropout rate. Defaults to 0.5.
         """
         super(TransformerEncoderModel, self).__init__()
-        self.d_model = input_dim
+        self.d_model = hidden_dim
 
-        self.pos_encoder = PositionalEncoding(input_dim, dropout)
-        self.embedding = nn.Embedding(num_token, input_dim)
+        self.pos_encoder = PositionalEncoding(hidden_dim, dropout)
+        self.embedding = nn.Linear(input_dim, hidden_dim)
 
         # Encoder layers
-        norm_layer = nn.LayerNorm(input_dim)
+        norm_layer = nn.LayerNorm(hidden_dim)
         encoder_layer = nn.TransformerEncoderLayer(
-            d_model=input_dim,
+            d_model=self.d_model,
             nhead=nhead,
             dim_feedforward=hidden_dim,
         )
@@ -87,7 +93,7 @@
             norm=norm_layer,
         )
 
-        self.linear = nn.Linear(input_dim, output_dim)
+        self.linear = nn.Linear(hidden_dim, output_dim)
         self.softmax = nn.Softmax(dim=-1)
 
         self.init_weights()
@@ -113,6 +119,8 @@ def forward(
         """
         src = self.embedding(src) * math.sqrt(self.d_model)
         src = self.pos_encoder(src)
-        output = self.encoder(src, mask=src_mask)
+        src = src.permute(1, 0, 2)
+        output = self.encoder(src, src_key_padding_mask=src_mask)
         output = self.linear(output)
+        output = output.permute(1, 0, 2)
         return output
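Note: the rewritten forward pass assumes batch-first inputs and a boolean key padding mask (True marks padded positions). A minimal shape sketch under those assumptions, with illustrative dimensions and the positional encoding omitted:

    import torch
    import torch.nn as nn

    batch_size, seq_len, input_dim, hidden_dim, output_dim = 2, 40, 79, 128, 6

    src = torch.rand(batch_size, seq_len, input_dim)
    src_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)  # True = padded

    x = nn.Linear(input_dim, hidden_dim)(src)  # (batch, seq, hidden_dim)
    # nn.TransformerEncoder defaults to sequence-first inputs, hence the permute.
    x = x.permute(1, 0, 2)                     # (seq, batch, hidden_dim)
    encoder = nn.TransformerEncoder(
        nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=4, dim_feedforward=hidden_dim
        ),
        num_layers=2,
    )
    out = encoder(x, src_key_padding_mask=src_mask)  # mask stays (batch, seq)
    out = nn.Linear(hidden_dim, output_dim)(out)     # (seq, batch, output_dim)
    out = out.permute(1, 0, 2)                       # (batch, seq, output_dim)
    assert out.shape == (batch_size, seq_len, output_dim)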
diff --git a/usersimcrs/simulator/neural/tus/tus.py b/usersimcrs/simulator/neural/tus/tus.py
new file mode 100644
index 00000000..a668db5a
--- /dev/null
+++ b/usersimcrs/simulator/neural/tus/tus.py
@@ -0,0 +1,239 @@
+"""Transformer-based User Simulator (TUS).
+
+Reference: Domain-independent User Simulation with Transformers for
+Task-oriented Dialogue Systems, Lin et al., 2021.
+See: https://arxiv.org/abs/2106.08838
+
+Implementation is adapted from the description in the paper and the original
+implementation by the authors:
+https://gitlab.cs.uni-duesseldorf.de/general/dsml/tus_public
+"""
+
+import logging
+import random
+from collections import defaultdict
+from typing import Any, DefaultDict, Dict, List, Optional
+
+import torch
+
+from dialoguekit.core.dialogue_act import DialogueAct
+from dialoguekit.core.intent import Intent
+from dialoguekit.core.slot_value_annotation import SlotValueAnnotation
+from dialoguekit.core.utterance import Utterance
+from dialoguekit.nlg.nlg_conditional import ConditionalNLG
+from dialoguekit.nlu.nlu import NLU
+from dialoguekit.participant import DialogueParticipant
+from usersimcrs.core.simulation_domain import SimulationDomain
+from usersimcrs.dialogue_management.dialogue_state_tracker import (
+    DialogueStateTracker,
+)
+from usersimcrs.items.item_collection import ItemCollection
+from usersimcrs.simulator.neural.core.feature_handler import (
+    FeatureMask,
+    FeatureVector,
+)
+from usersimcrs.simulator.neural.core.transformer import (
+    TransformerEncoderModel,
+)
+from usersimcrs.simulator.neural.tus.tus_feature_handler import (
+    TUSFeatureHandler,
+)
+from usersimcrs.simulator.user_simulator import UserSimulator
+
+logger = logging.getLogger(__name__)
+
+
+class TUS(UserSimulator):
+    """Transformer-based user simulator (TUS)."""
+
+    def __init__(
+        self,
+        id: str,
+        domain: SimulationDomain,
+        item_collection: ItemCollection,
+        nlu: NLU,
+        feature_handler: TUSFeatureHandler,
+        dialogue_state_tracker: DialogueStateTracker,
+        network_config: Dict[str, Any],
+        nlg: ConditionalNLG,
+    ) -> None:
+        """Initializes the Transformer-based User Simulator (TUS).
+
+        Args:
+            id: Simulator ID.
+            domain: Domain knowledge.
+            item_collection: Collection of items.
+            nlu: NLU module.
+            feature_handler: Feature handler.
+            dialogue_state_tracker: Dialogue state tracker.
+            network_config: Network configuration.
+            nlg: NLG module generating textual responses.
+        """
+        super().__init__(id=id, domain=domain, item_collection=item_collection)
+        self._nlu = nlu
+        self._nlg = nlg
+        self._tus_feature_handler = feature_handler
+        self._user_policy_network = TransformerEncoderModel(**network_config)
+        self._dialogue_state_tracker = dialogue_state_tracker
+        self._last_user_actions: DefaultDict[str, torch.Tensor] = defaultdict(
+            lambda: torch.tensor([])
+        )
+        self._last_turn_input: Optional[torch.Tensor] = None
+
+    def initialize(self) -> None:
+        """Initializes the user simulator."""
+        self._dialogue_state_tracker.reset_state()
+        self._last_user_actions.clear()
+        self._last_turn_input = None
+
+    def _generate_response(self, agent_utterance: Utterance) -> Utterance:
+        """Generates response to the agent utterance.
+
+        Args:
+            agent_utterance: Agent utterance.
+
+        Returns:
+            User utterance.
+        """
+        previous_state = self._dialogue_state_tracker.get_current_state()
+        # 1. Perform NLU on the agent utterance, i.e., extract dialogue acts.
+        agent_dialogue_acts = self._nlu.extract_dialogue_acts(agent_utterance)
+
+        # 2. Update dialogue state based on the agent dialogue acts.
+        self._dialogue_state_tracker.update_state(
+            dialogue_acts=agent_dialogue_acts,
+            participant=DialogueParticipant.AGENT,
+        )
+
+        # 3. Extract features for the current turn.
+        turn_feature, mask = self._tus_feature_handler.build_input_vector(
+            agent_dialogue_acts=agent_dialogue_acts,
+            previous_state=previous_state,
+            state=self._dialogue_state_tracker.get_current_state(),
+            information_need=self.information_need,
+            user_action_vectors=self._last_user_actions,
+        )
+
+        # 4. Predict user dialogue acts based on the features.
+        user_dialogue_acts = self.predict_user_dialogue_acts(
+            turn_feature, mask, self._tus_feature_handler.action_slots
+        )
+
+        # 5. Generate user utterance based on the predicted dialogue acts.
+        response = self._nlg.generate_utterance_text(user_dialogue_acts)
+        response.participant = DialogueParticipant.USER
+
+        # 6. Update dialogue state based on the user dialogue acts.
+        self._dialogue_state_tracker.update_state(
+            dialogue_acts=response.dialogue_acts,
+            participant=DialogueParticipant.USER,
+        )
+
+        return response
+
+    def predict_user_dialogue_acts(
+        self,
+        features: List[FeatureVector],
+        mask: FeatureMask,
+        action_slots: List[str],
+    ) -> List[DialogueAct]:
+        """Predicts user dialogue acts based on the features.
+
+        Args:
+            features: Feature vector.
+            mask: Mask vector.
+            action_slots: Action slots used to predict the user action per
+              slot.
+
+        Returns:
+            Predicted user dialogue acts.
+        """
+        output = self._user_policy_network(features, mask)
+        # fmt: off
+        output = output[
+            :, 1 : self._tus_feature_handler.max_turn_feature_length + 1, :  # noqa: E203, E501
+        ]
+        # fmt: on
+
+        slot_outputs: Dict[str, int] = {}
+        for index, slot_name in enumerate(action_slots):
+            o = int(torch.argmax(output[0, index + 1, :]).item())
+            assert o in range(6), f"Invalid output: {o}"
+            slot_outputs[slot_name] = o
+            # One-hot encoding of user action for the slot
+            o_i = torch.zeros(6)
+            o_i[o] = 1
+            self._last_user_actions[slot_name] = o_i
+
+        user_dialogue_acts = self._parse_policy_output(
+            action_slots, slot_outputs
+        )
+        return user_dialogue_acts
+
+    def _parse_policy_output(
+        self, action_slots: List[str], slot_outputs: Dict[str, int]
+    ) -> List[DialogueAct]:
+        """Parses the policy output to dialogue acts.
+
+        Args:
+            action_slots: Action slots.
+            slot_outputs: Output per slot.
+
+        Returns:
+            Dialogue acts.
+        """
+        belief_state = (
+            self._dialogue_state_tracker.get_current_state().belief_state
+        )
+        dialogue_acts = []
+
+        for slot in action_slots:
+            o = slot_outputs[slot]
+            dialogue_act = DialogueAct()
+
+            # Default intent is "inform"
+            dialogue_act.intent = Intent("inform")
+
+            # Determine the value of the slot
+            if o == 1:
+                # The slot's value is set to "dontcare"
+                dialogue_act.annotations.append(
+                    SlotValueAnnotation(slot, "dontcare")
+                )
+            elif o == 2:
+                # The slot's value is requested by the user
+                dialogue_act.intent = Intent("request")
+                dialogue_act.annotations.append(SlotValueAnnotation(slot))
+            elif o == 3:
+                # The slot's value is taken from the information need
+                if slot in self.information_need.constraints.keys():
+                    dialogue_act.annotations.append(
+                        SlotValueAnnotation(
+                            slot, self.information_need.constraints[slot]
+                        )
+                    )
+            elif o == 4:
+                # The slot's value was previously mentioned and is retrieved
+                # from the belief state
+                if slot in belief_state.keys():
+                    dialogue_act.annotations.append(
+                        SlotValueAnnotation(slot, belief_state[slot])
+                    )
+            elif o == 5:
+                # The slot's value in the information need is randomly modified
+                value = random.choice(
+                    list(
+                        self._item_collection.get_possible_property_values(slot)
+                    )
+                )
+                self.information_need.constraints[slot] = value
+                dialogue_act.annotations.append(
+                    SlotValueAnnotation(slot, value)
+                )
+            else:
+                logger.warning(f"{slot} is not mentioned in this turn.")
+                continue
+
+            dialogue_acts.append(dialogue_act)
+
+        return dialogue_acts
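Note: a minimal sketch of the per-slot decoding in predict_user_dialogue_acts, with hypothetical slot names and shapes. Each action slot gets a 6-way categorical output; the argmax is re-encoded one-hot and fed back as that slot's user action vector in the next turn:

    import torch

    action_slots = ["GENRE", "YEAR"]  # hypothetical action slots
    # 0: not mentioned, 1: dontcare, 2: requested, 3: information need,
    # 4: belief state, 5: random value (matching _get_label's encoding)
    num_actions = 6

    # Stand-in for the sliced policy output: (batch=1, positions, actions).
    output = torch.rand(1, len(action_slots) + 1, num_actions)

    last_user_actions = {}
    for index, slot in enumerate(action_slots):
        o = int(torch.argmax(output[0, index + 1, :]).item())
        one_hot = torch.zeros(num_actions)
        one_hot[o] = 1
        last_user_actions[slot] = one_hot  # reused as input feature next turn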
diff --git a/usersimcrs/simulator/neural/tus/tus_feature_handler.py b/usersimcrs/simulator/neural/tus/tus_feature_handler.py
index 1cab3dca..89fabc40 100644
--- a/usersimcrs/simulator/neural/tus/tus_feature_handler.py
+++ b/usersimcrs/simulator/neural/tus/tus_feature_handler.py
@@ -14,8 +14,8 @@
 import torch
 
 from dialoguekit.core.annotated_utterance import AnnotatedUtterance
-from dialoguekit.core.annotation import Annotation
 from dialoguekit.core.dialogue_act import DialogueAct
+from dialoguekit.core.slot_value_annotation import SlotValueAnnotation
 from usersimcrs.core.information_need import InformationNeed
 from usersimcrs.core.simulation_domain import SimulationDomain
 from usersimcrs.dialogue_management.dialogue_state import DialogueState
@@ -342,7 +342,7 @@ def build_input_vector(
         a special token. Note that n is inferred from the list of feature
         vectors provided. The input vector is structured as follows:
-        [CLS] $V^t$ [SEP] $V^{t-1}$ [SEP] ... [SEP] $V^{t-n}$ [SEP] {padding},
+        [CLS] $V^t$ [SEP] $V^{t-1}$ [SEP] ... [SEP] $V^{t-n}$ [SEP] {padding}
 
         where {padding} indicates the padding to reach the maximum length.
 
         Args:
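Note: a rough sketch of the layout above, using stand-in vectors for the [CLS] and [SEP] tokens and an illustrative maximum length; the real vectors and length come from the feature handler:

    import torch

    max_length, dim = 20, 4  # illustrative maximum length and feature size
    cls = torch.full((dim,), 1.0)  # stand-in for the [CLS] vector
    sep = torch.full((dim,), 2.0)  # stand-in for the [SEP] vector

    def assemble(turns):
        """Builds [CLS] V^t [SEP] ... V^{t-n} [SEP] {padding} and its mask.

        Args:
            turns: Feature vectors per turn, newest turn first.
        """
        vectors = [cls]
        for turn_vectors in turns:
            vectors.extend(turn_vectors)
            vectors.append(sep)
        num_padding = max_length - len(vectors)
        mask = [False] * len(vectors) + [True] * num_padding  # True = padded
        vectors.extend([torch.zeros(dim)] * num_padding)
        return torch.stack(vectors), torch.tensor(mask)

    inputs, mask = assemble([[torch.rand(dim)] * 3, [torch.rand(dim)] * 2])
    assert inputs.shape == (max_length, dim) and int(mask.sum()) == 12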
+ """ + belief_state = ( + self._dialogue_state_tracker.get_current_state().belief_state + ) + dialogue_acts = [] + + for slot in action_slots: + o = slot_outputs[slot] + dialogue_act = DialogueAct() + + # Default intent is "inform" + dialogue_act.intent = "inform" + + # Determine the value of the slot + if o == 1: + # The slot's value is requested by the user + dialogue_act.intent = "request" + dialogue_act.annotations.append(SlotValueAnnotation(slot)) + elif o == 2: + # The slot's value is set to "dontcare" + dialogue_act.annotations.append( + SlotValueAnnotation(slot, "dontcare") + ) + elif o == 3: + # The slot's value is taken from the information need + if slot in self.information_need.constraints.keys(): + dialogue_act.annotations.append( + SlotValueAnnotation( + slot, self.information_need.constraints[slot] + ) + ) + elif o == 4: + # The slot's value was previously mentioned and is retrieved + # from the belief state + if slot in belief_state.keys(): + dialogue_act.annotations.append( + SlotValueAnnotation(slot, belief_state[slot]) + ) + elif o == 5: + # The slot's value in the information need is randomly modified + value = random.choice( + list( + self._item_collection.get_possible_property_values(slot) + ) + ) + self.information_need.constraints[slot] = value + dialogue_act.annotations.append( + SlotValueAnnotation(slot, value) + ) + else: + logger.warning(f"{slot} is not mentioned in this turn.") + continue + + dialogue_acts.append(dialogue_act) + + return dialogue_acts diff --git a/usersimcrs/simulator/neural/tus/tus_feature_handler.py b/usersimcrs/simulator/neural/tus/tus_feature_handler.py index 1cab3dca..89fabc40 100644 --- a/usersimcrs/simulator/neural/tus/tus_feature_handler.py +++ b/usersimcrs/simulator/neural/tus/tus_feature_handler.py @@ -14,8 +14,8 @@ import torch from dialoguekit.core.annotated_utterance import AnnotatedUtterance -from dialoguekit.core.annotation import Annotation from dialoguekit.core.dialogue_act import DialogueAct +from dialoguekit.core.slot_value_annotation import SlotValueAnnotation from usersimcrs.core.information_need import InformationNeed from usersimcrs.core.simulation_domain import SimulationDomain from usersimcrs.dialogue_management.dialogue_state import DialogueState @@ -342,7 +342,7 @@ def build_input_vector( a special token. Note that is inferred from the list of feature vectors provided. The input vector is structured as follows: - [CLS] $V^t$ [SEP] $V^{t-1}$ [SEP] ... [SEP] $V^{t-n}$ [SEP] {padding}, + [CLS] $V^t$ [SEP] $V^{t-1}$ [SEP] ... [SEP] $V^{t-n}$ [SEP] {padding} where {padding} indicates the padding to reach the maximum length. Args: @@ -429,7 +429,7 @@ def get_label_vector( def _get_label( self, - annotation: Annotation, + slot_value_annotation: SlotValueAnnotation, current_state: DialogueState, information_need: InformationNeed, ) -> int: @@ -445,42 +445,43 @@ def _get_label( 5: The slot's value is randomly chosen. Args: - annotation: Annotation. + slot_value_annotation: Slot value pair annotation. current_state: Current state. information_need: Information need. Returns: Label. 
""" - if annotation.value == "dontcare": + if slot_value_annotation.value == "dontcare": return 1 - elif annotation.value is None: + elif slot_value_annotation.value is None: # The value is requested by the user return 2 - elif annotation.value == information_need.get_constraint_value( - annotation.slot - ) or annotation.value == information_need.requested_slots.get( - annotation.slot + elif ( + slot_value_annotation.value + == information_need.get_constraint_value(slot_value_annotation.slot) + or slot_value_annotation.value + == information_need.requested_slots.get(slot_value_annotation.slot) ): # The value is taken from the information need return 3 - elif annotation.value == current_state.belief_state.get( - annotation.slot + elif slot_value_annotation.value == current_state.belief_state.get( + slot_value_annotation.slot ): # The value was previously mentioned and is # retrieved from the belief state return 4 elif ( - annotation.slot + slot_value_annotation.slot in [ information_need.constraints.keys(), information_need.requested_slots.keys(), current_state.belief_state.keys(), ] - ) and annotation.value not in [ - information_need.get_constraint_value(annotation.slot), - information_need.requested_slots.get(annotation.slot), - current_state.belief_state.get(annotation.slot), + ) and slot_value_annotation.value not in [ + information_need.get_constraint_value(slot_value_annotation.slot), + information_need.requested_slots.get(slot_value_annotation.slot), + current_state.belief_state.get(slot_value_annotation.slot), ]: # The slot's value is randomly chosen return 5 diff --git a/usersimcrs/simulator/user_simulator.py b/usersimcrs/simulator/user_simulator.py index af238c4e..d8ef359d 100644 --- a/usersimcrs/simulator/user_simulator.py +++ b/usersimcrs/simulator/user_simulator.py @@ -17,7 +17,13 @@ def __init__( domain: SimulationDomain, item_collection: ItemCollection, ) -> None: - """Initializes the user simulator.""" + """Initializes the user simulator. + + Args: + id: User ID. + domain: Domain knowledge. + item_collection: Collection of items. + """ super().__init__(id, UserType.SIMULATOR) self._domain = domain self._item_collection = item_collection