|
7 | 7 |
|
8 | 8 | #! -*- coding: utf-8 -*-
|
9 | 9 | # 测试代码可用性: 提取特征
|
| 10 | + |
10 | 11 | import os
|
11 | 12 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
12 |
| -os.environ["KERAS_BACKEND"] = "jax" |
13 |
| -os.environ["CUDA_VISIBLE_DEVICES"]="-1" |
| 13 | +os.environ["KERAS_BACKEND"] = "torch" |
| 14 | +os.environ["FLASH_ATTN"] = "0"#启动flash-attn |
| 15 | + |
14 | 16 |
|
15 | 17 | import numpy as np
|
16 | 18 | from bert4keras3.backend import keras,ops
|
| 19 | +keras.backend.set_floatx('bfloat16') |
17 | 20 | from bert4keras3.models import build_transformer_model
|
18 | 21 | from bert4keras3.tokenizers import Tokenizer
|
19 | 22 | from bert4keras3.snippets import to_array
|
| 23 | + |
20 | 24 | #bert from
|
21 | 25 | #model download form https://open.zhuiyi.ai/releases/nlp/models/zhuiyi/chinese_roberta_L-4_H-312_A-12.zip
|
22 |
| -config_path = 'models/chinese_roberta_L-4_H-312_A-12/bert_config.json' |
23 |
| -checkpoint_path = 'models/chinese_roberta_L-4_H-312_A-12/bert_model.ckpt' |
24 |
| -dict_path = 'models/chinese_roberta_L-4_H-312_A-12/vocab.txt' |
| 26 | +config_path = 'chinese_wobert_L-12_H-768_A-12/bert_config.json' |
| 27 | +checkpoint_path = 'chinese_wobert_L-12_H-768_A-12/bert_model.ckpt' |
| 28 | +dict_path = 'chinese_wobert_L-12_H-768_A-12/vocab.txt' |
25 | 29 |
|
26 | 30 | tokenizer = Tokenizer(dict_path, do_lower_case=True) # 建立分词器
|
27 | 31 | model = build_transformer_model(config_path, checkpoint_path) # 建立模型,加载权重
|
|
30 | 34 | token_ids, segment_ids = tokenizer.encode(u'语言模型')
|
31 | 35 | token_ids, segment_ids = to_array([token_ids], [segment_ids])
|
32 | 36 |
|
| 37 | + |
| 38 | + |
33 | 39 | print('\n ===== predicting =====\n')
|
34 | 40 | print(model.predict([token_ids, segment_ids]))
|
35 |
| -model.predict([np.zeros([1024,32]),np.zeros([1024,32])],batch_size=8) |
36 |
| -model.predict([np.zeros([1024,32]),np.zeros([1024,32])],batch_size=8,verbose=1) |
| 41 | +model.predict([np.zeros([1024,32]),np.zeros([1024,32])],batch_size=16) |
| 42 | +model.predict([np.zeros([1024,32]),np.zeros([1024,32])],batch_size=16,verbose=1) |
37 | 43 |
|
38 | 44 |
|
39 | 45 | """
|
40 |
| -keras2.3.1 with tf2.2 : |
41 |
| -gpu: |
42 |
| -1024/1024 [==============================] - 0s 466us/step |
43 |
| -cpu: |
44 |
| -1024/1024 [==============================] - 3s 3ms/step |
45 |
| -array([[[-0.529619 , -0.08855614, 0.37196752, ..., -0.33411935, |
46 |
| - 0.13711435, -0.956624 ], |
47 |
| - [-0.83184814, 0.3324544 , -0.49911997, ..., -0.12592706, |
48 |
| - -0.87739134, -1.120965 ], |
49 |
| - [-0.9416604 , 0.11662968, 0.92229784, ..., 0.6774571 , |
50 |
| - 1.5154107 , -0.16043526], |
51 |
| - [-0.891538 , -0.8726713 , -1.5886593 , ..., 0.2074936 , |
52 |
| - -0.44794142, -1.0378699 ], |
53 |
| - [-0.87546647, 0.75775445, -0.2165907 , ..., 0.63286835, |
54 |
| - 2.0644133 , -0.0790057 ], |
55 |
| - [-0.26717812, -0.5348375 , 0.16076468, ..., -0.9300951 , |
56 |
| - 1.2696625 , -1.60602 ]]], dtype=float32) |
57 |
| -
|
58 |
| -keras3 torch backend |
59 |
| -gpu:128/128 ━━━━━━━━━━━━━━━━━━━━ 3s 26ms/step |
60 |
| -cpu:128/128 ━━━━━━━━━━━━━━━━━━━━ 6s 45ms/step |
61 |
| -[[[-0.52961934 -0.08855701 0.37196732 ... -0.33411875 0.13711427 |
62 |
| - -0.9566241 ] |
63 |
| - [-0.8318479 0.33245584 -0.49911973 ... -0.12592846 -0.8773916 |
64 |
| - -1.1209643 ] |
65 |
| - [-0.9416604 0.11662959 0.9222967 ... 0.67745614 1.515408 |
66 |
| - -0.16043454] |
67 |
| - [-0.8915373 -0.87267166 -1.5886595 ... 0.20749295 -0.44793993 |
68 |
| - -1.037869 ] |
69 |
| - [-0.8754667 0.7577548 -0.21659094 ... 0.6328681 2.064412 |
70 |
| - -0.07900612] |
71 |
| - [-0.26717797 -0.5348387 0.16076465 ... -0.93009293 1.2696612 |
72 |
| - -1.6060183 ]]] |
73 |
| -
|
74 |
| -keras3 jax backend |
75 |
| -cpu 128/128 ━━━━━━━━━━━━━━━━━━━━ 4s 28ms/step |
76 |
| -[[[-0.5296185 -0.08855577 0.37196666 ... -0.3341192 0.13711472 |
77 |
| - -0.9566229 ] |
78 |
| - [-0.8318465 0.33245653 -0.49912187 ... -0.12592922 -0.87738854 |
79 |
| - -1.1209626 ] |
80 |
| - [-0.9416592 0.11662999 0.92229736 ... 0.67745733 1.5154091 |
81 |
| - -0.16043147] |
82 |
| - [-0.89153856 -0.8726713 -1.5886576 ... 0.20749114 -0.4479399 |
83 |
| - -1.0378699 ] |
84 |
| - [-0.8754649 0.7577545 -0.21659109 ... 0.6328678 2.0644124 |
85 |
| - -0.07900498] |
86 |
| - [-0.26717943 -0.53483707 0.1607644 ... -0.93009156 1.2696632 |
87 |
| - -1.6060191 ]]] |
88 |
| -
|
89 |
| -
|
90 |
| -keras3 tf backend |
91 |
| -cpu 128/128 ━━━━━━━━━━━━━━━━━━━━ 3s 20ms/step |
92 |
| -[[[-0.52961934 -0.08855658 0.37196696 ... -0.3341183 0.13711505 |
93 |
| - -0.956624 ] |
94 |
| - [-0.8318459 0.3324545 -0.49912187 ... -0.1259276 -0.8773911 |
95 |
| - -1.1209633 ] |
96 |
| - [-0.941661 0.11663001 0.9222968 ... 0.67745817 1.5154085 |
97 |
| - -0.1604334 ] |
98 |
| - [-0.89153755 -0.8726738 -1.5886576 ... 0.20749217 -0.4479404 |
99 |
| - -1.0378699 ] |
100 |
| - [-0.87546563 0.7577539 -0.21659082 ... 0.63286823 2.0644112 |
101 |
| - -0.07900462] |
102 |
| - [-0.26717886 -0.53484 0.16076434 ... -0.93009007 1.2696624 |
103 |
| - -1.6060195 ]]] |
| 46 | +jax + keras3 |
| 47 | +1/1 ━━━━━━━━━━━━━━━━━━━━ 4s 4s/step |
| 48 | +[[[ 0.38988823 -0.01438433 1.5175353 ... -0.24279015 0.0602929 |
| 49 | + -0.5276529 ] |
| 50 | + [-0.7205516 -0.18944462 1.2657804 ... -1.0808806 -0.07783457 |
| 51 | + 0.02543534] |
| 52 | + [ 0.2350443 -0.59266067 1.0074751 ... -0.870237 -0.41888168 |
| 53 | + -0.20036128] |
| 54 | + [-0.46260795 0.2652816 1.2453502 ... -0.5627243 -0.51681757 |
| 55 | + -0.75340915] |
| 56 | + [-1.1026425 0.4773692 1.5190558 ... -0.9399947 -1.3550612 |
| 57 | + -0.25996414] |
| 58 | + [ 0.3899519 -0.01446744 1.5174159 ... -0.24278383 0.06029204 |
| 59 | + -0.5276821 ]]] |
| 60 | +128/128 ━━━━━━━━━━━━━━━━━━━━ 6s 2ms/step |
| 61 | +128/128 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step |
| 62 | +
|
| 63 | + keras3+torch |
| 64 | +
|
| 65 | + 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 46ms/step |
| 66 | +[[[ 0.39091924 -0.01291097 1.518835 ... -0.24233319 0.0598794 |
| 67 | + -0.52755356] |
| 68 | + [-0.7224981 -0.18857734 1.264718 ... -1.0812999 -0.07896316 |
| 69 | + 0.02745117] |
| 70 | + [ 0.23513678 -0.5940911 1.0072188 ... -0.87042177 -0.41914654 |
| 71 | + -0.200799 ] |
| 72 | + [-0.4600449 0.26725328 1.2442416 ... -0.5605526 -0.52189654 |
| 73 | + -0.75340474] |
| 74 | + [-1.1030028 0.47758132 1.5185008 ... -0.9410242 -1.352944 |
| 75 | + -0.26015034] |
| 76 | + [ 0.39091927 -0.01291088 1.5188353 ... -0.24233319 0.05987941 |
| 77 | + -0.5275535 ]]] |
| 78 | +128/128 ━━━━━━━━━━━━━━━━━━━━ 5s 42ms/step |
| 79 | +128/128 ━━━━━━━━━━━━━━━━━━━━ 5s 42ms/step |
| 80 | +
|
| 81 | +jax+bf16+flash-att |
| 82 | +1/1 ━━━━━━━━━━━━━━━━━━━━ 4s 4s/step |
| 83 | +[[[0.330078 -0.570312 0.535156 ... -0.178711 0.0795898 -0.257812] |
| 84 | + [1.24219 -0.546875 0.392578 ... 0.0874023 -0.117188 -0.769531] |
| 85 | + [0.316406 -0.326172 0.328125 ... 0.216797 0.185547 -0.667969] |
| 86 | + [0.457031 -0.0537109 0.503906 ... -0.0708008 -0.135742 -0.291016] |
| 87 | + [0.240234 -0.200195 0.123047 ... -0.460938 -0.097168 -0.316406] |
| 88 | + [0.330078 -0.570312 0.535156 ... -0.174805 0.0776367 -0.259766]]] |
| 89 | +64/64 ━━━━━━━━━━━━━━━━━━━━ 4s 5ms/step |
| 90 | +64/64 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step |
| 91 | +jax+bf16 |
| 92 | +
|
| 93 | +1/1 ━━━━━━━━━━━━━━━━━━━━ 4s 4s/step |
| 94 | +[[[0.332031 -0.574219 0.535156 ... -0.175781 0.0771484 -0.265625] |
| 95 | + [1.25781 -0.539062 0.384766 ... 0.0844727 -0.111816 -0.777344] |
| 96 | + [0.324219 -0.324219 0.330078 ... 0.21582 0.183594 -0.660156] |
| 97 | + [0.460938 -0.0551758 0.498047 ... -0.0766602 -0.128906 -0.296875] |
| 98 | + [0.245117 -0.201172 0.11084 ... -0.474609 -0.103516 -0.314453] |
| 99 | + [0.332031 -0.574219 0.535156 ... -0.175781 0.0751953 -0.263672]]] |
| 100 | +128/128 ━━━━━━━━━━━━━━━━━━━━ 4s 4ms/step |
| 101 | +128/128 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step |
| 102 | +
|
| 103 | +torch+bf16 |
| 104 | +1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 46ms/step |
| 105 | +[[[0.330078 -0.566406 0.546875 ... -0.169922 0.0742188 -0.267578] |
| 106 | + [1.25 -0.542969 0.402344 ... 0.074707 -0.111328 -0.761719] |
| 107 | + [0.320312 -0.302734 0.345703 ... 0.238281 0.169922 -0.660156] |
| 108 | + [0.455078 -0.0373535 0.511719 ... -0.0810547 -0.145508 -0.291016] |
| 109 | + [0.227539 -0.193359 0.111328 ... -0.466797 -0.0996094 -0.320312] |
| 110 | + [0.330078 -0.566406 0.546875 ... -0.173828 0.0742188 -0.265625]]] |
| 111 | +64/64 ━━━━━━━━━━━━━━━━━━━━ 3s 44ms/step |
| 112 | +64/64 ━━━━━━━━━━━━━━━━━━━━ 3s 43ms/step |
104 | 113 |
|
105 | 114 | """
|
106 | 115 |
|
0 commit comments