Keras error: ValueError: Shapes (None, 1) and (None, 2) are incompatible

Popularity: 78   Published: 2023-12-08 07:25:34.0


  • Background
  • Error message
  • Solution
  • Full program

Background

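An MLP is used for a binary time-series classification task: given historical stock prices, predict whether the price will rise or fall over the coming days.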
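
As a small illustration of the labeling step, here is a sketch that builds up/down labels for the next day's move. It assumes "future" means one day ahead, which the post does not state, and the prices are made up. Note that the full program further down labels each day by its change from the previous day instead.

```python
import numpy as np

# Made-up closing prices, one value per trading day.
prices = np.array([100.0, 101.0, 99.0, 102.0, 102.0])

# Return from day t to day t+1 (assumption: a one-day horizon).
next_day_return = prices[1:] / prices[:-1] - 1.0

# 1 = price rises the next day, 0 = it falls or stays flat.
labels = (next_day_return > 0).astype(int)

# The last day has no "next day", so it gets no label and is dropped.
features = prices[:-1].reshape(-1, 1)

print(features.shape, labels.shape)
```

With this alignment, each feature row only uses information available before the move it is meant to predict.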

Error message

ValueError: Shapes (None, 1) and (None, 2) are incompatible

Solution

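Convert the integer labels 0 and 1 into one-hot class vectors. The model ends in Dense(2) with softmax and is compiled with categorical_crossentropy, so it expects labels of shape (None, 2); a plain 0/1 column has shape (None, 1), which triggers the error.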

from tensorflow.keras.utils import to_categorical
y = to_categorical(dataset['binary_target'].values)
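
To make the shape change visible, the following NumPy sketch mimics what one-hot encoding does to integer labels. The helper `one_hot` and the sample labels are hypothetical and only for illustration; in the actual program `to_categorical` does this job.

```python
import numpy as np

def one_hot(labels, num_classes):
    """Turn integer class labels into one-hot rows (hypothetical helper)."""
    labels = np.asarray(labels, dtype=int).ravel()
    # Row i of the identity matrix is the one-hot vector for class i.
    return np.eye(num_classes)[labels]

binary_target = np.array([0, 1, 1, 0, 1])  # made-up labels
y = one_hot(binary_target, num_classes=2)

print(binary_target.shape)  # (5,)
print(y.shape)              # (5, 2)
```

The second dimension of `y` now matches the two output units of the final Dense layer, which is what the loss needs.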

Full program

import matplotlib.pylab as plt
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
# read_data 
data = pd.read_csv('./AAPL.csv')
close_price = data.loc[:, 'Adj Close'].tolist()
close_price_diffs = data.loc[:, 'Adj Close'].pct_change()
# generate the dataset 
target = close_price_diffs.apply(lambda x: 1 if x > 0 else 0)
dataset = pd.DataFrame({
    'close_price': close_price, 'binary_target':target })
# scale the dataset 
scaler = MinMaxScaler()
X_sc = scaler.fit_transform(dataset['close_price'].values.reshape(-1,1))
# y = dataset['binary_target'].values  # raises the error: integer labels, not one-hot
y = to_categorical(dataset['binary_target'].values)  # correct: one-hot labels, shape (n, 2)
X_train, X_test, y_train, y_test = train_test_split(X_sc, y, test_size=0.15, random_state=0)
# define the model 
model = Sequential()
model.add(Dense(64, input_dim=X_train.shape[1]))
model.add(BatchNormalization())
model.add(LeakyReLU())
model.add(Dense(2))
model.add(Activation('softmax'))
# lower the learning rate when the validation loss stops improving
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.9, patience=5, min_lr=0.000001, verbose=1)
model.compile(optimizer='rmsprop', loss='categorical_crossentropy',metrics=['accuracy'])
# fit the model 
history = model.fit(X_train, y_train, epochs = 50, batch_size = 128, verbose=1, validation_data=(X_test, y_test),shuffle=True,callbacks=[reduce_lr])
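
An alternative that is not part of the original post: Keras also ships a `sparse_categorical_crossentropy` loss that accepts integer labels directly, so the 0/1 column could be kept as is while the Dense(2) + softmax head stays unchanged. The NumPy sketch below is not Keras code. Under made-up probabilities, it only illustrates that the two losses compute the same value, and that one-hot cross-entropy fails when the label shape does not match the prediction shape. Both function bodies are simplified stand-ins written for this example.

```python
import numpy as np

def categorical_crossentropy(y_onehot, probs):
    """Mean cross-entropy for one-hot labels of shape (N, C)."""
    if y_onehot.shape != probs.shape:
        raise ValueError(
            f"Shapes {y_onehot.shape} and {probs.shape} are incompatible")
    return float(-np.mean(np.sum(y_onehot * np.log(probs), axis=1)))

def sparse_categorical_crossentropy(y_int, probs):
    """Mean cross-entropy for integer labels of shape (N,)."""
    # Pick the predicted probability of the true class in each row.
    picked = probs[np.arange(len(y_int)), y_int]
    return float(-np.mean(np.log(picked)))

# Made-up softmax outputs for three samples and two classes.
probs = np.array([[0.9, 0.1], [0.2, 0.8], [0.4, 0.6]])
y_int = np.array([0, 1, 0])       # integer labels
y_onehot = np.eye(2)[y_int]       # the same labels, one-hot encoded

loss_a = categorical_crossentropy(y_onehot, probs)
loss_b = sparse_categorical_crossentropy(y_int, probs)
print(loss_a, loss_b)  # the two values agree

try:
    # Integer labels reshaped to (N, 1) reproduce the shape mismatch.
    categorical_crossentropy(y_int.reshape(-1, 1), probs)
except ValueError as err:
    print(err)
```

Either route works for two classes; the choice mostly depends on whether one-hot labels are needed elsewhere in the pipeline.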