diff --git a/adapt/intent.py b/adapt/intent.py index 8abb662..805eceb 100644 --- a/adapt/intent.py +++ b/adapt/intent.py @@ -117,7 +117,7 @@ def resolve_one_of(tags, at_least_one): class Intent(object): - def __init__(self, name, requires, at_least_one, optional): + def __init__(self, name, requires, at_least_one, optional, excludes=None): """Create Intent object Args: @@ -130,6 +130,7 @@ def __init__(self, name, requires, at_least_one, optional): self.requires = requires self.at_least_one = at_least_one self.optional = optional + self.excludes = excludes or [] def validate(self, tags, confidence): """Using this method removes tags from the result of validate_with_tags @@ -160,6 +161,14 @@ def validate_with_tags(self, tags, confidence): local_tags = tags[:] used_tags = [] + # Check excludes first + for exclude_type in self.excludes: + exclude_tag, _canonical_form, _tag_confidence = \ + find_first_tag(local_tags, exclude_type) + if exclude_tag: + result['confidence'] = 0.0 + return result, [] + for require_type, attribute_name in self.requires: required_tag, canonical_form, tag_confidence = \ find_first_tag(local_tags, require_type) @@ -243,6 +252,7 @@ def __init__(self, intent_name): """ self.at_least_one = [] self.requires = [] + self.excludes = [] self.optional = [] self.name = intent_name @@ -277,6 +287,19 @@ def require(self, entity_type, attribute_name=None): self.requires += [(entity_type, attribute_name)] return self + def exclude(self, entity_type): + """ + The intent parser must not contain an entity of the provided type. + + Args: + entity_type(str): an entity type + + Returns: + self: to continue modifications. + """ + self.excludes.append(entity_type) + return self + def optionally(self, entity_type, attribute_name=None): """ Parsed intents from this parser can optionally include an entity of the @@ -302,4 +325,5 @@ def build(self): :return: an Intent instance. """ return Intent(self.name, self.requires, - self.at_least_one, self.optional) + self.at_least_one, self.optional, + self.excludes) diff --git a/test/IntentEngineTest.py b/test/IntentEngineTest.py index 7cd4203..77ea4e6 100644 --- a/test/IntentEngineTest.py +++ b/test/IntentEngineTest.py @@ -225,3 +225,30 @@ def testResultsAreSortedByConfidence(self): assert len(confidences) > 1 assert all(confidences[i] >= confidences[i+1] for i in range(len(confidences)-1)) + def testExclude(self): + parser1 = IntentBuilder("Parser1").require("Entity1").exclude("Entity2").build() + self.engine.register_intent_parser(parser1) + + parser2 = IntentBuilder("Parser2").require("Entity1").exclude("Entity3").build() + self.engine.register_intent_parser(parser2) + + self.engine.register_entity("go", "Entity1") + self.engine.register_entity("tree", "Entity2") + self.engine.register_entity("house", "Entity3") + + # Parser 1 cannot contain the word tree + utterance = "go to the tree" + intent = next(self.engine.determine_intent(utterance)) + assert intent + assert intent['intent_type'] == 'Parser2' + + # Parser 2 cannot contain the word house + utterance = "go to the house" + intent = next(self.engine.determine_intent(utterance)) + assert intent + assert intent['intent_type'] == 'Parser1' + + # Should fail because both excluded words are present + utterance = "go to the tree house" + with self.assertRaises(StopIteration): + intent = next(self.engine.determine_intent(utterance))