当前位置: 代码迷 >> python >> Numpy函数破坏了我继承的数据类型
  详细解决方案

Numpy函数破坏了我继承的数据类型

热度:99   发布时间:2023-06-13 13:55:13.0

假设我有一个继承自numpy.ndarray的类ndarray_plus并添加了一些额外的功能。 有时我将它传递给像np.sum这样的numpy函数,并按预期返回类型为ndarray_plus的对象。

其他时候,我通过增强对象返回numpy.ndarray对象的numpy函数,破坏额外的ndarray_plus属性中的信息。 当有问题的numpy函数执行np.asarray而不是np.asanyarray时,通常会发生这种情况。

有没有办法防止这种情况发生? 我无法进入numpy代码库并将np.asarray所有实例np.asarraynp.asanyarray 是否有一种Pythonic方法可以先发制人地保护我的继承对象?

asarray已定义和保证行为是将您的子类实例转换回基类

help on function asarray in numpy:

numpy.asarray = asarray(a, dtype=None, order=None)
Convert the input to an array.

Parameters
----------
a : array_like
    Input data, in any form that can be converted to an array.  This
    includes lists, lists of tuples, tuples, tuples of tuples, tuples
    of lists and ndarrays.
dtype : data-type, optional
    By default, the data-type is inferred from the input data.
order : {'C', 'F'}, optional
    Whether to use row-major (C-style) or
    column-major (Fortran-style) memory representation.
    Defaults to 'C'.

Returns
-------
out : ndarray
    Array interpretation of `a`.  No copy is performed if the input
    is already an ndarray.  If `a` is a subclass of ndarray, a base
    class ndarray is returned.

See Also
--------
asanyarray : Similar function which passes through subclasses.

< - snip - >

你可以尝试和monkeypatch:

>>> import numpy as np
>>> import mpt
>>> 
>>> s = np.matrix(3)
>>> mpt.aa(s)
array([[3]])
>>> np.asarray = np.asanyarray
>>> mpt.aa(s)
matrix([[3]])

文件mpt.py

import numpy as np

def aa(x):
   return np.asarray(x)

可悲的是,这并不总是有效。

替代mpt.py

from numpy import asarray

def aa(x):
   return asarray(x)

在这里,你运气不好。

  相关解决方案