BaseEstimator类用来处理输入数据的格式

类内的全局变量有[X,y,y_{required},fit_{required}]

[X,y]通过__setup_input()方法将[X,y]变为numpy.ndarray类型

如果输入数据没有[y], 则[y_{required} = False]

处理输入数据[X]的步骤如下:

  1. 若X不是numpy.ndarray类型,则转换类型;

  2. 若X为空数组,则提示值错误;

  3. 若X的维数为1,则X的样本数为1,X的特征数目为X.shape[0];

  4. 若X的维数不为1,则X的样本数为X.shape[0],X的特征数目为其他维数的长度之积。



如果输入了[y],则对[y]处理的步骤如下:

  1. 若需要输入y却没有输入y,则提示错误;

  2. 若X不是numpy.ndarray类型,则转换类型;

  3. 若输入了y,但是大小为0,则提示错误。



# coding: utf-8
import numpy as np


class BaseEstimator(Object):
    X = None
    y = None
    y_required = True
    fit_required = True
    def __setup_input(self,X,y=None):
        if not isinstance(X,np.ndarray):
            X = np.array(X)
        
        if X.size == 0:
            raise ValueError('Number of feautures must be > 0 ')
        
        if X.ndim == 1:
            self.n_samples, self.n_feautures = 1, X.shape
        else:
            self.n_samples, self.n_feautures = X.shape[0], np.prod(X.shape[1:])

        self.X = X

        if self.y_required:
            if y is None:
                raise ValueError('Missed required argument y ')

            if not isinstance(y,np.ndarray):
                y = np.array(y)

            if y.size == 0:
                raise ValueError('Number of target y must be > 0')

        self.y = y

    def fit(self,X,y= None):
        self.__setup_input(X,y)

    def predict(self,X=None):
        if not isinstance(X,np.ndarray):
            X = np.array(X)

        if X is not None or not fit_required:
            return self._predict(X)
        else:
            raise ValueError('You must call fit before predict')

    def _predict(X=None):
        raise NotImplementedError()
内容来源于网络如有侵权请私信删除

文章来源: 博客园

原文链接: https://www.cnblogs.com/shq-lj/p/11845048.html

你还没有登录,请先登录注册
  • 还没有人评论,欢迎说说您的想法!