1 change: 1 addition & 0 deletions .gitignore
@@ -194,3 +194,4 @@ doc/_autosummary/
# Ignore notebooks directory in Docker folder
/docker/notebooks/
/docker/notebooks/*
pyrit/auxiliary_attacks/gcg/experiments/mlruns/*
142 changes: 42 additions & 100 deletions pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py
@@ -165,107 +165,51 @@ def __init__(
self.test_new_toks = max(self.test_new_toks, len(self.tokenizer(prefix).input_ids))

self._update_ids()


# new _update_ids using apply_chat_template
def _update_ids(self):

messages = [
{"role": "user", "content": f"{self.goal} {self.control}"},
{"role": "assistant", "content": f"{self.target}"},
]

self.conv_template.append_message(self.conv_template.roles[0], f"{self.goal} {self.control}")
self.conv_template.append_message(self.conv_template.roles[1], f"{self.target}")
prompt = self.conv_template.get_prompt()


prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
print(f"_update_ids new version - {self.goal=}, {self.control=}, {self.target=}")
print(f"{self.conv_template.roles[0]=}, {self.conv_template.roles[1]=}")
print(f"checking prompt before encoding (new), {prompt=}, {self.goal=}")

encoding = self.tokenizer(prompt)
toks = encoding.input_ids

if self.conv_template.name == "llama-2" or self.conv_template.name == "llama-3":
self.conv_template.messages = []

self.conv_template.append_message(self.conv_template.roles[0], None)
toks = self.tokenizer(self.conv_template.get_prompt()).input_ids
self._user_role_slice = slice(None, len(toks))

self.conv_template.update_last_message(f"{self.goal}")
toks = self.tokenizer(self.conv_template.get_prompt()).input_ids
self._goal_slice = slice(self._user_role_slice.stop, max(self._user_role_slice.stop, len(toks)))

separator = " " if self.goal else ""
self.conv_template.update_last_message(f"{self.goal}{separator}{self.control}")
toks = self.tokenizer(self.conv_template.get_prompt()).input_ids
self._control_slice = slice(self._goal_slice.stop, len(toks))

self.conv_template.append_message(self.conv_template.roles[1], None)
toks = self.tokenizer(self.conv_template.get_prompt()).input_ids
self._assistant_role_slice = slice(self._control_slice.stop, len(toks))

self.conv_template.update_last_message(f"{self.target}")
toks = self.tokenizer(self.conv_template.get_prompt()).input_ids
self._target_slice = slice(self._assistant_role_slice.stop, len(toks) - 2)
self._loss_slice = slice(self._assistant_role_slice.stop - 1, len(toks) - 3)

else:
python_tokenizer = False or self.conv_template.name == "oasst_pythia"
try:
encoding.char_to_token(len(prompt) - 1)
except Exception:
python_tokenizer = True
if python_tokenizer:
# This is specific to the vicuna and pythia tokenizer and conversation prompt.
# It will not work with other tokenizers or prompts.
self.conv_template.messages = []

self.conv_template.append_message(self.conv_template.roles[0], None)
toks = self.tokenizer(self.conv_template.get_prompt()).input_ids
self._user_role_slice = slice(None, len(toks))

self.conv_template.update_last_message(f"{self.goal}")
toks = self.tokenizer(self.conv_template.get_prompt()).input_ids
self._goal_slice = slice(self._user_role_slice.stop, max(self._user_role_slice.stop, len(toks) - 1))

separator = " " if self.goal else ""
self.conv_template.update_last_message(f"{self.goal}{separator}{self.control}")
toks = self.tokenizer(self.conv_template.get_prompt()).input_ids
self._control_slice = slice(self._goal_slice.stop, len(toks) - 1)

self.conv_template.append_message(self.conv_template.roles[1], None)
toks = self.tokenizer(self.conv_template.get_prompt()).input_ids
self._assistant_role_slice = slice(self._control_slice.stop, len(toks))

self.conv_template.update_last_message(f"{self.target}")
toks = self.tokenizer(self.conv_template.get_prompt()).input_ids
self._target_slice = slice(self._assistant_role_slice.stop, len(toks) - 1)
self._loss_slice = slice(self._assistant_role_slice.stop - 1, len(toks) - 2)
else:
self._system_slice = slice(None, encoding.char_to_token(len(self.conv_template.system)))
self._user_role_slice = slice(
encoding.char_to_token(prompt.find(self.conv_template.roles[0])),
encoding.char_to_token(
prompt.find(self.conv_template.roles[0]) + len(self.conv_template.roles[0]) + 1
),
)
self._goal_slice = slice(
encoding.char_to_token(prompt.find(self.goal)),
encoding.char_to_token(prompt.find(self.goal) + len(self.goal)),
)
self._control_slice = slice(
encoding.char_to_token(prompt.find(self.control)),
encoding.char_to_token(prompt.find(self.control) + len(self.control)),
)
self._assistant_role_slice = slice(
encoding.char_to_token(prompt.find(self.conv_template.roles[1])),
encoding.char_to_token(
prompt.find(self.conv_template.roles[1]) + len(self.conv_template.roles[1]) + 1
),
)
self._target_slice = slice(
encoding.char_to_token(prompt.find(self.target)),
encoding.char_to_token(prompt.find(self.target) + len(self.target)),
)
self._loss_slice = slice(
encoding.char_to_token(prompt.find(self.target)) - 1,
encoding.char_to_token(prompt.find(self.target) + len(self.target)) - 1,
)

self.input_ids = torch.tensor(toks[: self._target_slice.stop], device="cpu")
self.conv_template.messages = []
print(f"{self.conv_template.roles[1]=}")
print(f"{prompt.find(self.conv_template.roles[1])=}")
self._goal_slice = slice(
encoding.char_to_token(prompt.find(self.goal)),
encoding.char_to_token(prompt.find(self.goal) + len(self.goal)),
)
self._control_slice = slice(
encoding.char_to_token(prompt.find(self.control)),
encoding.char_to_token(prompt.find(self.control) + len(self.control)),
)
self._assistant_role_slice = slice(
encoding.char_to_token(prompt.find(self.conv_template.roles[1])),
encoding.char_to_token(
prompt.find(self.conv_template.roles[1]) + len(self.conv_template.roles[1]) + 1
),
)
self._target_slice = slice(
encoding.char_to_token(prompt.find(self.target)),
encoding.char_to_token(prompt.find(self.target) + len(self.target)),
)
self._loss_slice = slice(
encoding.char_to_token(prompt.find(self.target)) - 1,
encoding.char_to_token(prompt.find(self.target) + len(self.target)) - 1,
)

self.input_ids = torch.tensor(toks[:self._target_slice.stop], device="cpu")
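As an aside (not part of the diff), the slice bookkeeping above hinges on char_to_token() mapping character offsets of the goal, control, and target substrings to token indices. A standalone sketch of that idea, with "gpt2" standing in for any fast-tokenizer checkpoint:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)

prompt = "User: Write a tutorial ! ! ! Assistant: Sure, here is a tutorial"
target = "Sure, here is a tutorial"

encoding = tokenizer(prompt)
start = prompt.find(target)
end = start + len(target)

# char_to_token(end) is None when the target ends the prompt, and slicing to
# None means "to the end of the sequence", which is what we want here.
target_slice = slice(encoding.char_to_token(start), encoding.char_to_token(end))

print(tokenizer.decode(encoding.input_ids[target_slice]))  # ~ " Sure, here is a tutorial"
```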

@torch.no_grad()
def generate(self, model, gen_config=None):
if gen_config is None:
@@ -961,7 +905,7 @@ def __init__(
self.test_workers = test_workers
self.progressive_goals = progressive_goals
self.progressive_models = progressive_models
self.control = control_init
self.control = control_init # conv template setting
self.test_prefixes = test_prefixes
self.logfile = logfile
self.managers = managers
@@ -1619,8 +1563,7 @@ def get_workers(params, eval=False):
tokenizers = []
for i in range(len(params.tokenizer_paths)):
tokenizer = AutoTokenizer.from_pretrained(
params.tokenizer_paths[i], token=params.token, trust_remote_code=False, **params.tokenizer_kwargs[i]
)
params.tokenizer_paths[i], token=params.token, trust_remote_code=False, use_fast=True)
Contributor:

What motivated this change?

Contributor Author:

char_to_token() is only supported by fast tokenizers in Hugging Face's Transformers library; switching to a fast tokenizer with use_fast=True resolves this.
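For reference, a minimal sketch of the difference (using "gpt2" purely as a stand-in checkpoint that ships both a slow and a fast tokenizer):

```python
from transformers import AutoTokenizer

fast = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
slow = AutoTokenizer.from_pretrained("gpt2", use_fast=False)

print(fast("hello world").char_to_token(6))  # token index covering character 6

try:
    slow("hello world").char_to_token(6)
except ValueError as err:
    # Slow (pure-Python) tokenizers do not track character-to-token offsets.
    print(err)
```

Only the fast (Rust-backed) tokenizer records offset mappings, which is what the slice computation in _update_ids relies on.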

if "oasst-sft-6-llama-30b" in params.tokenizer_paths[i]:
tokenizer.bos_token_id = 1
tokenizer.unk_token_id = 0
@@ -1660,7 +1603,6 @@ def get_workers(params, eval=False):
raw_conv_templates.append(conv_template)
else:
raise ValueError("Conversation template not recognized")

conv_templates = []
for conv in raw_conv_templates:
if conv.name == "zero_shot":
@@ -1731,4 +1673,4 @@ def get_goals_and_targets(params):
logger.info("Loaded {} train goals".format(len(train_goals)))
logger.info("Loaded {} test goals".format(len(test_goals)))

return train_goals, train_targets, test_goals, test_targets
return train_goals, train_targets, test_goals, test_targets
8 changes: 4 additions & 4 deletions pyrit/auxiliary_attacks/gcg/experiments/train.py
@@ -4,7 +4,7 @@
import logging
import time
from typing import Union

import torch
import mlflow
import numpy as np
import torch.multiprocessing as mp
@@ -57,9 +57,9 @@ def generate_suffix(
verbose: bool = True,
allow_non_ascii: bool = False,
num_train_models: int = 1,
devices: list = ["cuda:0"],
devices: list = ["mps"] if torch.backends.mps.is_available() else ["cuda:0"],
Contributor:

Is this just something you used because cuda wasn't working, or do you think this would be an improvement in some cases?

Contributor Author:

It's just an additional device mapping to also support torch running on macOS devices - https://docs.pytorch.org/docs/stable/notes/mps.html

I originally added it to do a local test run on my MacBook. It can be removed if we'll always run GCG attacks only on CUDA GPUs via cloud.

Contributor:

I have no problem with supporting more than cuda, but I worry that it may now use mps even when cuda is available (?)

Maybe I don't know enough about when mps is available vs cuda, or whether both can be available at the same time. If you know, please reply here 🙂

Contributor Author:

I already removed the condition (this is marked as an outdated version of train.py), leaving just the cuda mapping. In any case, the two are mostly mutually exclusive: a machine has either cuda or mps, not both at the same time, since mps is the cuda equivalent on Apple silicon devices with the MPS backend.
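For what it's worth, a defensive default along these lines (just a sketch, not part of the PR; the helper name is hypothetical) would keep CUDA preferred and only fall back to MPS when CUDA is absent:

```python
import torch

def pick_default_devices() -> list:
    # Prefer CUDA when present, fall back to Apple's MPS backend, else CPU.
    if torch.cuda.is_available():
        return ["cuda:0"]
    if torch.backends.mps.is_available():
        return ["mps"]
    return ["cpu"]
```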

model_kwargs: list = [{"low_cpu_mem_usage": True, "use_cache": False}],
tokenizer_kwargs: list = [{"use_fast": False}],
tokenizer_kwargs: list = [{"use_fast": True}],
n_test_data: int = 0,
test_data: str = "",
lr: float = 0.01,
@@ -186,4 +186,4 @@ def process_fn2(s):
)

for worker in workers + test_workers:
worker.stop()
worker.stop()