Commit f9a9eb1

Support FP16 intermediate results in graph computation
This commit is a demo aimed at using FP16 as the data type for intermediate results in graph inference, reducing computation and improving inference speed. Verification was conducted with the CANN backend on Qwen2.5, Qwen3-MoE, and DeepSeek-Lite-V2, showing performance improvements of 3%–10% depending on the concurrency and the model.

The main changes are: operators involved in the graph no longer hardcode FP32 and instead infer the result type from their inputs, GET_ROWS gains FP16 support, and t_embd and t_logits are cast back to FP32 at the end of inference.

This is only a very basic validation. Full FP16 support still requires:
1. Modifying all operators that currently hardcode FP32 to infer the type from their inputs.
2. Adding FP16 support to all backend operators.
3. Extending the test cases to cover FP16 data types.

Co-authored-by: noemotiovon <[email protected]>
1 parent 3b337b0 commit f9a9eb1
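As a rough sketch of the pattern this commit targets (not part of the diff below): an F16 activation flows through ggml_mul_mat without being promoted to F32, and only the final output is cast back to F32, mirroring what cast_outputs() does for t_embd and t_logits. The shapes, buffer size, and variable names are made up for illustration, and the F16 result type of ggml_mul_mat only holds with this patch applied; upstream ggml hardcodes F32 there.

// sketch only: assumes this commit is applied; upstream ggml would report F32 for cur
#include "ggml.h"
#include <stdio.h>

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,   // arbitrary scratch size for the example
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // hypothetical F16 weight and F16 activation
    struct ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 64, 128);
    struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 64, 4);

    // with this change, the result inherits the activation type (F16)
    // instead of the previously hardcoded GGML_TYPE_F32
    struct ggml_tensor * cur = ggml_mul_mat(ctx, w, x);

    // the final output is cast back to F32, as cast_outputs() does for logits/embeddings
    struct ggml_tensor * out = ggml_cast(ctx, cur, GGML_TYPE_F32);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);

    printf("mul_mat dst type: %s\n",  ggml_type_name(cur->type)); // F16 with this patch
    printf("graph output type: %s\n", ggml_type_name(out->type)); // F32

    ggml_free(ctx);
    return 0;
}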

File tree

6 files changed: +70 -9 lines changed


ggml/src/ggml-cpu/ops.cpp

Lines changed: 16 additions & 4 deletions
@@ -4580,9 +4580,15 @@ static void ggml_compute_forward_get_rows_f16(

             GGML_ASSERT(i01 >= 0 && i01 < ne01);

-            ggml_cpu_fp16_to_fp32(
-                    (const ggml_fp16_t*) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
-                    (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
+            // Supports both F16 and F32 as dst type.
+            if (dst->type == GGML_TYPE_F16)
+                ggml_vec_cpy_f16(nc,
+                        (ggml_fp16_t *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3),
+                        (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
+            else
+                ggml_cpu_fp16_to_fp32(
+                    (const ggml_fp16_t*) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+                    (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
         }
     }

@@ -4662,9 +4668,15 @@ static void ggml_compute_forward_get_rows_f32(

             GGML_ASSERT(i01 >= 0 && i01 < ne01);

-            ggml_vec_cpy_f32(nc,
+            // Supports both F16 and F32 as dst type.
+            if (dst->type == GGML_TYPE_F32)
+                ggml_vec_cpy_f32(nc,
                     (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3),
                     (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
+            else
+                ggml_cpu_fp32_to_fp16(
+                    (const float*) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+                    (ggml_fp16_t *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
         }
     }

ggml/src/ggml-cpu/vec.h

Lines changed: 1 addition & 0 deletions
@@ -87,6 +87,7 @@ inline static void ggml_vec_sub_f16 (const int n, ggml_fp16_t * z, const ggml_fp
 }
 inline static void ggml_vec_set_f32 (const int n, float * x, const float   v) { for (int i = 0; i < n; ++i) x[i]  = v;    }
 inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i]  = x[i]; }
+inline static void ggml_vec_cpy_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { for (int i = 0; i < n; ++i) y[i]  = x[i]; }
 inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i]  = -x[i]; }
 inline static void ggml_vec_neg_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
     for (int i = 0; i < n; ++i) {

ggml/src/ggml.c

Lines changed: 16 additions & 5 deletions
@@ -3024,7 +3024,10 @@ struct ggml_tensor * ggml_mul_mat(
     GGML_ASSERT(!ggml_is_transposed(a));

     const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+    // Tensor a is the weight, with its type determined by the model file.
+    // Tensor b is the activation, i.e., the intermediate computation result.
+    // Here, the destination type (dst) is kept the same as the input activation type.
+    struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne);

     result->op     = GGML_OP_MUL_MAT;
     result->src[0] = a;
@@ -3073,7 +3076,9 @@ struct ggml_tensor * ggml_mul_mat_id(
     GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast

     const int64_t ne[4] = { as->ne[1], ids->ne[0], b->ne[2], 1 };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+    // Tensor b is the activation, i.e., the intermediate computation result.
+    // Here, the destination type (dst) is kept the same as the input activation type.
+    struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne);

     result->op     = GGML_OP_MUL_MAT_ID;
     result->src[0] = as;
@@ -3628,7 +3633,9 @@ struct ggml_tensor * ggml_get_rows(
     GGML_ASSERT(b->type == GGML_TYPE_I32);

     // TODO: implement non F32 return
-    enum ggml_type type = GGML_TYPE_F32;
+    // TODO: Automatically select the destination type based on parameters,
+    // environment variables, or backend support. F16 is hard-coded here as an example.
+    enum ggml_type type = GGML_TYPE_F16;
     if (a->type == GGML_TYPE_I32) {
         type = a->type;
     }
@@ -3676,7 +3683,8 @@ struct ggml_tensor * ggml_set_rows(
     GGML_ASSERT(b->ne[2] % c->ne[1] == 0);
     GGML_ASSERT(b->ne[3] % c->ne[2] == 0);
     GGML_ASSERT(c->ne[3] == 1);
-    GGML_ASSERT(b->type == GGML_TYPE_F32);
+    // b->type can also be F16.
+    //GGML_ASSERT(b->type == GGML_TYPE_F32);
     GGML_ASSERT(c->type == GGML_TYPE_I64 || c->type == GGML_TYPE_I32);

     GGML_ASSERT(ggml_is_contiguous_rows(a));
@@ -5003,7 +5011,10 @@ struct ggml_tensor * ggml_flash_attn_ext(

     // permute(0, 2, 1, 3)
     int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+    // The types of k and v are the same as those in the KV cache,
+    // while q is an intermediate computation result.
+    // Here, the destination type (dst) is kept the same as the type of q.
+    struct ggml_tensor * result = ggml_new_tensor(ctx, q->type, 4, ne);

     float params[] = { scale, max_bias, logit_softcap };
     ggml_set_op_params(result, params, sizeof(params));

src/llama-graph.cpp

Lines changed: 32 additions & 0 deletions
@@ -1937,6 +1937,38 @@ void llm_graph_context::build_pooling(
     ggml_build_forward_expand(gf, cur);
 }

+void llm_graph_context::cast_outputs() const {
+    ggml_tensor * ori_embd = res->t_embd;
+    if (cparams.embeddings && res->t_embd->type != GGML_TYPE_F32) {
+        ggml_tensor * embd = res->t_embd;
+        embd = ggml_cast(ctx0, embd, GGML_TYPE_F32);
+        cb(embd, "result_embd_cast", -1);
+        ggml_build_forward_expand(gf, embd);
+        res->t_embd = embd;
+    }
+
+    if (cparams.embeddings && res->t_embd_pooled->type != GGML_TYPE_F32) {
+        // if LLAMA_POOLING_TYPE_NONE, embd_pooled == embd
+        if (res->t_embd_pooled == ori_embd) {
+            res->t_embd_pooled = res->t_embd;
+        } else {
+            ggml_tensor * embd_pooled = res->t_embd_pooled;
+            embd_pooled = ggml_cast(ctx0, embd_pooled, GGML_TYPE_F32);
+            cb(embd_pooled, "result_embd_pooled_cast", -1);
+            ggml_build_forward_expand(gf, embd_pooled);
+            res->t_embd_pooled = embd_pooled;
+        }
+    }
+
+    if (res->t_logits->type != GGML_TYPE_F32) {
+        ggml_tensor * logits = res->t_logits;
+        logits = ggml_cast(ctx0, logits, GGML_TYPE_F32);
+        cb(logits, "result_logits_cast", -1);
+        ggml_build_forward_expand(gf, logits);
+        res->t_logits = logits;
+    }
+}
+
 int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
     // TODO move to hparams if a T5 variant appears that uses a different value
     const int64_t max_distance = 128;

src/llama-graph.h

Lines changed: 2 additions & 0 deletions
@@ -814,6 +814,8 @@ struct llm_graph_context {
             ggml_tensor * cls_b,
             ggml_tensor * cls_out,
             ggml_tensor * cls_out_b) const;
+
+    void cast_outputs() const;
 };

 // TODO: better name

src/llama-model.cpp

Lines changed: 3 additions & 0 deletions
@@ -19618,6 +19618,9 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
     // add on pooling layer
     llm->build_pooling(cls, cls_b, cls_out, cls_out_b);

+    // cast output to F32
+    llm->cast_outputs();
+
     return llm->res->get_gf();
 }
