当前位置: 代码迷 >> 数据仓库 >> 数据挖掘札记-分类-决策树-4
  详细解决方案

数据挖掘札记-分类-决策树-4

热度:286   发布时间:2016-05-05 15:43:24.0
数据挖掘笔记-分类-决策树-4

之前写的代码都是在单机上跑的,发现现在Hadoop很流行,所以又试着用Hadoop MapReduce来实现决策树的构建。因为接触Hadoop不多,所以写得不好,勿怪。

?

看了一些Mahout处理决策树和随机森林的过程,大体流程是:Job只有一个Mapper,在map方法里面做数据的转换和收集工作,然后在cleanup方法里面构建决策树,再将决策树序列化到HDFS上;对样本数据集进行分类的时候,再从HDFS上取回决策树结构。总体来说,Mahout构建决策树的过程好像并没有利用分布式计算,不过我并没有仔仔细细地去研读Mahout里面的源码,所以可能是我没发现。下面是我实现的一个简单的Hadoop版本决策树,用的是C4.5算法,通过MapReduce来计算增益率。最后生成的决策树并未保存在HDFS上,后面有时间再考虑吧。下面是具体的代码实现:

?

public class DecisionTreeC45Job extends AbstractJob {		/** 对数据集做准备工作,主要就是将填充好默认值的数据集再次传到HDFS上*/	public String prepare(Data trainData) {		String path = FileUtils.obtainRandomTxtPath();		DataHandler.writeData(path, trainData);		System.out.println(path);		String name = path.substring(path.lastIndexOf(File.separator) + 1);		String hdfsPath = HDFSUtils.HDFS_TEMP_INPUT_URL + name;		HDFSUtils.copyFromLocalFile(conf, path, hdfsPath);		return hdfsPath;	}		/** 选择最佳属性,读取MapReduce计算后产生的文件,取增益率最大*/	public AttributeGainWritable chooseBestAttribute(String output) {		AttributeGainWritable maxAttribute = null;		Path path = new Path(output);		try {			FileSystem fs = path.getFileSystem(conf);			Path[] paths = HDFSUtils.getPathFiles(fs, path);			ShowUtils.print(paths);			double maxGainRatio = 0.0;			SequenceFile.Reader reader = null;			for (Path p : paths) {				reader = new SequenceFile.Reader(fs, p, conf);				Text key = (Text) ReflectionUtils.newInstance(						reader.getKeyClass(), conf);				AttributeGainWritable value = new AttributeGainWritable();				while (reader.next(key, value)) {					double gainRatio = value.getGainRatio();					if (gainRatio >= maxGainRatio) {						maxGainRatio = gainRatio;						maxAttribute = value;					}					value = new AttributeGainWritable();				}				IOUtils.closeQuietly(reader);			}			System.out.println("output: " + path.toString());			HDFSUtils.delete(conf, path);			System.out.println("hdfs delete file : " + path.toString());		} catch (IOException e) {			e.printStackTrace();		}		return maxAttribute;	}		/** 构造决策树 */	public Object build(String input, Data data) {		Object preHandleResult = preHandle(data);		if (null != preHandleResult) return preHandleResult;		String output = HDFSUtils.HDFS_TEMP_OUTPUT_URL;		HDFSUtils.delete(conf, new Path(output));		System.out.println("delete output path : " + output);		String[] paths = new String[]{input, output};		//通过MapReduce计算增益率		CalculateC45GainRatioMR.main(paths);				AttributeGainWritable bestAttr = 
chooseBestAttribute(output);		String attribute = bestAttr.getAttribute();		System.out.println("best attribute: " + attribute);		System.out.println("isCategory: " + bestAttr.isCategory());		if (bestAttr.isCategory()) {			return attribute;		}		String[] splitPoints = bestAttr.obtainSplitPoints();		System.out.print("splitPoints: ");		ShowUtils.print(splitPoints);		TreeNode treeNode = new TreeNode(attribute);		String[] attributes = data.getAttributesExcept(attribute);				//分割数据集,并将分割后的数据集传到HDFS上		DataSplit dataSplit = DataHandler.split(new Data(				data.getInstances(), attribute, splitPoints));		for (DataSplitItem item : dataSplit.getItems()) {			String path = item.getPath();			String name = path.substring(path.lastIndexOf(File.separator) + 1);			String hdfsPath = HDFSUtils.HDFS_TEMP_INPUT_URL + name;			HDFSUtils.copyFromLocalFile(conf, path, hdfsPath);			treeNode.setChild(item.getSplitPoint(), build(hdfsPath, 					new Data(attributes, item.getInstances())));		}		return treeNode;	}		/** 分类,根据决策树节点判断测试样本集的类型,并将结果上传到HDFS上*/	private void classify(TreeNode treeNode, String trainSet, String testSet, String output) {		OutputStream out = null;		BufferedWriter writer = null;		try {			Path trainSetPath = new Path(trainSet);			FileSystem trainFS = trainSetPath.getFileSystem(conf);			Path[] trainHdfsPaths = HDFSUtils.getPathFiles(trainFS, trainSetPath);			FSDataInputStream trainFSInputStream = trainFS.open(trainHdfsPaths[0]);			Data trainData = DataLoader.load(trainFSInputStream, true);						Path testSetPath = new Path(testSet);			FileSystem testFS = testSetPath.getFileSystem(conf);			Path[] testHdfsPaths = HDFSUtils.getPathFiles(testFS, testSetPath);			FSDataInputStream fsInputStream = testFS.open(testHdfsPaths[0]);			Data testData = DataLoader.load(fsInputStream, true);						DataHandler.fill(testData.getInstances(), trainData.getAttributes(), 0);			Object[] results = (Object[]) treeNode.classify(testData);			ShowUtils.print(results);			DataError dataError = new 
DataError(testData.getCategories(), results);			dataError.report();			String path = FileUtils.obtainRandomTxtPath();			out = new FileOutputStream(new File(path));			writer = new BufferedWriter(new OutputStreamWriter(out));			StringBuilder sb = null;			for (int i = 0, len = results.length; i < len; i++) {				sb = new StringBuilder();				sb.append(i+1).append("\t").append(results[i]);				writer.write(sb.toString());				writer.newLine();			}			writer.flush();			Path outputPath = new Path(output);			FileSystem fs = outputPath.getFileSystem(conf);			if (!fs.exists(outputPath)) {				fs.mkdirs(outputPath);			}			String name = path.substring(path.lastIndexOf(File.separator) + 1);			HDFSUtils.copyFromLocalFile(conf, path, output + 					File.separator + name);		} catch (IOException e) {			e.printStackTrace();		} finally {			IOUtils.closeQuietly(out);			IOUtils.closeQuietly(writer);		}	}		public void run(String[] args) {		try {			if (null == conf) conf = new Configuration();			String[] inputArgs = new GenericOptionsParser(					conf, args).getRemainingArgs();			if (inputArgs.length != 3) {				System.out.println("error, please input three path.");				System.out.println("1. trainset path.");				System.out.println("2. testset path.");				System.out.println("3. 
result output path.");				System.exit(2);			}			Path input = new Path(inputArgs[0]);			FileSystem fs = input.getFileSystem(conf);			Path[] hdfsPaths = HDFSUtils.getPathFiles(fs, input);			FSDataInputStream fsInputStream = fs.open(hdfsPaths[0]);			Data trainData = DataLoader.load(fsInputStream, true);			/** 填充缺失属性的默认值*/			DataHandler.fill(trainData, 0);			String hdfsInput = prepare(trainData);			TreeNode treeNode = (TreeNode) build(hdfsInput, trainData);			TreeNodeHelper.print(treeNode, 0, null);			classify(treeNode, inputArgs[0], inputArgs[1], inputArgs[2]);		} catch (Exception e) {			e.printStackTrace();		}	}		public static void main(String[] args) {		DecisionTreeC45Job job = new DecisionTreeC45Job();		long startTime = System.currentTimeMillis();		job.run(args);		long endTime = System.currentTimeMillis();		System.out.println("spend time: " + (endTime - startTime));	}}

CalculateC45GainRatioMR具体实现:

public class CalculateC45GainRatioMR {		private static void configureJob(Job job) {		job.setJarByClass(CalculateC45GainRatioMR.class);				job.setMapperClass(CalculateC45GainRatioMapper.class);		job.setMapOutputKeyClass(Text.class);		job.setMapOutputValueClass(AttributeWritable.class);		job.setReducerClass(CalculateC45GainRatioReducer.class);		job.setOutputKeyClass(Text.class);		job.setOutputValueClass(AttributeGainWritable.class);				job.setInputFormatClass(TextInputFormat.class);		job.setOutputFormatClass(SequenceFileOutputFormat.class);	}	public static void main(String[] args) {		Configuration configuration = new Configuration();		try {			String[] inputArgs = new GenericOptionsParser(						configuration, args).getRemainingArgs();			if (inputArgs.length != 2) {				System.out.println("error, please input two path. input and output");				System.exit(2);			}			Job job = new Job(configuration, "Decision Tree");						FileInputFormat.setInputPaths(job, new Path(inputArgs[0]));			FileOutputFormat.setOutputPath(job, new Path(inputArgs[1]));						configureJob(job);						System.out.println(job.waitForCompletion(true) ? 
0 : 1);		} catch (Exception e) {			e.printStackTrace();		}	}}class CalculateC45GainRatioMapper extends Mapper<LongWritable, Text, 	Text, AttributeWritable> {		@Override	protected void setup(Context context) throws IOException,			InterruptedException {		super.setup(context);	}	@Override	protected void map(LongWritable key, Text value, Context context)			throws IOException, InterruptedException {		String line = value.toString();		StringTokenizer tokenizer = new StringTokenizer(line);		Long id = Long.parseLong(tokenizer.nextToken());		String category = tokenizer.nextToken();		boolean isCategory = true;		while (tokenizer.hasMoreTokens()) {			isCategory = false;			String attribute = tokenizer.nextToken();			String[] entry = attribute.split(":");			context.write(new Text(entry[0]), new AttributeWritable(id, category, entry[1]));		}		if (isCategory) {			context.write(new Text(category), new AttributeWritable(id, category, category));		}	}		@Override	protected void cleanup(Context context) throws IOException, InterruptedException {		super.cleanup(context);	}}class CalculateC45GainRatioReducer extends Reducer<Text, AttributeWritable, Text, AttributeGainWritable> {		@Override	protected void setup(Context context) throws IOException, InterruptedException {		super.setup(context);	}		@Override	protected void reduce(Text key, Iterable<AttributeWritable> values,			Context context) throws IOException, InterruptedException {		String attributeName = key.toString();		double totalNum = 0.0;		Map<String, Map<String, Integer>> attrValueSplits = 				new HashMap<String, Map<String, Integer>>();		Iterator<AttributeWritable> iterator = values.iterator();		boolean isCategory = false;		while (iterator.hasNext()) {			AttributeWritable attribute = iterator.next();			String attributeValue = attribute.getAttributeValue();			if (attributeName.equals(attributeValue)) {				isCategory = true;				break;			}			Map<String, Integer> attrValueSplit = attrValueSplits.get(attributeValue);			if (null == 
attrValueSplit) {				attrValueSplit = new HashMap<String, Integer>();				attrValueSplits.put(attributeValue, attrValueSplit);			}			String category = attribute.getCategory();			Integer categoryNum = attrValueSplit.get(category);			attrValueSplit.put(category, null == categoryNum ? 1 : categoryNum + 1);			totalNum++;		}		if (isCategory) {			System.out.println("is Category");			int sum = 0;			iterator = values.iterator();			while (iterator.hasNext()) {				iterator.next();				sum += 1;			}			System.out.println("sum: " + sum);			context.write(key, new AttributeGainWritable(attributeName,					sum, true, null));		} else {			double gainInfo = 0.0;			double splitInfo = 0.0;			for (Map<String, Integer> attrValueSplit : attrValueSplits.values()) {				double totalCategoryNum = 0;				for (Integer categoryNum : attrValueSplit.values()) {					totalCategoryNum += categoryNum;				}				double entropy = 0.0;				for (Integer categoryNum : attrValueSplit.values()) {					double p = categoryNum / totalCategoryNum;					entropy -= p * (Math.log(p) / Math.log(2));				}				double dj = totalCategoryNum / totalNum;				gainInfo += dj * entropy;				splitInfo -= dj * (Math.log(dj) / Math.log(2));			}			double gainRatio = splitInfo == 0.0 ? 0.0 : gainInfo / splitInfo;			StringBuilder splitPoints = new StringBuilder();			for (String attrValue : attrValueSplits.keySet()) {				splitPoints.append(attrValue).append(",");			}			splitPoints.deleteCharAt(splitPoints.length() - 1);			System.out.println("attribute: " + attributeName);			System.out.println("gainRatio: " + gainRatio);			System.out.println("splitPoints: " + splitPoints.toString());			context.write(key, new AttributeGainWritable(attributeName,					gainRatio, false, splitPoints.toString()));		}	}		@Override	protected void cleanup(Context context) throws IOException, InterruptedException {		super.cleanup(context);	}	}

?

?

  相关解决方案