
Mahout source code analysis: logistic regression (1) - hands-on

Published: 2016-05-05 06:12:12

Version: mahout 0.9

The two main classes for logistic regression in Mahout are org.apache.mahout.classifier.sgd.TrainLogistic and org.apache.mahout.classifier.sgd.RunLogistic. The first builds the model and the second evaluates it.

First, the raw data. It can be downloaded from https://github.com/dirkweissenborn/mahout-rbmClassifier/blob/master/examples/src/main/resources/donut.csv#L1 and has the following format (a header row followed by one record per line):

"x","y","shape","color","k","k0","xx","xy","yy","a","b","c","bias"
0.923307513352484,0.0135197141207755,21,2,4,8,0.852496764213146,0.0124828536260896,0.000182782669907495,0.923406490600458,0.0778750292332978,0.644866125183976,1
0.711011884035543,0.909141522599384,22,2,3,9,0.505537899239772,0.64641042683833,0.826538308114327,1.15415605849213,0.953966686673604,0.46035073663368,1
0.75118898646906,0.836567111080512,23,2,3,9,0.564284893392414,0.62842000028592,0.699844531341594,1.12433510339845,0.872783737128441,0.419968245447719,1

Go to Mahout's bin directory and run:

./mahout trainlogistic --input /data/mahout-data/donut.csv --output /data/mahout-output/model2 --target color --categories 2 --predictors x y a b c --types numeric --features 20 --passes 100 --rate 50

The parameters are as follows:

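--input: the input data.
--output: the file the model is written to.
--target: the variable to predict (the first line of the input data must contain the variable names).
--categories: the number of distinct values the target variable can take.
--predictors: the variables used to build the model.
--types: the types of the predictor variables (one of numeric, word, text; if all predictors share the same type, giving it once is enough).
--passes: the number of passes made over the input data during training (I am not entirely sure about this one).
--features: the dimension of the internal feature vector used to build the model. As far as I understand it, larger is better, but training takes longer.
--rate: the learning rate (if the input data is large, this can be set somewhat higher).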
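To make the roles of --passes, --features and --rate more concrete, here is a minimal Python sketch of a logistic regression trained by stochastic gradient descent on a fixed-size feature vector. It is not Mahout's implementation: the `toy_slot` function is a hypothetical stand-in for a real hash function, the learning rate is kept constant for simplicity, and the four-row data set is made up for illustration.

```python
import math

def toy_slot(name, features):
    # Hypothetical stand-in for a hash function: maps a variable name to an index.
    return sum(ord(ch) for ch in name) % features

def encode(row, predictors, features):
    # Build the fixed-size vector; names that land on the same index add together.
    vec = [0.0] * features
    vec[toy_slot("Intercept Term", features)] += 1.0  # constant input of 1
    for name in predictors:
        vec[toy_slot(name, features)] += float(row[name])
    return vec

def sigmoid(z):
    z = max(-30.0, min(30.0, z))  # clamp to keep exp() well behaved
    return 1.0 / (1.0 + math.exp(-z))

def predict(weights, vec):
    return sigmoid(sum(w * v for w, v in zip(weights, vec)))

def train(rows, predictors, target, features=20, passes=100, rate=0.5):
    weights = [0.0] * features                       # features: vector length
    for _ in range(passes):                          # passes: sweeps over the input
        for row in rows:
            vec = encode(row, predictors, features)
            error = row[target] - predict(weights, vec)
            for i, v in enumerate(vec):
                weights[i] += rate * error * v       # rate: step size of each update
    return weights

# Made-up, linearly separable data: color is 1 when x is positive.
rows = [{"x": -2.0, "color": 0}, {"x": -1.0, "color": 0},
        {"x": 1.0, "color": 1}, {"x": 2.0, "color": 1}]
weights = train(rows, ["x"], "color")
scores = [predict(weights, encode(r, ["x"], 20)) for r in rows]
print([round(s, 3) for s in scores])
```

In this sketch a smaller `features` value makes it more likely that two variable names share an index and therefore a weight, which is one plausible reason a larger value can give a better model at a higher cost.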

This produces the following output:

Running on hadoop, using /opt/hadoop2/bin/hadoop and HADOOP_CONF_DIR=
MAHOUT-JOB: /opt/mahout-distribution-0.9/examples/target/mahout-examples-0.9-job.jar
SLF4J: Class path contains multiple SLF4J bindings.
SLF4J: Found binding in [jar:file:/opt/hadoop2/share/hadoop/common/lib/slf4j-log4j12-1.7.5.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: Found binding in [jar:file:/opt/hadoop2/share/hadoop/mapreduce/lib/mahout-core-0.9-job.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: See http://www.slf4j.org/codes.html#multiple_bindings for an explanation.
SLF4J: Actual binding is of type [org.slf4j.impl.Log4jLoggerFactory]
20
color ~ 7.068*Intercept Term + 0.581*a + -1.369*b + -25.059*c + 0.581*x + 2.319*y
      Intercept Term 7.06759
                   a 0.58123
                   b -1.36893
                   c -25.05945
                   x 0.58123
                   y 2.31879
    0.000000000     0.000000000     0.000000000     0.000000000     0.000000000    -1.368933989     0.000000000     0.000000000     0.000000000     0.000000000     0.581234210     0.000000000     0.000000000     7.067587159     0.000000000     0.000000000     0.000000000     2.318786209     0.000000000   -25.059452292
14/04/11 10:33:18 INFO driver.MahoutDriver: Program took 1758 ms (Minutes: 0.0293)

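I have an SLF4J jar conflict here, which I will ignore for now. What matters is the formula that follows (the coefficients in it are not necessarily the same on every training run). The final prediction is presumably computed from this formula, but for now I am not sure what the Intercept term is.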
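For reference, in logistic regression in general the intercept is the constant term of the linear part, that is, a coefficient multiplied by an input fixed at 1. The Python sketch below plugs the printed coefficients into that textbook form, using the first donut.csv row shown earlier. It assumes the plain logistic function with no extra preprocessing; whether Mahout applies further steps before scoring is not verified here, so the result need not match what runlogistic prints.

```python
import math

# Coefficients copied from the trainlogistic output above.
coefficients = {
    "Intercept Term": 7.06759,
    "a": 0.58123,
    "b": -1.36893,
    "c": -25.05945,
    "x": 0.58123,
    "y": 2.31879,
}

# Predictor values from the first data row of donut.csv shown in this article.
row = {
    "x": 0.923307513352484,
    "y": 0.0135197141207755,
    "a": 0.923406490600458,
    "b": 0.0778750292332978,
    "c": 0.644866125183976,
}

def linear_score(coefficients, row):
    # Assumption: the intercept acts as a coefficient on a constant input of 1.
    z = coefficients["Intercept Term"]
    for name, value in row.items():
        z += coefficients[name] * value
    return z

def logistic(z):
    # Textbook logistic function mapping a real number into (0, 1).
    return 1.0 / (1.0 + math.exp(-z))

z = linear_score(coefficients, row)
p = logistic(z)
print(f"z = {z:.4f}, logistic(z) = {p:.6f}")
```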

Next, run the model evaluation command (test data: https://svn.apache.org/repos/asf/mahout/trunk/examples/src/main/resources/donut-test.csv):

 ./mahout runlogistic --input /data/mahout-data/donut-test.csv --model /data/mahout-output/model2 --scores --auc --confusion

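--input is the test data; --model is the model file; --scores prints each predicted value next to the original value; --auc prints the AUC (the main evaluation criterion: the larger the better, ideally close to 1); --confusion prints the confusion matrix.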

This gives the following result:

Running on hadoop, using /opt/hadoop2/bin/hadoop and HADOOP_CONF_DIR=
MAHOUT-JOB: /opt/mahout-distribution-0.9/examples/target/mahout-examples-0.9-job.jar
SLF4J: Class path contains multiple SLF4J bindings.
SLF4J: Found binding in [jar:file:/opt/hadoop2/share/hadoop/common/lib/slf4j-log4j12-1.7.5.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: Found binding in [jar:file:/opt/hadoop2/share/hadoop/mapreduce/lib/mahout-core-0.9-job.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: See http://www.slf4j.org/codes.html#multiple_bindings for an explanation.
SLF4J: Actual binding is of type [org.slf4j.impl.Log4jLoggerFactory]
"target","model-output","log-likelihood"
0,0.009,-0.009241
0,0.000,-0.000481
1,0.985,-0.015038
1,0.991,-0.009407
0,0.001,-0.000883
1,0.974,-0.026000
1,0.823,-0.194875
0,0.041,-0.042015
0,0.051,-0.052565
0,0.613,-0.950008
0,0.147,-0.158538
1,0.910,-0.094177
1,0.252,-1.377220
1,0.924,-0.078521
1,0.998,-0.001777
0,0.023,-0.023756
1,0.990,-0.009928
0,0.003,-0.003118
1,0.961,-0.039284
0,0.000,-0.000046
0,0.167,-0.183160
0,0.049,-0.049822
0,0.006,-0.005792
0,0.706,-1.222487
0,0.000,-0.000421
1,0.999,-0.001045
1,0.969,-0.031452
0,0.034,-0.034088
0,0.370,-0.461632
0,0.011,-0.011489
0,0.465,-0.624971
0,0.053,-0.054646
0,0.340,-0.414959
0,0.053,-0.054123
0,0.007,-0.006800
0,0.248,-0.285650
1,0.482,-0.728835
0,0.781,-1.516960
0,0.024,-0.023975
0,0.022,-0.022281
AUC = 0.97
confusion: [[24.0, 2.0], [3.0, 11.0]]
entropy: [[-0.2, -2.8], [-4.1, -0.1]]
14/04/11 10:43:39 INFO driver.MahoutDriver: Program took 414 ms (Minutes: 0.0069)
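AUC = 0.97 suggests the model is fairly good. The confusion matrix shows that 2 samples that should have been classified as 1 were classified as 0, and 3 samples that should have been 0 were classified as 1.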
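The summary numbers can be cross-checked against the per-row scores. The Python sketch below recomputes a confusion matrix, a pairwise AUC and the log-likelihood column from the 40 rows printed by --scores. The 0.5 threshold, the [predicted][actual] layout of the matrix and the pairwise AUC definition are assumptions chosen for illustration, not taken from Mahout's source, and the scores are only available rounded to three decimals.

```python
import math

# (target, model-output, log-likelihood) rows copied from the runlogistic output.
SCORES = """
0,0.009,-0.009241
0,0.000,-0.000481
1,0.985,-0.015038
1,0.991,-0.009407
0,0.001,-0.000883
1,0.974,-0.026000
1,0.823,-0.194875
0,0.041,-0.042015
0,0.051,-0.052565
0,0.613,-0.950008
0,0.147,-0.158538
1,0.910,-0.094177
1,0.252,-1.377220
1,0.924,-0.078521
1,0.998,-0.001777
0,0.023,-0.023756
1,0.990,-0.009928
0,0.003,-0.003118
1,0.961,-0.039284
0,0.000,-0.000046
0,0.167,-0.183160
0,0.049,-0.049822
0,0.006,-0.005792
0,0.706,-1.222487
0,0.000,-0.000421
1,0.999,-0.001045
1,0.969,-0.031452
0,0.034,-0.034088
0,0.370,-0.461632
0,0.011,-0.011489
0,0.465,-0.624971
0,0.053,-0.054646
0,0.340,-0.414959
0,0.053,-0.054123
0,0.007,-0.006800
0,0.248,-0.285650
1,0.482,-0.728835
0,0.781,-1.516960
0,0.024,-0.023975
0,0.022,-0.022281
"""
rows = [tuple(float(x) for x in line.split(",")) for line in SCORES.split()]

# Confusion matrix indexed as [predicted][actual], assuming a 0.5 threshold.
confusion = [[0, 0], [0, 0]]
for target, score, _ in rows:
    predicted = 1 if score > 0.5 else 0
    confusion[predicted][int(target)] += 1

# Pairwise AUC: share of (positive, negative) pairs ranked in the right order.
pos = [s for t, s, _ in rows if t == 1]
neg = [s for t, s, _ in rows if t == 0]
wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
auc = wins / (len(pos) * len(neg))

# Log-likelihood check: log of the probability given to the true class.
max_gap = max(abs(ll - math.log(s if t == 1 else 1.0 - s)) for t, s, ll in rows)

print(confusion, round(auc, 2), max_gap)
```

On these 40 rows the sketch reproduces the printed confusion matrix and an AUC of 0.97 after rounding, and the third column agrees with the log of the probability assigned to the true class up to the rounding of the scores.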

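I had planned to plug the test data into the formula obtained above and see whether I could reproduce the first row of the output, for example 0.009. Since I do not know what the Intercept value is, I did not get 0.009. From a quick look through the source code, it seems some kind of normalization is needed. I will analyze this in detail next time.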

Summary:

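The open questions so far are: 1) how to use the formula above (what is the Intercept?); 2) how to get this running on Hadoop (judging from the output above, Mahout does not seem to have actually run on Hadoop).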


Share, grow, enjoy

If you repost this, please credit the blog: http://blog.csdn.net/fansy1990


