Skip to content

Commit 31a3a87

Browse files
authored
[Android] Move training part to its own package
Differential Revision: D79377344 Pull Request resolved: pytorch#13047
1 parent 5d3550f commit 31a3a87

File tree

5 files changed

+59
-30
lines changed

5 files changed

+59
-30
lines changed

extension/android/BUCK

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ non_fbcode_target(_kind = fb_android_library,
1313
"executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java",
1414
"executorch_android/src/main/java/org/pytorch/executorch/Module.java",
1515
"executorch_android/src/main/java/org/pytorch/executorch/Tensor.java",
16-
"executorch_android/src/main/java/org/pytorch/executorch/TrainingModule.java",
17-
"executorch_android/src/main/java/org/pytorch/executorch/SGD.java",
1816
"executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.java",
17+
"executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java",
18+
"executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java",
1919
],
2020
autoglob = False,
2121
language = "JAVA",

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TrainingModuleE2ETest.kt renamed to extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/training/TrainingModuleE2ETest.kt

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,24 @@
55
* This source code is licensed under the BSD-style license found in the
66
* LICENSE file in the root directory of this source tree.
77
*/
8-
package org.pytorch.executorch
8+
9+
package org.pytorch.executorch.training
910

1011
import android.Manifest
1112
import android.util.Log
1213
import androidx.test.ext.junit.runners.AndroidJUnit4
1314
import androidx.test.rule.GrantPermissionRule
14-
import java.io.File
15-
import java.io.IOException
16-
import java.net.URISyntaxException
1715
import org.apache.commons.io.FileUtils
1816
import org.junit.Assert
1917
import org.junit.Rule
2018
import org.junit.Test
2119
import org.junit.runner.RunWith
22-
import org.pytorch.executorch.TestFileUtils.getTestFilePath
20+
import org.pytorch.executorch.EValue
21+
import org.pytorch.executorch.Tensor
22+
import org.pytorch.executorch.TestFileUtils
23+
import java.io.File
24+
import java.io.IOException
25+
import java.net.URISyntaxException
2326
import kotlin.random.Random
2427
import kotlin.test.assertContains
2528

@@ -36,17 +39,20 @@ class TrainingModuleE2ETest {
3639
val pteFilePath = "/xor.pte"
3740
val ptdFilePath = "/xor.ptd"
3841

39-
val pteFile = File(getTestFilePath(pteFilePath))
42+
val pteFile = File(TestFileUtils.getTestFilePath(pteFilePath))
4043
val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
4144
FileUtils.copyInputStreamToFile(pteInputStream, pteFile)
4245
pteInputStream.close()
4346

44-
val ptdFile = File(getTestFilePath(ptdFilePath))
47+
val ptdFile = File(TestFileUtils.getTestFilePath(ptdFilePath))
4548
val ptdInputStream = javaClass.getResourceAsStream(ptdFilePath)
4649
FileUtils.copyInputStreamToFile(ptdInputStream, ptdFile)
4750
ptdInputStream.close()
4851

49-
val module = TrainingModule.load(getTestFilePath(pteFilePath), getTestFilePath(ptdFilePath))
52+
val module = TrainingModule.load(
53+
TestFileUtils.getTestFilePath(pteFilePath),
54+
TestFileUtils.getTestFilePath(ptdFilePath)
55+
)
5056
val params = module.namedParameters("forward")
5157

5258
Assert.assertEquals(4, params.size)
@@ -75,7 +81,10 @@ class TrainingModuleE2ETest {
7581
val targetDex = inputDex + 1
7682
val input = dataset.get(inputDex)
7783
val target = dataset.get(targetDex)
78-
val out = module.executeForwardBackward("forward", EValue.from(input), EValue.from(target))
84+
val out = module.executeForwardBackward("forward",
85+
EValue.from(input),
86+
EValue.from(target)
87+
)
7988
val gradients = module.namedGradients("forward")
8089

8190
if (i == 0) {
@@ -96,7 +105,9 @@ class TrainingModuleE2ETest {
96105
input.getDataAsFloatArray()[0],
97106
input.getDataAsFloatArray()[1],
98107
out[1].toTensor().getDataAsLongArray()[0],
99-
target.getDataAsLongArray()[0]));
108+
target.getDataAsLongArray()[0]
109+
)
110+
);
100111
}
101112

