-
Notifications
You must be signed in to change notification settings - Fork 2.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support for different data types (float16, float32) #93
Closed
Closed
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -17,6 +17,10 @@ Then run with: | |||||||||||||||
#include <fcntl.h> | ||||||||||||||||
#include <sys/mman.h> | ||||||||||||||||
|
||||||||||||||||
#ifndef DTYPE | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's better to typedef dtype rather than spreading DTYPE macro all over.
Suggested change
|
||||||||||||||||
#define DTYPE float | ||||||||||||||||
#endif | ||||||||||||||||
|
||||||||||||||||
// ---------------------------------------------------------------------------- | ||||||||||||||||
// Transformer and RunState structs, and related memory management | ||||||||||||||||
|
||||||||||||||||
|
@@ -28,30 +32,31 @@ typedef struct { | |||||||||||||||
int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery) | ||||||||||||||||
int vocab_size; // vocabulary size, usually 256 (byte-level) | ||||||||||||||||
int seq_len; // max sequence length | ||||||||||||||||
int dtype; // 0: float16, 1:float32 | ||||||||||||||||
} Config; | ||||||||||||||||
|
||||||||||||||||
typedef struct { | ||||||||||||||||
// token embedding table | ||||||||||||||||
float* token_embedding_table; // (vocab_size, dim) | ||||||||||||||||
DTYPE* token_embedding_table; // (vocab_size, dim) | ||||||||||||||||
// weights for rmsnorms | ||||||||||||||||
float* rms_att_weight; // (layer, dim) rmsnorm weights | ||||||||||||||||
float* rms_ffn_weight; // (layer, dim) | ||||||||||||||||
DTYPE* rms_att_weight; // (layer, dim) rmsnorm weights | ||||||||||||||||
DTYPE* rms_ffn_weight; // (layer, dim) | ||||||||||||||||
// weights for matmuls | ||||||||||||||||
float* wq; // (layer, dim, dim) | ||||||||||||||||
float* wk; // (layer, dim, dim) | ||||||||||||||||
float* wv; // (layer, dim, dim) | ||||||||||||||||
float* wo; // (layer, dim, dim) | ||||||||||||||||
DTYPE* wq; // (layer, dim, dim) | ||||||||||||||||
DTYPE* wk; // (layer, dim, dim) | ||||||||||||||||
DTYPE* wv; // (layer, dim, dim) | ||||||||||||||||
DTYPE* wo; // (layer, dim, dim) | ||||||||||||||||
// weights for ffn | ||||||||||||||||
float* w1; // (layer, hidden_dim, dim) | ||||||||||||||||
float* w2; // (layer, dim, hidden_dim) | ||||||||||||||||
float* w3; // (layer, hidden_dim, dim) | ||||||||||||||||
DTYPE* w1; // (layer, hidden_dim, dim) | ||||||||||||||||
DTYPE* w2; // (layer, dim, hidden_dim) | ||||||||||||||||
DTYPE* w3; // (layer, hidden_dim, dim) | ||||||||||||||||
// final rmsnorm | ||||||||||||||||
float* rms_final_weight; // (dim,) | ||||||||||||||||
DTYPE* rms_final_weight; // (dim,) | ||||||||||||||||
// freq_cis for RoPE relatively positional embeddings | ||||||||||||||||
float* freq_cis_real; // (seq_len, dim/2) | ||||||||||||||||
float* freq_cis_imag; // (seq_len, dim/2) | ||||||||||||||||
DTYPE* freq_cis_real; // (seq_len, dim/2) | ||||||||||||||||
DTYPE* freq_cis_imag; // (seq_len, dim/2) | ||||||||||||||||
// (optional) classifier weights for the logits, on the last layer | ||||||||||||||||
float* wcls; | ||||||||||||||||
DTYPE* wcls; | ||||||||||||||||
} TransformerWeights; | ||||||||||||||||
|
||||||||||||||||
typedef struct { | ||||||||||||||||
|
@@ -86,8 +91,8 @@ void malloc_run_state(RunState* s, Config* p) { | |||||||||||||||
s->key_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float)); | ||||||||||||||||
s->value_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float)); | ||||||||||||||||
// ensure all mallocs went fine | ||||||||||||||||
if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q | ||||||||||||||||
|| !s->k || !s->v || !s->att || !s->logits || !s->key_cache | ||||||||||||||||
if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q | ||||||||||||||||
|| !s->k || !s->v || !s->att || !s->logits || !s->key_cache | ||||||||||||||||
|| !s->value_cache) { | ||||||||||||||||
printf("malloc failed!\n"); | ||||||||||||||||
exit(1); | ||||||||||||||||
|
@@ -112,8 +117,8 @@ void free_run_state(RunState* s) { | |||||||||||||||
// ---------------------------------------------------------------------------- | ||||||||||||||||
// initialization: read from checkpoint | ||||||||||||||||
|
||||||||||||||||
void checkpoint_init_weights(TransformerWeights *w, Config* p, float* f, int shared_weights) { | ||||||||||||||||
float* ptr = f; | ||||||||||||||||
void checkpoint_init_weights(TransformerWeights *w, Config* p, DTYPE* f, int shared_weights) { | ||||||||||||||||
DTYPE* ptr = f; | ||||||||||||||||
w->token_embedding_table = ptr; | ||||||||||||||||
ptr += p->vocab_size * p->dim; | ||||||||||||||||
w->rms_att_weight = ptr; | ||||||||||||||||
|
@@ -153,7 +158,7 @@ void accum(float *a, float *b, int size) { | |||||||||||||||
} | ||||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
void rmsnorm(float* o, float* x, float* weight, int size) { | ||||||||||||||||
void rmsnorm(float* o, float* x, DTYPE* weight, int size) { | ||||||||||||||||
// calculate sum of squares | ||||||||||||||||
float ss = 0.0f; | ||||||||||||||||
for (int j = 0; j < size; j++) { | ||||||||||||||||
|
@@ -188,7 +193,7 @@ void softmax(float* x, int size) { | |||||||||||||||
} | ||||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
void matmul(float* xout, float* x, float* w, int n, int d) { | ||||||||||||||||
void matmul(float* xout, float* x, DTYPE* w, int n, int d) { | ||||||||||||||||
// W (d,n) @ x (n,) -> xout (d,) | ||||||||||||||||
#pragma omp parallel for | ||||||||||||||||
for (int i = 0; i < d; i++) { | ||||||||||||||||
|
@@ -201,24 +206,26 @@ void matmul(float* xout, float* x, float* w, int n, int d) { | |||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* w) { | ||||||||||||||||
|
||||||||||||||||
// a few convenience variables | ||||||||||||||||
float *x = s->x; | ||||||||||||||||
int dim = p->dim; | ||||||||||||||||
int hidden_dim = p->hidden_dim; | ||||||||||||||||
int head_size = dim / p->n_heads; | ||||||||||||||||
|
||||||||||||||||
// copy the token embedding into x | ||||||||||||||||
float* content_row = &(w->token_embedding_table[token * dim]); | ||||||||||||||||
memcpy(x, content_row, dim*sizeof(*x)); | ||||||||||||||||
DTYPE* content_row = &(w->token_embedding_table[token * dim]); | ||||||||||||||||
for (int i = 0; i < dim; i++) { | ||||||||||||||||
x[i] = content_row[i]; | ||||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
// pluck out the "pos" row of freq_cis_real and freq_cis_imag | ||||||||||||||||
float* freq_cis_real_row = w->freq_cis_real + pos * head_size / 2; | ||||||||||||||||
float* freq_cis_imag_row = w->freq_cis_imag + pos * head_size / 2; | ||||||||||||||||
DTYPE* freq_cis_real_row = w->freq_cis_real + pos * head_size / 2; | ||||||||||||||||
DTYPE* freq_cis_imag_row = w->freq_cis_imag + pos * head_size / 2; | ||||||||||||||||
|
||||||||||||||||
// forward all the layers | ||||||||||||||||
for(int l = 0; l < p->n_layers; l++) { | ||||||||||||||||
|
||||||||||||||||
// attention rmsnorm | ||||||||||||||||
rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim); | ||||||||||||||||
|
||||||||||||||||
|
@@ -253,7 +260,7 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* | |||||||||||||||
float* value_cache_row = s->value_cache + loff + pos * dim; | ||||||||||||||||
memcpy(key_cache_row, s->k, dim*sizeof(*key_cache_row)); | ||||||||||||||||
memcpy(value_cache_row, s->v, dim*sizeof(*value_cache_row)); | ||||||||||||||||
|
||||||||||||||||
// multihead attention. iterate over all heads | ||||||||||||||||
#pragma omp parallel for | ||||||||||||||||
for (int h = 0; h < p->n_heads; h++) { | ||||||||||||||||
|
@@ -277,7 +284,7 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* | |||||||||||||||
|
||||||||||||||||
// softmax the scores to get attention weights, from 0..pos inclusively | ||||||||||||||||
softmax(att, pos + 1); | ||||||||||||||||
|
||||||||||||||||
// weighted sum of the values, store back into xb | ||||||||||||||||
for (int i = 0; i < head_size; i++) { | ||||||||||||||||
float val = 0.0f; | ||||||||||||||||
|
@@ -301,12 +308,12 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* | |||||||||||||||
// first calculate self.w1(x) and self.w3(x) | ||||||||||||||||
matmul(s->hb, s->xb, w->w1 + l*dim*hidden_dim, dim, hidden_dim); | ||||||||||||||||
matmul(s->hb2, s->xb, w->w3 + l*dim*hidden_dim, dim, hidden_dim); | ||||||||||||||||
|
||||||||||||||||
// F.silu; silu(x)=x*σ(x),where σ(x) is the logistic sigmoid | ||||||||||||||||
for (int i = 0; i < hidden_dim; i++) { | ||||||||||||||||
s->hb[i] = s->hb[i] * (1.0f / (1.0f + expf(-s->hb[i]))); | ||||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
// elementwise multiply with w3(x) | ||||||||||||||||
for (int i = 0; i < hidden_dim; i++) { | ||||||||||||||||
s->hb[i] = s->hb[i] * s->hb2[i]; | ||||||||||||||||
|
@@ -318,7 +325,7 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* | |||||||||||||||
// residual connection | ||||||||||||||||
accum(x, s->xb, dim); | ||||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
// final rmsnorm | ||||||||||||||||
rmsnorm(x, x, w->rms_final_weight, dim); | ||||||||||||||||
|
||||||||||||||||
|
@@ -388,13 +395,13 @@ int main(int argc, char *argv[]) { | |||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
// seed rng with time. if you want deterministic behavior use temperature 0.0 | ||||||||||||||||
srand((unsigned int)time(NULL)); | ||||||||||||||||
srand((unsigned int)time(NULL)); | ||||||||||||||||
|
||||||||||||||||
// read in the model.bin file | ||||||||||||||||
Config config; | ||||||||||||||||
TransformerWeights weights; | ||||||||||||||||
int fd = 0; | ||||||||||||||||
float* data = NULL; | ||||||||||||||||
DTYPE* data = NULL; | ||||||||||||||||
long file_size; | ||||||||||||||||
{ | ||||||||||||||||
FILE *file = fopen(checkpoint, "rb"); | ||||||||||||||||
|
@@ -416,8 +423,27 @@ int main(int argc, char *argv[]) { | |||||||||||||||
if (fd == -1) { printf("open failed!\n"); return 1; } | ||||||||||||||||
data = mmap(NULL, file_size, PROT_READ, MAP_PRIVATE, fd, 0); | ||||||||||||||||
if (data == MAP_FAILED) { printf("mmap failed!\n"); return 1; } | ||||||||||||||||
float* weights_ptr = data + sizeof(Config)/sizeof(float); | ||||||||||||||||
checkpoint_init_weights(&weights, &config, weights_ptr, shared_weights); | ||||||||||||||||
enum dtype { | ||||||||||||||||
float16 = 0, | ||||||||||||||||
float32 = 1 | ||||||||||||||||
}; | ||||||||||||||||
switch (config.dtype) { | ||||||||||||||||
default: | ||||||||||||||||
printf("dtype not supported!\n"); | ||||||||||||||||
return 1; | ||||||||||||||||
|
||||||||||||||||
case float16: | ||||||||||||||||
if (sizeof(DTYPE) != sizeof(_Float16)) { printf("dtype doesn't match!\n"); return 1; } | ||||||||||||||||
DTYPE* weights_ptr_float16 = data + sizeof(Config)/sizeof(DTYPE); | ||||||||||||||||
checkpoint_init_weights(&weights, &config, weights_ptr_float16, shared_weights); | ||||||||||||||||
break; | ||||||||||||||||
|
||||||||||||||||
case float32: | ||||||||||||||||
if (sizeof(DTYPE) != sizeof(float)) { printf("dtype doesn't match!\n"); return 1; } | ||||||||||||||||
DTYPE* weights_ptr_float32 = data + sizeof(Config)/sizeof(DTYPE); | ||||||||||||||||
checkpoint_init_weights(&weights, &config, weights_ptr_float32, shared_weights); | ||||||||||||||||
break; | ||||||||||||||||
} | ||||||||||||||||
} | ||||||||||||||||
// right now we cannot run for more than config.seq_len steps | ||||||||||||||||
if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; } | ||||||||||||||||
|
@@ -444,7 +470,7 @@ int main(int argc, char *argv[]) { | |||||||||||||||
// create and init the application RunState | ||||||||||||||||
RunState state; | ||||||||||||||||
malloc_run_state(&state, &config); | ||||||||||||||||
|
||||||||||||||||
// the current position we are in | ||||||||||||||||
long start = time_in_ms(); | ||||||||||||||||
int next; | ||||||||||||||||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
reaons -> reasons