From 32412794665a451178dfca5c82974a51dbc7cc7e Mon Sep 17 00:00:00 2001
From: Adam Sotona
Date: Mon, 21 Jul 2025 17:38:48 +0200
Subject: [PATCH] LlamaDemo implementation

---
 cr-examples/onnx/README.md                    |  93 ++--------
 cr-examples/onnx/lib/.gitignore               |   2 +
 cr-examples/onnx/pom.xml                      |   3 +-
 .../java/oracle/code/onnx/OnnxRuntime.java    |   4 -
 .../onnx/genai/OnnxGenRuntimeSession.java     |  32 ++++
 .../java/oracle/code/onnx/llm/LlamaDemo.java  |  46 +++++
 .../java/oracle/code/onnx/llm/LlamaModel.java | 159 ++++++++++++++++++
 .../resources/oracle/code/onnx/llm/.gitignore |   4 +
 .../oracle/code/onnx/llm/genai_config.json    |  53 ++++++
 9 files changed, 308 insertions(+), 88 deletions(-)
 create mode 100644 cr-examples/onnx/lib/.gitignore
 create mode 100644 cr-examples/onnx/src/test/java/oracle/code/onnx/llm/LlamaDemo.java
 create mode 100644 cr-examples/onnx/src/test/java/oracle/code/onnx/llm/LlamaModel.java
 create mode 100644 cr-examples/onnx/src/test/resources/oracle/code/onnx/llm/.gitignore
 create mode 100644 cr-examples/onnx/src/test/resources/oracle/code/onnx/llm/genai_config.json

diff --git a/cr-examples/onnx/README.md b/cr-examples/onnx/README.md
index 838a974265d..c1f841806f0 100644
--- a/cr-examples/onnx/README.md
+++ b/cr-examples/onnx/README.md
@@ -1,93 +1,20 @@
-### MavenStyleProject using code reflection with a Java-based ONNX programming model.
+## MavenStyleProject using code reflection with a Java-based ONNX programming model.
 
-Running the demo:
-```
-JAVA_HOME=;mvn process-test-classes exec:java -Dexec.classpathScope=test -Dexec.mainClass=oracle.code.onnx.MNISTDemo
-```
-
-### Onnx Generation API to create and run LLM Onnx models.
+### ONNX Runtime running a convolutional neural network from Java source
 
-Example of direct execution of existing Onnx LLM model:
+Running the MNIST demo:
 ```
-// model-specific prompt format
-static final String PROMPT_TEMPLATE = "<|...|>%s<|...|><|...|>";
-
-public static void main(String... args) {
-
-    // compatible `libonnxruntime` library must be present in the same folder as `libonnxruntime-genai` library
-    // native library extension (.dylib, .so or .dll) is platform specific
-    System.load("path/To/libonnxruntime-genai.dylib");
-
-    // model folder must contain the Onnx model file and all configuration and external data files
-    try (OnnxGenRuntimeSession session = new OnnxGenRuntimeSession(Path.of("path/To/Onnx/Model/Folder/")) {
-        // each LLM model has specific prompt format
-        session.prompt(PROMPT_TEMPLATE.formatted("Tell me a joke"), System.out::print);
-    }
-}
-```
-
-Example of a custom LLM Onnx model generation from Java sources and execution:
+mvn process-test-classes exec:exec -Dexec.executable=/bin/java -Dexec.mainClass=oracle.code.onnx.mnist.MNISTDemo
 ```
-// model-specific prompt format
-static final String PROMPT_TEMPLATE = "<|...|>%s<|...|><|...|>";
-
-public static void main(String... args) {
-    // compatible `libonnxruntime` library must be present in the same folder as `libonnxruntime-genai` library
-    // native library extension (.dylib or .so or .dll) is platform specific
-    System.load("path/To/libonnxruntime-genai.dylib");
+### ONNX GenAI running a large language model from Java source
 
-    // instance of a custom Onnx LLM model
-    MyCustomLLMModel myCustomModelInstance = ...;
+Setup:
+ - Download the [onnxruntime-genai](https://github.com/microsoft/onnxruntime-genai/releases) native library corresponding to your system/architecture, unzip it and put it into the `cr-examples/onnx/lib` folder.
+ - Download the `model.onnx.data`, `tokenizer.json` and `tokenizer_config.json` data files from [Llama-3.2-1B-Instruct-ONNX](https://huggingface.co/onnx-community/Llama-3.2-1B-Instruct-ONNX/tree/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4) and put them into the `cr-examples/onnx/src/test/resources/oracle/code/onnx/llm` folder.
 
-    // target model folder must contain all configuration files
-    // `genai_config.json` must be configured following way:
-    //  - model filename to match generated model file name (below)
-    //  - model inputs to match main model method argument names
-    //  - model outputs to match main model result record component names
-    Path targetModelFolder = ...;
-
-    // Onnx model file and external data file are generated to the target model folder
-    // and the session is created from the generated model
-    try (OnnxGenRuntimeSession session = OnnxGenRuntimeSession.buildFromCodeReflection(myCustomModelInstance, "myMainModelMethod", targetModelFolder, "MyModelFileName.onnx", "MyDataFileName")) {
-        // each LLM model has specific prompt format
-        session.prompt(PROMPT_TEMPLATE.formatted("Tell me a joke"), System.out::print);
-    }
-}
+Running the Llama demo:
 ```
-
-Example of a custom LLM Onnx model Java source:
+mvn process-test-classes exec:exec -Dexec.executable=/bin/java -Dexec.mainClass=oracle.code.onnx.llm.LlamaDemo
 ```
-import oracle.code.onnx.Tensor;
-import jdk.incubator.code.CodeReflection;
-import static oracle.code.onnx.OnnxOperators.*;
-
-public final class MyCustomLLMModel {
-
-    public final Tensor myModelWeights...
-    public final Tensor otherMyModelWeights...
-
-    public MyCustomLLMModel(...) {
-        // initilize all weight tensors
-        // large tensors data can be memory-mapped
-        this.myModelWeights = ...
-        this.otherMyModelWeights = ...
-        ...
-    }
-
-    // custom record with main model method response
-    public record MyModelResponse(Tensor logits, Tensor presentKey0, Tensor presentValue0, ...) {
-    }
-
-    @CodeReflection
-    public MyModelResponse myMainModelMethod(Tensor inputIds, Tensor attentionMask, Tensor pastKey0, Tensor pastValue0, ...) {
-
-        // computation of the model using oracle.code.onnx.OnnxOperators.* method calls
-        ...
-        Tensor logits = MatMul(...
-
-        // composition of the return record
-        return new MyModelResponse(logits, key0, value0, ...);
-    }
-}
-```

diff --git a/cr-examples/onnx/lib/.gitignore b/cr-examples/onnx/lib/.gitignore
new file mode 100644
index 00000000000..f9b9d5ad59b
--- /dev/null
+++ b/cr-examples/onnx/lib/.gitignore
@@ -0,0 +1,2 @@
+/libonnxruntime.*
+/libonnxruntime-genai.*

diff --git a/cr-examples/onnx/pom.xml b/cr-examples/onnx/pom.xml
index 8e4e2de688f..f48eb945a41 100644
--- a/cr-examples/onnx/pom.xml
+++ b/cr-examples/onnx/pom.xml
@@ -96,7 +96,8 @@ questions.
                 exec-maven-plugin
                 3.5.1
-                    --add-modules jdk.incubator.code ${exec.args}
+                    test
+                    --add-modules jdk.incubator.code -classpath %classpath ${exec.mainClass}

diff --git a/cr-examples/onnx/src/main/java/oracle/code/onnx/OnnxRuntime.java b/cr-examples/onnx/src/main/java/oracle/code/onnx/OnnxRuntime.java
index 9507323b715..b49338a06ad 100644
--- a/cr-examples/onnx/src/main/java/oracle/code/onnx/OnnxRuntime.java
+++ b/cr-examples/onnx/src/main/java/oracle/code/onnx/OnnxRuntime.java
@@ -31,11 +31,8 @@
 import java.lang.invoke.MethodHandles;
 import java.lang.invoke.VarHandle;
 import java.lang.reflect.AccessFlag;
-import java.lang.reflect.Constructor;
 import java.lang.reflect.Field;
-import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.ParameterizedType;
-import java.lang.reflect.RecordComponent;
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.nio.file.StandardCopyOption;
@@ -55,7 +52,6 @@
 import oracle.code.onnx.compiler.OnnxTransformer;
 import oracle.code.onnx.foreign.OrtApi;
 import oracle.code.onnx.foreign.OrtApiBase;
-import oracle.code.onnx.proto.OnnxModel;
 
 import static oracle.code.onnx.foreign.onnxruntime_c_api_h.*;

diff --git a/cr-examples/onnx/src/main/java/oracle/code/onnx/genai/OnnxGenRuntimeSession.java b/cr-examples/onnx/src/main/java/oracle/code/onnx/genai/OnnxGenRuntimeSession.java
index 485c36364b3..2f7f3737637 100644
--- a/cr-examples/onnx/src/main/java/oracle/code/onnx/genai/OnnxGenRuntimeSession.java
+++ b/cr-examples/onnx/src/main/java/oracle/code/onnx/genai/OnnxGenRuntimeSession.java
@@ -32,6 +32,7 @@
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.util.List;
+import java.util.Locale;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.function.Consumer;
 import java.util.stream.Stream;
@@ -134,6 +135,37 @@
  */
 public class OnnxGenRuntimeSession implements AutoCloseable {
 
+    /**
+     * Loads the {@code onnxruntime-genai} native library from the given folder.
+     * This method unpacks the required {@code onnxruntime} library from the dependencies if it is missing.
+     *
+     * @param libRoot folder containing the native libraries
+     */
+    public static void loadGenAILib(Path libRoot) {
+        Path runtime = libRoot.resolve(System.mapLibraryName("onnxruntime"));
+        if (!Files.isRegularFile(runtime)) {
+            // onnxruntime-genai requires onnxruntime in the same directory
+            String arch = System.getProperty("os.arch", "generic").toLowerCase(Locale.ENGLISH).startsWith("aarch64") ? "aarch64" : "x64";
+            String os = System.getProperty("os.name", "generic").toLowerCase(Locale.ENGLISH);
+            String libResource;
+            if (os.contains("mac") || os.contains("darwin")) {
+                libResource = "/ai/onnxruntime/native/osx-" + arch + "/libonnxruntime.dylib";
+            } else if (os.contains("win")) {
+                libResource = "/ai/onnxruntime/native/win-" + arch + "/libonnxruntime.dll";
+            } else if (os.contains("nux")) {
+                libResource = "/ai/onnxruntime/native/linux-" + arch + "/libonnxruntime.so";
+            } else {
+                throw new IllegalStateException("Unsupported os: " + os);
+            }
+            try (var libStream = OnnxRuntime.class.getResourceAsStream(libResource)) {
+                Files.copy(libStream, runtime);
+            } catch (IOException ioe) {
+                throw new RuntimeException(ioe);
+            }
+        }
+        System.load(libRoot.resolve(System.mapLibraryName("onnxruntime-genai")).toAbsolutePath().toString());
+    }
+
     /**
      * Builds Onnx model from the provided Java model instance and loads it into a constructs the Onnx Generate API session.
      * @param codeReflectionModelInstance Instance of a class representing Onnx LLM model.
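For context, a minimal sketch of how a caller is expected to use the new loader; the `LoaderSketch` class name and working directory are illustrative assumptions, and `LlamaDemo` below does the same thing in-repo:

```
import java.nio.file.Path;
import oracle.code.onnx.genai.OnnxGenRuntimeSession;

// Hypothetical caller of the new loader, run from the cr-examples/onnx folder.
public class LoaderSketch {
    public static void main(String[] args) {
        // Expects the onnxruntime-genai library in ./lib; System.mapLibraryName
        // resolves the platform-specific file name (libonnxruntime-genai.dylib,
        // libonnxruntime-genai.so or onnxruntime-genai.dll). If onnxruntime itself
        // is not in ./lib, it is unpacked there from the classpath dependency.
        OnnxGenRuntimeSession.loadGenAILib(Path.of("lib"));
    }
}
```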
diff --git a/cr-examples/onnx/src/test/java/oracle/code/onnx/llm/LlamaDemo.java b/cr-examples/onnx/src/test/java/oracle/code/onnx/llm/LlamaDemo.java
new file mode 100644
index 00000000000..56df5f22d74
--- /dev/null
+++ b/cr-examples/onnx/src/test/java/oracle/code/onnx/llm/LlamaDemo.java
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+package oracle.code.onnx.llm;
+
+import java.lang.foreign.Arena;
+import java.nio.file.Path;
+import oracle.code.onnx.genai.OnnxGenRuntimeSession;
+
+public class LlamaDemo {
+
+    public static void main(String... args) throws Exception {
+
+        OnnxGenRuntimeSession.loadGenAILib(Path.of("lib"));
+
+        Path modelRoot = Path.of(LlamaDemo.class.getResource("LlamaDemo.class").toURI()).getParent();
+        try (Arena arena = Arena.ofConfined()) {
+            var modelInstance = new LlamaModel(arena);
+            try (OnnxGenRuntimeSession session = OnnxGenRuntimeSession.buildFromCodeReflection(modelInstance, "forward", modelRoot, "model.onnx", "model.data")) {
+                session.prompt("""
+                        <|start_header_id|>user<|end_header_id|>Hello, tell me a joke.<|eot_id|>
+                        <|start_header_id|>assistant<|end_header_id|>
+                        """, System.out::print);
+            }
+        }
+    }
+}

diff --git a/cr-examples/onnx/src/test/java/oracle/code/onnx/llm/LlamaModel.java b/cr-examples/onnx/src/test/java/oracle/code/onnx/llm/LlamaModel.java
new file mode 100644
index 00000000000..05c1e6ce27b
--- /dev/null
+++ b/cr-examples/onnx/src/test/java/oracle/code/onnx/llm/LlamaModel.java
@@ -0,0 +1,159 @@
+/*
+ * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+package oracle.code.onnx.llm;
+
+import java.io.IOException;
+import java.lang.foreign.Arena;
+import jdk.incubator.code.CodeReflection;
+import oracle.code.onnx.Tensor;
+import oracle.code.onnx.genai.TensorDataStream;
+
+import static java.util.Optional.*;
+import static oracle.code.onnx.OnnxOperators.*;
+import static oracle.code.onnx.Tensor.ElementType.*;
+import oracle.code.onnx.ir.OnnxType;
+
+public final class LlamaModel {
+
+    public static final int LAYERS = 16;
+    public static final long BITS = 4,
+                             BLOCK_SIZE = 32,
+                             NUM_KEY_VALUE_HEADS = 8,
+                             ACCURACY_LEVEL = 4,
+                             VOCAB_SIZE = 128256,
+                             HEAD_SIZE = 64,
+                             HIDDEN_SIZE = 2048,
+                             CONTEXT_SIZE = 131072,
+                             INTERMEDIATE_SIZE = 8192,
+                             ATTN_WEIGHTS_SIZE = 3072;
+    public static final float EPSILON = 1.0E-5f,
+                              SCALE = 0.125f;
+
+    public final Tensor flat1, scalar1;
+    public final Tensor tokensWeights, initWeight, cosCache, sinCache, headScales;
+    public final Tensor[] postAttentionWeights = new Tensor[LAYERS],
+                          inputWeights = new Tensor[LAYERS],
+                          attnQkvScales = new Tensor[LAYERS],
+                          attnOScales = new Tensor[LAYERS],
+                          mlpGateScales = new Tensor[LAYERS],
+                          mlpUpScales = new Tensor[LAYERS],
+                          mlpDownScales = new Tensor[LAYERS];
+    public final Tensor[] attnQkvWeight = new Tensor[LAYERS],
+                          attnOWeight = new Tensor[LAYERS],
+                          mlpGateWeight = new Tensor[LAYERS],
+                          mlpUpWeight = new Tensor[LAYERS],
+                          mlpDownWeight = new Tensor[LAYERS];
+    public final Tensor headWeight;
+
+    public LlamaModel(Arena arena) throws IOException {
+        flat1 = Tensor.ofFlat(arena, 1L);
+        scalar1 = Tensor.ofScalar(arena, 1L);
+        var modelData = new TensorDataStream(arena, LlamaModel.class.getResource("model.onnx.data").getPath());
+        tokensWeights = modelData.nextTensor(FLOAT, VOCAB_SIZE, HIDDEN_SIZE);
+        initWeight = modelData.nextTensor(FLOAT, HIDDEN_SIZE);
+        cosCache = modelData.nextTensor(FLOAT, CONTEXT_SIZE, HEAD_SIZE / 2);
+        sinCache = modelData.nextTensor(FLOAT, CONTEXT_SIZE, HEAD_SIZE / 2);
+        for (int i = 0; i < LAYERS; i++) {
+            postAttentionWeights[i] = modelData.nextTensor(FLOAT, HIDDEN_SIZE);
+            inputWeights[i] = modelData.nextTensor(FLOAT, HIDDEN_SIZE);
+        }
+        for (int i = 0; i < LAYERS; i++) {
+            attnQkvWeight[i] = modelData.nextTensor(UINT8, ATTN_WEIGHTS_SIZE, HEAD_SIZE, 16);
+            attnQkvScales[i] = modelData.nextTensor(FLOAT, ATTN_WEIGHTS_SIZE * HEAD_SIZE);
+            attnOWeight[i] = modelData.nextTensor(UINT8, HIDDEN_SIZE, HEAD_SIZE, 16);
+            attnOScales[i] = modelData.nextTensor(FLOAT, HIDDEN_SIZE * HEAD_SIZE);
+            mlpGateWeight[i] = modelData.nextTensor(UINT8, INTERMEDIATE_SIZE, HEAD_SIZE, 16);
+            mlpGateScales[i] = modelData.nextTensor(FLOAT, INTERMEDIATE_SIZE * HEAD_SIZE);
+            mlpUpWeight[i] = modelData.nextTensor(UINT8, INTERMEDIATE_SIZE, HEAD_SIZE, 16);
+            mlpUpScales[i] = modelData.nextTensor(FLOAT, INTERMEDIATE_SIZE * HEAD_SIZE);
+            mlpDownWeight[i] = modelData.nextTensor(UINT8, HIDDEN_SIZE, 256, 16);
+            mlpDownScales[i] = modelData.nextTensor(FLOAT, INTERMEDIATE_SIZE * HEAD_SIZE);
+        }
+        headWeight = modelData.nextTensor(UINT8, VOCAB_SIZE, HEAD_SIZE, 16);
+        headScales = modelData.nextTensor(FLOAT, VOCAB_SIZE * HEAD_SIZE);
+    }
+
+    public record ForwardResponse(Tensor logits,
+                                  Tensor[] presentKey,
+                                  Tensor[] presentValue) {
+    }
+
+    @CodeReflection
+    public ForwardResponse forward(Tensor inputIds, Tensor attentionMask, Tensor[] pastKey, Tensor[] pastValue) {
+
+        Tensor amSL = Cast(Sub(ReduceSum(attentionMask, of(flat1), empty(), empty()), flat1), empty(), OnnxType.INT32.id());
+        Tensor amTSL = Cast(Gather(Shape(attentionMask, empty(), empty()), scalar1, of(0L)), empty(), OnnxType.INT32.id());
+        Tensor skipBias = Gather(tokensWeights, inputIds, empty());
+        Tensor input = LayerNormalization(skipBias, initWeight, empty(), of(EPSILON), of(1L), of(-1L)).Y();
+
+        Tensor[] presentKeys = new Tensor[LAYERS];
+        Tensor[] presentValues = new Tensor[LAYERS];
+
+        for (int i = 0; i < LAYERS; i++) {
+            GroupQueryAttention attn = GroupQueryAttention(
+                    MatMulNBits(input,
+                            attnQkvWeight[i],
+                            attnQkvScales[i], empty(), empty(), empty(), HIDDEN_SIZE, ATTN_WEIGHTS_SIZE, of(ACCURACY_LEVEL), BITS, BLOCK_SIZE),
+                    empty(),
+                    empty(),
+                    of(pastKey[i]),
+                    of(pastValue[i]),
+                    amSL,
+                    amTSL,
+                    of(cosCache),
+                    of(sinCache), of(1L), NUM_KEY_VALUE_HEADS, empty(), BLOCK_SIZE, of(0L), of(SCALE));
+
+            SkipSimplifiedLayerNormalization postAttnLayernorm = SkipSimplifiedLayerNormalization(
+                    skipBias,
+                    MatMulNBits(attn.output(),
+                            attnOWeight[i],
+                            attnOScales[i], empty(), empty(), empty(), HIDDEN_SIZE, HIDDEN_SIZE, of(ACCURACY_LEVEL), BITS, BLOCK_SIZE),
+                    postAttentionWeights[i], empty(), of(EPSILON));
+
+            Tensor mlpGateProj = MatMulNBits(postAttnLayernorm.output(),
+                    mlpGateWeight[i],
+                    mlpGateScales[i], empty(), empty(), empty(), HIDDEN_SIZE, INTERMEDIATE_SIZE, of(ACCURACY_LEVEL), BITS, BLOCK_SIZE);
+
+            SkipSimplifiedLayerNormalization norm = SkipSimplifiedLayerNormalization(postAttnLayernorm.input_skip_bias_sum(),
+                    MatMulNBits(Mul(Mul(mlpGateProj,
+                                    Sigmoid(mlpGateProj)),
+                                    MatMulNBits(postAttnLayernorm.output(),
+                                            mlpUpWeight[i],
+                                            mlpUpScales[i], empty(), empty(), empty(), HIDDEN_SIZE, INTERMEDIATE_SIZE, of(ACCURACY_LEVEL), BITS, BLOCK_SIZE)),
+                            mlpDownWeight[i],
+                            mlpDownScales[i], empty(), empty(), empty(), INTERMEDIATE_SIZE, HIDDEN_SIZE, of(ACCURACY_LEVEL), BITS, BLOCK_SIZE),
+                    inputWeights[i], empty(), of(EPSILON));
+
+            input = norm.output();
+            skipBias = norm.input_skip_bias_sum();
+            presentKeys[i] = attn.present_key();
+            presentValues[i] = attn.present_value();
+        }
+
+        Tensor logits = MatMulNBits(input,
+                headWeight,
+                headScales, empty(), empty(), empty(), HIDDEN_SIZE, VOCAB_SIZE, of(ACCURACY_LEVEL), BITS, BLOCK_SIZE);
+
+        return new ForwardResponse(logits, presentKeys, presentValues);
+    }
+}

diff --git a/cr-examples/onnx/src/test/resources/oracle/code/onnx/llm/.gitignore b/cr-examples/onnx/src/test/resources/oracle/code/onnx/llm/.gitignore
new file mode 100644
index 00000000000..ccfc090bed3
--- /dev/null
+++ b/cr-examples/onnx/src/test/resources/oracle/code/onnx/llm/.gitignore
@@ -0,0 +1,4 @@
+/model.onnx.data
+/tokenizer_config.json
+/tokenizer.json
+

diff --git a/cr-examples/onnx/src/test/resources/oracle/code/onnx/llm/genai_config.json b/cr-examples/onnx/src/test/resources/oracle/code/onnx/llm/genai_config.json
new file mode 100644
index 00000000000..bab2c20c8cc
--- /dev/null
+++ b/cr-examples/onnx/src/test/resources/oracle/code/onnx/llm/genai_config.json
@@ -0,0 +1,53 @@
+{
+  "model": {
+    "bos_token_id": 128000,
+    "context_length": 131072,
+    "decoder": {
+      "session_options": {
+        "log_id": "onnxruntime-genai",
+        "provider_options": []
+      },
+      "filename": "model.onnx",
+      "head_size": 64,
+      "hidden_size": 2048,
+      "inputs": {
+        "input_ids": "inputIds",
+        "attention_mask": "attentionMask",
+        "past_key_names": "pastKey.%d",
+        "past_value_names": "pastValue.%d"
+      },
+      "outputs": {
+        "logits": "logits",
+        "present_key_names": "presentKey.%d",
+        "present_value_names": "presentValue.%d"
+      },
"num_attention_heads": 32, + "num_hidden_layers": 16, + "num_key_value_heads": 8 + }, + "eos_token_id": [ + 128001, + 128008, + 128009 + ], + "pad_token_id": 128001, + "type": "llama", + "vocab_size": 128256 + }, + "search": { + "diversity_penalty": 0.0, + "do_sample": true, + "early_stopping": true, + "length_penalty": 1.0, + "max_length": 131072, + "min_length": 0, + "no_repeat_ngram_size": 0, + "num_beams": 1, + "num_return_sequences": 1, + "past_present_share_buffer": true, + "repetition_penalty": 1.0, + "temperature": 0.6, + "top_k": 1, + "top_p": 0.9 + } +} \ No newline at end of file