102113
sgd.step(gradients)
@@ -113,12 +124,12 @@ class TrainingModuleE2ETest {
113124
fun testTrainXOR_PTEOnly() {
114125
val pteFilePath = "/xor_full.pte"
115126

116-
val pteFile = File(getTestFilePath(pteFilePath))
127+
val pteFile = File(TestFileUtils.getTestFilePath(pteFilePath))
117128
val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
118129
FileUtils.copyInputStreamToFile(pteInputStream, pteFile)
119130
pteInputStream.close()
120131

121-
val module = TrainingModule.load(getTestFilePath(pteFilePath));
132+
val module = TrainingModule.load(TestFileUtils.getTestFilePath(pteFilePath));
122133
val params = module.namedParameters("forward")
123134

124135
Assert.assertEquals(4, params.size)
@@ -147,7 +158,10 @@ class TrainingModuleE2ETest {
147158
val targetDex = inputDex + 1
148159
val input = dataset.get(inputDex)
149160
val target = dataset.get(targetDex)
150-
val out = module.executeForwardBackward("forward", EValue.from(input), EValue.from(target))
161+
val out = module.executeForwardBackward("forward",
162+
EValue.from(input),
163+
EValue.from(target)
164+
)
151165
val gradients = module.namedGradients("forward")
152166

153167
if (i == 0) {
@@ -168,7 +182,9 @@ class TrainingModuleE2ETest {
168182
input.getDataAsFloatArray()[0],
169183
input.getDataAsFloatArray()[1],
170184
out[1].toTensor().getDataAsLongArray()[0],
171-
target.getDataAsLongArray()[0]));
185+
target.getDataAsLongArray()[0]
186+
)
187+
);
172188
}
173189

174190
sgd.step(gradients)
@@ -184,24 +200,33 @@ class TrainingModuleE2ETest {
184200
@Throws(IOException::class)
185201
fun testMissingPteFile() {
186202
val exception = Assert.assertThrows(RuntimeException::class.java) {
187-
TrainingModule.load(getTestFilePath(MISSING_PTE_NAME))
203+
TrainingModule.load(TestFileUtils.getTestFilePath(MISSING_PTE_NAME))
188204
}
189-
Assert.assertEquals(exception.message, "Cannot load model path!! " + getTestFilePath(MISSING_PTE_NAME))
205+
Assert.assertEquals(
206+
exception.message,
207+
"Cannot load model path!! " + TestFileUtils.getTestFilePath(MISSING_PTE_NAME)
208+
)
190209
}
191210

192211
@Test
193212
@Throws(IOException::class)
194213
fun testMissingPtdFile() {
195214
val exception = Assert.assertThrows(RuntimeException::class.java) {
196215
val pteFilePath = "/xor.pte"
197-
val pteFile = File(getTestFilePath(pteFilePath))
216+
val pteFile = File(TestFileUtils.getTestFilePath(pteFilePath))
198217
val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
199218
FileUtils.copyInputStreamToFile(pteInputStream, pteFile)
200219
pteInputStream.close()
201220

202-
TrainingModule.load(getTestFilePath(pteFilePath), getTestFilePath(MISSING_PTD_NAME))
221+
TrainingModule.load(
222+
TestFileUtils.getTestFilePath(pteFilePath),
223+
TestFileUtils.getTestFilePath(MISSING_PTD_NAME)
224+
)
203225
}
204-
Assert.assertEquals(exception.message, "Cannot load data path!! " + getTestFilePath(MISSING_PTD_NAME))
226+
Assert.assertEquals(
227+
exception.message,
228+
"Cannot load data path!! " + TestFileUtils.getTestFilePath(MISSING_PTD_NAME)
229+
)
205230
}
206231

207232
companion object {
@@ -212,4 +237,4 @@ class TrainingModuleE2ETest {
212237
private const val MISSING_PTE_NAME = "/missing.pte"
213238
private const val MISSING_PTD_NAME = "/missing.ptd"
214239
}
215-
}
240+
}

