ether0 Reward Model

ether0: a scientific reasoning model, dataset, and reward functions for chemistry.

This repo contains the reward model for evaluating ether0 and similar models, along with utilities for working with the verifiable rewards in our benchmark.

Overview

ether0 is a reasoning language model post-trained through a loop of:

  1. Supervised fine-tuning (SFT) on long chain-of-thought reasoning traces, to elicit reasoning from a base model.
  2. Reinforcement learning with verifiable rewards (RLVR) to improve reasoning on focused task groups, each at its own pace. The resulting multitask models are referred to as 'specialists'.
  3. Rejection sampling to filter specialists' reasoning for correctness and quality.
  4. SFT on the base model again to make a 'generalist' reasoning model.
  5. RLVR to recover any lost performance and push further in an all-task setting.
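
The rejection sampling in step 3 can be pictured with a minimal sketch. This is not the training code (which this repo does not include); `Trace`, `rejection_sample`, and both filter callables are hypothetical names chosen for illustration, and the toy correctness and quality checks stand in for real verifiable rewards and quality filters.

```python
from collections.abc import Callable
from dataclasses import dataclass


@dataclass(frozen=True)
class Trace:
    """One sampled reasoning trace (hypothetical structure)."""

    problem: str
    reasoning: str
    answer: str


def rejection_sample(
    traces: list[Trace],
    is_correct: Callable[[Trace], bool],  # assumption: wraps a verifiable reward
    is_high_quality: Callable[[Trace], bool],  # assumption: any quality filter
) -> list[Trace]:
    """Keep only traces that pass both the correctness and quality filters."""
    return [t for t in traces if is_correct(t) and is_high_quality(t)]


problem = "Give a two-carbon alcohol as SMILES."
traces = [
    Trace(problem, "Ethanol has two carbons and one hydroxyl group.", "CCO"),
    Trace(problem, "Methanol is an alcohol.", "CO"),  # wrong answer
    Trace(problem, "", "CCO"),  # right answer, but no reasoning
]
kept = rejection_sample(
    traces,
    is_correct=lambda t: t.answer == "CCO",  # toy exact-match check
    is_high_quality=lambda t: bool(t.reasoning.strip()),  # toy quality check
)
print(len(kept))
```

Only the surviving traces would then feed the SFT pass in step 4.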

Repo Structure

This repo contains several packages:

  • ether0: reward functions, rdkit data utilities, dataset generation prompts, dataset data models, language model training prompts, and data models.
  • ether0.remotes: server code for ether0 reward functions involving exotic packages and/or third party models.

Note

This repo does not contain training code. Open-source repositories such as NeMo-RL or Hugging Face TRL can run the SFT and RL phases of training.

Open Weights

Please see our open-source weights on Hugging Face: https://huggingface.co/futurehouse/ether0

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("futurehouse/ether0")
tokenizer = AutoTokenizer.from_pretrained("futurehouse/ether0")

Open Test Set

Please see our open-source benchmark (test set) on Hugging Face: https://huggingface.co/datasets/futurehouse/ether0-benchmark

from datasets import load_dataset

test_ds = load_dataset("futurehouse/ether0-benchmark", split="test")

Usage

Installation

The easiest way to get started is a pip install from GitHub:

pip install git+https://github.com/Future-House/ether0.git

Or, if you want the full setup, clone the repo and use uv:

git clone https://github.com/Future-House/ether0.git
cd ether0
uv sync

Reward Functions

Here is a basic example of how to use the reward functions:

from ether0.rewards import valid_mol_eval

# Task: provide a valid completion of this molecule
partial_smiles = "O=C(OC1C(OC(=O)C=2C=CC=CC2)C3(O)C(C)(C)CCCC3(C)C4CC=5OC=CC5C(C)C14"

# Here are two model-proposed SMILES completions
invalid_completion_smiles = "CCC"
valid_completion_smiles = ")C=6C=CC=CC6"

# Evaluate the completions
assert not valid_mol_eval(invalid_completion_smiles, partial_smiles)
assert valid_mol_eval(valid_completion_smiles, partial_smiles)
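
To see why "CCC" fails while ")C=6C=CC=CC6" succeeds, it can help to look at the partial SMILES syntactically: it ends with one branch parenthesis still open. The sketch below is not how `valid_mol_eval` works (its internals are not shown here); `is_syntactically_closed` is a hypothetical helper that only checks that branch parentheses and single-digit ring closures are paired. It ignores `%nn` ring closures and says nothing about chemical validity.

```python
def is_syntactically_closed(smiles: str) -> bool:
    """Cheap pre-check: all branches and ring-closure digits must be paired."""
    depth = 0
    open_rings: set[str] = set()
    in_bracket = False
    for char in smiles:
        if char == "[":
            in_bracket = True
        elif char == "]":
            in_bracket = False
        elif in_bracket:
            continue  # digits inside [...] are isotopes, H counts, or charges
        elif char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth < 0:
                return False  # closed a branch that was never opened
        elif char.isdigit():
            open_rings ^= {char}  # first sighting opens a ring, second closes it
    return depth == 0 and not open_rings


partial_smiles = "O=C(OC1C(OC(=O)C=2C=CC=CC2)C3(O)C(C)(C)CCCC3(C)C4CC=5OC=CC5C(C)C14"
print(is_syntactically_closed(partial_smiles + "CCC"))
print(is_syntactically_closed(partial_smiles + ")C=6C=CC=CC6"))
```

Passing this check is necessary but not sufficient for a valid molecule; a full check needs a cheminformatics toolkit.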

Visualization

If it helps, you can visualize molecules:

from ether0.data import draw_molecule

# See above reward functions demo for where these came from
partial_smiles = "O=C(OC1C(OC(=O)C=2C=CC=CC2)C3(O)C(C)(C)CCCC3(C)C4CC=5OC=CC5C(C)C14"
invalid_completion_smiles = "CCC"
valid_completion_smiles = ")C=6C=CC=CC6"

valid_mol_text = draw_molecule(partial_smiles + valid_completion_smiles)
with open("valid_molecule.svg", "w") as f:
    f.write(valid_mol_text)

The output of draw_molecule can also be easily visualized using IPython.display, or in your terminal via chafa valid_molecule.svg (chafa docs).

Similarly, one can visualize reactions:

from ether0.data import draw_reaction

# Source: ether0-benchmark's test split's question 41251b01-7291-56a0-9030-ea51bab03a4c
rxn_smiles = "CC1CNCC1c1nc2c(cnn2C(C)C)c(=O)[nH]1.COc1ccc(C=O)cn1>>"

rxn_text = draw_reaction(rxn_smiles)
with open("reaction.svg", "w") as f:
    f.write(rxn_text)
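
The reaction SMILES above ends in ">>" because only the reactants are given. As a rough illustration of the `reactants>agents>products` layout, here is a plain-Python sketch; `split_reaction_smiles` is a hypothetical helper, not part of ether0, and it assumes a simple reaction SMILES with no extensions appended.

```python
def split_reaction_smiles(rxn_smiles: str) -> dict[str, list[str]]:
    """Split 'reactants>agents>products' into lists of component SMILES."""
    parts = rxn_smiles.split(">")
    if len(parts) != 3:
        raise ValueError(f"Expected exactly two '>' separators in {rxn_smiles!r}.")
    reactants, agents, products = parts
    return {
        # Components within each role are separated by '.'
        "reactants": [s for s in reactants.split(".") if s],
        "agents": [s for s in agents.split(".") if s],
        "products": [s for s in products.split(".") if s],
    }


rxn_smiles = "CC1CNCC1c1nc2c(cnn2C(C)C)c(=O)[nH]1.COc1ccc(C=O)cn1>>"
roles = split_reaction_smiles(rxn_smiles)
print(roles["reactants"])
print(roles["products"])
```

Here both reactants are recovered and the product list is empty.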

Benchmark

Here is a sample baseline that evaluates gpt-4o on ether0-benchmark using lmi. To install lmi, install ether0 with the baselines extra (for example, uv sync --extra baselines).

We also need to run our remote rewards server via ether0-serve (for more information, see ether0.remotes docs):

ETHER0_REMOTES_API_TOKEN=abc123 ether0-serve

Next, start ipython with the relevant environment variables set:

ETHER0_REMOTES_API_BASE_URL="http://127.0.0.1:8000" ETHER0_REMOTES_API_TOKEN=abc123 \
    ipython

And run the following Python code:

import itertools
import statistics
from collections import defaultdict

from aviary.core import Message
from datasets import load_dataset
from lmi import LiteLLMModel
from tqdm.asyncio import tqdm_asyncio as asyncio

from ether0.data import get_problem_category
from ether0.model_prompts import LOOSE_XML_ANSWER_USER_PROMPT, extract_answer_loose
from ether0.models import RewardFunctionInfo
from ether0.rewards import EVAL_FUNCTIONS

# Add an LLM prompt of your choosing to the dataset
test_ds = load_dataset("futurehouse/ether0-benchmark", split="test").map(
    lambda x: {"prompt": "\n\n".join((LOOSE_XML_ANSWER_USER_PROMPT, x["problem"]))}
)

# Send each prompt to the LLM
model = LiteLLMModel(name="gpt-4o")
results = await asyncio.gather(
    *(model.acompletion([Message(content=row["prompt"])]) for row in test_ds),
    desc="Running evaluation",
)

# Compute rewards
per_category_rewards = defaultdict(list)
for row, result in zip(test_ds, results, strict=True):
    # NOTE: you can also use `ether0.rewards.accuracy_reward`,
    # but we decided to go a bit "lower level" for this demo
    reward_info = RewardFunctionInfo.model_validate(row["solution"])
    yhat = extract_answer_loose(result[0].text)
    reward = EVAL_FUNCTIONS[reward_info.fxn_name](
        yhat=yhat, y=reward_info.answer_info, test=True
    )
    per_category_rewards[get_problem_category(reward_info.problem_type)].append(reward)

for category, rewards in sorted(per_category_rewards.items()):
    print(
        f"In category {category!r} of {len(rewards)} questions,"
        f" average reward was {statistics.mean(rewards):.3f}."
    )
accuracy = statistics.mean(itertools.chain.from_iterable(per_category_rewards.values()))
print(f"Cumulative average reward across {len(test_ds)} questions was {accuracy:.3f}.")
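
Note that the cumulative figure above averages over all questions, so larger categories weigh more than smaller ones. The toy sketch below contrasts that with an unweighted mean of per-category means; the category names and reward values are made up for illustration and are not benchmark results.

```python
import itertools
import statistics

# Hypothetical rewards, shaped like per_category_rewards in the demo above
toy_rewards = {
    "category-a": [1.0, 0.0, 1.0, 1.0],
    "category-b": [0.0, 1.0],
}

# Per-question average (what the demo above reports as cumulative)
per_question_mean = statistics.mean(
    itertools.chain.from_iterable(toy_rewards.values())
)
# Per-category average: every category counts equally
per_category_mean = statistics.mean(
    statistics.mean(rewards) for rewards in toy_rewards.values()
)
print(f"{per_question_mean:.3f} vs {per_category_mean:.3f}")
```

Which of the two to report depends on whether categories or individual questions should carry equal weight.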
