-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: support model dtype selection and cuda fallback cast for bf16 (#33
) * Support cast from bf16 weights on cuda * Add automatic dtype selection * Avoid double encoding of metal command buffer * Add default value for cli * Fix test and python api * Update readme
- Loading branch information
1 parent
9da4ea6
commit ac4423f
Showing
17 changed files
with
416 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
name: Analysis | ||
on: | ||
pull_request_target | ||
|
||
jobs: | ||
comment: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- name: Checkout code | ||
uses: actions/checkout@v4 | ||
|
||
- name: Install Rust and Cargo | ||
run: | | ||
curl -sSf https://sh.rustup.rs | sh -s -- -y | ||
source $HOME/.cargo/env | ||
- name: Install Tokei | ||
run: cargo install tokei | ||
|
||
- name: Run Tokei and get the lines of code | ||
run: tokei . > tokei_output.txt | ||
|
||
- name: Comment or Update PR | ||
uses: actions/github-script@v7 | ||
with: | ||
script: | | ||
const fs = require('fs'); | ||
const tokeiOutput = fs.readFileSync('tokei_output.txt', 'utf8'); | ||
const uniqueIdentifier = 'Code Metrics Report'; | ||
const codeReport = ` | ||
<details> | ||
<summary>${uniqueIdentifier}</summary> | ||
<pre> | ||
${tokeiOutput} | ||
</pre> | ||
</details> | ||
`; | ||
const issue_number = context.issue.number; | ||
const { owner, repo } = context.repo; | ||
const comments = await github.rest.issues.listComments({ | ||
issue_number, | ||
owner, | ||
repo | ||
}); | ||
const existingComment = comments.data.find(comment => comment.body.includes(uniqueIdentifier)); | ||
if (existingComment) { | ||
await github.rest.issues.updateComment({ | ||
owner, | ||
repo, | ||
comment_id: existingComment.id, | ||
body: codeReport | ||
}); | ||
} else { | ||
await github.rest.issues.createComment({ | ||
issue_number, | ||
owner, | ||
repo, | ||
body: codeReport | ||
}); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
#include<stdint.h> | ||
|
||
__device__ __forceinline__ float bf16_to_f32(const uint16_t i) | ||
{ | ||
// If NaN, keep current mantissa but also set most significant mantissa bit | ||
if ((i & 0x7FFFu) > 0x7F80u) { | ||
// NaN path | ||
uint32_t tmp = ((static_cast<uint32_t>(i) | 0x0040u) << 16); | ||
union { | ||
uint32_t as_int; | ||
float as_float; | ||
} u; | ||
u.as_int = tmp; | ||
return u.as_float; | ||
// Alternatively: | ||
// return __int_as_float(((static_cast<uint32_t>(i) | 0x0040u) << 16)); | ||
} else { | ||
// Normal path | ||
uint32_t tmp = (static_cast<uint32_t>(i) << 16); | ||
union { | ||
uint32_t as_int; | ||
float as_float; | ||
} u; | ||
u.as_int = tmp; | ||
return u.as_float; | ||
// Alternatively: | ||
// return __int_as_float(static_cast<uint32_t>(i) << 16); | ||
} | ||
} | ||
|
||
// Convert FP32 (float) to BF16 (unsigned short) | ||
__device__ __forceinline__ uint16_t f32_to_bf16(const float value) | ||
{ | ||
// Reinterpret float bits as uint32_t | ||
union { | ||
float as_float; | ||
uint32_t as_int; | ||
} u; | ||
u.as_float = value; | ||
uint32_t x = u.as_int; | ||
|
||
// Check for NaN | ||
if ((x & 0x7FFF'FFFFu) > 0x7F80'0000u) { | ||
// Keep high part of current mantissa but also set most significant mantissa bit | ||
return static_cast<uint16_t>((x >> 16) | 0x0040u); | ||
} | ||
|
||
// Round and shift | ||
constexpr uint32_t round_bit = 0x0000'8000u; // bit 15 | ||
if (((x & round_bit) != 0) && ((x & (3 * round_bit - 1)) != 0)) { | ||
// Round half to even (or to odd) depends on your preference | ||
return static_cast<uint16_t>((x >> 16) + 1); | ||
} else { | ||
return static_cast<uint16_t>(x >> 16); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.