extension/android/executorch_android/src/main/java/org/pytorch/executorch/SGD.java renamed to extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
package org.pytorch.executorch;
9+
package org.pytorch.executorch.training;
1010

1111
import com.facebook.jni.HybridData;
1212
import com.facebook.jni.annotations.DoNotStrip;
1313
import com.facebook.soloader.nativeloader.NativeLoader;
1414
import com.facebook.soloader.nativeloader.SystemDelegate;
1515
import java.util.Map;
16+
import org.pytorch.executorch.Tensor;
1617
import org.pytorch.executorch.annotations.Experimental;
1718

1819
/**
@@ -62,7 +63,7 @@ private SGD(
6263
* @param dampening The dampening value
6364
* @param weightDecay The weight decay value
6465
* @param nesterov Whether to use Nesterov momentum
65-
* @return new {@link org.pytorch.executorch.SGD} object
66+
* @return new {@link SGD} object
6667
*/
6768
public static SGD create(
6869
Map<String, Tensor> namedParameters,
@@ -79,7 +80,7 @@ public static SGD create(
7980
*
8081
* @param namedParameters Map of parameter names to tensors to be optimized
8182
* @param learningRate The learning rate for the optimizer
82-
* @return new {@link org.pytorch.executorch.SGD} object
83+
* @return new {@link SGD} object
8384
*/
8485
public static SGD create(Map<String, Tensor> namedParameters, double learningRate) {
8586
return create(namedParameters, learningRate, 0.0, 0.0, 0.0, false);

extension/android/executorch_android/src/main/java/org/pytorch/executorch/TrainingModule.java renamed to extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
package org.pytorch.executorch;
9+
package org.pytorch.executorch.training;
1010

1111
import android.util.Log;
1212
import com.facebook.jni.HybridData;
@@ -16,6 +16,8 @@
1616
import java.io.File;
1717
import java.util.HashMap;
1818
import java.util.Map;
19+
import org.pytorch.executorch.EValue;
20+
import org.pytorch.executorch.Tensor;
1921
import org.pytorch.executorch.annotations.Experimental;
2022

2123
/**
@@ -48,7 +50,7 @@ private TrainingModule(String moduleAbsolutePath, String dataAbsolutePath) {
4850
*
4951
* @param modelPath path to file that contains the serialized ExecuTorch module.
5052
* @param dataPath path to file that contains the ExecuTorch module external weights.
51-
* @return new {@link org.pytorch.executorch.TrainingModule} object which owns the model module.
53+
* @return new {@link TrainingModule} object which owns the model module.
5254
*/
5355
public static TrainingModule load(final String modelPath, final String dataPath) {
5456
File modelFile = new File(modelPath);
@@ -67,7 +69,7 @@ public static TrainingModule load(final String modelPath, final String dataPath)
6769
*
6870
* @param modelPath path to file that contains the serialized ExecuTorch module. This PTE does not
6971
* rely on external weights.
70-
* @return new {@link org.pytorch.executorch.TrainingModule} object which owns the model module.
72+
* @return new {@link TrainingModule} object which owns the model module.
7173
*/
7274
public static TrainingModule load(final String modelPath) {
7375
File modelFile = new File(modelPath);

extension/android/jni/jni_layer_training.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ class ExecuTorchTrainingJni
6767

6868
public:
6969
constexpr static auto kJavaDescriptor =
70-
"Lorg/pytorch/executorch/TrainingModule;";
70+
"Lorg/pytorch/executorch/training/TrainingModule;";
7171

7272
ExecuTorchTrainingJni(
7373
facebook::jni::alias_ref<jstring> modelPath,
@@ -226,7 +226,8 @@ class ExecuTorchTrainingJni
226226

227227
class SGDHybrid : public facebook::jni::HybridClass<SGDHybrid> {
228228
public:
229-
constexpr static const char* kJavaDescriptor = "Lorg/pytorch/executorch/SGD;";
229+
constexpr static const char* kJavaDescriptor =
230+
"Lorg/pytorch/executorch/training/SGD;";
230231

231232
static facebook::jni::local_ref<jhybriddata> initHybrid(
232233
facebook::jni::alias_ref<jclass>,

0 commit comments

Comments
 (0)