Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions .github/workflows/bumpversion.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Bump the package version and update the changelog with commitizen
# whenever commits are pushed to `main`.
name: Bump version

on:
  push:
    branches:
      - main

jobs:
  bump-version:
    # Skip runs whose head commit message starts with "bump:". Presumably that
    # is the commit pushed by the bump step itself, so this guard prevents the
    # workflow from re-triggering on its own output -- confirm against the
    # configured bump message.
    if: "!startsWith(github.event.head_commit.message, 'bump:')"
    runs-on: ubuntu-latest
    name: "Bump version and create changelog with commitizen"
    steps:
      - name: Check out
        # v4 runs on a supported Node runtime; v3 targets the deprecated Node 16.
        uses: actions/checkout@v4
        with:
          token: "${{ secrets.PERSONAL_ACCESS_TOKEN }}"
          # 0 fetches the full history (all branches and tags) instead of a
          # shallow clone.
          fetch-depth: 0
      - name: Create bump and changelog
        # NOTE(review): this tracks the moving `master` ref while being handed a
        # personal access token; consider pinning to a release tag or commit SHA.
        uses: commitizen-tools/commitizen-action@master
        with:
          github_token: ${{ secrets.PERSONAL_ACCESS_TOKEN }}
          # Bumps are cut as beta pre-releases.
          prerelease: beta
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,9 @@ repos:
language: system
pass_filenames: false
files: ^ui-ssr/

