对 np.float16 类型的 numpy array 做 dot 运算简直就是噩梦……

遇到的问题:

如题。类似如下代码:

X = np.zeros(2000, 24000), dtype=np.float16)
Y = np.zeros(2000, 24000), dtype=np.float16)
Z = X.dot(Y.T)

即,定义两个 2000 x 24000 的 numpy 二维数组,使用 dot 对其做点乘运算,此时会发现python开始了漫长的计算,漫长到每次都是以通过任务管理器强行停止python进程结束……

解决方法:

不要使用 float16 类型定义 numpy 数组。本来想降低内存使用量,但发现代价大得过分了,我调试的时候还郁闷了好久,想不通一个 dot 运算能出什么问题……

即,改为:

X = np.zeros(2000, 24000), dtype=np.float32)
Y = np.zeros(2000, 24000), dtype=np.float32)
Z = X.dot(Y.T)

为什么会这样:

和该运算基于 BLAS(Basic Linear Algebra Subprograms,基础线性代数程序集 ) 有关,未深究,可以进一步阅读:https://www.edureka.co/community/18855/numpy-multiplying-large-arrays-with-dtype-int8-is-slow

留下评论

您的电子邮箱地址不会被公开。 必填项已用*标注