22 changes: 22 additions & 0 deletions candle-examples/examples/quantized-qwen3-moe/README.md
@@ -0,0 +1,22 @@
# candle-quantized-qwen3-moe

[Qwen3](https://qwenlm.github.io/blog/qwen3/) is an upgraded version of Qwen2.5, released by Alibaba Cloud. This example runs the Mixture-of-Experts (MoE) variants of Qwen3 in quantized (GGUF) form.

## Running the example

```bash
cargo run --example quantized-qwen3-moe --release -- --prompt "Write a function to count prime numbers up to N."
```

The 30B model is used by default; the 235B model can be selected with the `--which` argument.

```bash
cargo run --example quantized-qwen3-moe --release -- --which 235b --prompt "A train is travelling at 120mph, how far does it travel in 3 minutes 30 seconds?"
```

To run on a CUDA GPU:

```bash
cargo run --example quantized-qwen3-moe --release --features cuda -- --prompt "Write a function to count prime numbers up to N."
```
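
The `--model` and `--tokenizer` flags accept local paths, so an already-downloaded GGUF file and tokenizer can be reused instead of being fetched from the Hugging Face Hub (the paths below are placeholders):

```bash
cargo run --example quantized-qwen3-moe --release -- \
  --model /path/to/Qwen3-30B-A3B-Q4_K_M.gguf \
  --tokenizer /path/to/tokenizer.json \
  --prompt "Write a function to count prime numbers up to N."
```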
306 changes: 306 additions & 0 deletions candle-examples/examples/quantized-qwen3-moe/main.rs
@@ -0,0 +1,306 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::{Parser, ValueEnum};
use std::io::Write;
use tokenizers::Tokenizer;

use candle::quantized::gguf_file;
use candle::Tensor;
use candle_transformers::generation::{LogitsProcessor, Sampling};

use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::models::quantized_qwen3_moe::ModelWeights as Qwen3_MoE;

const DEFAULT_PROMPT: &str = "Write a Rust function to calculate the factorial of a given number.";

#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
    #[value(name = "30b")]
    W3_MoE_30b,
    #[value(name = "235b")]
    W3_MoE_235b,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// GGUF file to load, typically a .gguf file generated by the quantize command from llama.cpp
    #[arg(long)]
    model: Option<String>,

    /// The initial prompt to generate a completion for.
    #[arg(long)]
    prompt: Option<String>,

    /// The length of the sample to generate (in tokens).
    #[arg(short = 'n', long, default_value_t = 1000)]
    sample_len: usize,

    /// The tokenizer config in json format.
    #[arg(long)]
    tokenizer: Option<String>,

    /// The temperature used to generate samples, use 0 for greedy sampling.
    #[arg(long, default_value_t = 0.8)]
    temperature: f64,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// Only sample among the top K samples.
    #[arg(long)]
    top_k: Option<usize>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// Process prompt elements separately.
    #[arg(long)]
    split_prompt: bool,

    /// Run on CPU rather than GPU even if a GPU is available.
    #[arg(long)]
    cpu: bool,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,

    /// The model size to use.
    #[arg(long, default_value = "30b")]
    which: Which,
}

impl Args {
    fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
        let tokenizer_path = match &self.tokenizer {
            Some(config) => std::path::PathBuf::from(config),
            None => {
                let api = hf_hub::api::sync::Api::new()?;
                let repo = match self.which {
                    Which::W3_MoE_30b => "Qwen/Qwen3-30B-A3B",
                    Which::W3_MoE_235b => "Qwen/Qwen3-235B-A22B",
                };
                let api = api.model(repo.to_string());
                api.get("tokenizer.json")?
            }
        };
        Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
    }

    fn model(&self) -> anyhow::Result<std::path::PathBuf> {
        let model_path = match &self.model {
            Some(config) => std::path::PathBuf::from(config),
            None => {
                let (repo, filename, revision) = match self.which {
                    Which::W3_MoE_30b => (
                        "unsloth/Qwen3-30B-A3B-GGUF",
                        "Qwen3-30B-A3B-Q4_K_M.gguf",
                        "main",
                    ),
                    Which::W3_MoE_235b => (
                        "unsloth/Qwen3-235B-A22B-GGUF",
                        "Qwen3-235B-A22B-Q4_K_M-00001-of-00003.gguf",
                        "main",
                    ),
                };
                let api = hf_hub::api::sync::Api::new()?;
                api.repo(hf_hub::Repo::with_revision(
                    repo.to_string(),
                    hf_hub::RepoType::Model,
                    revision.to_string(),
                ))
                .get(filename)?
            }
        };
        Ok(model_path)
    }
}

fn format_size(size_in_bytes: usize) -> String {
    if size_in_bytes < 1_000 {
        format!("{size_in_bytes}B")
    } else if size_in_bytes < 1_000_000 {
        format!("{:.2}KB", size_in_bytes as f64 / 1e3)
    } else if size_in_bytes < 1_000_000_000 {
        format!("{:.2}MB", size_in_bytes as f64 / 1e6)
    } else {
        format!("{:.2}GB", size_in_bytes as f64 / 1e9)
    }
}

fn main() -> anyhow::Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature, args.repeat_penalty, args.repeat_last_n
    );

    let model_path = args.model()?;
    let mut file = std::fs::File::open(&model_path)?;
    let start = std::time::Instant::now();
    let device = candle_examples::device(args.cpu)?;

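    // Read the GGUF metadata, report the total size of the quantized tensors,
    // then build the model weights on the selected device.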
    let mut model = {
        let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
        let mut total_size_in_bytes = 0;
        for (_, tensor) in model.tensor_infos.iter() {
            let elem_count = tensor.shape.elem_count();
            total_size_in_bytes +=
                elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
        }
        println!(
            "loaded {:?} tensors ({}) in {:.2}s",
            model.tensor_infos.len(),
            &format_size(total_size_in_bytes),
            start.elapsed().as_secs_f32(),
        );
        Qwen3_MoE::from_gguf(model, &mut file, &device, None)?
    };
    println!("model built");

    let tokenizer = args.tokenizer()?;
    let mut tos = TokenOutputStream::new(tokenizer);
    let prompt_str = args
        .prompt
        .clone()
        .unwrap_or_else(|| DEFAULT_PROMPT.to_string());

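    // Wrap the raw prompt in Qwen's ChatML template so the model sees a single user turn.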
    let prompt_str = format!("<|im_start|>user\n{prompt_str}<|im_end|>\n<|im_start|>assistant\n");
    print!("formatted prompt: {}", &prompt_str);

    let tokens = tos
        .tokenizer()
        .encode(prompt_str, true)
        .map_err(anyhow::Error::msg)?;

    let tokens = tokens.get_ids();

    let to_sample = args.sample_len.saturating_sub(1);

    let mut all_tokens = vec![];

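    // Build the sampler: greedy decoding when the temperature is 0, otherwise
    // temperature sampling optionally restricted by top-k and/or top-p.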
    let mut logits_processor = {
        let temperature = args.temperature;
        let sampling = if temperature <= 0. {
            Sampling::ArgMax
        } else {
            match (args.top_k, args.top_p) {
                (None, None) => Sampling::All { temperature },
                (Some(k), None) => Sampling::TopK { k, temperature },
                (None, Some(p)) => Sampling::TopP { p, temperature },
                (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
            }
        };
        LogitsProcessor::from_sampling(args.seed, sampling)
    };

    let start_prompt_processing = std::time::Instant::now();

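    // Process the whole prompt in one forward pass by default; with --split-prompt,
    // feed the prompt tokens one position at a time instead.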
    let mut next_token = if !args.split_prompt {
        let input = Tensor::new(tokens, &device)?.unsqueeze(0)?;
        let logits = model.forward(&input, 0)?;
        let logits = logits.squeeze(0)?;
        logits_processor.sample(&logits)?
    } else {
        let mut next_token = 0;
        for (pos, token) in tokens.iter().enumerate() {
            let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
            let logits = model.forward(&input, pos)?;
            let logits = logits.squeeze(0)?;
            next_token = logits_processor.sample(&logits)?
        }
        next_token
    };

    let prompt_dt = start_prompt_processing.elapsed();

    all_tokens.push(next_token);

    if let Some(t) = tos.next_token(next_token)? {
        print!("{t}");
        std::io::stdout().flush()?;
    }

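    // <|im_end|> closes an assistant turn in Qwen's ChatML format, so treat it as EOS.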
    let eos_token = *tos.tokenizer().get_vocab(true).get("<|im_end|>").unwrap();

    let start_post_prompt = std::time::Instant::now();

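    // Autoregressive generation: feed the last sampled token back in, one position at a time.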
    let mut sampled = 0;
    for index in 0..to_sample {
        let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
        let logits = model.forward(&input, tokens.len() + index)?;
        let logits = logits.squeeze(0)?;
        let logits = if args.repeat_penalty == 1. {
            logits
        } else {
            let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
            candle_transformers::utils::apply_repeat_penalty(
                &logits,
                args.repeat_penalty,
                &all_tokens[start_at..],
            )?
        };
        next_token = logits_processor.sample(&logits)?;
        all_tokens.push(next_token);
        if let Some(t) = tos.next_token(next_token)? {
            print!("{t}");
            std::io::stdout().flush()?;
        }
        sampled += 1;
        if next_token == eos_token {
            break;
        };
    }

    if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
        print!("{rest}");
    }

    std::io::stdout().flush()?;
    let dt = start_post_prompt.elapsed();
    println!(
        "\n\n{:4} prompt tokens processed: {:.2} token/s",
        tokens.len(),
        tokens.len() as f64 / prompt_dt.as_secs_f64(),
    );
    println!(
        "{sampled:4} tokens generated: {:.2} token/s",
        sampled as f64 / dt.as_secs_f64(),
    );
    Ok(())
}
1 change: 1 addition & 0 deletions candle-transformers/src/models/mod.rs
@@ -94,6 +94,7 @@ pub mod quantized_phi;
pub mod quantized_phi3;
pub mod quantized_qwen2;
pub mod quantized_qwen3;
pub mod quantized_qwen3_moe;
pub mod quantized_recurrent_gemma;
pub mod quantized_rwkv_v5;
pub mod quantized_rwkv_v6;