diff --git a/starcode_kv_cache_injection/__init__.py b/starcode_kv_cache_injection/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/starcode_kv_cache_injection/kv_cache_injection.py b/starcode_kv_cache_injection/kv_cache_injection.py
new file mode 100644
index 00000000000..907ed0a6e7f
--- /dev/null
+++ b/starcode_kv_cache_injection/kv_cache_injection.py
@@ -0,0 +1,274 @@
+from transformers import AutoConfig
+
+import onnx
+import logging
+import os
+from typing import List, Optional
+from onnx import TensorProto, ModelProto, NodeProto
+from sparseml.onnx.utils import ONNXGraph
+from sparseml.exporters.transforms.kv_cache.cache_keys_and_values import reshape_kv_cache_inputs_outputs
+from sparseml.exporters.transforms.kv_cache.transforms_codegen import AdditionalTransformsCodeGen
+from sparseml.onnx.utils.helpers import get_nodes_by_output_id
+
+_LOGGER = logging.getLogger(__name__)
+
+class AdditionalTransformsBigCode(AdditionalTransformsCodeGen):
+    """
+    Since the entries of the BigCode causal mask are similar in their values
+    and layout to the CodeGen causal mask, this class inherits from
+    AdditionalTransformsCodeGen.
+    """
+
+    # position ids are created by a Sub node (the one that is followed by a Where node
+    # in the onnx graph)
+    POSITION_IDS_MATCHING_PATTERN = dict(op_type="Sub", children_ops=[["Where"]])
+    # the causal mask is created by an Unsqueeze node (the one that is followed by a Where node
+    # in the onnx graph)
+    CAUSAL_MASK_MATCHING_PATTERN = dict(op_type="Unsqueeze", children_ops=[["Where", "Softmax"]])
+
+    def swap_nodes_for_input(
+        self,
+        model: ModelProto,
+        nodes: List[NodeProto],
+        input_name: str,
+        nodes_parent_op_type: Optional[str] = None,
+    ) -> ModelProto:
+        """
+        Injects the specified input into the graph, replacing the specified nodes.
+
+        :param model: the ONNX model to inject the input into
+        :param nodes: the nodes to replace with the input
+        :param input_name: the name of the input to replace the nodes with
+        :param nodes_parent_op_type: the parent op type of the nodes to replace
+
+        :return: the updated model
+        """
+        graph = ONNXGraph(model)
+        for node in nodes:
+            # edited so that multiple child nodes are supported
+            children_nodes = graph.get_node_children(node)
+            for child_node in children_nodes:
+                if nodes_parent_op_type:
+                    assert child_node.op_type == nodes_parent_op_type, (
+                        f"Expected to find {nodes_parent_op_type} node, "
+                        f"found {child_node.op_type}"
+                    )
+                output_to_replace = node.output[0]
+                self.log_match(node)
+                for idx, input_name_child_node in enumerate(child_node.input):
+                    if input_name_child_node == output_to_replace:
+                        graph.update_node_input(child_node, input_name, idx)
+
+        graph.delete_orphaned_node_branches()
+
+        _LOGGER.info(
+            f"Successfully swapped {len(nodes)} nodes for input '{input_name}'"
+        )
+
+        return model
+
+    def add_constant_reshape_node(self, model: ModelProto) -> ModelProto:
+        """
+        Adds a Constant node that feeds the fixed target shape (1, 256, 768),
+        i.e. (batch_size, sequence_length, hidden_size), to the Reshape node
+        that produces `/transformer/Reshape_2_output_0`.
+
+        :param model: model to update
+        :return: updated model
+        """
+        graph = ONNXGraph(model)
+        # create a constant node that will feed the value (1, 256, 768) to the reshape node
+        constant_node = onnx.helper.make_node(
+            "Constant",
+            inputs=[],
+            name="reshape_input_constant",
+            outputs=["reshape_input"],
+            value=onnx.helper.make_tensor(
+                name="const_tensor",
+                data_type=TensorProto.INT64,
+                dims=[3],
+                vals=[1, 256, 768],
+            ),
+        )
+        graph.add_node(constant_node)
+        reshape_node = get_nodes_by_output_id(model, "/transformer/Reshape_2_output_0")[0]
+        reshape_node.input[1] = "reshape_input"
+        _LOGGER.info("Inserted constant reshape node into the ONNX model")
+        return model
+
+    def add_causal_mask_reshape_node(self, model: ModelProto) -> ModelProto:
+        """
+        Adds a Transpose node (perm=(0, 3, 2, 1)) on top of the `causal_mask`
+        input and rewires the node that produces `causal_mask_adjusted` so that
+        its first input is the transposed mask.
+
+        :param model: model to update
+        :return: updated model
+        """
+        graph = ONNXGraph(model)
+
+        transpose_node = onnx.helper.make_node(
+            op_type="Transpose",
+            inputs=["causal_mask"],
+            outputs=["causal_mask_transpose"],
+            name="causal_mask_transpose",
+            perm=(0, 3, 2, 1),
+        )
+        graph.add_node(transpose_node)
+        reshape_node = get_nodes_by_output_id(model, "causal_mask_adjusted")[0]
+        reshape_node.input[0] = "causal_mask_transpose"
+        _LOGGER.info("Inserted a Transpose node for the causal mask into the ONNX model")
+        return model
+
+    def transform(self, model: ModelProto) -> ModelProto:
+        """
+        1. Adds `positions` as an input to the model
+        2. Adds `causal_mask` as an input to the model
+        3. Adds a Constant node that feeds a fixed shape to the Reshape node
+        4. Finds the node that initially creates the `position_ids` tensor
+        5. Updates the node to use the positions input instead of
+           computing it from the Range op
+        6. Finds the nodes that initially create the `causal_mask` tensors
+        7. Updates the nodes to use the causal_mask input instead of
+           computing it from the Slice op
+        8. Adjusts the causal mask and inserts the causal mask Transpose node
+
+        :param model: model to update
+        :return: updated model
+        """
+        model = self.add_positions_input(model)
+        model = self.add_causal_mask_input(model)
+        model = self.add_constant_reshape_node(model)
+
+        position_ids_nodes = self.find_nodes_by_pattern(
+            model, pattern=self.POSITION_IDS_MATCHING_PATTERN
+        )
+        if len(position_ids_nodes) != 1:
+            raise ValueError(
+                "Expected to find exactly one node matching "
+                f"the pattern {self.POSITION_IDS_MATCHING_PATTERN}, "
+                f"found {len(position_ids_nodes)}"
+            )
+
+        model = self.inject_positions(model, position_ids_nodes, "Where")
+
+        causal_mask_nodes = self.find_nodes_by_pattern(
+            model, pattern=self.CAUSAL_MASK_MATCHING_PATTERN
+        )
+        model = self.inject_causal_mask(model, causal_mask_nodes, "Where")
+        model = self.adjust_causal_mask(model)
+        model = self.add_causal_mask_reshape_node(model)
+        return model
+
+
+def inject_kv_cache_inputs_outputs(
+    model: ModelProto,
+    names_nodes: List[str],
+    hidden_size_kv_cache: int,
+    batch_size: int = 1,
+    key: bool = True,
+    output_num: int = 0,
+) -> ModelProto:
+    """
+    Adds a `past_key_values.{idx}.{key|value}` input and a `present.{idx}.{key|value}`
+    output for every node listed in `names_nodes`, and inserts a Concat node that
+    appends the freshly computed keys/values (output `output_num` of the matched node)
+    to the cached ones.
+    """
+    graph = ONNXGraph(model)
+
+    inputs_to_add = []
+    outputs_to_add = []
+    num_attention_heads = 1
+    attention_layer_idx = 0
+
+    for node in model.graph.node:
+        if node.name in names_nodes:
+
+            # inject kv cache input/output
+            cache_name = "key" if key else "value"
+            cache_input_name_concat = f"past_key_values.{attention_layer_idx}.{cache_name}"
+            cache_output_name_concat = f"present.{attention_layer_idx}.{cache_name}"
+
+            cache_input_info = onnx.helper.make_tensor_value_info(
+                cache_input_name_concat,
+                TensorProto.FLOAT,
+                [
+                    batch_size,
+                    num_attention_heads,
+                    "past_sequence_len",
+                    hidden_size_kv_cache,
+                ],
+            )
+
+            cache_output_info = onnx.helper.make_tensor_value_info(
+                cache_output_name_concat,
+                TensorProto.FLOAT,
+                [
+                    batch_size,
+                    num_attention_heads,
+                    "past_sequence_len + 1",
+                    hidden_size_kv_cache,
+                ],
+            )
+
+            model, cache_input_dims_concat, cache_input_name_concat, cache_output_name_concat = reshape_kv_cache_inputs_outputs(
+                model=model,
+                cache_input_name=cache_input_name_concat,
+                cache_output_name=cache_output_name_concat,
+                cache_input_dims=[
+                    batch_size,
+                    num_attention_heads,
+                    "past_sequence_len",
+                    hidden_size_kv_cache,
+                ],
+                batch_size=batch_size,
+                num_attention_heads=1,
+            )
+            cache_parent = node
+            concat_axis = 1  # concat over the length axis
+            concat_node = onnx.helper.make_node(
+                op_type="Concat",
+                inputs=[cache_input_name_concat, cache_parent.output[output_num]],
+                outputs=[cache_output_name_concat],
+                axis=concat_axis,
+                name=f"concat.{cache_input_name_concat}",
+            )
+
+            # reroute every consumer of the original key/value tensor to the concatenated cache
+            for _node in model.graph.node:
+                for input_idx, input_id in enumerate(_node.input):
+                    if input_id == cache_parent.output[output_num] and _node.name != concat_node.name:
+                        _node.input[input_idx] = cache_output_name_concat
+
+            graph.add_node(concat_node)
+            inputs_to_add.extend([cache_input_info])
+            outputs_to_add.extend([cache_output_info])
+
+            _LOGGER.info(f"Injected kv cache input/output for layer {attention_layer_idx} ({cache_name})")
+            attention_layer_idx += 1
+
+    model.graph.input.extend(inputs_to_add)
+    model.graph.output.extend(outputs_to_add)
+    return model
+
+
+def main(deployment_folder_path, save_name_injected_model):
+    onnx_model = onnx.load(os.path.join(deployment_folder_path, "model.onnx"), load_external_data=False)
+    config = AutoConfig.from_pretrained(os.path.join(deployment_folder_path, "config.json"))
+    # KV Cache injection
+    onnx_model = inject_kv_cache_inputs_outputs(model=onnx_model,
+                                                names_nodes=[f"/transformer/h.{i}/attn/Split_1" for i in range(config.n_layer)],
+                                                hidden_size_kv_cache=config.n_embd // config.n_head,
+                                                key=True,
+                                                output_num=0)
+    onnx_model = inject_kv_cache_inputs_outputs(model=onnx_model,
+                                                names_nodes=[f"/transformer/h.{i}/attn/Split_1" for i in range(config.n_layer)],
+                                                hidden_size_kv_cache=config.n_embd // config.n_head,
+                                                key=False,
+                                                output_num=1)
+    # Adjustment of causal masks and positions
+    transformation = AdditionalTransformsBigCode()
+    onnx_model = transformation.transform(model=onnx_model)
+    # Save the model
+    onnx.save_model(onnx_model, os.path.join(deployment_folder_path, save_name_injected_model))
+    _LOGGER.info(f"Saved injected model to {os.path.join(deployment_folder_path, save_name_injected_model)}")
+
+
+if __name__ == "__main__":
+    PATH_TO_DEPLOYMENT_FOLDER = "/Users/damian/Code/nm/sparseml/tiny_starcoder_py/deployment/"
+    # model created by running:
+    # sparseml.export /Users/damian/Code/nm/sparseml/tiny_starcoder_py/ --task text-generation --integration transformers --sequence_length 256 --trust_remote_code True
+    NAME_INJECTED_MODEL = "test.onnx"
+    main(PATH_TO_DEPLOYMENT_FOLDER, NAME_INJECTED_MODEL)
+
+
diff --git a/starcode_kv_cache_injection/run_model.py b/starcode_kv_cache_injection/run_model.py
new file mode 100644
index 00000000000..170fef0632a
--- /dev/null
+++ b/starcode_kv_cache_injection/run_model.py
@@ -0,0 +1,10 @@
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+checkpoint = "bigcode/tiny_starcoder_py"
+device = "cpu"
+tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
+
+inputs = tokenizer.encode("def print_hello_world():", return_tensors="pt").to(device)
+outputs = model.generate(inputs, max_new_tokens=10)
+print(tokenizer.decode(outputs[0]))
\ No newline at end of file
diff --git a/starcode_kv_cache_injection/validation.py b/starcode_kv_cache_injection/validation.py
new file mode 100644
index 00000000000..89e734cabd8
--- /dev/null
+++ b/starcode_kv_cache_injection/validation.py
@@ -0,0 +1,222 @@
+import onnxruntime as ort
+import numpy as np
+import onnx
+from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+from onnx.tools import update_model_dims
+from sparseml.onnx.utils import ONNXGraph
+import logging
+import numpy
+from typing import List, Union
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def create_causal_mask(
+    input_ids: Union[numpy.ndarray, List[int]],
+    attention_mask: Union[numpy.ndarray, List[int]],
+    dtype: numpy.dtype = numpy.int64,
+) -> numpy.ndarray:
+    """
+    Compute a causal mask from a set of module inputs.
+    In transformers, a causal mask is a boolean mask that is used to
+    prevent information from future positions in a sequence from
+    being used to predict the current position. Each element of the mask
+    is set to 1 if the corresponding position in the input sequence
+    is allowed to attend to positions up to and including that position,
+    and 0 otherwise.
+
+    In case of a single-token input, the causal mask is an array
+    of shape [1, 1, 1, sequence_length]
+    (essentially the reshaped attention_mask).
+
+    In case of a multi-token input, the causal mask is an array
+    of shape [batch_size, 1, input_ids_length, sequence_length];
+    it is a concatenation of a:
+     - past (cache) causal mask
+     - and a causal mask (a lower triangular matrix of 1's and 0's)
+    e.g.
+    ```
+    input_ids = [[1,2,3,4]]
+    attention_mask = [[1,1,1,1,1,1]]
+
+    causal_mask = [[[[ 1 1 | 1 0 0 0 ],
+                     [ 1 1 | 1 1 0 0 ],
+                     [ 1 1 | 1 1 1 0 ],
+                     [ 1 1 | 1 1 1 1 ]]]]
+    ```
+    or
+    ```
+    input_ids = [[1,2,3,4]]
+    attention_mask = [[0,0,1,1,1,1,1,1]]
+
+    causal_mask = [[[[ 0 0 1 1 | 1 0 0 0 ],
+                     [ 0 0 1 1 | 1 1 0 0 ],
+                     [ 0 0 1 1 | 1 1 1 0 ],
+                     [ 0 0 1 1 | 1 1 1 1 ]]]]
+    ```
+
+    :param input_ids: input ids of the model input
+    :param attention_mask: attention mask of the model input
+    :param dtype: data type of the mask
+    :return: causal mask
+    """
+    if isinstance(input_ids, numpy.ndarray):
+        batch_size, input_ids_length = input_ids.shape
+    else:
+        batch_size, input_ids_length = 1, len(input_ids)
+
+    if isinstance(attention_mask, numpy.ndarray):
+        sequence_length = attention_mask.shape[1]
+    else:
+        sequence_length = len(attention_mask)
+        attention_mask = numpy.array(attention_mask)[None, ...]
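+
+    # The two branches below mirror the docstring above: a single new token only
+    # needs the attention mask reshaped to [batch_size, 1, 1, sequence_length],
+    # while a multi-token prompt gets a lower-triangular mask with the past
+    # (cache) portion prepended along the last axis.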
+    if input_ids_length == 1:
+        causal_mask = numpy.reshape(attention_mask, (batch_size, 1, 1, sequence_length))
+        return causal_mask.astype(dtype)
+
+    causal_mask = numpy.tril(
+        numpy.ones((batch_size, 1, input_ids_length, input_ids_length), dtype=dtype), 0
+    )
+    past_causal_mask = numpy.ones(
+        (batch_size, 1, input_ids_length, sequence_length - input_ids_length),
+        dtype=dtype,
+    )
+    causal_mask = numpy.concatenate((past_causal_mask, causal_mask), axis=-1)
+
+    num_zeros = numpy.count_nonzero(attention_mask == 0)
+
+    # change relative to the original function: zero out the rows from num_zeros onwards
+    causal_mask[:, :, num_zeros:, :] = 0
+
+    return causal_mask
+
+
+def apply_input_shapes(model, onnx_model_path, sequence_length, config):
+    kv_cache_hidden_dim = config.n_embd // config.n_head
+    cache_changes_in = {n.name: [1, 1, "dynamic_len_1", kv_cache_hidden_dim] for n in model.graph.input if n.name.startswith("past_key_values")}
+    cache_changes_out = {n.name: [1, 1, "dynamic_len_2", kv_cache_hidden_dim] for n in model.graph.output if n.name.startswith("present")}
+    graph = ONNXGraph(model)
+
+    graph.delete_unused_initializers()
+    graph.delete_orphaned_node_branches()
+    graph.sort_nodes_topologically()
+
+    # each dynamic axis gets its own symbolic name so that unrelated dims are not tied together
+    model = update_model_dims.update_inputs_outputs_dims(model,
+                                                         {"input_ids": [1, "dynamic_len_3"],
+                                                          "positions": [1, "dynamic_len_4"],
+                                                          "attention_mask": [1, sequence_length],
+                                                          "causal_mask": [1, 1, "dynamic_len_5", "dynamic_len_6"],
+                                                          **cache_changes_in},
+                                                         {"logits": [1, "dynamic_len_7", config.vocab_size], **cache_changes_out})
+
+    onnx.save(model, onnx_model_path)
+    return model
+
+
+def multitoken_inference_test(onnx_model_path, prompt, config, tokenizer, sequence_length, logits_gt, kv_cache_gt):
+    # feed the whole sequence to the model so that we can initially validate
+    # the correctness of the kv cache injected model
+    kv_cache_hidden_dim = config.n_embd // config.n_head
+    inputs = tokenizer(prompt, return_tensors="np", padding='max_length', max_length=sequence_length)
+    input_ids = inputs.input_ids  # (1, sequence_length)
+    attention_mask = inputs.attention_mask  # (1, sequence_length)
+    kv_cache_value = {f"past_key_values.{i}.value": np.zeros((1, 1, 0, kv_cache_hidden_dim), dtype=np.float32) for i in
+                      range(config.n_layer)}  # (1, 1, 0, kv_cache_hidden_dim)
+    kv_cache_keys = {f"past_key_values.{i}.key": np.zeros((1, 1, 0, kv_cache_hidden_dim), dtype=np.float32) for i in
+                     range(config.n_layer)}  # (1, 1, 0, kv_cache_hidden_dim)
+    kv_cache = {**kv_cache_keys, **kv_cache_value}
+    causal_mask = create_causal_mask(input_ids, attention_mask)  # (1, 1, sequence_length, sequence_length)
+    positions = attention_mask.cumsum(-1) - 1  # (1, sequence_length)
+
+    session = ort.InferenceSession(onnx_model_path)
+
+    out = session.run(
+        None,
+        {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            **kv_cache,
+            "causal_mask": causal_mask,
+            "positions": positions,
+        },
+    )
+    logits, *kv_cache = out
+
+    num_tokens_processed = logits_gt.shape[1]  # only test the relevant, non-padded tokens
+    assert np.allclose(logits[:, :num_tokens_processed, :], logits_gt, atol=1e-3)
+    assert all(np.allclose(x[:, :num_tokens_processed, :], y, atol=1e-3) for x, y in zip(kv_cache, kv_cache_gt))
+
+
+def singletoken_inference_test(onnx_model_path, prompt, config, tokenizer, sequence_length, logits_gt, kv_cache_gt):
+    # feed the model one token at a time to validate the correctness of the kv cache injected model
+    model = onnx.load(onnx_model_path, load_external_data=True)
+    apply_input_shapes(model, onnx_model_path, sequence_length, config)
+
+    kv_cache_hidden_dim = config.n_embd // config.n_head
+    inputs = tokenizer(prompt, return_tensors="np")
+    attention_mask = np.zeros((1, sequence_length), dtype=np.int64)
+    kv_cache_keys = {f"past_key_values.{i}.key": np.zeros((1, 1, sequence_length - 1, kv_cache_hidden_dim), dtype=np.float32) for i in range(config.n_layer)}
+    kv_cache_values = {f"past_key_values.{i}.value": np.zeros((1, 1, sequence_length - 1, kv_cache_hidden_dim), dtype=np.float32) for i in range(config.n_layer)}
+    kv_cache = {**kv_cache_keys, **kv_cache_values}
+    session = ort.InferenceSession(onnx_model_path)
+
+    for idx, token in enumerate(inputs.input_ids[0]):
+        if token == tokenizer.pad_token_id:
+            break
+        attention_mask[:, -(idx + 1):] = 1
+        positions = np.array([[idx]])
+        input_ids = np.array([[token]])
+        causal_mask = create_causal_mask(input_ids, attention_mask)
+        # debugging output
+        print(causal_mask.shape)
+        print(input_ids.shape)
+        print(attention_mask.shape)
+        print(positions)
+        print(kv_cache["past_key_values.0.key"].shape)
+        outputs = session.run(None, {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "positions": positions,
+            "causal_mask": causal_mask,
+            **kv_cache
+        })
+        # NOTE: the cache comparison below does not yet run without throwing an error;
+        # there are still some missing pieces that need to be addressed
+        #logits, *kv_cache = outputs
+        #for _idx, (cache_gt, cache) in enumerate(zip(kv_cache_gt, kv_cache)):
+        #    if np.allclose(cache_gt[:,idx,:], cache[:,-(idx + 1)],atol=1e-3):
+        #        print(f"Cache {_idx} matches for iteration {idx}")
+
+
+def get_baseline(prompt, hf_model_name, tokenizer):
+    model = AutoModelForCausalLM.from_pretrained(hf_model_name)
+    tokens = tokenizer.encode(prompt, return_tensors="pt")
+    model.generate(tokens[:, :1], max_length=256)
+    out = model(tokens, return_dict=True)
+    logits_gt = out.logits.detach().numpy()
+    kv_cache_gt = [t.detach().numpy() for t in out.past_key_values]
+    return logits_gt, kv_cache_gt
+
+
+def main(prompt, hf_model_name, onnx_model_path, sequence_length):
+    config = AutoConfig.from_pretrained(hf_model_name)
+    tokenizer = AutoTokenizer.from_pretrained(hf_model_name)
+    tokenizer.pad_token = tokenizer.eos_token
+
+    logits_gt, kv_cache_gt = get_baseline(prompt, hf_model_name, tokenizer)
+
+    # multitoken_inference_test(onnx_model_path, prompt, config, tokenizer, sequence_length, logits_gt, kv_cache_gt)
+    # _LOGGER.info("Successfully ran multi-token inference on the kv cache injected model")
+    singletoken_inference_test(onnx_model_path, prompt, config, tokenizer, sequence_length, logits_gt, kv_cache_gt)
+    _LOGGER.info("Successfully ran single-token inference on the kv cache injected model")
+
+
+if __name__ == "__main__":
+    PROMPT = "def eight_queens():\n if True:\n return 1\n "
+    HF_MODEL_NAME = "bigcode/tiny_starcoder_py"
+    ONNX_MODEL_PATH = "/Users/damian/Code/nm/sparseml/tiny_starcoder_py/deployment/test.onnx"
+    SEQUENCE_LENGTH = 256
+    main(PROMPT, HF_MODEL_NAME, ONNX_MODEL_PATH, SEQUENCE_LENGTH)