数据
x=(1,2,3,4,5)
y = (1,1.5,3,4.5,5)
算法结果
R语言运行结果
算法原理
x的均值:
xp = (x1 + x2 + x3 + … + xn) / n
y的均值 :
yp = (y1 + y2 + y3 + … + yn) / n
x 的离差平方和:
lxx=sum( (xi-xp) ^ 2 )
x 与 y 的离差乘积和(即协方差的分子):
lxy=sum( (xi-xp)*(yi-yp) )
拟合直线 y' = kx + b
k=lxy / lxx
b=yp-k*xp
代码实现
(数据容器选用集合,这样可以把数据当向量运算)
集合求和方法
/**
 * Sums every element of {@code c} as a {@code double}.
 *
 * @param c the numbers to add; may be {@code null}
 * @return the sum ({@code 0.0} for an empty list), or {@link Double#NaN} if {@code c} is
 *     {@code null}
 * @throws NullPointerException if {@code c} contains a {@code null} element
 */
public static double sum(List<Number> c) {
    // A plain null check instead of catching the exception thrown by Objects.requireNonNull.
    if (c == null) {
        return Double.NaN;
    }
    double total = 0;
    for (Number value : c) {
        total += value.doubleValue();
    }
    return total;
}
由于计算离差乘积需要把集合内的所有数相乘,这与求和一样,都是对集合元素逐个累积计算,因此不妨定义一个通用的连续计算方法。
集合连续计算方法
第一个参数为一个集合
第二个参数为函数接口:入参1为累积值(初始为集合的第一个元素),入参2为当前待计算的元素,出参为计算后的值,并作为下一次迭代的入参1。
//Continuous computationpublic static double conc(List<Number> c,BiFunction<Number,Number,Number> fun) {try {Objects.requireNonNull(c);}catch (Exception e) {return Double.NaN;}Number ret=null;Iterator<Number> itr=c.iterator();while(itr.hasNext()) {if(ret==null) {ret=itr.next();}else{ret=fun.apply(ret,itr.next().doubleValue()); // System.out.println("");}}return ret.doubleValue();}
使用方法(具体计算方式由函数接口决定):
可以求和:
conc(x, (acc, v) -> acc.doubleValue() + v.doubleValue()); // sum of all elements of x
可以求积:
conc(x, (acc, v) -> acc.doubleValue() * v.doubleValue()); // product of all elements of x
对集合内每一个元素进行计算,生成新集合
例如计算 xi-xp
/**
 * Element-wise map: applies {@code fun} to every element of a list and collects the results
 * into a new list (e.g. {@code xi - xp}). The input list is not modified.
 *
 * <p>Returns {@code null} when either the list or the function is {@code null}.
 */
public static BiFunction<List<Number>, Function<Number, Number>, List<Number>> calc =
        (a, fun) -> {
            if (a == null || fun == null) {
                return null;
            }
            // Typed and presized instead of a raw ArrayList.
            List<Number> b = new ArrayList<>(a.size());
            for (Number item : a) {
                b.add(fun.apply(item));
            }
            return b;
        };
多个集合运算生成新集合
主要用于计算离差乘积:把两个集合的对应元素相乘,即 (xi-xp)*(yi-yp)
/**
 * Index-wise combination of several lists. For each index {@code h} it gathers
 * {@code c.get(0).get(h), c.get(1).get(h), ...} into a row list and stores
 * {@code fun.apply(row)} in the result. It is used here for the deviation products
 * {@code (xi-xp)*(yi-yp)}.
 *
 * <p>Like a zip, only indices present in every input list are processed. The result therefore
 * has the length of the shortest input list, and an empty {@code c} yields an empty list.
 * Returns {@code null} when {@code c} or {@code fun} is {@code null}.
 */
public static BiFunction<List<List<Number>>, Function<List<Number>, Number>, List<Number>> call =
        (c, fun) -> {
            if (c == null || fun == null) {
                return null;
            }
            // Shortest input length: avoids c.get(0) on an empty list and IndexOutOfBounds
            // when a later list is shorter than the first one.
            int height = c.isEmpty() ? 0 : Integer.MAX_VALUE;
            for (List<Number> column : c) {
                height = Math.min(height, column.size());
            }
            List<Number> b = new ArrayList<>(height);
            for (int h = 0; h < height; h++) {
                List<Number> row = new ArrayList<>(c.size());
                for (List<Number> column : c) {
                    row.add(column.get(h));
                }
                b.add(fun.apply(row));
            }
            return b;
        };
拟合算法
/**
 * Fits the least-squares straight line {@code y = k*x + b} through the points
 * {@code (x[i], y[i])}.
 *
 * <p>Uses k = lxy / lxx and b = yp - k*xp, where xp and yp are the means, lxx is the sum of
 * squared x-deviations and lxy is the sum of products of x- and y-deviations. If all x values
 * are equal, lxx is (close to) zero and the returned slope/intercept are not meaningful.
 *
 * @param x the x coordinates
 * @param y the y coordinates, index-aligned with {@code x}
 * @return a two-element array {@code {k, b}}
 */
