Eval bug: garbage output right after kv-cache defragmentation for CPU backend #12253
Comments
Does the issue occur without quantized KV cache?
@ggerganov it does not occur if I disable the defrag, as mentioned.
It looks like the bug does not occur with KV-cache quantization disabled.
I am not very sure that the "CPU defrag" mode works correctly - I used it just for debugging during the initial implementation, and it might have broken since then. It's better to verify that no garbage is produced with a non-quantized cache without using the "CPU defrag". Also check if it happens with a quantized cache and using 1 thread (
Yes, it does occur.
Does this patch fix it?
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index f2ab4c5d..6988187a 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -4085,7 +4085,7 @@ static void ggml_compute_forward_dup_bytes(
ne00 == ne0 &&
nb00 == type_size && nb0 == type_size) {
// copy by rows
- const size_t rs = ne00 * type_size;
+ const size_t rs = ggml_row_size(src0->type, ne00);
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ir0; i01 < ir1; i01++) {
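For quantized types, ggml_type_size() returns the size of one block (e.g. 34 bytes for Q8_0, which packs 32 elements), not the size of one element, so ne00 * type_size overestimates the row byte count by the block size. Below is a minimal sketch of that arithmetic; it assumes ggml.h is available on the include path, and the file name and main() wrapper are only for illustration:

// row_size_check.cpp - a minimal sketch (assumes ggml.h is on the include path)
// illustrating why `ne00 * type_size` overestimates the row size for quantized
// types: ggml_type_size() is the size of one *block*, not one element.
#include <cstdio>
#include <initializer_list>
#include "ggml.h"

int main() {
    const int64_t ne00 = 256; // elements per row, as in the test case discussed below

    for (ggml_type t : { GGML_TYPE_F16, GGML_TYPE_Q8_0 }) {
        const size_t  ts    = ggml_type_size(t);      // bytes per block
        const int64_t bs    = ggml_blck_size(t);      // elements per block (1 for f16, 32 for q8_0)
        const size_t  naive = ne00 * ts;              // what the old code computed
        const size_t  row   = ggml_row_size(t, ne00); // ne00 / bs * ts

        printf("%-6s blck=%2lld type_size=%2zu  naive=%5zu  row_size=%5zu\n",
               ggml_type_name(t), (long long) bs, ts, naive, row);
    }
    return 0;
}

For f16 the two values agree (512 bytes), while for Q8_0 the naive formula gives 8704 bytes instead of the correct 272, a 32x overcount.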
Not completely. It is less bad, though. Some slots continued to output correct content.
Another thing to note: I actually made a mistake before - with graph defrag, I get garbled output even with an f16 KV-cache.
Note: increasing
Try to also add this to the patch:
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index f2ab4c5d..141b9aec 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -3111,7 +3111,7 @@ static void ggml_compute_forward_dup_same_cont(
const int nth = params->nth; // number of threads
// parallelize by elements
- const int ne = ggml_nelements(dst);
+ const int ne = ggml_nelements(dst)/ggml_blck_size(dst->type);
const int dr = (ne + nth - 1) / nth;
const int ie0 = dr * ith;
const int ie1 = MIN(ie0 + dr, ne);
I think the bug is clear - we are miscalculating the buffer size for
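In other words, for a quantized type the unit of work in this copy has to be a whole block (ggml_blck_size() elements stored in ggml_type_size() bytes), so the per-thread ranges must be computed in blocks. A rough, hypothetical sketch of what the patched routine effectively does (the name dup_same_cont_sketch is made up; this is not the actual ggml implementation):

// Copy a contiguous tensor split across `nth` threads by *blocks* rather than
// by elements. With a quantized type, one block covers ggml_blck_size(type)
// elements and occupies ggml_type_size(type) bytes.
#include <algorithm>
#include <cstring>
#include "ggml.h"

static void dup_same_cont_sketch(const ggml_tensor * src, ggml_tensor * dst,
                                 int ith, int nth) {
    // number of blocks, not elements - this is what the patch above changes
    const int64_t nblk = ggml_nelements(dst) / ggml_blck_size(dst->type);
    const size_t  bsz  = ggml_type_size(dst->type); // bytes per block

    // split the blocks evenly across threads
    const int64_t dr  = (nblk + nth - 1) / nth;
    const int64_t ie0 = dr * ith;
    const int64_t ie1 = std::min(ie0 + dr, nblk);

    if (ie0 < ie1) {
        memcpy((char *)       dst->data + ie0 * bsz,
               (const char *) src->data + ie0 * bsz,
               (ie1 - ie0) * bsz);
    }
}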
With your patch, all of the slots make some correct forward progress before they end up outputting garbage, so it IS better. Do you know how I can help fix the buffer-size miscalculation?
Another question: what kind of assert can I add in order to abort instead of producing invalid output?
An assert is not necessary - we just have to fix the implementation. In these two patches, it was assumed that the block size of the tensor is 1, which is not true for quantized data types. From what I saw today, there are a few places in the code where this mistake was made. I think you can try to add
@ggerganov I tried to add a test case, but it does not seem to fail.
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 1dc2cdda..2d7f67d4 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -3943,6 +3943,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {1, 0, 2, 3})); // cpy not-contiguous
}
}
+ test_cases.emplace_back(new test_cpy(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0, {256, 2, 3, 4}, {1, 0, 2, 3}));
test_cases.emplace_back(new test_cont());
test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 1 ,1}));
Try to run: ./bin/test-backend-ops -o CPY -b CPU
I just tried and it fails.
Indeed. Thanks.
@ggerganov the backend is quite a complex beast.
@ggerganov It is much better. When a cache shift occurs (often accompanied by a defrag), only the current completion becomes broken - but it then runs into an infinite repetition issue, which may be another bug entirely.
When the model uses a chat template (i.e. requires a specific structure of the context), shifting the KV cache is not recommended because it will destroy the structure of the context (i.e. the system message will be shifted away, the role information of the messages can become incomplete, etc.). This will likely lead to repetition and corrupted generation, so this behavior is expected. You can solve it on the client side by keeping a queue of the chat messages and discarding the old ones when you get close to filling the context (see the sketch after this exchange). Using the
@ggerganov does that mean I should disable cache shifting in that case?
Yes,
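For completeness, here is a minimal client-side sketch of the queue-of-messages approach described above. It is illustrative only, not part of llama.cpp; the class names and the 4-characters-per-token estimate are assumptions - a real client could get exact counts from the server's /tokenize endpoint instead.

// A client-side chat history that drops the oldest non-system messages once
// the estimated size approaches the server's context limit.
#include <cstddef>
#include <deque>
#include <string>

struct ChatMessage {
    std::string role;    // "system", "user" or "assistant"
    std::string content;
};

class ChatHistory {
public:
    ChatHistory(size_t n_ctx, size_t n_reserve)
        : n_ctx_(n_ctx), n_reserve_(n_reserve) {}

    void push(ChatMessage msg) {
        messages_.push_back(std::move(msg));
        trim();
    }

    const std::deque<ChatMessage> & messages() const { return messages_; }

private:
    // very rough estimate: ~4 characters per token, plus a few tokens of template overhead
    static size_t estimate_tokens(const ChatMessage & m) {
        return m.content.size() / 4 + 8;
    }

    size_t total_tokens() const {
        size_t n = 0;
        for (const auto & m : messages_) n += estimate_tokens(m);
        return n;
    }

    void trim() {
        // drop the oldest non-system message while we are too close to n_ctx
        while (total_tokens() + n_reserve_ > n_ctx_ && messages_.size() > 1) {
            const size_t i = (messages_.front().role == "system") ? 1 : 0;
            if (i >= messages_.size()) break;
            messages_.erase(messages_.begin() + i);
        }
    }

    std::deque<ChatMessage> messages_;
    size_t n_ctx_;      // context size per slot (e.g. 4096 / 2 slots here)
    size_t n_reserve_;  // tokens to keep free for the next generation
};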
Name and Version
Operating systems
Linux
GGML backends
CPU
Hardware
Ryzen 5950X (configured with LLC as NUMA node)
Models
bartowski/microsoft_Phi-4-mini-instruct-GGUF:Q4_0
Problem description & steps to reproduce
When running llama-server with this exact command:
$ LLMA_ARG_N_PREDICT=-1 LLAMA_ARG_HF_REPO=bartowski/microsoft_Phi-4-mini-instruct-GGUF:Q4_0 LLAMA_ARG_ALIAS=myalias LLAMA_ARG_N_GPU_LAYERS=0 LLAMA_ARG_BATCH=1984 LLAMA_ARG_THREADS=8 LLAMA_ARG_NUMA=distribute llama-server --cache-type-k q8_0 --cache-type-v q8_0 -c 4096 --parallel 2 -fa --metrics
and using the API or the WebUI to make the model generate large outputs on two slots at once, I get garbage output (which stays until server restart) as soon as KV-cache defragmentation occurs once.
In fact, disabling it makes the issue disappear, and setting a breakpoint in llama_kv_cache_defrag_impl(lctx); allows pinpointing with absolute certainty that it is indeed caused by KV-cache defragmentation.
Setting #if 0 to #if 1 causes a segfault in:
memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
Here is the stack trace:
Note: here are my compile options:
$ cmake -B build -S . -DCMAKE_INSTALL_PREFIX=out/install/default -DCMAKE_BUILD_TYPE=RelWithDebInfo -DGGML_NATIVE=OFF -DGGML_LTO=ON -DGGML_AARCH64_REPACK=ON -DLLAMA_BUILD_SERVER=ON -DCMAKE_CXX_FLAGS="-flto=auto -ggdb -fno-omit-frame-pointer" -DLLAMA_CURL=ON
First Bad Commit
No response
Relevant log output