nanorwkv tflite inference #2

Open · wants to merge 2 commits into base: master
174 changes: 134 additions & 40 deletions model.py
@@ -20,16 +20,44 @@
import torch.nn as nn
from torch.nn import functional as F


@dataclass
class GPTConfig:
block_size: int = 1024
vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
n_layer: int = 12
n_head: int = 12
n_embd: int = 768
dropout: float = 0.0
bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster


class LayerState:
# the recurrent neural network (RNN) state for a layer of RWKV5.2
def __init__(self, cfg: GPTConfig, batch_size, dtype=torch.float32, device='cpu'):
B, C, H, K = batch_size, cfg.n_embd, cfg.n_head, cfg.n_embd // cfg.n_head
V = K

# a (B,C) size tensor representing latest time mixer token embedding processed
self.time_mixer_state = torch.zeros(B,C, dtype=dtype, device=device)
# a (B,H,K,V) size tensor representing a decaying token embedding memory for each head, where H=number_of_heads, K=key_dim_per_head, V=value_dim_per_head
self.kv_state = torch.zeros(B,H,K,V, dtype=dtype, device=device)
# a (B,C) size tensor representing latest channel mixer token embedding processed
self.channel_mixer_state = torch.zeros(B,C, dtype=dtype, device=device)


class LayerNorm(nn.Module):
""" LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """

def __init__(self, ndim, bias):
super().__init__()
self.weight = nn.Parameter(torch.ones(ndim))
self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
self.eps = 1e-5

def forward(self, input):
return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
return F.layer_norm(input, self.weight.shape, self.weight, self.bias, self.eps)


class RWKV_TimeMix_x051a(nn.Module):

@@ -74,7 +102,7 @@ def __init__(self, config, layer_id):

self.dropout = nn.Dropout(config.dropout)

def forward(self, x):
def forward(self, x, warn=True):
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
H, N = self.n_head, self.head_size
#
@@ -86,7 +114,8 @@ def forward(self, x):
elif T % 128 == 0: Q = 128
else:
Q = T
warnings.warn(f'\n{"#"*80}\n\n{" "*38}Note\nThe GPT-mode forward() should only be called when we are training models.\nNow we are using it for inference for simplicity, which works, but will be very inefficient.\n\n{"#"*80}\n')
if warn:
warnings.warn(f'\n{"#"*80}\n\n{" "*38}Note\nThe GPT-mode forward() should only be called when we are training models.\nNow we are using it for inference for simplicity, which works, but will be very inefficient.\n\n{"#"*80}\n')
assert T % Q == 0

xx = self.time_shift(x) - x
@@ -138,6 +167,43 @@ def forward(self, x):
y = self.dropout(self.output(y))
return y

def forward_step(self, x, state, kv_state):
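# Single-token (RNN-mode) variant of forward().
# x: (B, C) current token embedding
# state: (B, C) previous token embedding seen by this time mixer (token-shift state)
# kv_state: (B, H, K, V) per-head decaying key/value memory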
B, H, N = x.size(0), self.n_head, self.head_size

xx = state - x
xk = x + xx * self.time_maa_k.squeeze(-2)
xv = x + xx * self.time_maa_v.squeeze(-2)
xr = x + xx * self.time_maa_r.squeeze(-2)
xg = x + xx * self.time_maa_g.squeeze(-2)
r = self.receptance(xr).view(B, H, 1, N)
k = self.key(xk).view(B, H, N, 1)
v = self.value(xv).view(B, H, 1, N)
g = F.silu(self.gate(xg)) # extra gate

w = torch.exp(-torch.exp(self.time_decay.float())).unsqueeze(-1) # time_decay
u = self.time_faaaa.float().unsqueeze(-1) # time_first

y, kv_state = self.single_time_step(r, k, v, u, w, kv_state)

y = y.contiguous().view(B, H * N)
y = self.ln_x(y) * g

# output projection
y = self.dropout(self.output(y))
return y, x, kv_state

@staticmethod
def single_time_step(r, k, v, u, w, kv_state):
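# RWKV-5.2 recurrence for one token, per batch element and per head:
# out_t = r_t @ (S_{t-1} + u * (k_t @ v_t)) -- read the memory, with an extra per-channel "bonus" u for the current token
# S_t = w * S_{t-1} + (k_t @ v_t) -- decay the memory with w, then write the current token into it
# where S is kv_state, w = exp(-exp(time_decay)) and u = time_faaaa (set up in forward_step above)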
y = kv_state # BHKV
y = y + (k @ v) * u # BHKV * HK1 + BHKV = BHKV
out = r @ y # BH1K @ BHKV = BH1V

kv_state = kv_state * w # BHKV
kv_state = kv_state + (k @ v) # BHKV + (BHK1 @ BH1V) = BHKV

return out.squeeze(-2), kv_state # BHV, BHKV


class RWKV_ChannelMix_x051a(nn.Module):

def __init__(self, config, layer_id):
@@ -169,6 +235,19 @@ def forward(self, x):
x = self.dropout(x)
return x

def forward_step(self, x, state):
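# Single-token (RNN-mode) variant of forward().
# x: (B, C) current token embedding
# state: (B, C) previous token embedding seen by this channel mixer (token-shift state)
# returns the channel-mix output and the new state (the current x)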
xx = state - x
xk = x + xx * self.time_maa_k.squeeze(-2)
xr = x + xx * self.time_maa_r.squeeze(-2)

out = self.key(xk)
out = torch.relu(out) ** 2
out = self.value(out)
out = torch.sigmoid(self.receptance(xr)) * out
out = self.dropout(out)
return out, x


class Block(nn.Module):

def __init__(self, config, layer_id):
@@ -178,20 +257,20 @@ def __init__(self, config, layer_id):
self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
self.cmix = RWKV_ChannelMix_x051a(config, layer_id)

def forward(self, x):
x = x + self.tmix(self.ln_1(x))
def forward(self, x, warn=True):
x = x + self.tmix(self.ln_1(x), warn=warn)
x = x + self.cmix(self.ln_2(x))
return x

@dataclass
class GPTConfig:
block_size: int = 1024
vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
n_layer: int = 12
n_head: int = 12
n_embd: int = 768
dropout: float = 0.0
bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
def forward_step(self, x, s: LayerState):
out, s.time_mixer_state, s.kv_state = \
self.tmix.forward_step(self.ln_1(x), s.time_mixer_state, s.kv_state)
x = x + out
out, s.channel_mixer_state = \
self.cmix.forward_step(self.ln_2(x), s.channel_mixer_state)
x = x + out
return x, s


class GPT(nn.Module):

@@ -245,31 +324,36 @@ def _init_weights(self, module):
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

def forward(self, idx, targets=None):
def forward(self, idx, targets=None, warn=True):
device = idx.device
b, t = idx.size()
assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
time = idx.size(1)
assert time <= self.config.block_size, f"Cannot forward sequence of length {time}, block size is only {self.config.block_size}"
pos = torch.arange(0, time, dtype=torch.long, device=device) # shape (t)

# forward the GPT model itself
tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
x = self.transformer.drop(tok_emb + pos_emb)
for block in self.transformer.h:
x = block(x)
x = block(x, warn)
x = self.transformer.ln_f(x)

logits, loss = self.lm_head(x), None
if targets is not None:
# if we are given some desired targets also calculate the loss
logits = self.lm_head(x)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
else:
# inference-time mini-optimization: only forward the lm_head on the very last position
logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
loss = None

return logits, loss

def forward_step(self, token, pos, state):
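# Single-token (RNN-mode) forward pass.
# token: (B,) current token ids; pos: (1,) current position index; state: one LayerState per layer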
tok_emb = self.transformer.wte(token) # token embeddings of shape (b, n_embd)
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (n_embd)
token = self.transformer.drop(tok_emb + pos_emb)
for layer_id, block in enumerate(self.transformer.h): # run each rwkv block
token, state[layer_id] = block.forward_step(token, state[layer_id])
token = self.transformer.ln_f(token)
logits = self.lm_head(token)
return logits, state

def crop_block_size(self, block_size):
# model surgery to decrease the block size if necessary
# e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
@@ -381,28 +465,38 @@ def estimate_mfu(self, fwdbwd_per_iter, dt):
return mfu

@torch.no_grad()
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
def generate(self, tokens, max_tokens, temperature=1.0, top_k=None):
"""
Take a conditioning sequence of token indices tokens (LongTensor of shape (t,)) and extend
the sequence until it is max_tokens tokens long, feeding each prediction back into the model
via the RNN-mode forward_step. Most likely you'll want to be in model.eval() mode of operation for this.
"""
for _ in range(max_new_tokens):
# if the sequence context is growing too long we must crop it at block_size
idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
# forward the model to get the logits for the index in the sequence
logits, _ = self(idx_cond)
# pluck the logits at the final step and scale by desired temperature
logits = logits[:, -1, :] / temperature
dtype = self.lm_head.weight.dtype
device = self.lm_head.weight.device
states = [LayerState(self.config, 1, dtype, device) for _ in range(self.config.n_layer)]
pos = torch.zeros(1, dtype=torch.long, device=device)

for i in range(len(tokens)):
logits, states = self.forward_step(tokens[i].unsqueeze(0), pos, states)
pos += 1

top_k = None if top_k is None else min(top_k, logits.size(-1)) # clamp top_k to the vocab size; None disables top-k filtering
for _ in range(max_tokens - len(tokens)):
# pluck the logits at the final step
logits = logits[0]
# optionally crop the logits to only the top k options
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
# apply softmax to convert logits to (normalized) probabilities
probs = F.softmax(logits, dim=-1)
logits, ids = torch.topk(logits, top_k)
# scale by desired temperature and apply softmax to convert logits to (normalized) probabilities
probs = F.softmax(logits / temperature, dim=-1)
# sample from the distribution
idx_next = torch.multinomial(probs, num_samples=1)
next_token = torch.multinomial(probs, num_samples=1)
if top_k is not None:
next_token = ids[next_token]
# append sampled index to the running sequence and continue
idx = torch.cat((idx, idx_next), dim=1)
tokens = torch.cat((tokens, next_token), dim=0)
# forward the model to get the logits for the index in the sequence
logits, states = self.forward_step(next_token, pos, states)
pos += 1

return idx
return tokens
26 changes: 20 additions & 6 deletions sample.py
@@ -6,7 +6,7 @@
from contextlib import nullcontext
import torch
import tiktoken
from model import GPTConfig, GPT
from model import GPTConfig, GPT, LayerState

# -----------------------------------------------------------------------------
init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
@@ -73,18 +73,32 @@
encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
decode = lambda l: enc.decode(l)

# test that RNN mode produces exactly the same outputs as transformer (GPT) mode
def test_rnn_mode():
test_seq = torch.randint(0, gptconf.vocab_size, size=(1, gptconf.block_size))
gt, _ = model.forward(test_seq, warn=False)

states = [LayerState(gptconf, 1, ptdtype, device) for _ in range(gptconf.n_layer)]
for i, test_token in enumerate(test_seq[0]):
tokens = torch.tensor([test_token], dtype=torch.long, device=device)
pos = torch.tensor([i], dtype=torch.long, device=device)
logits, states = model.forward_step(tokens, pos, states)
assert torch.allclose(gt[:, i], logits, rtol=5e-2, atol=1e-5), i

# encode the beginning of the prompt
if start.startswith('FILE:'):
with open(start[5:], 'r', encoding='utf-8') as f:
start = f.read()
start_ids = encode(start)
x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
start_ids = torch.tensor(start_ids, dtype=torch.long, device=device)

# run generation
with torch.no_grad():
with ctx:
enable_test = True # temporary optional RNN-mode validation
if enable_test:
test_rnn_mode()

for k in range(num_samples):
print('(note: this is using "GPT-mode" for inference (very slow), so we limit it to 100 characters. The much faster "RNN-mode" for inference is coming soon)')
y = model.generate(x, 100, temperature=temperature, top_k=top_k)
print(decode(y[0].tolist()))
print('---------------')
predicted = model.generate(start_ids, gptconf.block_size, temperature=temperature, top_k=top_k)
print(decode(predicted.tolist()))
44 changes: 44 additions & 0 deletions tflite_convertation/convertation.py
@@ -0,0 +1,44 @@
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input

from tflite_convertation.keras_model import Keras_RWKV


class TfliteModel(tf.keras.Model):
def __init__(self, keras_model, states, logits_lambda=None):
inputs = [
Input(shape=tuple(), name='tokens', dtype=np.int32),
Input(shape=tuple(), name='position', dtype=np.int32),
*[
Input(shape=state.shape[1:], name=f'state_{i}')
for i, state in enumerate(states)
],
]
outputs = list(keras_model(inputs))
if logits_lambda is not None:
outputs[0] = logits_lambda(outputs[0])
super().__init__(inputs=inputs, outputs=outputs)


def convert_model(keras_model: Keras_RWKV, path_to_write=None,
logits_lambda=None, use_dynamic_quantization=False):
tokens = np.zeros(1, dtype=np.int32)
positions, keras_states = keras_model.get_states(batch_size=1)

tfliteModel = TfliteModel(keras_model, keras_states, logits_lambda)
tfliteModel([tokens, positions, *keras_states])

converter = tf.lite.TFLiteConverter.from_keras_model(tfliteModel)

converter.allow_custom_ops = False
converter.experimental_new_converter = True
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
if use_dynamic_quantization:
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converted_model = converter.convert()

if path_to_write is not None:
with open(path_to_write, 'wb') as file:
file.write(converted_model)
return converted_model
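
For reference, here is a minimal sketch (not part of this diff) of how a model produced by convert_model could be driven with the TFLite Interpreter for step-by-step decoding. It assumes the input ordering reported by get_input_details() matches the tokens/position/state_i ordering of TfliteModel and that the first output is the logits; a robust runner should match tensors by name instead, and the 16-step greedy loop is illustrative only.

import numpy as np
import tensorflow as tf

# converted_model is assumed to come from convert_model(keras_model) above
interpreter = tf.lite.Interpreter(model_content=converted_model)
interpreter.allocate_tensors()
inputs = interpreter.get_input_details()
outputs = interpreter.get_output_details()

token = np.zeros(1, dtype=np.int32)      # current token id
position = np.zeros(1, dtype=np.int32)   # current position
# zero-initialized per-layer states, shaped and typed like the corresponding inputs
states = [np.zeros(d['shape'], dtype=d['dtype']) for d in inputs[2:]]

for _ in range(16):  # a few greedy decoding steps
    interpreter.set_tensor(inputs[0]['index'], token)
    interpreter.set_tensor(inputs[1]['index'], position)
    for d, s in zip(inputs[2:], states):
        interpreter.set_tensor(d['index'], s)
    interpreter.invoke()
    logits = interpreter.get_tensor(outputs[0]['index'])
    states = [interpreter.get_tensor(d['index']) for d in outputs[1:]]
    token = np.argmax(logits, axis=-1).astype(np.int32).reshape(1)
    position = (position + 1).astype(np.int32)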