当前位置: 代码迷 >> 综合 >> paddlepaddle加载预训练词向量
  详细解决方案

paddlepaddle加载预训练词向量

热度:42   发布时间:2024-02-26 16:33:28.0

直接见代码:

import paddle.fluid as fluid
import numpy as npdata = fluid.layers.data(name='data', shape=[1], dtype='int64')
np.random.seed(28)
weight_data = np.random.random(size=(20, 8))
print(weight_data)#加载用户自定义或预训练的词向量
w_param_attrs = fluid.ParamAttr(name="w_param_attrs",initializer=fluid.initializer.NumpyArrayInitializer(weight_data),trainable=False)emb_2 = fluid.embedding(input=data, size=(20, 8), param_attr=w_param_attrs, dtype='float32')cpu = fluid.CPUPlace() # 定义运算场所
exe = fluid.Executor(cpu) # 创建执行器
exe.run(fluid.default_startup_program()) # 网络参数初始化x = np.array([[1], [2]])
outs = exe.run(feed={'data': x}, fetch_list=[emb_2.name])
print(outs)

结果如下:

[[0.72901374 0.5612396  0.12496709 0.39759237 0.78130821 0.510992980.18269336 0.85351288][0.95537189 0.98421347 0.19270097 0.9707951  0.23480835 0.026353850.94606034 0.92172485][0.29397577 0.1662737  0.39542284 0.51066973 0.30803723 0.429568830.83006941 0.56239357][0.83088831 0.99692929 0.33257881 0.09100813 0.77383156 0.149383730.72535506 0.95514643][0.07309577 0.44716275 0.84111807 0.14553967 0.76527154 0.781784920.67507855 0.13170219][0.03930318 0.65602308 0.25118261 0.98841838 0.53338304 0.059175240.69875531 0.62717477][0.89577854 0.16192467 0.61038158 0.3169851  0.76326567 0.156282080.92988758 0.49781052][0.83323397 0.22996943 0.10681001 0.67370038 0.57898325 0.875849370.99712764 0.27530634][0.74263626 0.28473195 0.72624867 0.49107034 0.86801609 0.16226170.9713251  0.04888569][0.70054591 0.65194491 0.04645909 0.19730088 0.33060701 0.752644950.36501458 0.53077101][0.35418132 0.51467406 0.26169937 0.85173949 0.62324126 0.304469750.77547856 0.89555198][0.7374077  0.85555241 0.82012533 0.86522095 0.38212962 0.611407060.41550595 0.2421348 ][0.06125105 0.81751611 0.38363211 0.97884048 0.38187252 0.630149680.44335181 0.02552223][0.23321525 0.77924846 0.16996923 0.41457111 0.59480006 0.910870080.50639157 0.4386332 ][0.03229215 0.22840922 0.18160441 0.24255622 0.8094556  0.519288470.36861752 0.46235367][0.60488351 0.55737864 0.03305479 0.39902018 0.08332113 0.483166350.85653765 0.84775654][0.37035053 0.71812028 0.00461064 0.76418841 0.74670009 0.858918820.45676896 0.94777212][0.63737347 0.49762039 0.18912248 0.75981605 0.37119162 0.209273750.32256109 0.20617277][0.40986867 0.13548799 0.81640462 0.63828349 0.67581164 0.008539340.73750379 0.76717025][0.16223589 0.9606869  0.79786617 0.58411784 0.04252264 0.342688690.36767624 0.88560098]]
[array([[[0.9553719 , 0.9842135 , 0.19270097, 0.9707951 , 0.23480836,0.02635385, 0.94606036, 0.92172486]],[[0.29397577, 0.1662737 , 0.39542285, 0.5106697 , 0.30803722,0.42956883, 0.8300694 , 0.56239355]]], dtype=float32)]