Skip to content

Commit 091b12a

Browse files
authored
update tests #2234 (#2297)
1 parent 93564bc commit 091b12a

File tree

3 files changed

+483
-27
lines changed

3 files changed

+483
-27
lines changed

tests/torch4ms/test_simple_op.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import torch4ms
2+
import torch
3+
import mindspore
4+
import numpy as np
5+
6+
env = torch4ms.default_env()
7+
env.__enter__()
8+
9+
def test_matrix_operations():
10+
"""测试矩阵乘法和加法组合运算"""
11+
# 创建测试数据
12+
np_x = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
13+
np_y = np.array([[5.0, 6.0], [7.0, 8.0]], dtype=np.float32)
14+
np_z = np.array([[9.0, 10.0], [11.0, 12.0]], dtype=np.float32)
15+
16+
x = torch.tensor(np_x)
17+
y = torch.tensor(np_y)
18+
z = torch.tensor(np_z)
19+
result = torch.matmul(x, y) + z
20+
21+
print(f"x = {x}")
22+
print(f"y = {y}")
23+
print(f"z = {z}")
24+
print(f"x * y + z = {result}")
25+
26+
expected = np.matmul(np_x, np_y) + np_z
27+
print(f"\n预期结果:")
28+
print(f"{expected}")
29+
30+
np_result = result.detach().numpy()
31+
print(f"\n数值验证:")
32+
print(f"结果是否接近预期: {np.allclose(np_result, expected, atol=1e-5)}")
33+
34+
def test_activation_functions():
35+
"""测试激活函数"""
36+
np_data = np.array([[-1.0, 0.0, 1.0], [-0.5, 0.5, 1.5]], dtype=np.float32)
37+
x = torch.tensor(np_data)
38+
39+
print("\n" + "="*40)
40+
print("测试激活函数:")
41+
print(f"输入: {x}")
42+
43+
relu_result = torch.relu(x)
44+
print(f"\nReLU结果: {relu_result}")
45+
46+
sigmoid_result = torch.sigmoid(x)
47+
print(f"Sigmoid结果: {sigmoid_result}")
48+
49+
tanh_result = torch.tanh(x)
50+
print(f"Tanh结果: {tanh_result}")
51+
52+
if __name__ == "__main__":
53+
print("PyTorch版本: {}".format(torch.__version__))
54+
print("MindSpore版本: {}".format(mindspore.__version__))
55+
print("="*40)
56+
57+
test_matrix_operations()
58+
test_activation_functions()
59+
env.__exit__(None, None, None)

0 commit comments

Comments
 (0)