35 changes: 25 additions & 10 deletions README.md
@@ -1,18 +1,18 @@
# [NeurIPS 2023] Reflexion: Language Agents with Verbal Reinforcement Learning
# [NeurIPS 2023] Reflexion: Language Agents with Verbal Reinforcement Learning (Replication + Extension)

This repo holds the code, demos, and log files for [Reflexion: Language Agents with Verbal Reinforcement Learning](https://arxiv.org/abs/2303.11366) by Noah Shinn, Federico Cassano, Edward Berman, Ashwin Gopinath, Karthik Narasimhan, Shunyu Yao.
<!-- This repo holds the code, demos, and log files for [Reflexion: Language Agents with Verbal Reinforcement Learning](https://arxiv.org/abs/2303.11366) by Noah Shinn, Federico Cassano, Edward Berman, Ashwin Gopinath, Karthik Narasimhan, Shunyu Yao. -->

![Reflexion RL diagram](./figures/reflexion_rl.png)
<!-- ![Reflexion RL diagram](./figures/reflexion_rl.png) -->

![Reflexion tasks](./figures/reflexion_tasks.png)
<!-- ![Reflexion tasks](./figures/reflexion_tasks.png) -->

We have released the LeetcodeHardGym [here](https://github.com/GammaTauAI/leetcode-hard-gym)
<!-- We have released the LeetcodeHardGym [here](https://github.com/GammaTauAI/leetcode-hard-gym) -->

## To Run: reasoning (HotPotQA)

We have provided a set of notebooks to easily run, explore, and interact with the results of the reasoning experiments. Each experiment consists of a random sample of 100 questions from the HotPotQA distractor dataset. Each question in the sample is attempted by an agent with a specific type and reflexion strategy.

### Setup
### Setup (for HotPotQA)

To get started:

@@ -22,9 +22,22 @@ To get started:
git clone https://github.com/noahshinn/reflexion && cd ./hotpotqa_runs
```

2. Install the module dependencies into your environment:
2. Use the correct Python version and install the module dependencies into your environment:
*Get pyenv here if needed*: <https://github.com/pyenv/pyenv>

```bash
# Use python version 3.11.9

pyenv install 3.11.9
pyenv local 3.11.9
python -V


python -m venv .venv
source .venv/bin/activate

# Install dependencies
python -m pip install --upgrade pip
pip install -r requirements.txt
```
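
The notebooks build their default LLMs from `os.environ['OPENAI_API_KEY']` (see the `agents.py` changes below), so it is worth confirming the key is visible in your environment before launching anything. A minimal sanity check, assuming the key is exported in your shell:

```python
# Minimal sketch: fail fast if the key the notebooks expect is missing.
import os

if "OPENAI_API_KEY" not in os.environ:
    raise RuntimeError("Set OPENAI_API_KEY before running the HotPotQA notebooks")
```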

@@ -40,7 +53,7 @@ Agent type is determined by the notebook you choose to run. The available agent

- `ReAct` - ReAct Agent

- `CoT_context` - CoT Agent given supporting context about the question
- `CoT_context` - CoT Agent given supporting context about the question

- `CoT_no_context` - CoT Agent given no supporting context about the question

@@ -50,14 +63,16 @@ The notebook for each agent type is located in the `./hotpot_runs/notebooks` dir

Each notebook allows you to specify the reflexion strategy to be used by the agents. The available reflexion strategies, which are defined in an `Enum`, include:

- `ReflexionStrategy.NONE` - The agent is not given any information about its last attempt.
- `ReflexionStrategy.NONE` - The agent is not given any information about its last attempt.

- `ReflexionStrategy.LAST_ATTEMPT` - The agent is given its reasoning trace from its last attempt on the question as context.

- `ReflexionStrategy.REFLEXION` - The agent is given its self-reflection on the last attempt as context.
- `ReflexionStrategy.REFLEXION` - The agent is given its self-reflection on the last attempt as context.

- `ReflexionStrategy.LAST_ATTEMPT_AND_REFLEXION` - The agent is given both its reasoning trace and self-reflection on the last attempt as context.
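
For reference, a minimal sketch of how such an `Enum` might be declared (member names follow the list above; the values and docstrings in the repository's actual definition may differ):

```python
from enum import Enum, auto

class ReflexionStrategy(Enum):
    NONE = auto()                        # no information about the last attempt
    LAST_ATTEMPT = auto()                # reasoning trace from the last attempt
    REFLEXION = auto()                   # self-reflection on the last attempt
    LAST_ATTEMPT_AND_REFLEXION = auto()  # both trace and self-reflection
```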

# Yuchen - The sections below are unchanged and thus may not work correctly

### To Run: decision-making (AlfWorld)

Clone this repo and move to the AlfWorld directory
2 changes: 1 addition & 1 deletion alfworld_runs/run_reflexion.sh
@@ -1,4 +1,4 @@
python main.py \
python3 main.py \
--num_trials 10 \
--num_envs 134 \
--run_name "reflexion_run_logs" \
206 changes: 154 additions & 52 deletions hotpotqa_runs/agents.py
@@ -2,18 +2,50 @@
from typing import List, Union, Literal
from enum import Enum
import tiktoken
from langchain import OpenAI, Wikipedia
from langchain.llms.base import BaseLLM
from langchain.chat_models import ChatOpenAI
from langchain.chat_models.base import BaseChatModel
from langchain.schema import (
SystemMessage,
HumanMessage,
AIMessage,
)
from langchain.agents.react.base import DocstoreExplorer
from langchain.docstore.base import Docstore
from langchain.prompts import PromptTemplate

try:
from langchain_openai import OpenAI, ChatOpenAI
except ImportError:
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI

try:
from langchain_community.docstore.wikipedia import Wikipedia
except ImportError:
from langchain.docstore.wikipedia import Wikipedia

try:
from langchain_core.language_models.llms import BaseLLM
except ImportError:
from langchain.llms.base import BaseLLM

try:
from langchain_core.language_models.chat_models import BaseChatModel
except ImportError:
from langchain.chat_models.base import BaseChatModel

try:
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
except ImportError:
from langchain.schema import SystemMessage, HumanMessage, AIMessage

try:
from langchain_classic.agents.react.base import DocstoreExplorer
except ImportError:
try:
from langchain.agents.react.base import DocstoreExplorer
except ImportError:
from langchain_community.agent_toolkits.base import DocstoreExplorer

try:
from langchain_community.docstore.base import Docstore
except ImportError:
from langchain.docstore.base import Docstore

try:
from langchain_core.prompts import PromptTemplate
except ImportError:
from langchain.prompts import PromptTemplate
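# The try/except chains above prefer the newer split packages (langchain_openai,
# langchain_community, langchain_core) and fall back to the legacy monolithic
# langchain import paths, so the module imports on either style of install.
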
from llm import AnyOpenAILLM
from prompts import reflect_prompt, react_agent_prompt, react_reflect_agent_prompt, REFLECTION_HEADER, LAST_TRIAL_HEADER, REFLECTION_AFTER_LAST_TRIAL_HEADER
from prompts import cot_agent_prompt, cot_reflect_agent_prompt, cot_reflect_prompt, COT_INSTRUCTION, COT_REFLECT_INSTRUCTION
@@ -42,18 +74,8 @@ def __init__(self,
reflect_prompt: PromptTemplate = cot_reflect_prompt,
cot_examples: str = COT,
reflect_examples: str = COT_REFLECT,
self_reflect_llm: AnyOpenAILLM = AnyOpenAILLM(
temperature=0,
max_tokens=250,
model_name="gpt-3.5-turbo",
model_kwargs={"stop": "\n"},
openai_api_key=os.environ['OPENAI_API_KEY']),
action_llm: AnyOpenAILLM = AnyOpenAILLM(
temperature=0,
max_tokens=250,
model_name="gpt-3.5-turbo",
model_kwargs={"stop": "\n"},
openai_api_key=os.environ['OPENAI_API_KEY']),
self_reflect_llm = None,
action_llm = None,
) -> None:
self.question = question
self.context = context
@@ -62,8 +84,34 @@ def __init__(self,
self.reflect_prompt = reflect_prompt
self.cot_examples = cot_examples
self.reflect_examples = reflect_examples
self.self_reflect_llm = self_reflect_llm
self.action_llm = action_llm

# Initialize LLMs with defaults if not provided
if self_reflect_llm is None:
if 'OPENAI_API_KEY' in os.environ:
self.self_reflect_llm = AnyOpenAILLM(
temperature=0,
max_tokens=250,
model_name="gpt-3.5-turbo",
model_kwargs={"stop": "\n"},
openai_api_key=os.environ['OPENAI_API_KEY'])
else:
raise ValueError("self_reflect_llm must be provided or OPENAI_API_KEY must be set")
else:
self.self_reflect_llm = self_reflect_llm

if action_llm is None:
if 'OPENAI_API_KEY' in os.environ:
self.action_llm = AnyOpenAILLM(
temperature=0,
max_tokens=250,
model_name="gpt-3.5-turbo",
model_kwargs={"stop": "\n"},
openai_api_key=os.environ['OPENAI_API_KEY'])
else:
raise ValueError("action_llm must be provided or OPENAI_API_KEY must be set")
else:
self.action_llm = action_llm
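# Note: building the default LLMs lazily inside __init__ (rather than as argument
# defaults) avoids requiring OPENAI_API_KEY at import time; callers may still pass
# in their own LLM instances.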

self.reflections: List[str] = []
self.reflections_str = ''
self.answer = ''
@@ -158,24 +206,35 @@ def __init__(self,
key: str,
max_steps: int = 6,
agent_prompt: PromptTemplate = react_agent_prompt,
docstore: Docstore = Wikipedia(),
react_llm: AnyOpenAILLM = AnyOpenAILLM(
temperature=0,
max_tokens=100,
model_name="gpt-3.5-turbo",
model_kwargs={"stop": "\n"},
openai_api_key=os.environ['OPENAI_API_KEY']),
docstore = None,
react_llm = None,
) -> None:

self.question = question
self.answer = ''
self.key = key
self.max_steps = max_steps
self.agent_prompt = agent_prompt
self.react_examples = WEBTHINK_SIMPLE6

# Initialize docstore with default if not provided
if docstore is None:
docstore = Wikipedia()
self.docstore = DocstoreExplorer(docstore) # Search, Lookup
self.llm = react_llm

# Initialize LLM with default if not provided
if react_llm is None:
if 'OPENAI_API_KEY' in os.environ:
self.llm = AnyOpenAILLM(
temperature=0,
max_tokens=100,
model_name="gpt-3.5-turbo",
model_kwargs={"stop": "\n"},
openai_api_key=os.environ['OPENAI_API_KEY'])
else:
raise ValueError("react_llm must be provided or OPENAI_API_KEY must be set")
else:
self.llm = react_llm

self.enc = tiktoken.encoding_for_model("text-davinci-003")

@@ -216,7 +275,45 @@ def step(self) -> None:

if action_type == 'Search':
try:
self.scratchpad += format_step(self.docstore.search(argument))
result = self.docstore.search(argument)
if result.startswith('Could not find'):
self.scratchpad += format_step(result)
else:
if hasattr(self.docstore, 'document') and self.docstore.document is not None:
page_url = self.docstore.document.metadata.get('page', '')
page_title_raw = page_url.split('/')[-1].replace('_', ' ')
page_title = page_title_raw.replace(',', '').lower()
search_term = argument.replace(',', '').lower()

page_title_clean = page_title.replace('inc.', '').replace('ltd.', '').replace('(', '').replace(')', '').strip()
search_term_clean = search_term.replace('inc.', '').replace('ltd.', '').strip()

# Fuzzy matching: check whether the returned page title loosely matches the search term
search_words = [w for w in search_term_clean.split() if len(w) > 2] # Skip short words
title_words = [w for w in page_title_clean.split() if len(w) > 2]

def words_similar(w1, w2):
if w1 == w2:
return True
if w1.startswith(w2) or w2.startswith(w1):
return True
if len(w1) >= 4 and len(w2) >= 4:
common_chars = sum(1 for c in set(w1) if c in w2)
return common_chars >= min(len(w1), len(w2)) - 1
return False
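# words_similar: two words match if they are equal, one is a prefix of the other,
# or (for words of 4+ characters) at least min(len(w1), len(w2)) - 1 of w1's distinct
# characters also appear in w2 -- a loose, order-insensitive overlap heuristic.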

if len(search_words) > 0:
matching_words = sum(1 for sw in search_words
if any(words_similar(sw, tw) for tw in title_words))
match_ratio = matching_words / len(search_words)
else:
match_ratio = 1.0
if match_ratio < 0.6:
self.scratchpad += f'Could not find [{argument}]. The search returned a different page ("{page_title_raw}"). Try searching for a related topic or more specific terms.'
else:
self.scratchpad += format_step(result)
else:
self.scratchpad += format_step(result)
except Exception as e:
print(e)
self.scratchpad += f'Could not find that page, please try again.'
@@ -268,22 +365,25 @@ def __init__(self,
max_steps: int = 6,
agent_prompt: PromptTemplate = react_reflect_agent_prompt,
reflect_prompt: PromptTemplate = reflect_prompt,
docstore: Docstore = Wikipedia(),
react_llm: AnyOpenAILLM = AnyOpenAILLM(
temperature=0,
max_tokens=100,
model_name="gpt-3.5-turbo",
model_kwargs={"stop": "\n"},
openai_api_key=os.environ['OPENAI_API_KEY']),
reflect_llm: AnyOpenAILLM = AnyOpenAILLM(
temperature=0,
max_tokens=250,
model_name="gpt-3.5-turbo",
openai_api_key=os.environ['OPENAI_API_KEY']),
docstore = None,
react_llm = None,
reflect_llm = None,
) -> None:

super().__init__(question, key, max_steps, agent_prompt, docstore, react_llm)
self.reflect_llm = reflect_llm

# Initialize reflect_llm with default if not provided
if reflect_llm is None:
if 'OPENAI_API_KEY' in os.environ:
self.reflect_llm = AnyOpenAILLM(
temperature=0,
max_tokens=250,
model_name="gpt-3.5-turbo",
openai_api_key=os.environ['OPENAI_API_KEY'])
else:
raise ValueError("reflect_llm must be provided or OPENAI_API_KEY must be set")
else:
self.reflect_llm = reflect_llm
self.reflect_prompt = reflect_prompt
self.reflect_examples = REFLECTIONS
self.reflections: List[str] = []
@@ -292,6 +392,8 @@ def __init__(self,
def run(self, reset = True, reflect_strategy: ReflexionStrategy = ReflexionStrategy.REFLEXION) -> None:
if (self.is_finished() or self.is_halted()) and not self.is_correct():
self.reflect(reflect_strategy)
# After reflection, always reset to start a fresh attempt with new reflections
reset = True

ReactAgent.run(self, reset)

@@ -341,9 +443,9 @@ def parse_action(string):
action_type = match.group(1)
argument = match.group(2)
return action_type, argument

else:
return None
return None, None
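# Returning a (None, None) pair on a failed parse lets callers unpack the result
# unconditionally, e.g.:
#   action_type, argument = parse_action(step_text)
#   if action_type is None:
#       ...  # handle the malformed action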

def format_step(step: str) -> str:
return step.strip('\n').strip().replace('\n', '')