
Commit f6a0ca6

Added baseline prompt and README example (#6)
1 parent: 5620bc8

File tree: 4 files changed, +406 −1 lines changed

README.md

Lines changed: 67 additions & 0 deletions
````diff
@@ -128,3 +128,70 @@ or in your terminal via `chafa valid_molecule.svg`
 ([chafa docs](https://hpjansson.org/chafa/)).
 
 ![valid molecule](docs/assets/valid_molecule.svg)
+
+### Benchmark
+
+Here is a sample baseline of
+[`ether0-benchmark`](https://huggingface.co/datasets/futurehouse/ether0-benchmark)
+on `gpt-4o` using [`lmi`](https://github.com/Future-House/ldp/tree/main/packages/lmi).
+To install `lmi`, please install `ether0` with the `baselines` extra
+(for example `uv sync --extra baselines`).
+
+We also need to run our remote rewards server via `ether0-serve`
+(for more information, see the [`ether0.remotes` docs](packages/remotes/README.md)):
+
+```bash
+ETHER0_REMOTES_API_TOKEN=abc123 ether0-serve
+```
+
+Next, start `ipython` with the relevant environment variables set:
+
+```bash
+ETHER0_REMOTES_API_BASE_URL="http://127.0.0.1:8000" ETHER0_REMOTES_API_TOKEN=abc123 ipython
+```
+
+And run the following Python code:
+
+```python
+import itertools
+import statistics
+from collections import defaultdict
+
+from aviary.core import Message
+from datasets import load_dataset
+from lmi import LiteLLMModel
+from tqdm.asyncio import tqdm_asyncio as asyncio
+
+from ether0.data import get_problem_category
+from ether0.model_prompts import LOOSE_XML_ANSWER_USER_PROMPT, extract_answer_loose
+from ether0.models import RewardFunctionInfo
+from ether0.rewards import EVAL_FUNCTIONS
+
+# Add an LLM prompt of your making to the dataset
+test_ds = load_dataset("futurehouse/ether0-benchmark", split="test").map(
+    lambda x: {"prompt": "\n\n".join((LOOSE_XML_ANSWER_USER_PROMPT, x["problem"]))}
+)
+
+# Prompt the LLM
+model = LiteLLMModel(name="gpt-4o")
+results = await asyncio.gather(
+    *(model.acompletion([Message(content=row["prompt"])]) for row in test_ds),
+    desc="Running evaluation",
+)
+
+# Compute rewards
+per_category_rewards = defaultdict(list)
+for row, result in zip(test_ds, results, strict=True):
+    reward_info = RewardFunctionInfo.model_validate(row["solution"])
+    yhat = extract_answer_loose(result[0].text)
+    reward = EVAL_FUNCTIONS[reward_info.fxn_name](yhat=yhat, y=reward_info.answer_info)
+    per_category_rewards[get_problem_category(reward_info.problem_type)].append(reward)
+
+for category, rewards in sorted(per_category_rewards.items()):
+    print(
+        f"In category {category!r} of {len(rewards)} questions,"
+        f" average reward was {statistics.mean(rewards):.3f}."
+    )
+accuracy = statistics.mean(itertools.chain.from_iterable(per_category_rewards.values()))
+print(f"Cumulative average reward across {len(test_ds)} questions was {accuracy:.3f}.")
+```
````
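A note on the `ipython` choice: the snippet above relies on top-level `await`, which the plain `python` interpreter does not support. If you would rather run the baseline as a script, one option is to wrap the fan-out in a coroutine and drive it with the standard library's `asyncio.run`. This wrapper is a sketch, not part of the commit; it reuses only calls that appear in the README snippet:

```python
import asyncio  # stdlib runner; distinct from the tqdm alias used in the README snippet

from aviary.core import Message
from datasets import load_dataset
from lmi import LiteLLMModel
from tqdm.asyncio import tqdm_asyncio

from ether0.model_prompts import LOOSE_XML_ANSWER_USER_PROMPT


async def main() -> list:
    # Same prompt construction as the README snippet
    test_ds = load_dataset("futurehouse/ether0-benchmark", split="test").map(
        lambda x: {"prompt": "\n\n".join((LOOSE_XML_ANSWER_USER_PROMPT, x["problem"]))}
    )
    model = LiteLLMModel(name="gpt-4o")
    # Fan out the completions inside a coroutine so asyncio.run can drive them
    return await tqdm_asyncio.gather(
        *(model.acompletion([Message(content=row["prompt"])]) for row in test_ds),
        desc="Running evaluation",
    )


if __name__ == "__main__":
    results = asyncio.run(main())
```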

pyproject.toml

Lines changed: 5 additions & 0 deletions
```diff
@@ -49,6 +49,11 @@ add-tokens = [
     "ipywidgets>=8",  # For Jupyter notebook support, and pin to keep recent
     "transformers>=4.49",  # Pin to keep recent
 ]
+baselines = [
+    "fhaviary>=0.19",  # Pin for Python 3.13 compatibility
+    "fhlmi>=0.26",  # Pin for Python 3.13 compatibility
+    "ipython",
+]
 dev = [
     "ether0[add-tokens,typing]",
     "huggingface-hub[cli]",  # For login inside of CI
```

src/ether0/model_prompts.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -118,6 +118,14 @@ def extract_thought_answer_strict(
         return None, None  # Consider nested answer as a failure
 
 
+LOOSE_XML_ANSWER_USER_PROMPT = (
+    "When answering,"
+    " be sure to place the final answer as"
+    " SMILES notation into XML tags <answer></answer>."
+    " An example is <answer>CCO</answer>."
+)
+
+
 def extract_answer_loose(text: str | None) -> str:
     """
     Extract thought and answer from text using a loose XML pattern.
```
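The new prompt and the pre-existing loose extractor work as a pair: `LOOSE_XML_ANSWER_USER_PROMPT` asks the model to wrap its SMILES answer in `<answer>` tags, and `extract_answer_loose` recovers the tag body. A quick sketch of that round trip; the exact return value is an assumption based on the docstring's loose-XML description:

```python
from ether0.model_prompts import LOOSE_XML_ANSWER_USER_PROMPT, extract_answer_loose

# A completion shaped the way the prompt requests
completion = "Ethanol satisfies the constraints. <answer>CCO</answer>"

# Expected "CCO", assuming the loose pattern captures the tag body
print(extract_answer_loose(completion))
```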
