5
5
* This source code is licensed under the BSD-style license found in the
6
6
* LICENSE file in the root directory of this source tree.
7
7
*/
8
- package org.pytorch.executorch
8
+
9
+ package org.pytorch.executorch.training
9
10
10
11
import android.Manifest
11
12
import android.util.Log
12
13
import androidx.test.ext.junit.runners.AndroidJUnit4
13
14
import androidx.test.rule.GrantPermissionRule
14
- import java.io.File
15
- import java.io.IOException
16
- import java.net.URISyntaxException
17
15
import org.apache.commons.io.FileUtils
18
16
import org.junit.Assert
19
17
import org.junit.Rule
20
18
import org.junit.Test
21
19
import org.junit.runner.RunWith
22
- import org.pytorch.executorch.TestFileUtils.getTestFilePath
20
+ import org.pytorch.executorch.EValue
21
+ import org.pytorch.executorch.Tensor
22
+ import org.pytorch.executorch.TestFileUtils
23
+ import java.io.File
24
+ import java.io.IOException
25
+ import java.net.URISyntaxException
23
26
import kotlin.random.Random
24
27
import kotlin.test.assertContains
25
28
@@ -36,17 +39,20 @@ class TrainingModuleE2ETest {
36
39
val pteFilePath = " /xor.pte"
37
40
val ptdFilePath = " /xor.ptd"
38
41
39
- val pteFile = File (getTestFilePath(pteFilePath))
42
+ val pteFile = File (TestFileUtils . getTestFilePath(pteFilePath))
40
43
val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
41
44
FileUtils .copyInputStreamToFile(pteInputStream, pteFile)
42
45
pteInputStream.close()
43
46
44
- val ptdFile = File (getTestFilePath(ptdFilePath))
47
+ val ptdFile = File (TestFileUtils . getTestFilePath(ptdFilePath))
45
48
val ptdInputStream = javaClass.getResourceAsStream(ptdFilePath)
46
49
FileUtils .copyInputStreamToFile(ptdInputStream, ptdFile)
47
50
ptdInputStream.close()
48
51
49
- val module = TrainingModule .load(getTestFilePath(pteFilePath), getTestFilePath(ptdFilePath))
52
+ val module = TrainingModule .load(
53
+ TestFileUtils .getTestFilePath(pteFilePath),
54
+ TestFileUtils .getTestFilePath(ptdFilePath)
55
+ )
50
56
val params = module.namedParameters(" forward" )
51
57
52
58
Assert .assertEquals(4 , params.size)
@@ -75,7 +81,10 @@ class TrainingModuleE2ETest {
75
81
val targetDex = inputDex + 1
76
82
val input = dataset.get(inputDex)
77
83
val target = dataset.get(targetDex)
78
- val out = module.executeForwardBackward(" forward" , EValue .from(input), EValue .from(target))
84
+ val out = module.executeForwardBackward(" forward" ,
85
+ EValue .from(input),
86
+ EValue .from(target)
87
+ )
79
88
val gradients = module.namedGradients(" forward" )
80
89
81
90
if (i == 0 ) {
@@ -96,7 +105,9 @@ class TrainingModuleE2ETest {
96
105
input.getDataAsFloatArray()[0 ],
97
106
input.getDataAsFloatArray()[1 ],
98
107
out [1 ].toTensor().getDataAsLongArray()[0 ],
99
- target.getDataAsLongArray()[0 ]));
108
+ target.getDataAsLongArray()[0 ]
109
+ )
110
+ );
100
111
}
101
112
102
113
sgd.step(gradients)
@@ -113,12 +124,12 @@ class TrainingModuleE2ETest {
113
124
fun testTrainXOR_PTEOnly () {
114
125
val pteFilePath = " /xor_full.pte"
115
126
116
- val pteFile = File (getTestFilePath(pteFilePath))
127
+ val pteFile = File (TestFileUtils . getTestFilePath(pteFilePath))
117
128
val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
118
129
FileUtils .copyInputStreamToFile(pteInputStream, pteFile)
119
130
pteInputStream.close()
120
131
121
- val module = TrainingModule .load(getTestFilePath(pteFilePath));
132
+ val module = TrainingModule .load(TestFileUtils . getTestFilePath(pteFilePath));
122
133
val params = module.namedParameters(" forward" )
123
134
124
135
Assert .assertEquals(4 , params.size)
@@ -147,7 +158,10 @@ class TrainingModuleE2ETest {
147
158
val targetDex = inputDex + 1
148
159
val input = dataset.get(inputDex)
149
160
val target = dataset.get(targetDex)
150
- val out = module.executeForwardBackward(" forward" , EValue .from(input), EValue .from(target))
161
+ val out = module.executeForwardBackward(" forward" ,
162
+ EValue .from(input),
163
+ EValue .from(target)
164
+ )
151
165
val gradients = module.namedGradients(" forward" )
152
166
153
167
if (i == 0 ) {
@@ -168,7 +182,9 @@ class TrainingModuleE2ETest {
168
182
input.getDataAsFloatArray()[0 ],
169
183
input.getDataAsFloatArray()[1 ],
170
184
out [1 ].toTensor().getDataAsLongArray()[0 ],
171
- target.getDataAsLongArray()[0 ]));
185
+ target.getDataAsLongArray()[0 ]
186
+ )
187
+ );
172
188
}
173
189
174
190
sgd.step(gradients)
@@ -184,24 +200,33 @@ class TrainingModuleE2ETest {
184
200
@Throws(IOException ::class )
185
201
fun testMissingPteFile () {
186
202
val exception = Assert .assertThrows(RuntimeException ::class .java) {
187
- TrainingModule .load(getTestFilePath(MISSING_PTE_NAME ))
203
+ TrainingModule .load(TestFileUtils . getTestFilePath(MISSING_PTE_NAME ))
188
204
}
189
- Assert .assertEquals(exception.message, " Cannot load model path!! " + getTestFilePath(MISSING_PTE_NAME ))
205
+ Assert .assertEquals(
206
+ exception.message,
207
+ " Cannot load model path!! " + TestFileUtils .getTestFilePath(MISSING_PTE_NAME )
208
+ )
190
209
}
191
210
192
211
@Test
193
212
@Throws(IOException ::class )
194
213
fun testMissingPtdFile () {
195
214
val exception = Assert .assertThrows(RuntimeException ::class .java) {
196
215
val pteFilePath = " /xor.pte"
197
- val pteFile = File (getTestFilePath(pteFilePath))
216
+ val pteFile = File (TestFileUtils . getTestFilePath(pteFilePath))
198
217
val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
199
218
FileUtils .copyInputStreamToFile(pteInputStream, pteFile)
200
219
pteInputStream.close()
201
220
202
- TrainingModule .load(getTestFilePath(pteFilePath), getTestFilePath(MISSING_PTD_NAME ))
221
+ TrainingModule .load(
222
+ TestFileUtils .getTestFilePath(pteFilePath),
223
+ TestFileUtils .getTestFilePath(MISSING_PTD_NAME )
224
+ )
203
225
}
204
- Assert .assertEquals(exception.message, " Cannot load data path!! " + getTestFilePath(MISSING_PTD_NAME ))
226
+ Assert .assertEquals(
227
+ exception.message,
228
+ " Cannot load data path!! " + TestFileUtils .getTestFilePath(MISSING_PTD_NAME )
229
+ )
205
230
}
206
231
207
232
companion object {
@@ -212,4 +237,4 @@ class TrainingModuleE2ETest {
212
237
private const val MISSING_PTE_NAME = " /missing.pte"
213
238
private const val MISSING_PTD_NAME = " /missing.ptd"
214
239
}
215
- }
240
+ }
0 commit comments