
Commit 5bdad0e

Add files via upload
1 parent b0aba45 commit 5bdad0e

1 file changed: 80 additions, 71 deletions

examples/basic_extract_features.py

@@ -7,21 +7,25 @@
 
 #! -*- coding: utf-8 -*-
 # Test that the code runs: feature extraction
+
 import os
 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-os.environ["KERAS_BACKEND"] = "jax"
-os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+os.environ["KERAS_BACKEND"] = "torch"
+os.environ["FLASH_ATTN"] = "0"  # toggle flash-attn
+
 
 import numpy as np
 from bert4keras3.backend import keras, ops
+keras.backend.set_floatx('bfloat16')
 from bert4keras3.models import build_transformer_model
 from bert4keras3.tokenizers import Tokenizer
 from bert4keras3.snippets import to_array
+
 # bert from:
 # model download from https://open.zhuiyi.ai/releases/nlp/models/zhuiyi/chinese_roberta_L-4_H-312_A-12.zip
-config_path = 'models/chinese_roberta_L-4_H-312_A-12/bert_config.json'
-checkpoint_path = 'models/chinese_roberta_L-4_H-312_A-12/bert_model.ckpt'
-dict_path = 'models/chinese_roberta_L-4_H-312_A-12/vocab.txt'
+config_path = 'chinese_wobert_L-12_H-768_A-12/bert_config.json'
+checkpoint_path = 'chinese_wobert_L-12_H-768_A-12/bert_model.ckpt'
+dict_path = 'chinese_wobert_L-12_H-768_A-12/vocab.txt'
 
 tokenizer = Tokenizer(dict_path, do_lower_case=True)  # build the tokenizer
 model = build_transformer_model(config_path, checkpoint_path)  # build the model and load the weights
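
The first hunk switches the example from the CPU-pinned jax setup to the torch backend, adds the FLASH_ATTN environment toggle, and moves the model to bfloat16 via keras.backend.set_floatx. A minimal sketch of that setup order, assuming the same bert4keras3 API, is below; the checkpoint paths are placeholders, and the key point is that KERAS_BACKEND must be set before the first keras import, and set_floatx must run before the model is built.

# Sketch only: backend and precision are chosen before keras is imported
# and before any weights are created. Paths below are placeholders.
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"   # silence TensorFlow logging
os.environ["KERAS_BACKEND"] = "torch"      # or "jax" / "tensorflow"
os.environ["FLASH_ATTN"] = "0"             # bert4keras3 flash-attn toggle, as in this commit

from bert4keras3.backend import keras      # first keras import happens here
keras.backend.set_floatx('bfloat16')       # default dtype for layers built afterwards

from bert4keras3.models import build_transformer_model
from bert4keras3.tokenizers import Tokenizer

config_path = 'path/to/bert_config.json'       # placeholder
checkpoint_path = 'path/to/bert_model.ckpt'    # placeholder
dict_path = 'path/to/vocab.txt'                # placeholder

tokenizer = Tokenizer(dict_path, do_lower_case=True)
model = build_transformer_model(config_path, checkpoint_path)
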
@@ -30,77 +34,82 @@
 token_ids, segment_ids = tokenizer.encode(u'语言模型')
 token_ids, segment_ids = to_array([token_ids], [segment_ids])
 
+
+
 print('\n ===== predicting =====\n')
 print(model.predict([token_ids, segment_ids]))
-model.predict([np.zeros([1024, 32]), np.zeros([1024, 32])], batch_size=8)
-model.predict([np.zeros([1024, 32]), np.zeros([1024, 32])], batch_size=8, verbose=1)
+model.predict([np.zeros([1024, 32]), np.zeros([1024, 32])], batch_size=16)
+model.predict([np.zeros([1024, 32]), np.zeros([1024, 32])], batch_size=16, verbose=1)
 
 
 """
-keras 2.3.1 with tf 2.2:
-gpu:
-1024/1024 [==============================] - 0s 466us/step
-cpu:
-1024/1024 [==============================] - 3s 3ms/step
-array([[[-0.529619  , -0.08855614,  0.37196752, ..., -0.33411935,
-          0.13711435, -0.956624  ],
-        [-0.83184814,  0.3324544 , -0.49911997, ..., -0.12592706,
-         -0.87739134, -1.120965  ],
-        [-0.9416604 ,  0.11662968,  0.92229784, ...,  0.6774571 ,
-          1.5154107 , -0.16043526],
-        [-0.891538  , -0.8726713 , -1.5886593 , ...,  0.2074936 ,
-         -0.44794142, -1.0378699 ],
-        [-0.87546647,  0.75775445, -0.2165907 , ...,  0.63286835,
-          2.0644133 , -0.0790057 ],
-        [-0.26717812, -0.5348375 ,  0.16076468, ..., -0.9300951 ,
-          1.2696625 , -1.60602   ]]], dtype=float32)
-
-keras3 torch backend
-gpu: 128/128 ━━━━━━━━━━━━━━━━━━━━ 3s 26ms/step
-cpu: 128/128 ━━━━━━━━━━━━━━━━━━━━ 6s 45ms/step
-[[[-0.52961934 -0.08855701  0.37196732 ... -0.33411875  0.13711427
-   -0.9566241 ]
-  [-0.8318479   0.33245584 -0.49911973 ... -0.12592846 -0.8773916
-   -1.1209643 ]
-  [-0.9416604   0.11662959  0.9222967  ...  0.67745614  1.515408
-   -0.16043454]
-  [-0.8915373  -0.87267166 -1.5886595  ...  0.20749295 -0.44793993
-   -1.037869  ]
-  [-0.8754667   0.7577548  -0.21659094 ...  0.6328681   2.064412
-   -0.07900612]
-  [-0.26717797 -0.5348387   0.16076465 ... -0.93009293  1.2696612
-   -1.6060183 ]]]
-
-keras3 jax backend
-cpu: 128/128 ━━━━━━━━━━━━━━━━━━━━ 4s 28ms/step
-[[[-0.5296185  -0.08855577  0.37196666 ... -0.3341192   0.13711472
-   -0.9566229 ]
-  [-0.8318465   0.33245653 -0.49912187 ... -0.12592922 -0.87738854
-   -1.1209626 ]
-  [-0.9416592   0.11662999  0.92229736 ...  0.67745733  1.5154091
-   -0.16043147]
-  [-0.89153856 -0.8726713  -1.5886576  ...  0.20749114 -0.4479399
-   -1.0378699 ]
-  [-0.8754649   0.7577545  -0.21659109 ...  0.6328678   2.0644124
-   -0.07900498]
-  [-0.26717943 -0.53483707  0.1607644  ... -0.93009156  1.2696632
-   -1.6060191 ]]]
-
-
-keras3 tf backend
-cpu: 128/128 ━━━━━━━━━━━━━━━━━━━━ 3s 20ms/step
-[[[-0.52961934 -0.08855658  0.37196696 ... -0.3341183   0.13711505
-   -0.956624  ]
-  [-0.8318459   0.3324545  -0.49912187 ... -0.1259276  -0.8773911
-   -1.1209633 ]
-  [-0.941661    0.11663001  0.9222968  ...  0.67745817  1.5154085
-   -0.1604334 ]
-  [-0.89153755 -0.8726738  -1.5886576  ...  0.20749217 -0.4479404
-   -1.0378699 ]
-  [-0.87546563  0.7577539  -0.21659082 ...  0.63286823  2.0644112
-   -0.07900462]
-  [-0.26717886 -0.53484     0.16076434 ... -0.93009007  1.2696624
-   -1.6060195 ]]]
+jax + keras3
+1/1 ━━━━━━━━━━━━━━━━━━━━ 4s 4s/step
+[[[ 0.38988823 -0.01438433  1.5175353  ... -0.24279015  0.0602929
+   -0.5276529 ]
+  [-0.7205516  -0.18944462  1.2657804  ... -1.0808806  -0.07783457
+    0.02543534]
+  [ 0.2350443  -0.59266067  1.0074751  ... -0.870237   -0.41888168
+   -0.20036128]
+  [-0.46260795  0.2652816   1.2453502  ... -0.5627243  -0.51681757
+   -0.75340915]
+  [-1.1026425   0.4773692   1.5190558  ... -0.9399947  -1.3550612
+   -0.25996414]
+  [ 0.3899519  -0.01446744  1.5174159  ... -0.24278383  0.06029204
+   -0.5276821 ]]]
+128/128 ━━━━━━━━━━━━━━━━━━━━ 6s 2ms/step
+128/128 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step
+
+keras3 + torch
+1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 46ms/step
+[[[ 0.39091924 -0.01291097  1.518835   ... -0.24233319  0.0598794
+   -0.52755356]
+  [-0.7224981  -0.18857734  1.264718   ... -1.0812999  -0.07896316
+    0.02745117]
+  [ 0.23513678 -0.5940911   1.0072188  ... -0.87042177 -0.41914654
+   -0.200799  ]
+  [-0.4600449   0.26725328  1.2442416  ... -0.5605526  -0.52189654
+   -0.75340474]
+  [-1.1030028   0.47758132  1.5185008  ... -0.9410242  -1.352944
+   -0.26015034]
+  [ 0.39091927 -0.01291088  1.5188353  ... -0.24233319  0.05987941
+   -0.5275535 ]]]
+128/128 ━━━━━━━━━━━━━━━━━━━━ 5s 42ms/step
+128/128 ━━━━━━━━━━━━━━━━━━━━ 5s 42ms/step
+
+jax + bf16 + flash-attn
+1/1 ━━━━━━━━━━━━━━━━━━━━ 4s 4s/step
+[[[0.330078 -0.570312 0.535156 ... -0.178711 0.0795898 -0.257812]
+ [1.24219 -0.546875 0.392578 ... 0.0874023 -0.117188 -0.769531]
+ [0.316406 -0.326172 0.328125 ... 0.216797 0.185547 -0.667969]
+ [0.457031 -0.0537109 0.503906 ... -0.0708008 -0.135742 -0.291016]
+ [0.240234 -0.200195 0.123047 ... -0.460938 -0.097168 -0.316406]
+ [0.330078 -0.570312 0.535156 ... -0.174805 0.0776367 -0.259766]]]
+64/64 ━━━━━━━━━━━━━━━━━━━━ 4s 5ms/step
+64/64 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step
+
+jax + bf16
+1/1 ━━━━━━━━━━━━━━━━━━━━ 4s 4s/step
+[[[0.332031 -0.574219 0.535156 ... -0.175781 0.0771484 -0.265625]
+ [1.25781 -0.539062 0.384766 ... 0.0844727 -0.111816 -0.777344]
+ [0.324219 -0.324219 0.330078 ... 0.21582 0.183594 -0.660156]
+ [0.460938 -0.0551758 0.498047 ... -0.0766602 -0.128906 -0.296875]
+ [0.245117 -0.201172 0.11084 ... -0.474609 -0.103516 -0.314453]
+ [0.332031 -0.574219 0.535156 ... -0.175781 0.0751953 -0.263672]]]
+128/128 ━━━━━━━━━━━━━━━━━━━━ 4s 4ms/step
+128/128 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step
+
+torch + bf16
+1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 46ms/step
+[[[0.330078 -0.566406 0.546875 ... -0.169922 0.0742188 -0.267578]
+ [1.25 -0.542969 0.402344 ... 0.074707 -0.111328 -0.761719]
+ [0.320312 -0.302734 0.345703 ... 0.238281 0.169922 -0.660156]
+ [0.455078 -0.0373535 0.511719 ... -0.0810547 -0.145508 -0.291016]
+ [0.227539 -0.193359 0.111328 ... -0.466797 -0.0996094 -0.320312]
+ [0.330078 -0.566406 0.546875 ... -0.173828 0.0742188 -0.265625]]]
+64/64 ━━━━━━━━━━━━━━━━━━━━ 3s 44ms/step
+64/64 ━━━━━━━━━━━━━━━━━━━━ 3s 43ms/step
 
 
 """
