
Eval bug: garbage output right after kv-cache defragmentation for CPU backend #12253

Closed · aviallon opened this issue Mar 7, 2025 · 24 comments
Labels: bug (Something isn't working)

Comments

aviallon (Contributor) commented Mar 7, 2025

Name and Version

$ llama-server --version
version: 4798 (1782cdfe)
built with cc (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0 for x86_64-linux-gnu

Operating systems

Linux

GGML backends

CPU

Hardware

Ryzen 5950X (configured with LLC as NUMA node)

Models

bartowski/microsoft_Phi-4-mini-instruct-GGUF:Q4_0

Problem description & steps to reproduce

When running llama-server with this exact command:

$ LLMA_ARG_N_PREDICT=-1 LLAMA_ARG_HF_REPO=bartowski/microsoft_Phi-4-mini-instruct-GGUF:Q4_0 LLAMA_ARG_ALIAS=myalias LLAMA_ARG_N_GPU_LAYERS=0 LLAMA_ARG_BATCH=1984 LLAMA_ARG_THREADS=8 LLAMA_ARG_NUMA=distribute llama-server --cache-type-k q8_0 --cache-type-v q8_0 -c 4096 --parallel 2 -fa --metrics

and using the API or the WebUI to make the model generate large outputs on two slots at once, I get garbage output (which persists until the server is restarted) as soon as KV-cache defragmentation occurs once.
Disabling defragmentation makes the issue disappear, and setting a breakpoint in llama_kv_cache_defrag_impl(lctx); confirms that the corruption is indeed caused by KV-cache defragmentation.

Changing the #if 0 to #if 1 (which enables the CPU defrag path) causes a segfault in memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);

Here is the stack trace:

libc.so.6!__memcpy_avx_unaligned_erms() (memmove-vec-unaligned-erms.S:265)
libllama.so!memcpy(const void * restrict __src, void * restrict __dest) (/usr/include/x86_64-linux-gnu/bits/string_fortified.h:29)
libllama.so!llama_kv_cache_defrag_impl(llama_context & lctx) (/workspaces/llama.cpp/src/llama.cpp:9210)
libllama.so!llama_kv_cache_update_impl(llama_context & lctx) (/workspaces/llama.cpp/src/llama.cpp:9271)
libllama.so!llama_kv_cache_update(llama_context * ctx) (/workspaces/llama.cpp/src/llama.cpp:9970)
libllama.so!llama_prepare_ubatch(const uint32_t n_tokens_all, llama_ubatch & ubatch, llama_kv_slot_restorer & kv_slot_restorer, llama_context & lctx) (/workspaces/llama.cpp/src/llama.cpp:8547)
libllama.so!llama_decode_impl(llama_context & lctx, llama_batch inp_batch) (/workspaces/llama.cpp/src/llama.cpp:8628)
libllama.so!llama_decode(llama_context * ctx, llama_batch batch) (/workspaces/llama.cpp/src/llama.cpp:9993)
server_context::update_slots(server_context * const this) (/workspaces/llama.cpp/examples/server/server.cpp:3131)
std::function<void ()>::operator()() const(const std::function<void()> * const this) (/usr/include/c++/13/bits/std_function.h:591)
server_queue::start_loop(server_queue * const this) (/workspaces/llama.cpp/examples/server/server.cpp:1622)
main(int argc, char ** argv) (/workspaces/llama.cpp/examples/server/server.cpp:4500)

Note: here are my compile options:

$ cmake -B build -S . -DCMAKE_INSTALL_PREFIX=out/install/default -DCMAKE_BUILD_TYPE=RelWithDebInfo -DGGML_NATIVE=OFF -DGGML_LTO=ON -DGGML_AARCH64_REPACK=ON -DLLAMA_BUILD_SERVER=ON -DCMAKE_CXX_FLAGS="-flto=auto -ggdb -fno-omit-frame-pointer" -DLLAMA_CURL=ON

First Bad Commit

No response

Relevant log output

$ LLMA_ARG_N_PREDICT=-1 LLAMA_ARG_HF_REPO=bartowski/microsoft_Phi-4-mini-instruct-GGUF:Q4_0 LLAMA_ARG_ALIAS=myalias LLAMA_ARG_N_GPU_LAYERS=0 LLAMA_ARG_BATCH=1984 LLAMA_ARG_THREADS=8 LLAMA_ARG_NUMA=distribute llama-server --cache-type-k q8_0 --cache-type-v q8_0 -c 4096 --parallel 2 -fa --metrics
warning: no usable GPU found, --gpu-layers option will be ignored
warning: one possible reason is that llama.cpp was compiled without GPU support
warning: consult docs/build.md for compilation instructions
build: 4798 (1782cdfe) with cc (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0 for x86_64-linux-gnu
system info: n_threads = 8, n_threads_batch = 8, total_threads = 32

system_info: n_threads = 8 (n_threads_batch = 8) / 32 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | LLAMAFILE = 1 | OPENMP = 1 | AARCH64_REPACK = 1 | 

main: HTTP server is listening, hostname: 127.0.0.1, port: 8080, http threads: 31
main: loading model
srv    load_model: loading model '/home/vscode/.cache/llama.cpp/bartowski_microsoft_Phi-4-mini-instruct-GGUF_microsoft_Phi-4-mini-instruct-Q4_0.gguf'
common_download_file: previous metadata file found /home/vscode/.cache/llama.cpp/bartowski_microsoft_Phi-4-mini-instruct-GGUF_microsoft_Phi-4-mini-instruct-Q4_0.gguf.json: {"etag":"\"9d79733b873cc227388646f1ab2f102c-146\"","lastModified":"Fri, 28 Feb 2025 15:50:38 GMT","url":"https://huggingface.co/bartowski/microsoft_Phi-4-mini-instruct-GGUF/resolve/main/microsoft_Phi-4-mini-instruct-Q4_0.gguf"}
curl_perform_with_retry: Trying to download from https://huggingface.co/bartowski/microsoft_Phi-4-mini-instruct-GGUF/resolve/main/microsoft_Phi-4-mini-instruct-Q4_0.gguf (attempt 1 of 3)...
llama_model_loader: loaded meta data with 40 key-value pairs and 196 tensors from /home/vscode/.cache/llama.cpp/bartowski_microsoft_Phi-4-mini-instruct-GGUF_microsoft_Phi-4-mini-instruct-Q4_0.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = phi3
llama_model_loader: - kv   1:              phi3.rope.scaling.attn_factor f32              = 1.190238
llama_model_loader: - kv   2:                               general.type str              = model
llama_model_loader: - kv   3:                               general.name str              = Phi 4 Mini Instruct
llama_model_loader: - kv   4:                           general.finetune str              = instruct
llama_model_loader: - kv   5:                           general.basename str              = Phi-4
llama_model_loader: - kv   6:                         general.size_label str              = mini
llama_model_loader: - kv   7:                            general.license str              = mit
llama_model_loader: - kv   8:                       general.license.link str              = https://huggingface.co/microsoft/Phi-...
llama_model_loader: - kv   9:                               general.tags arr[str,3]       = ["nlp", "code", "text-generation"]
llama_model_loader: - kv  10:                          general.languages arr[str,1]       = ["multilingual"]
llama_model_loader: - kv  11:                        phi3.context_length u32              = 131072
llama_model_loader: - kv  12:  phi3.rope.scaling.original_context_length u32              = 4096
llama_model_loader: - kv  13:                      phi3.embedding_length u32              = 3072
llama_model_loader: - kv  14:                   phi3.feed_forward_length u32              = 8192
llama_model_loader: - kv  15:                           phi3.block_count u32              = 32
llama_model_loader: - kv  16:                  phi3.attention.head_count u32              = 24
llama_model_loader: - kv  17:               phi3.attention.head_count_kv u32              = 8
llama_model_loader: - kv  18:      phi3.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  19:                  phi3.rope.dimension_count u32              = 96
llama_model_loader: - kv  20:                        phi3.rope.freq_base f32              = 10000.000000
llama_model_loader: - kv  21:              phi3.attention.sliding_window u32              = 262144
llama_model_loader: - kv  22:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  23:                         tokenizer.ggml.pre str              = gpt-4o
llama_model_loader: - kv  24:                      tokenizer.ggml.tokens arr[str,200064]  = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv  25:                  tokenizer.ggml.token_type arr[i32,200064]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  26:                      tokenizer.ggml.merges arr[str,199742]  = ["Ġ Ġ", "ĠĠ ĠĠ", "i n", "e r", ...
llama_model_loader: - kv  27:                tokenizer.ggml.bos_token_id u32              = 199999
llama_model_loader: - kv  28:                tokenizer.ggml.eos_token_id u32              = 199999
llama_model_loader: - kv  29:            tokenizer.ggml.unknown_token_id u32              = 199999
llama_model_loader: - kv  30:            tokenizer.ggml.padding_token_id u32              = 199999
llama_model_loader: - kv  31:               tokenizer.ggml.add_bos_token bool             = false
llama_model_loader: - kv  32:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  33:                    tokenizer.chat_template str              = {% for message in messages %}{% if me...
llama_model_loader: - kv  34:               general.quantization_version u32              = 2
llama_model_loader: - kv  35:                          general.file_type u32              = 2
llama_model_loader: - kv  36:                      quantize.imatrix.file str              = /models_out/Phi-4-mini-instruct-GGUF/...
llama_model_loader: - kv  37:                   quantize.imatrix.dataset str              = /training_dir/calibration_datav3.txt
llama_model_loader: - kv  38:             quantize.imatrix.entries_count i32              = 128
llama_model_loader: - kv  39:              quantize.imatrix.chunks_count i32              = 123
llama_model_loader: - type  f32:   67 tensors
llama_model_loader: - type q4_0:  124 tensors
llama_model_loader: - type q4_1:    4 tensors
llama_model_loader: - type q6_K:    1 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type   = Q4_0
print_info: file size   = 2.16 GiB (4.85 BPW) 
load: special tokens cache size = 12
load: token to piece cache size = 1.3333 MB
print_info: arch             = phi3
print_info: vocab_only       = 0
print_info: n_ctx_train      = 131072
print_info: n_embd           = 3072
print_info: n_layer          = 32
print_info: n_head           = 24
print_info: n_head_kv        = 8
print_info: n_rot            = 96
print_info: n_swa            = 262144
print_info: n_embd_head_k    = 128
print_info: n_embd_head_v    = 128
print_info: n_gqa            = 3
print_info: n_embd_k_gqa     = 1024
print_info: n_embd_v_gqa     = 1024
print_info: f_norm_eps       = 0.0e+00
print_info: f_norm_rms_eps   = 1.0e-05
print_info: f_clamp_kqv      = 0.0e+00
print_info: f_max_alibi_bias = 0.0e+00
print_info: f_logit_scale    = 0.0e+00
print_info: n_ff             = 8192
print_info: n_expert         = 0
print_info: n_expert_used    = 0
print_info: causal attn      = 1
print_info: pooling type     = 0
print_info: rope type        = 2
print_info: rope scaling     = linear
print_info: freq_base_train  = 10000.0
print_info: freq_scale_train = 1
print_info: n_ctx_orig_yarn  = 4096
print_info: rope_finetuned   = unknown
print_info: ssm_d_conv       = 0
print_info: ssm_d_inner      = 0
print_info: ssm_d_state      = 0
print_info: ssm_dt_rank      = 0
print_info: ssm_dt_b_c_rms   = 0
print_info: model type       = 3B
print_info: model params     = 3.84 B
print_info: general.name     = Phi 4 Mini Instruct
print_info: vocab type       = BPE
print_info: n_vocab          = 200064
print_info: n_merges         = 199742
print_info: BOS token        = 199999 '<|endoftext|>'
print_info: EOS token        = 199999 '<|endoftext|>'
print_info: EOT token        = 199999 '<|endoftext|>'
print_info: UNK token        = 199999 '<|endoftext|>'
print_info: PAD token        = 199999 '<|endoftext|>'
print_info: LF token         = 198 'Ċ'
print_info: EOG token        = 199999 '<|endoftext|>'
print_info: EOG token        = 200020 '<|end|>'
print_info: max token length = 256
load_tensors: loading model tensors, this can take a while... (mmap = true)
load_tensors:  CPU_AARCH64 model buffer size =  1674.00 MiB
load_tensors:   CPU_Mapped model buffer size =  2188.57 MiB
.........................................................................
llama_init_from_model: n_seq_max     = 2
llama_init_from_model: n_ctx         = 4096
llama_init_from_model: n_ctx_per_seq = 2048
llama_init_from_model: n_batch       = 1984
llama_init_from_model: n_ubatch      = 512
llama_init_from_model: flash_attn    = 1
llama_init_from_model: freq_base     = 10000.0
llama_init_from_model: freq_scale    = 1
llama_init_from_model: n_ctx_per_seq (2048) < n_ctx_train (131072) -- the full capacity of the model will not be utilized
llama_kv_cache_init: kv_size = 4096, offload = 1, type_k = 'q8_0', type_v = 'q8_0', n_layer = 32, can_shift = 1
llama_kv_cache_init:        CPU KV buffer size =   272.00 MiB
llama_init_from_model: KV self size  =  272.00 MiB, K (q8_0):  136.00 MiB, V (q8_0):  136.00 MiB
llama_init_from_model:        CPU  output buffer size =     1.53 MiB
llama_init_from_model:        CPU compute buffer size =   404.76 MiB
llama_init_from_model: graph nodes  = 1159
llama_init_from_model: graph splits = 1
common_init_from_params: setting dry_penalty_last_n to ctx_size = 4096
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
srv          init: initializing slots, n_slots = 2
slot         init: id  0 | task -1 | new slot n_ctx_slot = 2048
slot         init: id  1 | task -1 | new slot n_ctx_slot = 2048
main: model loaded
main: chat template, chat_template: {% for message in messages %}{% if message['role'] == 'system' and 'tools' in message and message['tools'] is not none %}{{ '<|' + message['role'] + '|>' + message['content'] + '<|tool|>' + message['tools'] + '<|/tool|>' + '<|end|>' }}{% else %}{{ '<|' + message['role'] + '|>' + message['content'] + '<|end|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>' }}{% else %}{{ eos_token }}{% endif %}, example_format: '<|system|>
You are a helpful assistant<|end|>
<|user|>
Hello<|end|>
<|assistant|>
Hi there<|end|>
<|user|>
How are you?<|end|>
<|assistant|>
'
main: server is listening on http://127.0.0.1:8080 - starting the main loop
srv  update_slots: all slots are idle
srv  params_from_: Chat format: Content-only
slot launch_slot_: id  0 | task 0 | processing task
slot update_slots: id  0 | task 0 | new prompt, n_ctx_slot = 2048, n_keep = 0, n_prompt_tokens = 109
slot update_slots: id  0 | task 0 | kv cache rm [0, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 109, n_tokens = 109, progress = 1.000000
slot update_slots: id  0 | task 0 | prompt done, n_past = 109, n_tokens = 109
srv  params_from_: Chat format: Content-only
slot launch_slot_: id  1 | task 41 | processing task
slot update_slots: id  1 | task 41 | new prompt, n_ctx_slot = 2048, n_keep = 0, n_prompt_tokens = 29
slot update_slots: id  1 | task 41 | kv cache rm [0, end)
slot update_slots: id  1 | task 41 | prompt processing progress, n_past = 29, n_tokens = 30, progress = 1.000000
slot update_slots: id  1 | task 41 | prompt done, n_past = 29, n_tokens = 30
slot      release: id  0 | task 0 | stop processing: n_past = 684, truncated = 0
slot print_timing: id  0 | task 0 | 
prompt eval time =     981.50 ms /   109 tokens (    9.00 ms per token,   111.05 tokens per second)
       eval time =   47713.23 ms /   576 tokens (   82.84 ms per token,    12.07 tokens per second)
      total time =   48694.73 ms /   685 tokens
srv  log_server_r: request: POST /v1/chat/completions 127.0.0.1 200
srv  params_from_: Chat format: Content-only
slot launch_slot_: id  0 | task 671 | processing task
slot update_slots: id  0 | task 671 | new prompt, n_ctx_slot = 2048, n_keep = 0, n_prompt_tokens = 109
slot update_slots: id  0 | task 671 | need to evaluate at least 1 token to generate logits, n_past = 109, n_prompt_tokens = 109
slot update_slots: id  0 | task 671 | kv cache rm [108, end)
slot update_slots: id  0 | task 671 | prompt processing progress, n_past = 109, n_tokens = 2, progress = 0.009174
slot update_slots: id  0 | task 671 | prompt done, n_past = 109, n_tokens = 2
slot      release: id  1 | task 41 | stop processing: n_past = 691, truncated = 0
slot print_timing: id  1 | task 41 | 
prompt eval time =     306.28 ms /    29 tokens (   10.56 ms per token,    94.69 tokens per second)
       eval time =   54127.38 ms /   663 tokens (   81.64 ms per token,    12.25 tokens per second)
      total time =   54433.65 ms /   692 tokens
srv  log_server_r: request: POST /v1/chat/completions 127.0.0.1 200
srv  params_from_: Chat format: Content-only
slot launch_slot_: id  1 | task 786 | processing task
slot update_slots: id  1 | task 786 | new prompt, n_ctx_slot = 2048, n_keep = 0, n_prompt_tokens = 905
slot update_slots: id  1 | task 786 | kv cache rm [18, end)
slot update_slots: id  1 | task 786 | prompt processing progress, n_past = 905, n_tokens = 888, progress = 0.980111
slot update_slots: id  1 | task 786 | prompt done, n_past = 905, n_tokens = 888
slot      release: id  0 | task 671 | stop processing: n_past = 673, truncated = 0
slot print_timing: id  0 | task 671 | 
prompt eval time =     103.39 ms /     1 tokens (  103.39 ms per token,     9.67 tokens per second)
       eval time =   58151.75 ms /   565 tokens (  102.92 ms per token,     9.72 tokens per second)
      total time =   58255.14 ms /   566 tokens
srv  log_server_r: request: POST /v1/chat/completions 127.0.0.1 200
srv  params_from_: Chat format: Content-only
slot launch_slot_: id  0 | task 1269 | processing task
slot update_slots: id  0 | task 1269 | new prompt, n_ctx_slot = 2048, n_keep = 0, n_prompt_tokens = 109
slot update_slots: id  0 | task 1269 | need to evaluate at least 1 token to generate logits, n_past = 109, n_prompt_tokens = 109
slot update_slots: id  0 | task 1269 | kv cache rm [108, end)
slot update_slots: id  0 | task 1269 | prompt processing progress, n_past = 109, n_tokens = 2, progress = 0.009174
slot update_slots: id  0 | task 1269 | prompt done, n_past = 109, n_tokens = 2
@ggerganov (Member)

Does the issue occur without quantized KV cache?
Also does it occur if you simply disable the defrag with -dt 0?

aviallon (Contributor Author) commented Mar 7, 2025

@ggerganov it does not occur if I disable the defrag, as mentioned.
I did not try with unquantized KV cache. I'll try now.

aviallon (Contributor Author) commented Mar 7, 2025

It looks like the bug does not occur with KV-cache quantization disabled.
To be more precise, with quantization enabled I was previously triggering an out-of-bounds write in the "CPU defrag" mode; with quantization disabled, that no longer happens.

@ggerganov (Member)

I am not sure the "CPU defrag" mode still works correctly - I used it only for debugging during the initial implementation, and it may have broken since. It's better to verify that no garbage is produced with a non-quantized cache and without using the "CPU defrag".

Also check if with quantized cache and using 1 thread (-t 1) the issue disappears.

aviallon (Contributor Author) commented Mar 7, 2025

Yes, it does occur.
I get garbled output with the default (graph-based) defrag algorithm, even with -t 1.

@ggerganov (Member)

Does this patch fix it:

diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index f2ab4c5d..6988187a 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -4085,7 +4085,7 @@ static void ggml_compute_forward_dup_bytes(
         ne00 == ne0 &&
         nb00 == type_size && nb0 == type_size) {
         // copy by rows
-        const size_t rs = ne00 * type_size;
+        const size_t rs = ggml_row_size(src0->type, ne00);
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
                 for (int64_t i01 = ir0; i01 < ir1; i01++) {
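
For context on why that one-line change matters: assuming Q8_0's usual layout of 32 elements per 34-byte block (illustrative numbers, not measured from this run), the old expression treats type_size as a per-element size and overestimates the row size:

// Illustrative sketch of the row-size arithmetic for a block-quantized type.
// Q8_0 packs 32 elements into one 34-byte block (32 int8 quants + one fp16 scale).
#include <cstddef>
#include <cstdio>

int main() {
    const std::size_t ne00      = 1024; // elements in one row
    const std::size_t blck_size = 32;   // elements per Q8_0 block (QK8_0)
    const std::size_t type_size = 34;   // bytes per Q8_0 block (sizeof(block_q8_0))

    const std::size_t wrong = ne00 * type_size;              // old code: per-element assumption
    const std::size_t right = ne00 / blck_size * type_size;  // what ggml_row_size() computes

    std::printf("wrong: %zu bytes, right: %zu bytes\n", wrong, right); // 34816 vs 1088
    return 0;
}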

aviallon (Contributor Author) commented Mar 7, 2025

Not completely. It is less bad, though: some slots continued to output correct content.

aviallon (Contributor Author) commented Mar 7, 2025

Another thing to note: I actually made a mistake earlier. With the graph defrag I get garbled output even with an f16 KV cache; I did not see this when using the "CPU defrag".

aviallon (Contributor Author) commented Mar 7, 2025

Note: increasing --parallel (and making parallel requests) helps make the problem occur faster.

@ggerganov (Member)

Try to also add this to the patch:

diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index f2ab4c5d..141b9aec 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -3111,7 +3111,7 @@ static void ggml_compute_forward_dup_same_cont(
     const int nth = params->nth; // number of threads
 
     // parallelize by elements
-    const int ne = ggml_nelements(dst);
+    const int ne = ggml_nelements(dst)/ggml_blck_size(dst->type);
     const int dr = (ne + nth - 1) / nth;
     const int ie0 = dr * ith;
     const int ie1 = MIN(ie0 + dr, ne);

I think the bug is clear - we are miscalculating the buffer size for CPY(Q -> Q). I'm blindly going through the code and guessing what might have an effect on the defrag. But the proper solution is to fix all places where this mistake was made.
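
To make the intent concrete, here is a rough sketch (not the actual ggml code) of splitting a contiguous copy of a block-quantized tensor across threads once the work unit is a whole block rather than a single element:

// Sketch only: the work unit is one quantization block, so the count is
// nelements / blck_size and the byte offset of unit i is i * type_size.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>

static void copy_quantized_contiguous(const char * src, char * dst,
                                      std::int64_t nelements,  // logical element count
                                      std::int64_t blck_size,  // e.g. 32 for Q8_0
                                      std::size_t  type_size,  // e.g. 34 bytes for Q8_0
                                      int ith, int nth) {      // this thread / thread count
    const std::int64_t ne  = nelements / blck_size;  // number of blocks, not elements
    const std::int64_t dr  = (ne + nth - 1) / nth;   // blocks per thread
    const std::int64_t ie0 = dr * ith;               // first block for this thread
    const std::int64_t ie1 = std::min(ie0 + dr, ne); // one past the last block

    if (ie0 < ie1) {
        std::memcpy(dst + ie0 * type_size, src + ie0 * type_size,
                    static_cast<std::size_t>(ie1 - ie0) * type_size);
    }
}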

aviallon (Contributor Author) commented Mar 7, 2025

With your patch, all of the slots make some correct forward progress before they end up outputting garbage, so it IS better.
To be more precise, the defragmentation issue no longer contaminates the other slots; only the slot(s) that triggered the defragmentation output garbage.

Do you know how I can help fix the buffer size miscalculation?

aviallon (Contributor Author) commented Mar 7, 2025

Another question: what kind of assert can I add in order to abort instead of producing invalid output?
In my case, aborting is a much more desirable outcome, since I can detect crashes quite easily.

ggerganov (Member) commented Mar 7, 2025

Assert is not necessary - we just have to fix the implementation. In these 2 patches, it was assumed that the block size of the tensor is 1, which is not true for quantized data types. From what I saw today, there are a few places in the code where this mistake was made.
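
Concretely (a simplified sketch from memory, field names approximate): quantized tensors are stored as arrays of fixed-size blocks, so per-element byte arithmetic is only valid for types whose block size is 1, such as F32 or F16:

// Roughly how ggml stores Q8_0 data: a row of ne elements occupies
// ne / 32 consecutive blocks of 34 bytes each.
#include <cstdint>

using ggml_half = std::uint16_t;   // fp16 storage

struct block_q8_0 {
    ggml_half   d;      // per-block scale
    std::int8_t qs[32]; // 32 quantized values
};                      // 34 bytes for 32 elements

The Q <-> F32 paths go through per-row (de)quantization helpers that already account for this layout, which is likely why only the Q -> Q copies trip over the blck_size == 1 assumption.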

I think you can try adding test_cpy cases for Q8_0 -> Q8_0 in test-backend-ops.cpp. These will likely fail when you run test-backend-ops, and you can start from there to find the fix.

aviallon (Contributor Author) commented Mar 7, 2025

@ggerganov I tried to add a test case, but it does not seem to fail.

diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 1dc2cdda..2d7f67d4 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -3943,6 +3943,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {1, 0, 2, 3})); // cpy not-contiguous
         }
     }
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0, {256, 2, 3, 4}, {1, 0, 2, 3}));
 
     test_cases.emplace_back(new test_cont());
     test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 1 ,1}));

@ggerganov (Member)

Try to run:

./bin/test-backend-ops -o CPY -b CPU

I just tried and it fails.

aviallon (Contributor Author) commented Mar 7, 2025

Indeed. Thanks.

aviallon (Contributor Author) commented Mar 8, 2025

@ggerganov the backend is quite a complex beast.
I am trying to understand how it is structured.
To be more precise, why does the copy work for all other combinations (QX_0 <-> F32), but not for any QX_0 <-> QX_0?
How are tensors represented in llama.cpp's memory?


ggerganov added the bug label and removed the bug-unconfirmed label on Mar 8, 2025
@ggerganov (Member)

@aviallon Could you test if #12310 works?

aviallon (Contributor Author) commented Mar 10, 2025

@ggerganov It is much better. When a cache shift occurs (often accompanied by a defrag), only the current completion breaks, but it breaks with an infinite-repetition loop, which may be another bug entirely.

@ggerganov (Member)

When the model uses a chat template (i.e. requires a specific structure of the context), shifting the KV cache is not recommended because it destroys that structure (the system message gets shifted away, the role information of the messages can become incomplete, etc.). This will likely lead to repetition and corrupted generation, so the behaviour you are seeing is expected.

You can solve this on the client side by keeping a queue of the chat messages and discarding the oldest messages when you get close to filling the context. Using the --cache-reuse parameter will still avoid recomputing most of the new context after you evict old messages, so the performance will remain good.
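
For illustration, a minimal client-side sketch of that idea (the Message type, token budget, and count_tokens heuristic below are hypothetical placeholders, not llama.cpp APIs):

// Hypothetical client-side trimming: keep the system message, drop the oldest
// user/assistant turns once the estimated prompt size approaches the slot's
// context budget, then send the trimmed history; with --cache-reuse on the
// server, the surviving prefix should not need to be recomputed.
#include <cstddef>
#include <deque>
#include <string>

struct Message {
    std::string role;     // "system", "user", "assistant"
    std::string content;
};

// Placeholder: a real client would tokenize (e.g. via the server's /tokenize
// endpoint) or use a better estimate than this chars-per-token heuristic.
static std::size_t count_tokens(const Message & m) {
    return m.content.size() / 4 + 8;
}

static void trim_history(std::deque<Message> & history, std::size_t budget) {
    std::size_t total = 0;
    for (const auto & m : history) total += count_tokens(m);

    // Evict the oldest non-system message until the history fits the budget.
    while (total > budget && history.size() > 1) {
        const std::size_t idx = (history.front().role == "system") ? 1 : 0;
        if (idx >= history.size()) break;
        total -= count_tokens(history[idx]);
        history.erase(history.begin() + idx);
    }
}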

@aviallon (Contributor Author)

@ggerganov does that mean I should disable cache shifting in that case?

@ggerganov (Member)

Yes, microsoft_Phi-4-mini-instruct is a chat-tuned model, so it is incompatible with context shift.