# Validate commit messages with commitizen at the commit-msg stage.
# NOTE(review): rev is aligned with the `commitizen>=4.11.0` dev dependency so
# the hook and the CLI used for bumping agree -- confirm the tag exists upstream.
- repo: https://github.com/commitizen-tools/commitizen
  rev: v4.11.0
  hooks:
    - id: commitizen
      stages: [commit-msg]
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@
Use [pip](https://pypi.org/project/pip/) to install Language-Model-SAEs:

```bash
pip install lm-saes==2.0.0b4
pip install lm-saes==2.0.0b5
```

We also highly recommend using [uv](https://docs.astral.sh/uv/) to manage your own project dependencies. You can use

```bash
uv add lm-saes==2.0.0b4
uv add lm-saes==2.0.0b5
```

to add Language-Model-SAEs as your project dependency.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def parse_args():
activation_factory=ActivationFactoryConfig(
sources=[
ActivationFactoryActivationsSource(
path=str(args.activation_path),
path=os.path.expanduser(args.activation_path),
name="pythia-160m-2d",
device="cuda",
dtype=torch.float16,
Expand Down
9 changes: 4 additions & 5 deletions examples/generate_pythia_activation_1d.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import argparse
import os
from pathlib import Path

import torch

Expand Down Expand Up @@ -64,15 +63,15 @@ def parse_args():
dataset=DatasetConfig(dataset_name_or_path="Hzfinfdu/SlimPajama-3B"),
dataset_name="SlimPajama-3B",
hook_points=[f"blocks.{layer}.hook_resid_post" for layer in layers],
output_dir=Path(args.activation_path).expanduser(),
output_dir=os.path.expanduser(args.activation_path),
total_tokens=800_000_000,
context_size=2048,
n_samples_per_chunk=None,
model_batch_size=int(32),
model_batch_size=32,
num_workers=None,
target=ActivationFactoryTarget.ACTIVATIONS_1D,
batch_size=int(2048 * 64),
buffer_size=int(2048 * 200),
batch_size=2048 * 64,
buffer_size=2048 * 200,
buffer_shuffle=BufferShuffleConfig(
perm_seed=42,
generator_device="cuda",
Expand Down
3 changes: 1 addition & 2 deletions examples/generate_pythia_activation_2d.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import argparse
import os
from pathlib import Path

import torch

Expand Down Expand Up @@ -64,7 +63,7 @@ def parse_args():
dataset=DatasetConfig(dataset_name_or_path="Hzfinfdu/SlimPajama-3B"),
dataset_name="SlimPajama-3B",
hook_points=[f"blocks.{layer}.hook_resid_post" for layer in layers],
output_dir=Path(args.activation_path).expanduser(),
output_dir=os.path.expanduser(args.activation_path),
total_tokens=100_000_000,
context_size=2048,
model_batch_size=args.model_batch_size,
Expand Down
1 change: 1 addition & 0 deletions examples/train_pythia_clt_topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
init_encoder_with_decoder_transpose=False,
),
trainer=TrainerConfig(
amp_dtype=torch.float32,
initial_k=768 * 12 * 8 // 2,
k_warmup_steps=1.0,
k_schedule_type="exponential",
Expand Down
1 change: 1 addition & 0 deletions examples/train_pythia_lorsa_topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
encoder_uniform_bound=1 / math.sqrt(768 * 8),
),
trainer=TrainerConfig(
amp_dtype=torch.float32,
lr=2e-4,
initial_k=64,
k_warmup_steps=0.1,
Expand Down
1 change: 1 addition & 0 deletions examples/train_pythia_sae_batchtopk.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
grid_search_init_norm=True,
),
trainer=TrainerConfig(
amp_dtype=torch.float32,
lr=1e-4,
initial_k=50,
k_warmup_steps=0.1,
Expand Down
1 change: 1 addition & 0 deletions examples/train_pythia_sae_jumprelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
init_log_jumprelu_threshold_value=math.log(0.1),
),
trainer=TrainerConfig(
amp_dtype=torch.float32,
lr=5e-5,
l1_coefficient=0.3,
total_training_tokens=800_000_000,
Expand Down
1 change: 1 addition & 0 deletions examples/train_pythia_sae_topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
grid_search_init_norm=True,
),
trainer=TrainerConfig(
amp_dtype=torch.float32,
lr=1e-4,
initial_k=50,
k_warmup_steps=0.1,
Expand Down
4 changes: 2 additions & 2 deletions examples/train_pythia_sae_with_pre_generated_activations.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import argparse
import math
import os
from pathlib import Path

import torch

Expand Down Expand Up @@ -56,6 +55,7 @@ def parse_args():
init_encoder_with_decoder_transpose_factor=1.0,
),
trainer=TrainerConfig(
amp_dtype=torch.float32,
lr=5e-5,
l1_coefficient=0.3,
total_training_tokens=800_000_000,
Expand All @@ -74,7 +74,7 @@ def parse_args():
activation_factory=ActivationFactoryConfig(
sources=[
ActivationFactoryActivationsSource(
path=str(Path(args.activation_path).expanduser()),
path=os.path.expanduser(args.activation_path),
name="pythia-160m-1d",
device="cuda",
dtype=torch.float32,
Expand Down
16 changes: 14 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "lm-saes"
version = "2.0.0b4"
version = "2.0.0b5"
description = "For OpenMOSS Mechanistic Interpretability Team's Sparse Autoencoder (SAE) research. Open-sourced and constantly updated."
dependencies = [
"transformer-lens",
Expand Down Expand Up @@ -95,6 +95,7 @@ dev = [
"gradio>=5.34.0",
"sqlalchemy>=2.0.44",
"apscheduler>=3.11.1",
"commitizen>=4.11.0",
]
docs = [
"mkdocs-gen-files>=0.5.0",
Expand Down Expand Up @@ -198,4 +199,15 @@ requires-dist = ["torch", "einops"]

[tool.uv.sources.transformer-lens]
path = "./TransformerLens"
editable = true
editable = true

# commitizen settings used for version bumps and changelog generation.
[tool.commitizen]
name = "cz_conventional_commits"
# Git tags are written as "v" followed by the version.
tag_format = "v$version"
version_scheme = "pep440"
version_provider = "uv"
update_changelog_on_bump = true
# Each entry is "<file>:<regex>". The regex only selects which lines of the
# file get their version string rewritten; it is not a template, so it must
# match the line as it actually appears in README.md. A literal "{version}"
# placeholder would not match lines such as "pip install lm-saes==2.0.0b5".
version_files = [
    "README.md:pip install lm-saes==",
    "README.md:uv add lm-saes==",
]
Loading