public static double[] lineFit(List<Number> x, List<Number> y) {
    double xp = mean(x);
    double yp = mean(y);
    // Deviations from the mean: xi - xp and yi - yp.
    List<Number> xDev = calc.apply(x, e -> e.doubleValue() - xp);
    List<Number> yDev = calc.apply(y, e -> e.doubleValue() - yp);
    // lxx = sum((xi - xp)^2). No debug printing: the method has no output side effects.
    double lxx = sum(calc.apply(xDev, e -> e.doubleValue() * e.doubleValue()));
    // lxy = sum((xi - xp) * (yi - yp)): call pairs index i of both lists, conc multiplies the pair.
    double lxy = sum(call.apply(Arrays.asList(xDev, yDev),
            pair -> conc(pair, (r, n) -> r.doubleValue() * n.doubleValue())));
    double k = lxy / lxx;
    double b = yp - k * xp;
    return new double[] {k, b};
}
完整代码
package utility;import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.function.Function;import javafx.application.Application;
import javafx.scene.Parent;
import javafx.scene.Scene;
import javafx.scene.chart.LineChart;
import javafx.scene.chart.NumberAxis;
import javafx.scene.chart.XYChart.Data;
import javafx.scene.chart.XYChart.Series;
import javafx.stage.Stage;
public class LineFit extends Application{public static double sum(List<Number> c) {try {Objects.requireNonNull(c);}catch (Exception e) {return Double.NaN;}double ret=0;Iterator<Number> itr=c.iterator();while(itr.hasNext()) {ret+=itr.next().doubleValue();}return ret;}//Continuous computationpublic static double conc(List<Number> c,BiFunction<Number,Number,Number> fun) {try {Objects.requireNonNull(c);}catch (Exception e) {return Double.NaN;}Number ret=null;Iterator<Number> itr=c.iterator();while(itr.hasNext()) {if(ret==null) {ret=itr.next();}else{ret=fun.apply(ret,itr.next().doubleValue()); }}return ret.doubleValue();}public static double mean(List<Number> c) {try {Objects.requireNonNull(c);}catch (Exception e) {return Double.NaN;}double ret = c.size()>0? sum(c)/c.size():0;return ret;}public static double[] lineFit(List<Number> x,List<Number> y){double xp=mean(x);List<Number> xi_xp = calc.apply(x,e->e.doubleValue()-xp); double lxx = sum( calc.apply(xi_xp,e->e.doubleValue()*e.doubleValue()) ); double yp=mean(y);List<Number> yi_yp = calc.apply(y,e->e.doubleValue()-yp);double lxy = sum( call.apply(Arrays.asList(xi_xp,yi_yp),e->conc(e,(r,n)->r.doubleValue()*n.doubleValue()) ));double k=lxy/lxx;double b=yp-k*xp;return new double[] {k,b};}public static void main(String[] args) {launch();}@Overridepublic void start(Stage primaryStage) throws Exception {List<Number> x = Arrays.asList(1,2,3,4,5);List<Number> y = Arrays.asList(1,1.5,3,4.5,5);primaryStage.setScene(new Scene(plot(x,y)));primaryStage.show();}public Series<Number, Number> test(List<Number> x,List<Number> y){double[] kb = lineFit(x,y);double k = kb[0];double b = kb[1];Iterator<Number> xi = x.iterator();Series<Number, Number> series = new LineChart.Series<Number,Number>();while(xi.hasNext()) {Number tmp = xi.next();series.getData().add(new Data(tmp,k*tmp.doubleValue()+b));}series.setName("拟合直线 "+"Y="+k+"x+("+String.format("%.2f",b)+")");return series;}public Series<Number, Number> data(List<Number> x,List<Number> 
y){Series<Number, Number> series = new LineChart.Series<Number,Number>();Iterator<Number> xi = x.iterator();Iterator<Number> yi = y.iterator();while(xi.hasNext()&&yi.hasNext()) {series.getData().add(new Data(xi.next(),yi.next()));}series.setName("data");return series;}public Parent plot(List<Number> x,List<Number> y) { NumberAxis xAxis=new NumberAxis();NumberAxis yAxis=new NumberAxis();LineChart chart = new LineChart(xAxis, yAxis); chart.getData().add(data(x,y));chart.getData().add(test(x,y));return chart;}public static BiFunction< List<Number>,Function<Number,Number>,List<Number> > calc=new BiFunction< List<Number>,Function<Number,Number>,List<Number> >() {public List<Number> apply(List<Number> a, Function<Number,Number> fun) {try {Objects.requireNonNull(a);Objects.requireNonNull(fun);}catch (Exception e) {return null;}List<Number> b=new ArrayList();a.forEach(itme->{b.add(fun.apply(itme));});return b;}};public static BiFunction< List<List<Number>>,Function<List<Number>,Number>,List<Number> > call=new BiFunction< List<List<Number>> ,Function<List<Number>,Number>,List<Number> >() {public List<Number> apply( List<List<Number>> c, Function<List<Number>,Number> fun) {try {Objects.requireNonNull(c);Objects.requireNonNull(fun);}catch (Exception e) {return null;}int width=c.size();List<Number> b=new ArrayList();int height=c.get(0).size();for(int h=0;h<height;h++) {List<Number> tmp=new ArrayList();for(int w=0;w<width;w++) {tmp.add(c.get(w).get(h));} b.add(fun.apply(tmp));}return b;}};
}
R语言代码
# Reference fit with R's lm() on the same data as the Java example.
x = c(1,2,3,4,5)
y = c(1,1.5,3,4.5,5)
data1 = data.frame(x = x, y = y)
lm.data1 <- lm(y ~ x, data = data1)
# coefficients[1] is the intercept, coefficients[2] the slope of x
b <- round(lm.data1$coefficients[1], 3)
k <- round(lm.data1$coefficients[2], 3)
plot(data1$x, data1$y, xlab = "x", ylab = "y", col = "red", pch = "*")
abline(lm.data1, col = "blue")
# Label the fitted line: the slope k multiplies x and the intercept b goes in parentheses
# (the label previously printed k in both places).
text(mean(data1$x), max(data1$y), paste0("y = ", k, "x+(", b, ")"))
















