
BGV/CKKS: support scale management #1459

Open
wants to merge 40 commits into base: main

Conversation

@ZenithalHourlyRate (Collaborator) commented Feb 24, 2025

See #1169

I am afraid we can only do this for the Lattigo backend, as OpenFHE does not expose an explicit API for setting the scale. That said, the policy implemented here is in line with OpenFHE's own implementation; OpenFHE just applies it automatically.

The detailed rationale/implementation of the scale management should be put in a design doc within this PR.

There are a few changes to support scale:

  • mgmt.level_reduce and mgmt.adjust_scale ops to support the corresponding operations
  • Modify secret-insert-mgmt-bgv to use these ops to handle cross-level ops, where adjust_scale is a placeholder
  • --validate-noise will generate parameters aware of these management ops
  • --populate-scale (better name wanted) to concretely fill in the scale based on the parameters

TODO

Cc @AlexanderViand-Intel: A comment on #1295 (comment) is that the two backends we have can safely Add(ct0, ct1) with ciphertexts of different scales, because internally, when they find the scales mismatched, they adjust the scale themselves. So the mixed-degree option for optimize-relinearization can be turned on without affecting correctness, though the noise is different. Merging this PR does not fix the scale-mismatch problem possibly induced by optimize-relinearization for our current two backends, but it does pave the way for our own poly backend, which must be scale-aware.

Example

The input MLIR:

func.func @cross_level_add(%base: tensor<4xi16> {secret.secret}, %add: tensor<4xi16> {secret.secret}) -> tensor<4xi16> {
  // same level
  %base0 = arith.addi %base, %add : tensor<4xi16>
  // increase one level
  %mul1 = arith.muli %base0, %base0 : tensor<4xi16>
  // cross level add
  %base1 = arith.addi %mul1, %add : tensor<4xi16>
  // increase one level
  %mul2 = arith.muli %base1, %base1 : tensor<4xi16>
  // cross level add
  %base2 = arith.addi %mul2, %add : tensor<4xi16>
  // increase one level
  %mul3 = arith.muli %base2, %base2 : tensor<4xi16>
  // cross level add
  %base3 = arith.addi %mul3, %add : tensor<4xi16>
  return %base3 : tensor<4xi16>
}

After secret-insert-mgmt-bgv, we get

      %1 = arith.addi %input0, %input1 {mgmt.mgmt = #mgmt.mgmt<level = 3>} : tensor<4xi16>
      %2 = arith.muli %1, %1 {mgmt.mgmt = #mgmt.mgmt<level = 3, dimension = 3>} : tensor<4xi16>
      %3 = mgmt.relinearize %2 {mgmt.mgmt = #mgmt.mgmt<level = 3>} : tensor<4xi16>
      %4 = arith.addi %3, %input1 {mgmt.mgmt = #mgmt.mgmt<level = 3>} : tensor<4xi16>
      %5 = mgmt.modreduce %4 {mgmt.mgmt = #mgmt.mgmt<level = 2>} : tensor<4xi16>
      %6 = arith.muli %5, %5 {mgmt.mgmt = #mgmt.mgmt<level = 2, dimension = 3>} : tensor<4xi16>
      %7 = mgmt.relinearize %6 {mgmt.mgmt = #mgmt.mgmt<level = 2>} : tensor<4xi16>
      %8 = mgmt.adjust_scale %input1 {mgmt.mgmt = #mgmt.mgmt<level = 3>} : tensor<4xi16>
      %9 = mgmt.modreduce %8 {mgmt.mgmt = #mgmt.mgmt<level = 2>} : tensor<4xi16>
      %10 = arith.addi %7, %9 {mgmt.mgmt = #mgmt.mgmt<level = 2>} : tensor<4xi16>
      %11 = mgmt.modreduce %10 {mgmt.mgmt = #mgmt.mgmt<level = 1>} : tensor<4xi16>
      %12 = arith.muli %11, %11 {mgmt.mgmt = #mgmt.mgmt<level = 1, dimension = 3>} : tensor<4xi16>
      %13 = mgmt.relinearize %12 {mgmt.mgmt = #mgmt.mgmt<level = 1>} : tensor<4xi16>
      %14 = mgmt.level_reduce %input1 {mgmt.mgmt = #mgmt.mgmt<level = 2>} : tensor<4xi16>
      %15 = mgmt.adjust_scale %14 {mgmt.mgmt = #mgmt.mgmt<level = 2>} : tensor<4xi16>
      %16 = mgmt.modreduce %15 {mgmt.mgmt = #mgmt.mgmt<level = 1>} : tensor<4xi16>
      %17 = arith.addi %13, %16 {mgmt.mgmt = #mgmt.mgmt<level = 1>} : tensor<4xi16>
      %18 = mgmt.modreduce %17 {mgmt.mgmt = #mgmt.mgmt<level = 0>} : tensor<4xi16>

where adjust_scale has no concrete scale parameter yet.

After --validate-noise and --populate-scale, we get the per-level scale and the value to fill in for each adjust_scale:

PopulateScale: scale = [57802, 46604, 21845, 1, ]
PopulateScale: scaleBig = [60481, 36636, 29128, 1, ]
PopulateScale: adjustScale = [1, 2528, 13431, 21845, ]
Propagate ScaleState(1) to <block argument> of type 'tensor<4xi16>' at index: 0
Propagate ScaleState(1) to <block argument> of type 'tensor<4xi16>' at index: 1
Propagate ScaleState(1) to %1 = arith.addi %input0, %input1 {mgmt.mgmt = #mgmt.mgmt<level = 3>} : tensor<4xi16>
Propagate ScaleState(1) to %2 = arith.muli %1, %1 {mgmt.mgmt = #mgmt.mgmt<level = 3, dimension = 3>} : tensor<4xi16>
Propagate ScaleState(1) to %3 = mgmt.relinearize %2 {mgmt.mgmt = #mgmt.mgmt<level = 3>} : tensor<4xi16>
Propagate ScaleState(1) to %4 = arith.addi %3, %input1 {mgmt.mgmt = #mgmt.mgmt<level = 3>} : tensor<4xi16>
Propagate ScaleState(21845) to %5 = mgmt.modreduce %4 {mgmt.mgmt = #mgmt.mgmt<level = 2>} : tensor<4xi16>
Propagate ScaleState(29128) to %6 = arith.muli %5, %5 {mgmt.mgmt = #mgmt.mgmt<level = 2, dimension = 3>} : tensor<4xi16>
Propagate ScaleState(29128) to %7 = mgmt.relinearize %6 {mgmt.mgmt = #mgmt.mgmt<level = 2>} : tensor<4xi16>
Propagate ScaleState(21845) to %8 = mgmt.adjust_scale %input1 {mgmt.mgmt = #mgmt.mgmt<level = 3>, scale = 21845 : i64} : tensor<4xi16>
Propagate ScaleState(29128) to %9 = mgmt.modreduce %8 {mgmt.mgmt = #mgmt.mgmt<level = 2>} : tensor<4xi16>
Propagate ScaleState(29128) to %10 = arith.addi %7, %9 {mgmt.mgmt = #mgmt.mgmt<level = 2>} : tensor<4xi16>
Propagate ScaleState(46604) to %11 = mgmt.modreduce %10 {mgmt.mgmt = #mgmt.mgmt<level = 1>} : tensor<4xi16>
Propagate ScaleState(36636) to %12 = arith.muli %11, %11 {mgmt.mgmt = #mgmt.mgmt<level = 1, dimension = 3>} : tensor<4xi16>
Propagate ScaleState(36636) to %13 = mgmt.relinearize %12 {mgmt.mgmt = #mgmt.mgmt<level = 1>} : tensor<4xi16>
Propagate ScaleState(1) to %14 = mgmt.level_reduce %input1 {mgmt.mgmt = #mgmt.mgmt<level = 2>} : tensor<4xi16>
Propagate ScaleState(13431) to %15 = mgmt.adjust_scale %14 {mgmt.mgmt = #mgmt.mgmt<level = 2>, scale = 13431 : i64} : tensor<4xi16>
Propagate ScaleState(36636) to %16 = mgmt.modreduce %15 {mgmt.mgmt = #mgmt.mgmt<level = 1>} : tensor<4xi16>
Propagate ScaleState(36636) to %17 = arith.addi %13, %16 {mgmt.mgmt = #mgmt.mgmt<level = 1>} : tensor<4xi16>
Propagate ScaleState(57802) to %18 = mgmt.modreduce %17 {mgmt.mgmt = #mgmt.mgmt<level = 0>} : tensor<4xi16>

The first three lines are computed purely from the BGV scheme parameters, and the rest is the analysis that validates whether the scales match.

The initial scaling factor is chosen to be 1 for both include-first-mul={true,false}: for include-first-mul=false the scaling factor at the last level must stay the same across the first mul, and 1 * 1 = 1 satisfies that.
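
As a quick sanity check of these numbers (not output from the pass; a minimal Python sketch assuming the scale bookkeeping is multiplication modulo the plaintext modulus t = 65537 and that each mgmt.modreduce multiplies the scale by the inverse of the dropped RNS prime mod t), the printed arrays are internally consistent:

t = 65537
scale     = [57802, 46604, 21845, 1]  # per-level scale, indexed by level
scale_big = [60481, 36636, 29128, 1]  # scale right after a mul at that level
adjust    = [1, 2528, 13431, 21845]   # value filled into adjust_scale per level

# a mul squares the operand scale mod t
assert all(scale_big[l] == scale[l] * scale[l] % t for l in range(4))

# the level-3 modreduce takes scale 1 to scale[2] = 21845, so the dropped prime
# q_3 has inverse 21845 mod t; the adjusted operand (1 * adjust[3]) then lands on
# scale_big[2] after the same modreduce, matching %7 above
inv_q3 = scale[2]
assert adjust[3] * inv_q3 % t == scale_big[2]

# likewise for the level-2 modreduce, inferring the inverse of q_2 from the
# transition scale_big[2] -> scale[1]; the adjusted operand matches %13 above
inv_q2 = scale[1] * pow(scale_big[2], -1, t) % t
assert adjust[2] * inv_q2 % t == scale_big[1]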

@ZenithalHourlyRate (Collaborator, Author) commented:

Supporting scale has been quite messy, as we have to change the things below:

  • mgmt insertion policy for both BGV and CKKS
    • insert rescale after mul, before mul, or before mul including the first mul
    • note that the original cross-level policy for CKKS is wrong
  • populate scale with regard to all three rescale insertion policies
  • the LWE type, where adding a new field means quite a bunch of test files need to change
  • the two backends

My idea is to skip the LWE type support and use an attribute to pass the information temporarily, and to skip OpenFHE as it handles that itself anyway.

The pipeline now works for Lattigo; see the example below.

func.func @cross_level_add(%base: tensor<4xi16> {secret.secret}, %add: tensor<4xi16> {secret.secret}) -> tensor<4xi16> {
  // increase one level
  %mul1 = arith.muli %base, %add : tensor<4xi16>
  // cross level add
  %base1 = arith.addi %mul1, %add : tensor<4xi16>
  return %base1 : tensor<4xi16>
}

After proper management and calculation of the scale we get:

      %1 = mgmt.modreduce %input0 {mgmt.mgmt = #mgmt.mgmt<level = 1, scale = 4>} : tensor<4xi16>
      %2 = mgmt.modreduce %input1 {mgmt.mgmt = #mgmt.mgmt<level = 1, scale = 4>} : tensor<4xi16>
      %3 = arith.muli %1, %2 {mgmt.mgmt = #mgmt.mgmt<level = 1, dimension = 3, scale = 16>} : tensor<4xi16>
      %4 = mgmt.relinearize %3 {mgmt.mgmt = #mgmt.mgmt<level = 1, scale = 16>} : tensor<4xi16>
      // need to adjust the scale by mul_const delta_scale
      %5 = mgmt.adjust_scale %input1 {delta_scale = 4 : i64, mgmt.mgmt = #mgmt.mgmt<level = 2, scale = 4>, scale = 4 : i64} : tensor<4xi16>
      %6 = mgmt.modreduce %5 {mgmt.mgmt = #mgmt.mgmt<level = 1, scale = 16>} : tensor<4xi16>
      %7 = arith.addi %4, %6 {mgmt.mgmt = #mgmt.mgmt<level = 1, scale = 16>} : tensor<4xi16>
      %8 = mgmt.modreduce %7 {mgmt.mgmt = #mgmt.mgmt<level = 0, scale = 65505>} : tensor<4xi16>

adjust_scale is materialized as the following

    %cst = arith.constant dense<1> : tensor<4xi16>
    %pt = lwe.rlwe_encode %cst {encoding = #full_crt_packing_encoding, lwe.scale = 4 : i64, ring = #ring_Z65537_i64_1_x4_} : tensor<4xi16> -> !pt
    %ct_5 = bgv.mul_plain %ct_0, %pt : (!ct_L2_, !pt) -> !ct_L2_
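
As an aside on why this multiplication changes only metadata (a sketch of standard BGV scale semantics, not something stated in this PR): a ciphertext at scale $\Delta$ holds $\Delta \cdot m \pmod t$, and the plaintext above encodes the constant $1$ at scale $N$ (here $N = 4$), i.e. it holds $N \cdot 1$. The product therefore holds

$$(\Delta \cdot m)(N \cdot 1) = (\Delta N)\, m \pmod{t},$$

which is the same message $m$ now tracked at scale $\Delta N$: a metadata change rather than a message change.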

When emitted to Lattigo with the debug handler, we can observe exactly the same scale changes:

Input
  Scale:  1
Input
  Scale:  1
lattigo.bgv.rescale_new
  Scale:  4
lattigo.bgv.rescale_new
  Scale:  4
lattigo.bgv.mul_new
  Scale:  16
lattigo.bgv.relinearize_new
  Scale:  16
// this is adjust_scale
lattigo.bgv.mul_new
  Scale:  4
lattigo.bgv.rescale_new
  Scale:  16
lattigo.bgv.add_new
  Scale:  16
lattigo.bgv.rescale_new
  Scale:  65505
Result [4 9 16 25]
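
The same mod-t bookkeeping reproduces this trace (again a sketch, assuming t = 65537 and reading the rescale factor off the 1 -> 4 transition; the variable names are illustrative, not part of the pass):

t = 65537
lhs = 1              # both inputs start at scale 1
lhs = lhs * 4 % t    # rescale_new: 1 -> 4, so the dropped prime has inverse 4 mod t
lhs = lhs * lhs % t  # mul_new (relinearize_new keeps the scale): 16

rhs = 1 * 4 % t      # adjust_scale, i.e. mul_plain by 1 encoded at scale 4: 4
rhs = rhs * 4 % t    # rescale_new by the same prime: 16

# add_new operands now agree on scale 16; the final rescale_new drops a
# different prime (with a different inverse mod t), giving the 65505 shown above
assert lhs == rhs == 16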

@j2kun (Collaborator) commented Mar 3, 2025

Talking about this in office hours. Some ideas:

  • Have one integer scaling factor attribute that specifies its own bitwidth, to support large-bitwidth scaling factors (e.g., for CKKS). The lowerings to a particular backend (e.g., C++/OpenFHE) would then need to pick an appropriate type (perhaps long double) to represent that scaling factor in the target language.
  • Make the scaling factor attribute optional on the LWE type to avoid having to update the entire codebase, and raise errors or handle the default case when the scaling factor is not present. We could also have some backlog work to go update the rest of the codebase so that scaling factors are present everywhere and we can later remove the optionality.

@ZenithalHourlyRate ZenithalHourlyRate marked this pull request as ready for review March 5, 2025 17:21
@ZenithalHourlyRate ZenithalHourlyRate changed the title BGV: support scale management BGV/CKKS: support scale management Mar 5, 2025
@ZenithalHourlyRate (Collaborator, Author) commented:

99 files changed so far... it would be insane if more changes were introduced. Asking for review now because many technical changes need discussion/decisions. Docs/cleanup are not done yet.

Loop problem

The hard part of supporting scale management is making the scales match everywhere. The current state of the PR will break loop support.

The intrinsic problem with loop support is that we need to make it FHE-aware enough. This is the same problem as LevelAnalysis in #1181, where we want to know some invariant kept by the loop. We used to think about keeping level/dimension the same; now we additionally need to keep the scale the same.

The following example shows that the current matmul code cannot survive the scale analysis:

affine.for %result_tensor (assume 2^45 scale initially)
  %a, %const : scale 2^45
  %0 = mul_plain %a, %const // result scale 2^90
  %1 = add %0, %result_tensor // scale 2^90
  tensor.insert %1 into %result_tensor // scale mismatch!

This certainly needs some insight into the loop.

We cannot even deal with the unrolled version, because we need some kind of back-propagation:

%result = tensor.empty // how do we know its scale when we encounter it
tensor.insert %sth into %result // only now do we know

Current status

  • Changed secret-insert-mgmt with three different ways of inserting mgmt ops; note that these different ways imply different initial scaling factor requirements
    • for before-mul-include-first-mul, we need to encode at a doubled scale like 2^90 (see the sketch after this list)
  • In secret-insert-mgmt, cross-level ops are now adjusted to the same level with level_reduce + adjust_scale + modreduce.
    • adjust_scale at this point is emitted as adjust_scale { scale = -1 }, as we do not know the scale yet
  • annotate-mgmt now also annotates the MgmtAttr for plaintexts because we need to know the scale of the plaintext. As multiple arith.constant ops will be canonicalized away, mgmt.no_op is introduced as a placeholder to carry the mgmt attr
  • generate-param generates parameters as usual
  • populate-scale now knows the parameters, so it can determine the scale of each ciphertext; it then fills each adjust_scale with a concrete scale via a back-propagation-style heuristic, and adjust_scale is lowered to mul_plain %ct, 1 where the arith.constant 1 carries the mgmt attr with scale = N. Note that this is a metadata change rather than a message change.
  • secret-to-<scheme> carries the scale into the LWE type
  • Refactored the code structure of the LWE-type-related dialects, with a verifier on the scale
  • Made the backends aware of the scale (Lattigo can set pt.Scale = NewScale(Pow(2, 90)); for OpenFHE we cannot)
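
For the before-mul-include-first-mul point above, a sketch of the reasoning (assuming a CKKS-style setting where the modreduce inserted before the first mul rescales by a top prime $q_L \approx \Delta = 2^{45}$): the inputs are rescaled once before they are ever multiplied, so to arrive at the working scale $\Delta$ after that rescale they must be encoded at

$$\Delta_{\text{in}} \approx \Delta \cdot q_L \approx 2^{45} \cdot 2^{45} = 2^{90}.$$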

Problems with the backends

  • OpenFHE does the automation itself, so the mul_plain %ct, 1 there does not have any metadata effect. Instead it introduces noise, and OpenFHE then does the adjustment itself, introducing more noise. We might want to just turn off adjust_scale for the OpenFHE backend, but then the scale-matching problem resurfaces in the LWE type system, where we might need an lwe.opaque_scale_cast indicating that the backend is doing its own job.

  • For Lattigo BGV, our adjustment to the scale is exact and Lattigo likes it. For CKKS this is not the case, as there will be a tiny scale mismatch (2^44.9999999 != 2^45); Lattigo will then automatically rescale somewhere, consuming extra levels, and fail execution when it finds no more levels to consume.
