typistapp/
correlation.rs

1use log;
2
3use crate::F64_ALMOST_ZERO;
4
5/// Computes the Pearson correlation coefficient between two vectors of f64 values.
6/// Returns None if the input lengths do not match or are empty.
7pub fn correlation(x_values: &[f64], y_values: &[f64]) -> Option<f64> {
8    if x_values.len() != y_values.len() || x_values.is_empty() || y_values.is_empty() {
9        return None;
10    }
11
12    let n = x_values.len();
13    let mean_x = x_values.iter().sum::<f64>() / n as f64;
14    let mean_y = y_values.iter().sum::<f64>() / n as f64;
15
16    let mut numerator = 0.0;
17    let mut den_x = 0.0;
18    let mut den_y = 0.0;
19
20    for (x, y) in x_values.iter().zip(y_values.iter()) {
21        let diff_x = x - mean_x;
22        let diff_y = y - mean_y;
23        numerator += diff_x * diff_y;
24        den_x += diff_x * diff_x;
25        den_y += diff_y * diff_y;
26    }
27
28    let denominator = den_x.sqrt() * den_y.sqrt();
29    if denominator.abs() < F64_ALMOST_ZERO {
30        let is_den_x_zero = den_x.abs() < F64_ALMOST_ZERO;
31        let is_den_y_zero = den_y.abs() < F64_ALMOST_ZERO;
32        let are_means_equal = (mean_x - mean_y).abs() < F64_ALMOST_ZERO;
33
34        return match (is_den_x_zero, is_den_y_zero, are_means_equal) {
35            (true, true, true) => Some(1.0),
36            _ => Some(0.0),
37        };
38    }
39
40    let result = numerator / denominator;
41    log::trace!("Correlation result: {result}");
42    Some(result)
43}
44
#[cfg(test)]
mod tests {
    use super::*;

    /// Slices of unequal length must be rejected.
    #[test]
    fn correlation_different_lengths_returns_none() {
        let shorter = [1.0];
        let longer = [1.0, 2.0];
        assert!(correlation(&shorter, &longer).is_none());
    }

    /// Two empty slices must be rejected.
    #[test]
    fn correlation_empty_slices_returns_none() {
        let outcome = correlation(&[], &[]);
        assert!(outcome.is_none());
    }

    /// Two series that differ only by a constant offset correlate at 1.0.
    #[test]
    fn correlation_valid_data_returns_some() {
        let outcome = correlation(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]);
        assert!(outcome.is_some());

        let deviation = (outcome.unwrap() - 1.0).abs();
        assert!(deviation < 1e-9);
    }
}