001/* 002 * Copyright (C) 2012 The Guava Authors 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except 005 * in compliance with the License. You may obtain a copy of the License at 006 * 007 * http://www.apache.org/licenses/LICENSE-2.0 008 * 009 * Unless required by applicable law or agreed to in writing, software distributed under the License 010 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 011 * or implied. See the License for the specific language governing permissions and limitations under 012 * the License. 013 */ 014 015package com.google.common.math; 016 017import static com.google.common.base.Preconditions.checkState; 018import static com.google.common.primitives.Doubles.isFinite; 019import static java.lang.Double.NaN; 020import static java.lang.Double.isNaN; 021 022import com.google.common.annotations.Beta; 023import com.google.common.annotations.GwtIncompatible; 024 025/** 026 * A mutable object which accumulates paired double values (e.g. points on a plane) and tracks some 027 * basic statistics over all the values added so far. This class is not thread safe. 028 * 029 * @author Pete Gillin 030 * @since 20.0 031 */ 032@Beta 033@GwtIncompatible 034public final class PairedStatsAccumulator { 035 036 // These fields must satisfy the requirements of PairedStats' constructor as well as those of the 037 // stat methods of this class. 038 private final StatsAccumulator xStats = new StatsAccumulator(); 039 private final StatsAccumulator yStats = new StatsAccumulator(); 040 private double sumOfProductsOfDeltas = 0.0; 041 042 /** 043 * Adds the given pair of values to the dataset. 044 */ 045 public void add(double x, double y) { 046 // We extend the recursive expression for the one-variable case at Art of Computer Programming 047 // vol. 2, Knuth, 4.2.2, (16) to the two-variable case. We have two value series x_i and y_i. 048 // We define the arithmetic means X_n = 1/n \sum_{i=1}^n x_i, and Y_n = 1/n \sum_{i=1}^n y_i. 049 // We also define the sum of the products of the differences from the means 050 // C_n = \sum_{i=1}^n x_i y_i - n X_n Y_n 051 // for all n >= 1. Then for all n > 1: 052 // C_{n-1} = \sum_{i=1}^{n-1} x_i y_i - (n-1) X_{n-1} Y_{n-1} 053 // C_n - C_{n-1} = x_n y_n - n X_n Y_n + (n-1) X_{n-1} Y_{n-1} 054 // = x_n y_n - X_n [ y_n + (n-1) Y_{n-1} ] + [ n X_n - x_n ] Y_{n-1} 055 // = x_n y_n - X_n y_n - x_n Y_{n-1} + X_n Y_{n-1} 056 // = (x_n - X_n) (y_n - Y_{n-1}) 057 xStats.add(x); 058 if (isFinite(x) && isFinite(y)) { 059 if (xStats.count() > 1) { 060 sumOfProductsOfDeltas += (x - xStats.mean()) * (y - yStats.mean()); 061 } 062 } else { 063 sumOfProductsOfDeltas = NaN; 064 } 065 yStats.add(y); 066 } 067 068 /** 069 * Adds the given statistics to the dataset, as if the individual values used to compute the 070 * statistics had been added directly. 071 */ 072 public void addAll(PairedStats values) { 073 if (values.count() == 0) { 074 return; 075 } 076 077 xStats.addAll(values.xStats()); 078 if (yStats.count() == 0) { 079 sumOfProductsOfDeltas = values.sumOfProductsOfDeltas(); 080 } else { 081 // This is a generalized version of the calculation in add(double, double) above. Note that 082 // non-finite inputs will have sumOfProductsOfDeltas = NaN, so non-finite values will result 083 // in NaN naturally. 084 sumOfProductsOfDeltas += 085 values.sumOfProductsOfDeltas() 086 + (values.xStats().mean() - xStats.mean()) 087 * (values.yStats().mean() - yStats.mean()) 088 * values.count(); 089 } 090 yStats.addAll(values.yStats()); 091 } 092 093 /** 094 * Returns an immutable snapshot of the current statistics. 095 */ 096 public PairedStats snapshot() { 097 return new PairedStats(xStats.snapshot(), yStats.snapshot(), sumOfProductsOfDeltas); 098 } 099 100 /** 101 * Returns the number of pairs in the dataset. 102 */ 103 public long count() { 104 return xStats.count(); 105 } 106 107 /** 108 * Returns an immutable snapshot of the statistics on the {@code x} values alone. 109 */ 110 public Stats xStats() { 111 return xStats.snapshot(); 112 } 113 114 /** 115 * Returns an immutable snapshot of the statistics on the {@code y} values alone. 116 */ 117 public Stats yStats() { 118 return yStats.snapshot(); 119 } 120 121 /** 122 * Returns the population covariance of the values. The count must be non-zero. 123 * 124 * <p>This is guaranteed to return zero if the dataset contains a single pair of finite values. It 125 * is not guaranteed to return zero when the dataset consists of the same pair of values multiple 126 * times, due to numerical errors. 127 * 128 * <h3>Non-finite values</h3> 129 * 130 * <p>If the dataset contains any non-finite values ({@link Double#POSITIVE_INFINITY}, 131 * {@link Double#NEGATIVE_INFINITY}, or {@link Double#NaN}) then the result is {@link Double#NaN}. 132 * 133 * @throws IllegalStateException if the dataset is empty 134 */ 135 public double populationCovariance() { 136 checkState(count() != 0); 137 return sumOfProductsOfDeltas / count(); 138 } 139 140 /** 141 * Returns the sample covariance of the values. The count must be greater than one. 142 * 143 * <p>This is not guaranteed to return zero when the dataset consists of the same pair of values 144 * multiple times, due to numerical errors. 145 * 146 * <h3>Non-finite values</h3> 147 * 148 * <p>If the dataset contains any non-finite values ({@link Double#POSITIVE_INFINITY}, 149 * {@link Double#NEGATIVE_INFINITY}, or {@link Double#NaN}) then the result is {@link Double#NaN}. 150 * 151 * @throws IllegalStateException if the dataset is empty or contains a single pair of values 152 */ 153 public final double sampleCovariance() { 154 checkState(count() > 1); 155 return sumOfProductsOfDeltas / (count() - 1); 156 } 157 158 /** 159 * Returns the <a href="http://mathworld.wolfram.com/CorrelationCoefficient.html">Pearson's or 160 * product-moment correlation coefficient</a> of the values. The count must greater than one, and 161 * the {@code x} and {@code y} values must both have non-zero population variance (i.e. 162 * {@code xStats().populationVariance() > 0.0 && yStats().populationVariance() > 0.0}). The result 163 * is not guaranteed to be exactly +/-1 even when the data are perfectly (anti-)correlated, due to 164 * numerical errors. However, it is guaranteed to be in the inclusive range [-1, +1]. 165 * 166 * <h3>Non-finite values</h3> 167 * 168 * <p>If the dataset contains any non-finite values ({@link Double#POSITIVE_INFINITY}, 169 * {@link Double#NEGATIVE_INFINITY}, or {@link Double#NaN}) then the result is {@link Double#NaN}. 170 * 171 * @throws IllegalStateException if the dataset is empty or contains a single pair of values, or 172 * either the {@code x} and {@code y} dataset has zero population variance 173 */ 174 public final double pearsonsCorrelationCoefficient() { 175 checkState(count() > 1); 176 if (isNaN(sumOfProductsOfDeltas)) { 177 return NaN; 178 } 179 double xSumOfSquaresOfDeltas = xStats.sumOfSquaresOfDeltas(); 180 double ySumOfSquaresOfDeltas = yStats.sumOfSquaresOfDeltas(); 181 checkState(xSumOfSquaresOfDeltas > 0.0); 182 checkState(ySumOfSquaresOfDeltas > 0.0); 183 // The product of two positive numbers can be zero if the multiplication underflowed. We 184 // force a positive value by effectively rounding up to MIN_VALUE. 185 double productOfSumsOfSquaresOfDeltas = 186 ensurePositive(xSumOfSquaresOfDeltas * ySumOfSquaresOfDeltas); 187 return ensureInUnitRange(sumOfProductsOfDeltas / Math.sqrt(productOfSumsOfSquaresOfDeltas)); 188 } 189 190 /** 191 * Returns a linear transformation giving the best fit to the data according to 192 * <a href="http://mathworld.wolfram.com/LeastSquaresFitting.html">Ordinary Least Squares linear 193 * regression</a> of {@code y} as a function of {@code x}. The count must be greater than one, and 194 * either the {@code x} or {@code y} data must have a non-zero population variance (i.e. 195 * {@code xStats().populationVariance() > 0.0 || yStats().populationVariance() > 0.0}). The result 196 * is guaranteed to be horizontal if there is variance in the {@code x} data but not the {@code y} 197 * data, and vertical if there is variance in the {@code y} data but not the {@code x} data. 198 * 199 * <p>This fit minimizes the root-mean-square error in {@code y} as a function of {@code x}. This 200 * error is defined as the square root of the mean of the squares of the differences between the 201 * actual {@code y} values of the data and the values predicted by the fit for the {@code x} 202 * values (i.e. it is the square root of the mean of the squares of the vertical distances between 203 * the data points and the best fit line). For this fit, this error is a fraction 204 * {@code sqrt(1 - R*R)} of the population standard deviation of {@code y}, where {@code R} is the 205 * Pearson's correlation coefficient (as given by {@link #pearsonsCorrelationCoefficient()}). 206 * 207 * <p>The corresponding root-mean-square error in {@code x} as a function of {@code y} is a 208 * fraction {@code sqrt(1/(R*R) - 1)} of the population standard deviation of {@code x}. This fit 209 * does not normally minimize that error: to do that, you should swap the roles of {@code x} and 210 * {@code y}. 211 * 212 * <h3>Non-finite values</h3> 213 * 214 * <p>If the dataset contains any non-finite values ({@link Double#POSITIVE_INFINITY}, 215 * {@link Double#NEGATIVE_INFINITY}, or {@link Double#NaN}) then the result is 216 * {@link LinearTransformation#forNaN()}. 217 * 218 * @throws IllegalStateException if the dataset is empty or contains a single pair of values, or 219 * both the {@code x} and {@code y} dataset have zero population variance 220 */ 221 public final LinearTransformation leastSquaresFit() { 222 checkState(count() > 1); 223 if (isNaN(sumOfProductsOfDeltas)) { 224 return LinearTransformation.forNaN(); 225 } 226 double xSumOfSquaresOfDeltas = xStats.sumOfSquaresOfDeltas(); 227 if (xSumOfSquaresOfDeltas > 0.0) { 228 if (yStats.sumOfSquaresOfDeltas() > 0.0) { 229 return LinearTransformation.mapping(xStats.mean(), yStats.mean()) 230 .withSlope(sumOfProductsOfDeltas / xSumOfSquaresOfDeltas); 231 } else { 232 return LinearTransformation.horizontal(yStats.mean()); 233 } 234 } else { 235 checkState(yStats.sumOfSquaresOfDeltas() > 0.0); 236 return LinearTransformation.vertical(xStats.mean()); 237 } 238 } 239 240 private double ensurePositive(double value) { 241 if (value > 0.0) { 242 return value; 243 } else { 244 return Double.MIN_VALUE; 245 } 246 } 247 248 private static double ensureInUnitRange(double value) { 249 if (value >= 1.0) { 250 return 1.0; 251 } 252 if (value <= -1.0) { 253 return -1.0; 254 } 255 return value; 256 } 257}