diff --git a/candle-examples/examples/quantized-qwen3-moe/README.md b/candle-examples/examples/quantized-qwen3-moe/README.md
new file mode 100644
index 0000000000..b0f25727c0
--- /dev/null
+++ b/candle-examples/examples/quantized-qwen3-moe/README.md
@@ -0,0 +1,22 @@
+# candle-quantized-qwen3-moe
+
+[Qwen3](https://qwenlm.github.io/blog/qwen3/) is an upgraded version of Qwen2.5, released by Alibaba Cloud.
+This example runs the Mixture-of-Experts (MoE) variants of Qwen3 in quantized GGUF form.
+
+## Running the example
+
+```bash
+cargo run --example quantized-qwen3-moe --release -- --prompt "Write a function to count prime numbers up to N."
+```
+
+The 30B-A3B model is used by default; the 235B-A22B model can be selected with the `--which` argument.
+
+```bash
+cargo run --example quantized-qwen3-moe --release -- --which 235b --prompt "A train is travelling at 120mph, how far does it travel in 3 minutes 30 seconds?"
+```
+
+To run on a CUDA GPU, build with the `cuda` feature:
+
+```bash
+cargo run --example quantized-qwen3-moe --release --features cuda -- --prompt "Write a function to count prime numbers up to N."
+```
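+
+The example downloads the GGUF weights and tokenizer from the Hugging Face Hub on first use. If you already have them locally, you can point the example at them with the `--model` and `--tokenizer` flags (the paths below are placeholders):
+
+```bash
+cargo run --example quantized-qwen3-moe --release -- \
+  --model /path/to/Qwen3-30B-A3B-Q4_K_M.gguf \
+  --tokenizer /path/to/tokenizer.json \
+  --prompt "Write a function to count prime numbers up to N."
+```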
diff --git a/candle-examples/examples/quantized-qwen3-moe/main.rs b/candle-examples/examples/quantized-qwen3-moe/main.rs
new file mode 100644
index 0000000000..dd4bcdb0d5
--- /dev/null
+++ b/candle-examples/examples/quantized-qwen3-moe/main.rs
@@ -0,0 +1,306 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use clap::{Parser, ValueEnum};
+use std::io::Write;
+use tokenizers::Tokenizer;
+
+use candle::quantized::gguf_file;
+use candle::Tensor;
+use candle_transformers::generation::{LogitsProcessor, Sampling};
+
+use candle_examples::token_output_stream::TokenOutputStream;
+use candle_transformers::models::quantized_qwen3_moe::ModelWeights as Qwen3_MoE;
+
+const DEFAULT_PROMPT: &str = "Write a Rust function to calculate the factorial of a given number.";
+
+#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
+enum Which {
+    #[value(name = "30b")]
+    W3_MoE_30b,
+    #[value(name = "235b")]
+    W3_MoE_235b,
+}
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// GGUF file to load, typically a .gguf file generated by the quantize command from llama.cpp
+    #[arg(long)]
+    model: Option<String>,
+
+    /// The initial prompt.
+    #[arg(long)]
+    prompt: Option<String>,
+
+    /// The length of the sample to generate (in tokens).
+    #[arg(short = 'n', long, default_value_t = 1000)]
+    sample_len: usize,
+
+    /// The tokenizer config in json format.
+    #[arg(long)]
+    tokenizer: Option<String>,
+
+    /// The temperature used to generate samples, use 0 for greedy sampling.
+    #[arg(long, default_value_t = 0.8)]
+    temperature: f64,
+
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
+    /// Only sample among the top K samples.
+    #[arg(long)]
+    top_k: Option<usize>,
+
+    /// The seed to use when generating random samples.
+    #[arg(long, default_value_t = 299792458)]
+    seed: u64,
+
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
+    /// Process prompt elements separately.
+    #[arg(long)]
+    split_prompt: bool,
+
+    /// Run on CPU rather than GPU even if a GPU is available.
+    #[arg(long)]
+    cpu: bool,
+
+    /// Penalty to be applied for repeating tokens, 1. means no penalty.
+    #[arg(long, default_value_t = 1.1)]
+    repeat_penalty: f32,
+
+    /// The context size to consider for the repeat penalty.
+    #[arg(long, default_value_t = 64)]
+    repeat_last_n: usize,
+
+    /// The model size to use.
+    #[arg(long, default_value = "30b")]
+    which: Which,
+}
+
+impl Args {
+    fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
+        let tokenizer_path = match &self.tokenizer {
+            Some(config) => std::path::PathBuf::from(config),
+            None => {
+                let api = hf_hub::api::sync::Api::new()?;
+                let repo = match self.which {
+                    Which::W3_MoE_30b => "Qwen/Qwen3-30B-A3B",
+                    Which::W3_MoE_235b => "Qwen/Qwen3-235B-A22B",
+                };
+                let api = api.model(repo.to_string());
+                api.get("tokenizer.json")?
+            }
+        };
+        Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
+    }
+
+    fn model(&self) -> anyhow::Result<std::path::PathBuf> {
+        let model_path = match &self.model {
+            Some(config) => std::path::PathBuf::from(config),
+            None => {
+                let (repo, filename, revision) = match self.which {
+                    Which::W3_MoE_30b => (
+                        "unsloth/Qwen3-30B-A3B-GGUF",
+                        "Qwen3-30B-A3B-Q4_K_M.gguf",
+                        "main",
+                    ),
+                    Which::W3_MoE_235b => (
+                        "unsloth/Qwen3-235B-A22B-GGUF",
+                        "Qwen3-235B-A22B-Q4_K_M-00001-of-00003.gguf",
+                        "main",
+                    ),
+                };
+                let api = hf_hub::api::sync::Api::new()?;
+                api.repo(hf_hub::Repo::with_revision(
+                    repo.to_string(),
+                    hf_hub::RepoType::Model,
+                    revision.to_string(),
+                ))
+                .get(filename)?
+            }
+        };
+        Ok(model_path)
+    }
+}
+
+fn format_size(size_in_bytes: usize) -> String {
+    if size_in_bytes < 1_000 {
+        format!("{size_in_bytes}B")
+    } else if size_in_bytes < 1_000_000 {
+        format!("{:.2}KB", size_in_bytes as f64 / 1e3)
+    } else if size_in_bytes < 1_000_000_000 {
+        format!("{:.2}MB", size_in_bytes as f64 / 1e6)
+    } else {
+        format!("{:.2}GB", size_in_bytes as f64 / 1e9)
+    }
+}
+
+fn main() -> anyhow::Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
+    let args = Args::parse();
+    let _guard = if args.tracing {
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
+
+    println!(
+        "avx: {}, neon: {}, simd128: {}, f16c: {}",
+        candle::utils::with_avx(),
+        candle::utils::with_neon(),
+        candle::utils::with_simd128(),
+        candle::utils::with_f16c()
+    );
+    println!(
+        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
+        args.temperature, args.repeat_penalty, args.repeat_last_n
+    );
+
+    let model_path = args.model()?;
+    let mut file = std::fs::File::open(&model_path)?;
+    let start = std::time::Instant::now();
+    let device = candle_examples::device(args.cpu)?;
+
+    let mut model = {
+        let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
+        let mut total_size_in_bytes = 0;
+        for (_, tensor) in model.tensor_infos.iter() {
+            let elem_count = tensor.shape.elem_count();
+            total_size_in_bytes +=
+                elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
+        }
+        println!(
+            "loaded {:?} tensors ({}) in {:.2}s",
+            model.tensor_infos.len(),
+            &format_size(total_size_in_bytes),
+            start.elapsed().as_secs_f32(),
+        );
+        Qwen3_MoE::from_gguf(model, &mut file, &device, None)?
+    };
+    println!("model built");
+
+    let tokenizer = args.tokenizer()?;
+    let mut tos = TokenOutputStream::new(tokenizer);
+    let prompt_str = args
+        .prompt
+        .clone()
+        .unwrap_or_else(|| DEFAULT_PROMPT.to_string());
+
+    let prompt_str = format!("<|im_start|>user\n{prompt_str}<|im_end|>\n<|im_start|>assistant\n");
+    print!("formatted prompt: {}", &prompt_str);
+
+    let tokens = tos
+        .tokenizer()
+        .encode(prompt_str, true)
+        .map_err(anyhow::Error::msg)?;
+
+    let tokens = tokens.get_ids();
+
+    let to_sample = args.sample_len.saturating_sub(1);
+
+    let mut all_tokens = vec![];
+
+    let mut logits_processor = {
+        let temperature = args.temperature;
+        let sampling = if temperature <= 0. {
+            Sampling::ArgMax
+        } else {
+            match (args.top_k, args.top_p) {
+                (None, None) => Sampling::All { temperature },
+                (Some(k), None) => Sampling::TopK { k, temperature },
+                (None, Some(p)) => Sampling::TopP { p, temperature },
+                (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
+            }
+        };
+        LogitsProcessor::from_sampling(args.seed, sampling)
+    };
+
+    let start_prompt_processing = std::time::Instant::now();
+
+    let mut next_token = if !args.split_prompt {
+        let input = Tensor::new(tokens, &device)?.unsqueeze(0)?;
+        let logits = model.forward(&input, 0)?;
+        let logits = logits.squeeze(0)?;
+        logits_processor.sample(&logits)?
+    } else {
+        let mut next_token = 0;
+        for (pos, token) in tokens.iter().enumerate() {
+            let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
+            let logits = model.forward(&input, pos)?;
+            let logits = logits.squeeze(0)?;
+            next_token = logits_processor.sample(&logits)?
+        }
+        next_token
+    };
+
+    let prompt_dt = start_prompt_processing.elapsed();
+
+    all_tokens.push(next_token);
+
+    if let Some(t) = tos.next_token(next_token)? {
+        print!("{t}");
+        std::io::stdout().flush()?;
+    }
+
+    let eos_token = *tos.tokenizer().get_vocab(true).get("<|im_end|>").unwrap();
+
+    let start_post_prompt = std::time::Instant::now();
+
+    let mut sampled = 0;
+    for index in 0..to_sample {
+        let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
+        let logits = model.forward(&input, tokens.len() + index)?;
+        let logits = logits.squeeze(0)?;
+        let logits = if args.repeat_penalty == 1. {
+            logits
+        } else {
+            let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
+            candle_transformers::utils::apply_repeat_penalty(
+                &logits,
+                args.repeat_penalty,
+                &all_tokens[start_at..],
+            )?
+        };
+        next_token = logits_processor.sample(&logits)?;
+        all_tokens.push(next_token);
+        if let Some(t) = tos.next_token(next_token)? {
+            print!("{t}");
+            std::io::stdout().flush()?;
+        }
+        sampled += 1;
+        if next_token == eos_token {
+            break;
+        };
+    }
+
+    if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
+        print!("{rest}");
+    }
+
+    std::io::stdout().flush()?;
+    let dt = start_post_prompt.elapsed();
+    println!(
+        "\n\n{:4} prompt tokens processed: {:.2} token/s",
+        tokens.len(),
+        tokens.len() as f64 / prompt_dt.as_secs_f64(),
+    );
+    println!(
+        "{sampled:4} tokens generated: {:.2} token/s",
+        sampled as f64 / dt.as_secs_f64(),
+    );
+    Ok(())
+}
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index e77ba4a36f..2d93833581 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -94,6 +94,7 @@ pub mod quantized_phi;
 pub mod quantized_phi3;
 pub mod quantized_qwen2;
 pub mod quantized_qwen3;
+pub mod quantized_qwen3_moe;
 pub mod quantized_recurrent_gemma;
 pub mod quantized_rwkv_v5;
 pub mod quantized_rwkv_v6;
diff --git a/candle-transformers/src/models/quantized_qwen3.rs b/candle-transformers/src/models/quantized_qwen3.rs
index 5d9f414658..e616ef03a5 100644
--- a/candle-transformers/src/models/quantized_qwen3.rs
+++ b/candle-transformers/src/models/quantized_qwen3.rs
@@ -14,38 +14,38 @@ use candle_nn::{kv_cache::ConcatKvCache, Activation, Embedding, Module};
 use std::io::{Read, Seek};
 use std::sync::Arc;
 
-struct Gguf<R: Read + Seek> {
+pub(crate) struct Gguf<R: Read + Seek> {
     ct: gguf_file::Content,
     reader: R,
     device: Device,
 }
 
 impl<R: Read + Seek> Gguf<R> {
-    fn new(ct: gguf_file::Content, reader: R, device: Device) -> Self {
+    pub(crate) fn new(ct: gguf_file::Content, reader: R, device: Device) -> Self {
         Self { ct, reader, device }
     }
 
-    fn qmatmul(&mut self, name: &str) -> Result<QMatMul> {
+    pub(crate) fn qmatmul(&mut self, name: &str) -> Result<QMatMul> {
         let ws = self.ct.tensor(&mut self.reader, name, &self.device)?;
         QMatMul::from_weights(ws.into())
     }
 
-    fn rms_norm(&mut self, name: &str, eps: f64) -> Result<RmsNorm> {
+    pub(crate) fn rms_norm(&mut self, name: &str, eps: f64) -> Result<RmsNorm> {
         let ws = self.ct.tensor(&mut self.reader, name, &self.device)?;
         RmsNorm::from_qtensor(ws, eps)
     }
 
-    fn metadata(&self) -> &std::collections::HashMap<String, gguf_file::Value> {
+    pub(crate) fn metadata(&self) -> &std::collections::HashMap<String, gguf_file::Value> {
         &self.ct.metadata
     }
 
-    fn tensor(&mut self, name: &str) -> Result<QTensor> {
+    pub(crate) fn tensor(&mut self, name: &str) -> Result<QTensor> {
         self.ct.tensor(&mut self.reader, name, &self.device)
     }
 }
 
 #[derive(Debug, Clone)]
-struct MlpWeights {
+pub(crate) struct MlpWeights {
     gate_proj: QMatMul,
     up_proj: QMatMul,
     down_proj: QMatMul,
@@ -54,7 +54,7 @@ struct MlpWeights {
 }
 
 impl MlpWeights {
-    fn new<R: Read + Seek>(gg: &mut Gguf<R>, prefix: &str) -> Result<Self> {
+    pub(crate) fn new<R: Read + Seek>(gg: &mut Gguf<R>, prefix: &str) -> Result<Self> {
         let gate_proj = gg.qmatmul(&format!("{prefix}.ffn_gate.weight"))?;
         let up_proj = gg.qmatmul(&format!("{prefix}.ffn_up.weight"))?;
         let down_proj = gg.qmatmul(&format!("{prefix}.ffn_down.weight"))?;
@@ -81,13 +81,13 @@ impl Module for MlpWeights {
 }
 
 #[derive(Debug, Clone)]
-struct RotaryEmbedding {
+pub(crate) struct RotaryEmbedding {
     sin: Tensor,
     cos: Tensor,
 }
 
 impl RotaryEmbedding {
-    fn new(
+    pub(crate) fn new(
         dtype: DType,
         head_dim: usize,
         max_position_embeddings: usize,
@@ -113,7 +113,7 @@ impl RotaryEmbedding {
     }
 
     /// Apply RoPE (q, k shape: B x H x L x D)
-    fn apply(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)> {
+    pub(crate) fn apply(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)> {
         let (_, _, seq_len, _) = q.dims4()?;
         let cos = self.cos.narrow(0, offset, seq_len)?.to_dtype(q.dtype())?;
         let sin = self.sin.narrow(0, offset, seq_len)?.to_dtype(q.dtype())?;
@@ -124,7 +124,7 @@ impl RotaryEmbedding {
 }
 
 #[derive(Debug, Clone)]
-struct AttentionWeights {
+pub(crate) struct AttentionWeights {
     q_proj: QMatMul,
     k_proj: QMatMul,
     v_proj: QMatMul,
@@ -141,7 +141,7 @@ struct AttentionWeights {
 }
 
 impl AttentionWeights {
-    fn new<R: Read + Seek>(
+    pub(crate) fn new<R: Read + Seek>(
         gg: &mut Gguf<R>,
         num_heads: usize,
         num_kv_heads: usize,
@@ -181,7 +181,12 @@ impl AttentionWeights {
         })
     }
 
-    fn forward(&mut self, x: &Tensor, attn_mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
+    pub(crate) fn forward(
+        &mut self,
+        x: &Tensor,
+        attn_mask: Option<&Tensor>,
+        offset: usize,
+    ) -> Result<Tensor> {
         let _enter = self.span_attn.enter();
 
         let (b, l, _) = x.dims3()?;
@@ -234,7 +239,7 @@ impl AttentionWeights {
         self.o_proj.forward(&reshaped_ctx)
     }
 
-    fn clear_kv_cache(&mut self) {
+    pub(crate) fn clear_kv_cache(&mut self) {
         self.kv_cache.reset();
     }
 }
diff --git a/candle-transformers/src/models/quantized_qwen3_moe.rs b/candle-transformers/src/models/quantized_qwen3_moe.rs
new file mode 100644
index 0000000000..47a83cc7e5
--- /dev/null
+++ b/candle-transformers/src/models/quantized_qwen3_moe.rs
@@ -0,0 +1,332 @@
+//! Quantized Qwen3 MoE implementation.
+//!
+//! References:
+//! - [Qwen3 MoE Models](https://huggingface.co/docs/transformers/model_doc/qwen3_moe) (architecture based on official implementations)
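+//!
+//! The weights are expected in llama.cpp-style GGUF files: model hyper-parameters are read from
+//! the `qwen3moe.*` metadata keys, and for each block `blk.{i}` the router projection is stored
+//! as `ffn_gate_inp` while the per-expert gate/up/down weights are stacked into the
+//! `ffn_gate_exps`, `ffn_up_exps` and `ffn_down_exps` tensors.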
+//!
+use super::with_tracing::QMatMul;
+use crate::models::quantized_qwen3::{AttentionWeights, Gguf, MlpWeights, RotaryEmbedding};
+use crate::quantized_nn::RmsNorm;
+use candle::quantized::gguf_file;
+use candle::{DType, Device, Result, Tensor, D};
+use candle_nn::{Activation, Embedding, Module};
+use std::io::{Read, Seek};
+use std::sync::Arc;
+
+// Sparse MoE block stores the concatenated weights of all experts, no split!
+#[derive(Debug, Clone)]
+struct SparseMoeBlockWeights {
+    gate: QMatMul,
+    experts_gate: QMatMul,
+    experts_up: QMatMul,
+    experts_down: QMatMul,
+    act: candle_nn::Activation,
+    norm_topk_prob: bool,
+    num_experts_per_tok: usize,
+    span: tracing::Span,
+}
+
+impl SparseMoeBlockWeights {
+    fn new<R: Read + Seek>(
+        gg: &mut Gguf<R>,
+        prefix: &str,
+        act: Activation,
+        norm_topk_prob: bool,
+        num_experts_per_tok: usize,
+    ) -> Result<Self> {
+        let gate = gg.qmatmul(&format!("{prefix}.ffn_gate_inp.weight"))?;
+        let experts_gate = gg.qmatmul(&format!("{prefix}.ffn_gate_exps.weight"))?;
+        let experts_up = gg.qmatmul(&format!("{prefix}.ffn_up_exps.weight"))?;
+        let experts_down = gg.qmatmul(&format!("{prefix}.ffn_down_exps.weight"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "MoEBlock");
+        Ok(Self {
+            gate,
+            experts_gate,
+            experts_up,
+            experts_down,
+            act,
+            norm_topk_prob,
+            num_experts_per_tok,
+            span,
+        })
+    }
+}
+
+impl Module for SparseMoeBlockWeights {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let (b_size, seq_len, hidden_dim) = xs.dims3()?;
+        // [b_size * seq_len, hidden_dim]
+        let xs = xs.reshape(((), hidden_dim))?;
+        let original_dtype = xs.dtype();
+        let (num_tokens, hidden_dim) = xs.dims2()?;
+
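+        // Routing: softmax over the router logits gives a probability per expert; keep the
+        // `num_experts_per_tok` best experts for every token (descending arg-sort), gather their
+        // probabilities, and optionally re-normalize them so the kept weights sum to one.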
+        let router_logits = self.gate.forward(&xs.to_dtype(DType::F32)?)?;
+        let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;
+
+        // Extract topk experts per token
+        let experts_per_tok = routing_weights
+            .arg_sort_last_dim(false)?
+            .narrow(D::Minus1, 0, self.num_experts_per_tok)?
+            .contiguous()?;
+        let mut routing_weights = routing_weights.gather(&experts_per_tok, D::Minus1)?;
+
+        if self.norm_topk_prob {
+            routing_weights =
+                routing_weights.broadcast_div(&routing_weights.sum_keepdim(D::Minus1)?)?;
+        }
+
+        let ys = {
+            let xs = xs.reshape((num_tokens, 1, hidden_dim))?;
+            let gate = self
+                .experts_gate
+                .indexed_moe_forward(&xs, &experts_per_tok)?;
+            let up = self.experts_up.indexed_moe_forward(&xs, &experts_per_tok)?;
+            self.experts_down
+                .indexed_moe_forward(&(up * gate.apply(&self.act)?)?, &experts_per_tok)?
+        };
+
+        ys.broadcast_mul(&routing_weights.unsqueeze(D::Minus1)?)?
+            .sum(D::Minus2)?
+            .reshape((b_size, seq_len, hidden_dim))?
+            .to_dtype(original_dtype)
+    }
+}
+
+#[derive(Debug, Clone)]
+enum MoeOrMlpWeights {
+    Moe(SparseMoeBlockWeights),
+    Mlp(MlpWeights),
+}
+
+impl Module for MoeOrMlpWeights {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        match self {
+            Self::Moe(m) => m.forward(xs),
+            Self::Mlp(m) => m.forward(xs),
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+struct DecoderLayer {
+    self_attn: AttentionWeights,
+    feed_forward: MoeOrMlpWeights,
+    ln1: RmsNorm,
+    ln2: RmsNorm,
+}
+
+impl DecoderLayer {
+    fn new<R: Read + Seek>(
+        gg: &mut Gguf<R>,
+        num_attention_heads: usize,
+        num_key_value_heads: usize,
+        head_dim: usize,
+        rms_norm_eps: f64,
+        rotary: Arc<RotaryEmbedding>,
+        layer_idx: usize,
+        num_experts: usize,
+        decoder_sparse_step: usize,
+        norm_topk_prob: bool,
+        num_experts_per_tok: usize,
+    ) -> Result<Self> {
+        let prefix = format!("blk.{layer_idx}");
+
+        let ln1 = gg.rms_norm(&format!("{prefix}.attn_norm.weight"), rms_norm_eps)?;
+        let ln2 = gg.rms_norm(&format!("{prefix}.ffn_norm.weight"), rms_norm_eps)?;
+        let self_attn = AttentionWeights::new(
+            gg,
+            num_attention_heads,
+            num_key_value_heads,
+            head_dim,
+            rms_norm_eps,
+            rotary,
+            &prefix,
+        )?;
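+        // With `decoder_sparse_step` hard-coded to 1 in `from_gguf`, every layer gets the sparse
+        // MoE block as long as the GGUF reports experts; otherwise the dense Qwen3 MLP is used.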
+        let feed_forward = if num_experts > 0 && (layer_idx + 1).is_multiple_of(decoder_sparse_step)
+        {
+            MoeOrMlpWeights::Moe(SparseMoeBlockWeights::new(
+                gg,
+                &prefix,
+                candle_nn::Activation::Silu,
+                norm_topk_prob,
+                num_experts_per_tok,
+            )?)
+        } else {
+            MoeOrMlpWeights::Mlp(MlpWeights::new(gg, &prefix)?)
+        };
+        Ok(Self {
+            self_attn,
+            feed_forward,
+            ln1,
+            ln2,
+        })
+    }
+
+    fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
+        let h = self.ln1.forward(x)?;
+        let h = self.self_attn.forward(&h, mask, offset)?;
+        let x = (x + h)?;
+        let h2 = self.ln2.forward(&x)?;
+        let h2 = h2.apply(&self.feed_forward)?;
+        x + h2
+    }
+
+    fn clear_kv_cache(&mut self) {
+        self.self_attn.clear_kv_cache();
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct ModelWeights {
+    embed_tokens: Embedding,
+    layers: Vec<DecoderLayer>,
+    norm: RmsNorm,
+    lm_head: QMatMul,
+    device: Device,
+    dtype: DType,
+    span: tracing::Span,
+    span_output: tracing::Span,
+}
+
+impl ModelWeights {
+    pub fn from_gguf<R: Read + Seek>(
+        ct: gguf_file::Content,
+        reader: &mut R,
+        device: &Device,
+        num_experts_per_tok: Option<usize>,
+    ) -> Result<Self> {
+        let mut gg = Gguf::new(ct, reader, device.clone());
+        let md_get = |s: &str| match gg.metadata().get(s) {
+            None => candle::bail!("cannot find {s} in metadata"),
+            Some(v) => Ok(v),
+        };
+
+        let num_attention_heads = md_get("qwen3moe.attention.head_count")?.to_u32()? as usize;
+        let num_kv_heads = md_get("qwen3moe.attention.head_count_kv")?.to_u32()? as usize;
+        let head_dim = md_get("qwen3moe.attention.key_length")?.to_u32()? as usize;
+        let num_layers = md_get("qwen3moe.block_count")?.to_u32()? as usize;
+        let hidden_size = md_get("qwen3moe.embedding_length")?.to_u32()? as usize;
+        let max_position_embeddings = md_get("qwen3moe.context_length")?.to_u32()? as usize;
+        let rms_norm_eps = md_get("qwen3moe.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
+        let rope_freq_base = md_get("qwen3moe.rope.freq_base")?.to_f32()? as f64;
+        let decoder_sparse_step = 1;
+        let moe_intermediate_size =
+            md_get("qwen3moe.expert_feed_forward_length")?.to_u32()? as usize;
+        let num_experts_per_tok = if let Some(n) = num_experts_per_tok {
+            n
+        } else {
+            md_get("qwen3moe.expert_used_count")?.to_u32()? as usize
+        };
+        let num_experts = md_get("qwen3moe.expert_count")?.to_u32()? as usize;
+        let norm_topk_prob = false;
+
+        let dtype = match gg.metadata().get("general.dtype") {
+            Some(v) => match v.to_u32() {
+                Ok(0) => DType::F32,
+                Ok(1) => DType::F16,
+                _ => DType::F16,
+            },
+            None => DType::F16,
+        };
+
+        let embed_tensor = gg.tensor("token_embd.weight")?;
+        let embed_tokens = Embedding::new(embed_tensor.dequantize(device)?, hidden_size);
+
+        let rotary = Arc::new(RotaryEmbedding::new(
+            dtype,
+            head_dim,
+            max_position_embeddings,
+            rope_freq_base,
+            device,
+        )?);
+
+        let mut layers = Vec::with_capacity(num_layers);
+        for i in 0..num_layers {
+            layers.push(DecoderLayer::new(
+                &mut gg,
+                num_attention_heads,
+                num_kv_heads,
+                head_dim,
+                rms_norm_eps,
+                rotary.clone(),
+                i,
+                num_experts,
+                decoder_sparse_step,
+                norm_topk_prob,
+                num_experts_per_tok,
+            )?);
+        }
+
+        let norm = gg.rms_norm("output_norm.weight", rms_norm_eps)?;
+        // Load output projection tensor, falling back to tied embeddings like gemma3
+        let lm_head_tensor = match gg.tensor("output.weight") {
+            Ok(tensor) => tensor,
+            Err(_) => gg.tensor("token_embd.weight")?,
+        };
+        let lm_head = QMatMul::from_weights(lm_head_tensor.into())?;
+        let span = tracing::span!(tracing::Level::TRACE, "model");
+        let span_output = tracing::span!(tracing::Level::TRACE, "output");
+        Ok(Self {
+            embed_tokens,
+            layers,
+            norm,
+            lm_head,
+            device: device.clone(),
+            dtype,
+            span,
+            span_output,
+        })
+    }
+
+    fn causal_mask(
+        &self,
+        b: usize,
+        tgt: usize,
+        offset: usize,
+        sw: Option<usize>,
+    ) -> Result<Tensor> {
+        let minf = f32::NEG_INFINITY;
+        let mask: Vec<_> = (0..tgt)
+            .flat_map(|i| {
+                (0..(tgt + offset)).map(move |j| {
+                    let past_ok = j <= i + offset;
+                    let sw_ok = match sw {
+                        Some(w) => (i + offset) as i64 - j as i64 <= w as i64,
+                        None => true,
+                    };
+                    if past_ok && sw_ok {
+                        0.
+                    } else {
+                        minf
+                    }
+                })
+            })
+            .collect();
+        Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
+    }
+
+    pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let (b, l) = input.dims2()?;
+        let mut h = self.embed_tokens.forward(input)?;
+        let causal_mask = if l == 1 {
+            None
+        } else {
+            Some(self.causal_mask(b, l, offset, None)?)
+        };
+        for layer in &mut self.layers {
+            h = layer.forward(&h, causal_mask.as_ref(), offset)?;
+        }
+        let h = self.norm.forward(&h)?;
+        let _enter = self.span_output.enter();
+        let last_hidden = h.narrow(1, l - 1, 1)?;
+        self.lm_head.forward(&last_hidden)?.squeeze(1)
+    }
+
+    pub fn clear_kv_cache(&mut self) {
+        for layer in &mut self.layers {
+            layer.clear_kv_cache();
+        }
+    }
+}
diff --git a/candle-transformers/src/models/with_tracing.rs b/candle-transformers/src/models/with_tracing.rs
index f4706c7e95..77887ff297 100644
--- a/candle-transformers/src/models/with_tracing.rs
+++ b/candle-transformers/src/models/with_tracing.rs
@@ -122,6 +122,11 @@ impl QMatMul {
         let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
         Ok(Self { inner, span })
     }
+
+    pub fn indexed_moe_forward(&self, xs: &Tensor, ids: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.indexed_moe_forward(xs, ids)
+    }
 }
 
 impl Module for QMatMul {