93 changes: 10 additions & 83 deletions cr-examples/onnx/README.md
@@ -1,93 +1,20 @@
-### MavenStyleProject using code reflection with a Java-based ONNX programming model.
+## MavenStyleProject using code reflection with a Java-based ONNX programming model.
 
-Running the demo:
-```
-JAVA_HOME=<path to the Babylon JDK home> mvn process-test-classes exec:java -Dexec.classpathScope=test -Dexec.mainClass=oracle.code.onnx.MNISTDemo
-```
-
-### Onnx Generation API to create and run LLM Onnx models.
-
-Example of direct execution of existing Onnx LLM model:
-```
-// model-specific prompt format
-static final String PROMPT_TEMPLATE = "<|...|>%s<|...|><|...|>";
-
-public static void main(String... args) {
-
-    // compatible `libonnxruntime` library must be present in the same folder as `libonnxruntime-genai` library
-    // native library extension (.dylib, .so or .dll) is platform specific
-    System.load("path/To/libonnxruntime-genai.dylib");
-
-    // model folder must contain the Onnx model file and all configuration and external data files
-    try (OnnxGenRuntimeSession session = new OnnxGenRuntimeSession(Path.of("path/To/Onnx/Model/Folder/"))) {
-        // each LLM model has a specific prompt format
-        session.prompt(PROMPT_TEMPLATE.formatted("Tell me a joke"), System.out::print);
-    }
-}
-```
-
-Example of a custom LLM Onnx model generation from Java sources and execution:
-```
-// model-specific prompt format
-static final String PROMPT_TEMPLATE = "<|...|>%s<|...|><|...|>";
-
-public static void main(String... args) {
-
-    // compatible `libonnxruntime` library must be present in the same folder as `libonnxruntime-genai` library
-    // native library extension (.dylib, .so or .dll) is platform specific
-    System.load("path/To/libonnxruntime-genai.dylib");
-
-    // instance of a custom Onnx LLM model
-    MyCustomLLMModel myCustomModelInstance = ...;
-
-    // target model folder must contain all configuration files
-    // `genai_config.json` must be configured the following way:
-    //   - model filename to match the generated model file name (below)
-    //   - model inputs to match the main model method argument names
-    //   - model outputs to match the main model result record component names
-    Path targetModelFolder = ...;
-
-    // Onnx model file and external data file are generated to the target model folder
-    // and the session is created from the generated model
-    try (OnnxGenRuntimeSession session = OnnxGenRuntimeSession.buildFromCodeReflection(myCustomModelInstance, "myMainModelMethod", targetModelFolder, "MyModelFileName.onnx", "MyDataFileName")) {
-        // each LLM model has a specific prompt format
-        session.prompt(PROMPT_TEMPLATE.formatted("Tell me a joke"), System.out::print);
-    }
-}
-```
-
-Example of a custom LLM Onnx model Java source:
-```
-import oracle.code.onnx.Tensor;
-import jdk.incubator.code.CodeReflection;
-import static oracle.code.onnx.OnnxOperators.*;
-
-public final class MyCustomLLMModel {
-
-    public final Tensor<Float> myModelWeights...
-    public final Tensor<Byte> otherMyModelWeights...
-
-    public MyCustomLLMModel(...) {
-        // initialize all weight tensors
-        // large tensors data can be memory-mapped
-        this.myModelWeights = ...
-        this.otherMyModelWeights = ...
-        ...
-    }
-
-    // custom record with main model method response
-    public record MyModelResponse(Tensor<Float> logits, Tensor<Float> presentKey0, Tensor<Float> presentValue0, ...) {
-    }
-
-    @CodeReflection
-    public MyModelResponse myMainModelMethod(Tensor<Long> inputIds, Tensor<Long> attentionMask, Tensor<Float> pastKey0, Tensor<Float> pastValue0, ...) {
-
-        // computation of the model using oracle.code.onnx.OnnxOperators.* method calls
-        ...
-        Tensor<Float> logits = MatMul(...
-
-        // composition of the return record
-        return new MyModelResponse(logits, key0, value0, ...);
-    }
-}
-```
+### ONNX Runtime running a convolutional neural network from Java source
+
+Running the MNIST demo:
+```
+mvn process-test-classes exec:exec -Dexec.executable=<path to the Babylon JDK home>/bin/java -Dexec.mainClass=oracle.code.onnx.mnist.MNISTDemo
+```
+
+### ONNX GenAI running a large language model from Java source
+
+Setup:
+- Download the [onnxruntime-genai](https://github.com/microsoft/onnxruntime-genai/releases) native library corresponding to your system/architecture, unzip it, and put it into the `cr-examples/onnx/lib` folder.
+- Download the `model.onnx.data`, `tokenizer.json` and `tokenizer_config.json` data files from [Llama-3.2-1B-Instruct-ONNX](https://huggingface.co/onnx-community/Llama-3.2-1B-Instruct-ONNX/tree/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4) and put them into the `cr-examples/onnx/src/test/resources/oracle/code/onnx/llm` folder.
+
+Running the Llama demo:
+```
+mvn process-test-classes exec:exec -Dexec.executable=<path to the Babylon JDK home>/bin/java -Dexec.mainClass=oracle.code.onnx.llm.LlamaDemo
+```
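
After the setup steps in the new README, the working tree should look roughly like this (paths per the two `.gitignore` files added below; the matching `onnxruntime` library is unpacked into `lib` automatically on first run):
```
cr-examples/onnx/
├── lib/
│   └── libonnxruntime-genai.*   (downloaded; platform-specific extension)
└── src/test/resources/oracle/code/onnx/llm/
    ├── model.onnx.data
    ├── tokenizer.json
    └── tokenizer_config.json
```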
2 changes: 2 additions & 0 deletions cr-examples/onnx/lib/.gitignore
@@ -0,0 +1,2 @@
/libonnxruntime.*
/libonnxruntime-genai.*
3 changes: 2 additions & 1 deletion cr-examples/onnx/pom.xml
@@ -96,7 +96,8 @@ questions.
       <artifactId>exec-maven-plugin</artifactId>
       <version>3.5.1</version>
       <configuration>
-        <commandlineArgs>--add-modules jdk.incubator.code ${exec.args}</commandlineArgs>
+        <classpathScope>test</classpathScope>
+        <commandlineArgs>--add-modules jdk.incubator.code -classpath %classpath ${exec.mainClass}</commandlineArgs>
       </configuration>
     </plugin>
   </plugins>
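
With `classpathScope` set to `test` and the `%classpath` placeholder resolved by the plugin, the exec goal from the README expands to roughly this command line (illustrative):
```
<path to the Babylon JDK home>/bin/java --add-modules jdk.incubator.code -classpath <main + test classpath> <exec.mainClass>
```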
@@ -31,11 +31,8 @@
import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;
import java.lang.reflect.AccessFlag;
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.RecordComponent;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
@@ -55,7 +52,6 @@
import oracle.code.onnx.compiler.OnnxTransformer;
import oracle.code.onnx.foreign.OrtApi;
import oracle.code.onnx.foreign.OrtApiBase;
import oracle.code.onnx.proto.OnnxModel;

import static oracle.code.onnx.foreign.onnxruntime_c_api_h.*;

@@ -32,6 +32,7 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import java.util.stream.Stream;
@@ -134,6 +135,37 @@
*/
public class OnnxGenRuntimeSession implements AutoCloseable {

/**
* Loads the {@code onnxruntime-genai} native library from the given folder.
* If the required {@code onnxruntime} library is missing from that folder, this method unpacks it there from the classpath dependencies.
*
* @param libRoot folder containing the {@code onnxruntime-genai} native library
*/
public static void loadGenAILib(Path libRoot) {
Path runtime = libRoot.resolve(System.mapLibraryName("onnxruntime"));
if (!Files.isRegularFile(runtime)) {
// onnxruntime-genai requires onnxruntime in the same directory
String arch = System.getProperty("os.arch", "generic").toLowerCase(Locale.ENGLISH).startsWith("aarch64") ? "aarch64" : "x64";
String os = System.getProperty("os.name", "generic").toLowerCase(Locale.ENGLISH);
String libResource;
if (os.contains("mac") || os.contains("darwin")) {
libResource = "/ai/onnxruntime/native/osx-" + arch + "/libonnxruntime.dylib";
} else if (os.contains("win")) {
libResource = "/ai/onnxruntime/native/win-" + arch + "/libonnxruntime.dll";
} else if (os.contains("nux")) {
libResource = "/ai/onnxruntime/native/linux-" + arch + "/libonnxruntime.so";
} else {
throw new IllegalStateException("Unsupported os: " + os);
}
try (var libStream = OnnxRuntime.class.getResourceAsStream(libResource)) {
Files.copy(libStream, runtime);
} catch (IOException ioe) {
throw new RuntimeException(ioe);
}
}
System.load(libRoot.resolve(System.mapLibraryName("onnxruntime-genai")).toAbsolutePath().toString());
}

/**
* Builds an Onnx model from the provided Java model instance and loads it into a newly constructed Onnx Generate API session.
* @param codeReflectionModelInstance instance of a class representing the Onnx LLM model.
46 changes: 46 additions & 0 deletions cr-examples/onnx/src/test/java/oracle/code/onnx/llm/LlamaDemo.java
@@ -0,0 +1,46 @@
/*
* Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package oracle.code.onnx.llm;

import java.lang.foreign.Arena;
import java.nio.file.Path;
import oracle.code.onnx.genai.OnnxGenRuntimeSession;

public class LlamaDemo {

public static void main(String... args) throws Exception {

OnnxGenRuntimeSession.loadGenAILib(Path.of("lib"));

Path modelRoot = Path.of(LlamaDemo.class.getResource("LlamaDemo.class").toURI()).getParent();
try (Arena arena = Arena.ofConfined()) {
var modelInstance = new LlamaModel(arena);
try (OnnxGenRuntimeSession session = OnnxGenRuntimeSession.buildFromCodeReflection(modelInstance, "forward", modelRoot, "model.onnx", "model.data")) {
session.prompt("""
<|start_header_id|>user<|end_header_id|>Hello, tell me a joke.<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
""", System.out::print);
}
}
}
}
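
The inline prompt above follows the Llama 3 chat template. Prompt formats are model-specific, so porting the demo to another model means swapping the control tokens; a hypothetical template constant (illustrative, not part of this PR) makes that shape explicit:
```
// Hypothetical helper: each LLM family defines its own chat template.
static final String PROMPT_TEMPLATE = """
        <|start_header_id|>user<|end_header_id|>%s<|eot_id|>
        <|start_header_id|>assistant<|end_header_id|>
        """;

// usage: session.prompt(PROMPT_TEMPLATE.formatted("Hello, tell me a joke."), System.out::print);
```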
159 changes: 159 additions & 0 deletions cr-examples/onnx/src/test/java/oracle/code/onnx/llm/LlamaModel.java
@@ -0,0 +1,159 @@
/*
* Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package oracle.code.onnx.llm;

import java.io.IOException;
import java.lang.foreign.Arena;
import jdk.incubator.code.CodeReflection;
import oracle.code.onnx.Tensor;
import oracle.code.onnx.genai.TensorDataStream;

import static java.util.Optional.*;
import static oracle.code.onnx.OnnxOperators.*;
import static oracle.code.onnx.Tensor.ElementType.*;
import oracle.code.onnx.ir.OnnxType;

public final class LlamaModel {

public static final int LAYERS = 16;
public static final long BITS = 4,
BLOCK_SIZE = 32,
NUM_KEY_VALUE_HEADS = 8,
ACCURACY_LEVEL = 4,
VOCAB_SIZE = 128256,
HEAD_SIZE = 64,
HIDEN_SIZE = 2048,
CONTEXT_SIZE = 131072,
INTERMEDIATE_SIZE = 8192,
ATTN_WEIGHTS_SIZE = 3072;
public static final float EPSILON = 1.0E-5f,
SCALE = 0.125f;

public final Tensor<Long> flat1, scalar1;
public final Tensor<Float> tokensWeights, initWeight, cosCache, sinCache, headScales;
public final Tensor<Float>[] postAttentionWeights = new Tensor[LAYERS],
inputWeights = new Tensor[LAYERS],
attnQkvScales = new Tensor[LAYERS],
attnOScales = new Tensor[LAYERS],
mlpGateScales = new Tensor[LAYERS],
mlpUpScales = new Tensor[LAYERS],
mlpDownScales = new Tensor[LAYERS];
public final Tensor<Byte>[] attnQkvWeight = new Tensor[LAYERS],
attnOWeight = new Tensor[LAYERS],
mlpGateWeight = new Tensor[LAYERS],
mlpUpWeight = new Tensor[LAYERS],
mlpDownWeight = new Tensor[LAYERS];
public final Tensor<Byte> headWeight;

public LlamaModel(Arena arena) throws IOException {
flat1 = Tensor.ofFlat(arena, 1l);
scalar1 = Tensor.ofScalar(arena, 1l);
var modelData = new TensorDataStream(arena, LlamaModel.class.getResource("model.onnx.data").getPath());
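// nextTensor reads sequentially, so the declaration order below must match the layout of model.onnx.data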
tokensWeights = modelData.nextTensor(FLOAT, VOCAB_SIZE, HIDEN_SIZE);
initWeight = modelData.nextTensor(FLOAT, HIDEN_SIZE);
cosCache = modelData.nextTensor(FLOAT, CONTEXT_SIZE, HEAD_SIZE / 2);
sinCache = modelData.nextTensor(FLOAT, CONTEXT_SIZE, HEAD_SIZE / 2);
for (int i = 0; i < LAYERS; i++) {
postAttentionWeights[i] = modelData.nextTensor(FLOAT, HIDEN_SIZE);
inputWeights[i] = modelData.nextTensor(FLOAT, HIDEN_SIZE);
}
for (int i = 0; i < LAYERS; i++) {
attnQkvWeight[i] = modelData.nextTensor(UINT8, ATTN_WEIGHTS_SIZE, HEAD_SIZE, 16);
attnQkvScales[i] = modelData.nextTensor(FLOAT, ATTN_WEIGHTS_SIZE * HEAD_SIZE);
attnOWeight[i] = modelData.nextTensor(UINT8, HIDEN_SIZE, HEAD_SIZE, 16);
attnOScales[i] = modelData.nextTensor(FLOAT, HIDEN_SIZE * HEAD_SIZE);
mlpGateWeight[i] = modelData.nextTensor(UINT8, INTERMEDIATE_SIZE, HEAD_SIZE, 16);
mlpGateScales[i] = modelData.nextTensor(FLOAT, INTERMEDIATE_SIZE * HEAD_SIZE);
mlpUpWeight[i] = modelData.nextTensor(UINT8, INTERMEDIATE_SIZE, HEAD_SIZE, 16);
mlpUpScales[i] = modelData.nextTensor(FLOAT, INTERMEDIATE_SIZE * HEAD_SIZE);
mlpDownWeight[i] = modelData.nextTensor(UINT8, HIDEN_SIZE, 256, 16);
mlpDownScales[i] = modelData.nextTensor(FLOAT, INTERMEDIATE_SIZE * HEAD_SIZE);
}
headWeight = modelData.nextTensor(UINT8, VOCAB_SIZE, HEAD_SIZE, 16);
headScales = modelData.nextTensor(FLOAT, VOCAB_SIZE * HEAD_SIZE);
}

public record ForwardResponse(Tensor<Float> logits,
Tensor<Float>[] presentKey,
Tensor<Float>[] presentValue) {
}

@CodeReflection
public ForwardResponse forward(Tensor<Long> inputIds, Tensor<Long> attentionMask, Tensor<Float>[] pastKey, Tensor<Float>[] pastValue) {

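// sequence-length inputs for GroupQueryAttention below: seqlens_k = sum(attention_mask) - 1,
// total_sequence_length = attention_mask length, both cast to int32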
Tensor<Integer> amSL = Cast(Sub(ReduceSum(attentionMask, of(flat1), empty(), empty()), flat1), empty(), OnnxType.INT32.id());
Tensor<Integer> amTSL = Cast(Gather(Shape(attentionMask, empty(), empty()), scalar1, of(0l)), empty(), OnnxType.INT32.id());
Tensor<Float> skipBias = Gather(tokensWeights, inputIds, empty());
Tensor<Float> input = LayerNormalization(skipBias, initWeight, empty(), of(EPSILON), of(1l), of(-1l)).Y();

Tensor<Float>[] presentKeys = new Tensor[LAYERS];
Tensor<Float>[] presentValues = new Tensor[LAYERS];

for (int i = 0; i < LAYERS; i++) {
GroupQueryAttention<Float> attn = GroupQueryAttention(
MatMulNBits(input,
attnQkvWeight[i],
attnQkvScales[i], empty(), empty(), empty(), HIDEN_SIZE, ATTN_WEIGHTS_SIZE, of(ACCURACY_LEVEL), BITS, BLOCK_SIZE),
empty(),
empty(),
of(pastKey[i]),
of(pastValue[i]),
amSL,
amTSL,
of(cosCache),
of(sinCache), of(1l), NUM_KEY_VALUE_HEADS, empty(), BLOCK_SIZE, of(0l), of(SCALE));

SkipSimplifiedLayerNormalization<Float> postAttnLayernorm = SkipSimplifiedLayerNormalization(
skipBias,
MatMulNBits(attn.output(),
attnOWeight[i],
attnOScales[i], empty(), empty(), empty(), HIDEN_SIZE, HIDEN_SIZE, of(ACCURACY_LEVEL), BITS, BLOCK_SIZE),
postAttentionWeights[i], empty(), of(EPSILON));

Tensor<Float> mlpGateProj = MatMulNBits(postAttnLayernorm.output(),
mlpGateWeight[i],
mlpGateScales[i], empty(), empty(), empty(), HIDEN_SIZE, INTERMEDIATE_SIZE, of(ACCURACY_LEVEL), BITS, BLOCK_SIZE);

SkipSimplifiedLayerNormalization<Float> norm = SkipSimplifiedLayerNormalization(postAttnLayernorm.input_skip_bias_sum(),
MatMulNBits(Mul(Mul(mlpGateProj,
Sigmoid(mlpGateProj)),
MatMulNBits(postAttnLayernorm.output(),
mlpUpWeight[i],
mlpUpScales[i], empty(), empty(), empty(), HIDEN_SIZE, INTERMEDIATE_SIZE, of(ACCURACY_LEVEL), BITS, BLOCK_SIZE)),
mlpDownWeight[i],
mlpDownScales[i], empty(), empty(), empty(), INTERMEDIATE_SIZE, HIDEN_SIZE, of(ACCURACY_LEVEL), BITS, BLOCK_SIZE),
inputWeights[i], empty(), of(EPSILON));

input = norm.output();
skipBias = norm.input_skip_bias_sum();
presentKeys[i] = attn.present_key();
presentValues[i] = attn.present_value();
}

Tensor<Float> logits = MatMulNBits(input,
headWeight,
headScales, empty(), empty(), empty(), HIDEN_SIZE, VOCAB_SIZE, of(ACCURACY_LEVEL), BITS, BLOCK_SIZE);

return new ForwardResponse(logits, presentKeys, presentValues);
}
}
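
One detail worth spelling out: every `MatMulNBits` call above consumes block-quantized 4-bit weights (`BITS` = 4, `BLOCK_SIZE` = 32), which is why each `UINT8` weight tensor is shaped (N, K/32, 16), i.e. 32 four-bit values packed into 16 bytes per block, with one float scale per block. A minimal dequantization sketch, assuming the operator's default zero point of 8 (used when no `zero_points` input is supplied) and low-nibble-first packing:
```
// Illustrative only: expands one 16-byte block of a MatMulNBits weight
// tensor back into 32 floats; zero point 8 and nibble order are assumptions.
static float[] dequantizeBlock(byte[] block16, float scale) {
    float[] values = new float[32];
    for (int i = 0; i < 16; i++) {
        int b = block16[i] & 0xFF;
        values[2 * i]     = scale * ((b & 0x0F) - 8); // low nibble
        values[2 * i + 1] = scale * ((b >>> 4) - 8);  // high nibble
    }
    return values;
}
```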
4 changes: 4 additions & 0 deletions cr-examples/onnx/src/test/resources/oracle/code/onnx/llm/.gitignore
@@ -0,0 +1,4 @@
/model.onnx.data
/tokenizer_config.json
/tokenizer.json
