
Column transformer with fit_transform error

I get an error when using make_column_transformer with LabelEncoder:

import pickle

import numpy as np
from sklearn.compose import make_column_transformer
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelEncoder, OneHotEncoder


def train_or_load_model(data, learn=True):
    # Column lists prepared earlier and stored on disk
    with open('to_categorical.pickle', 'rb') as f:
        to_categorical = pickle.load(f)
    with open('to_OH.pickle', 'rb') as f:
        to_OH = pickle.load(f)
    with open('to_drop.pickle', 'rb') as f:
        to_drop = pickle.load(f)

    # print(to_drop)

    # drop(..., inplace=True) returns None, so there is nothing to assign
    data.drop(['id'], axis=1, inplace=True)

    if learn:
        # 1-D target vector
        target = np.array(data['target'])
        print(type(target))
        to_drop.append('target')
    data.drop(to_drop, axis=1, inplace=True)

    if learn:
        transformer = make_column_transformer(
            (LabelEncoder(), to_categorical),  # this step raises the TypeError shown below
            (OneHotEncoder(), to_OH)
        )
        model = Pipeline(
            steps=[('preprocess_data', transformer),
                   ('model', KNeighborsClassifier(2, n_jobs=-1))
                   ]
        )
        X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.2)
        model.fit(X_train, y_train)

The data I am using comes from https://www.kaggle.com/c/cat-thedat/data and I get this error:

Traceback (most recent call last):
  File "c:\Users\barte\.vscode\extensions\ms-python.python-2019.11.50794\pythonFiles\ptvsd_launcher.py", line 43, in <module>
    main(ptvsdArgs)
  File "c:\Users\barte\.vscode\extensions\ms-python.python-2019.11.50794\pythonFiles\lib\python\old_ptvsd\ptvsd\__main__.py", line 432, in main
    run()
  File "c:\Users\barte\.vscode\extensions\ms-python.python-2019.11.50794\pythonFiles\lib\python\old_ptvsd\ptvsd\__main__.py", line 316, in run_file
    runpy.run_path(target, run_name='__main__')
  File "C:\Users\barte\AppData\Local\Programs\Python\Python36\Lib\runpy.py", line 263, in run_path
    pkg_name=pkg_name, script_name=fname)
  File "C:\Users\barte\AppData\Local\Programs\Python\Python36\Lib\runpy.py", line 96, in _run_module_code
    mod_name, mod_spec, pkg_name, script_name)
  File "C:\Users\barte\AppData\Local\Programs\Python\Python36\Lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "c:\Users\barte\Desktop\Projects\tf\kaggle categorical feature\main.py", line 102, in <module>
    print(train_or_load_model(raw_data))
  File "c:\Users\barte\Desktop\Projects\tf\kaggle categorical feature\main.py", line 97, in train_or_load_model
    model.fit(X_train,y_train)
  File "C:\Users\barte\Desktop\Projects\tf\env\lib\site-packages\sklearn\pipeline.py", line 352, in fit
    Xt, fit_params = self._fit(X, y, **fit_params)
  File "C:\Users\barte\Desktop\Projects\tf\env\lib\site-packages\sklearn\pipeline.py", line 317, in _fit
    **fit_params_steps[name])
  File "C:\Users\barte\Desktop\Projects\tf\env\lib\site-packages\joblib\memory.py", line 355, in __call__
    return self.func(*args, **kwargs)
  File "C:\Users\barte\Desktop\Projects\tf\env\lib\site-packages\sklearn\pipeline.py", line 716, in _fit_transform_one
    res = transformer.fit_transform(X, y, **fit_params)
  File "C:\Users\barte\Desktop\Projects\tf\env\lib\site-packages\sklearn\compose\_column_transformer.py", line 476, in fit_transform
    result = self._fit_transform(X, y, _fit_transform_one)
  File "C:\Users\barte\Desktop\Projects\tf\env\lib\site-packages\sklearn\compose\_column_transformer.py", line 420, in _fit_transform
    self._iter(fitted=fitted, replace_strings=True), 1))
  File "C:\Users\barte\Desktop\Projects\tf\env\lib\site-packages\joblib\parallel.py", line 921, in __call__
    if self.dispatch_one_batch(iterator):
  File "C:\Users\barte\Desktop\Projects\tf\env\lib\site-packages\joblib\parallel.py", line 759, in dispatch_one_batch
    self._dispatch(tasks)
  File "C:\Users\barte\Desktop\Projects\tf\env\lib\site-packages\joblib\parallel.py", line 716, in _dispatch
    job = self._backend.apply_async(batch, callback=cb)
  File "C:\Users\barte\Desktop\Projects\tf\env\lib\site-packages\joblib\_parallel_backends.py", line 182, in apply_async        
    result = ImmediateResult(func)
  File "C:\Users\barte\Desktop\Projects\tf\env\lib\site-packages\joblib\_parallel_backends.py", line 549, in __init__
    self.results = batch()
  File "C:\Users\barte\Desktop\Projects\tf\env\lib\site-packages\joblib\parallel.py", line 225, in __call__
    for func, args, kwargs in self.items]
  File "C:\Users\barte\Desktop\Projects\tf\env\lib\site-packages\joblib\parallel.py", line 225, in <listcomp>
    for func, args, kwargs in self.items]
  File "C:\Users\barte\Desktop\Projects\tf\env\lib\site-packages\sklearn\pipeline.py", line 716, in _fit_transform_one
    res = transformer.fit_transform(X, y, **fit_params)
TypeError: fit_transform() takes 2 positional arguments but 3 were given

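I don't know why the pipeline passes 3 arguments to LabelEncoder when only X_train should go there. I also tried writing my own class, something like MyLabelEncoder(BaseEstimator, TransformerMixin), but then I got serious shape errors during fitting. Thanks for your help, and Merry Christmas :)

Source: StackOverflow, /questions/59466597/column-transformer-with-fit-transform-error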
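The "3 arguments" in the error message can be reproduced without the Kaggle data. The sketch below is only an illustration, assuming a reasonably recent scikit-learn; the tiny `X` and `y` lists are made up. It calls `LabelEncoder().fit_transform(X, y)` directly, which is the same call shape as `transformer.fit_transform(X, y, **fit_params)` in the traceback, and counts `self` as the first of the three positional arguments.

```python
import inspect

from sklearn.preprocessing import LabelEncoder

# LabelEncoder is designed for the target vector y, so its fit_transform
# has no X parameter at all.
label_params = list(inspect.signature(LabelEncoder.fit_transform).parameters)
print(label_params)

# Made-up data, only used to trigger the call.
X = [["red"], ["green"], ["red"]]
y = [0, 1, 0]

error_message = None
try:
    # Same call shape as in the traceback: transformer.fit_transform(X, y)
    LabelEncoder().fit_transform(X, y)
except TypeError as exc:
    error_message = str(exc)

print(error_message)
```

So the pipeline is not doing anything unusual: the column transformer hands every transformer both `X` and `y`, and `LabelEncoder` only accepts one of them.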

kun坤 2019-12-25 09:57:45
1 answer
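  • Remove LabelEncoder() and use OneHotEncoder() instead. In scikit-learn, OneHotEncoder accepts string categories directly, so a separate label-encoding step before one-hot encoding is no longer needed.

    LabelEncoder() also does not work as a feature transformer inside a pipeline: it is meant for the target, its fit_transform(y) accepts a single argument, and the column transformer calls fit_transform(X, y), which produces the TypeError above. If you need to encode specific columns this way, run that transformation on its own, outside the pipeline.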

    2019-12-26 14:25:42
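The advice above can be sketched as follows. This is a minimal illustration, not the asker's actual setup: the data frame and its column names are made up, and `OrdinalEncoder` is swapped in here (it is not mentioned in the answer) as one possible replacement for the columns that were meant to get integer codes. `LabelEncoder` is still used, but only on the target and outside the pipeline, which is what it is designed for.

```python
import pandas as pd
from sklearn.compose import make_column_transformer
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelEncoder, OneHotEncoder, OrdinalEncoder

# Tiny made-up frame standing in for the Kaggle data (hypothetical columns).
data = pd.DataFrame({
    "size": ["S", "M", "L", "M", "S", "L"],
    "color": ["red", "blue", "red", "green", "blue", "green"],
    "target": ["no", "yes", "no", "yes", "no", "yes"],
})
to_categorical = ["size"]  # columns to map to integer codes
to_OH = ["color"]          # columns to one-hot encode

# Encode the target on its own, outside the pipeline.
y = LabelEncoder().fit_transform(data["target"])
X = data.drop(columns=["target"])

transformer = make_column_transformer(
    (OrdinalEncoder(), to_categorical),               # assumption: replaces LabelEncoder
    (OneHotEncoder(handle_unknown="ignore"), to_OH),  # works on strings directly
)
model = Pipeline(steps=[
    ("preprocess_data", transformer),
    ("model", KNeighborsClassifier(2)),
])

model.fit(X, y)
encoded = model.named_steps["preprocess_data"].transform(X)
predictions = model.predict(X)
print(encoded.shape, predictions)
```

If every categorical column should simply be one-hot encoded, as the answer suggests, the `OrdinalEncoder` entry can be dropped and its columns added to `to_OH`.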