遇到的问题:
如题。类似如下代码:
X = np.zeros(2000, 24000), dtype=np.float16)
Y = np.zeros(2000, 24000), dtype=np.float16)
Z = X.dot(Y.T)
即,定义两个 2000 x 24000 的 numpy 二维数组,使用 dot 对其做点乘运算,此时会发现python开始了漫长的计算,漫长到每次都是以通过任务管理器强行停止python进程结束……
解决方法:
不要使用 float16 类型定义 numpy 数组。本来想降低内存使用量,但发现代价大得过分了,我调试的时候还郁闷了好久,想不通一个 dot 运算能出什么问题……
即,改为:
X = np.zeros(2000, 24000), dtype=np.float32)
Y = np.zeros(2000, 24000), dtype=np.float32)
Z = X.dot(Y.T)
为什么会这样:
和该运算基于 BLAS(Basic Linear Algebra Subprograms,基础线性代数程序集 ) 有关,未深究,可以进一步阅读:https://www.edureka.co/community/18855/numpy-multiplying-large-arrays-with-dtype-int8-is-slow