---
title: "Computing the Wasserstein distance"
bibliography: ../inst/REFERENCES.bib
output:
  rmarkdown::html_vignette:
    toc: true
    toc_depth: 2
    number_sections: false
    fig_caption: true
    mathjax: default
vignette: >
  %\VignetteIndexEntry{Computing the Wasserstein distance}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r init, eval=TRUE, echo=FALSE}
if (!requireNamespace("mlbench", quietly = TRUE)) {
  knitr::opts_chunk$set(eval = FALSE)
  message("Install 'mlbench' to run the code in this vignette.")
}
```

```{r setup, include=FALSE, echo=FALSE}
knitr::opts_chunk$set(cache = TRUE)
library(mlbench)
library(T4transport)
set.seed(10)
```

# Introduction

The Wasserstein distance [@villani_2003_TopicsOptimalTransportation] provides a natural and geometrically meaningful way to compare probability distributions. Unlike traditional dissimilarities such as the Kullback–Leibler divergence or total variation, it reflects the minimal "cost" of transporting one distribution to another with respect to the underlying geometry of the sample space.

Formally, the \( p \)-Wasserstein distance between two probability measures \( \mu \) and \( \nu \) on a metric space \( (\mathcal{X}, d) \) is defined as
\[
W_p(\mu, \nu) = \left( \inf_{\gamma \in \Gamma(\mu, \nu)} \int_{\mathcal{X} \times \mathcal{X}} d(x, y)^p \, d\gamma(x, y) \right)^{1/p},
\]
where \( \Gamma(\mu, \nu) \) denotes the set of all couplings (i.e., joint distributions) with marginals \( \mu \) and \( \nu \). This is known as the Kantorovich formulation [@kantorovitch_1958_TranslocationMasses], which differs from the original formulation of @monge_1781_MemoireTheorieDeblais yet yields an equivalent notion of distance under mild conditions.

# Toy example

We use two artificial datasets - `cassini` and `smiley` - from the [mlbench](https://CRAN.R-project.org/package=mlbench) package to demonstrate the usage of our package. Let's generate small samples of cardinality $n=100$. For visual clarity, the two datasets are first normalized to have mean zero and unit variance in both dimensions, and the second one is translated by +5 in the $x$-direction.

```{r toy_code, echo=TRUE, eval=TRUE, cache=FALSE, fig.alt="Two datasets: Cassini and Smiley"}
# load the library
library(mlbench)

# generate two datasets
data1 = mlbench::mlbench.cassini(n=100)$x
data2 = mlbench::mlbench.smiley(n=100)$x

# normalize the datasets
data1 = as.matrix(scale(data1))
data2 = as.matrix(scale(data2))

# translate the second dataset
data2[,1] = data2[,1] + 5

# plot the datasets
plot(data1, col="blue", pch=19, cex=0.5, main="Two datasets",
     xlim=c(-2, 7), xlab="x", ylab="y")
points(data2[,1], data2[,2], col="red", cex=0.5, pch=19)
```

As shown in the figure, the two datasets differ markedly in shape. Moreover, the horizontal translation applied to the second dataset (`smiley`) makes them even more distinct.

# How to compute?

Under the Kantorovich formulation, computing the Wasserstein distance requires solving a linear programming problem. In **T4transport**, the `wasserstein()` function does exactly that. Let's see how to use it in the simplest case of order 2.

```{r compute1, echo=TRUE, eval=TRUE, cache=TRUE, fig.alt="Wasserstein distance computation"}
# call the function
output = wasserstein(data1, data2, p=2)

# print the output
print(paste0("2-wasserstein distance: ",round(output$distance, 4)))
```

The computed Wasserstein distance of order 2 between the two empirical measures is `r round(output$distance,4)`.
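The order of the distance is not fixed at 2. Below is a minimal sketch that reuses the same call with `p=1` to obtain the 1-Wasserstein (earth mover's) distance; the chunk label and the variable name `output1` are purely illustrative, and we assume the `p` argument accepts any order greater than or equal to one.

```{r compute1_p1, echo=TRUE, eval=TRUE, cache=TRUE}
# 1-Wasserstein (earth mover's) distance on the same two datasets
# (illustrative sketch; assumes any order p >= 1 is accepted)
output1 = wasserstein(data1, data2, p=1)

# print the output
print(paste0("1-wasserstein distance: ", round(output1$distance, 4)))
```

Since $W_1 \leq W_2$ holds for any pair of probability measures (by Jensen's inequality), the printed value should not exceed the order-2 distance reported above.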
Another benefit of the Kantorovich formulation is that it returns an optimal coupling matrix that matches elements from the two sets. The attained coupling matrix $\hat{\Gamma}$ is shown below.

```{r compute2a, echo=FALSE, eval=TRUE, fig.alt="Optimal coupling matrix", fig.align="center"}
par(pty="s")

## --- Plot 1: Optimal coupling matrix
P = output$plan
image(
  x = 1:nrow(P),
  y = 1:ncol(P),
  z = t(P)[, nrow(P):1],   # transpose + flip rows
  col = gray.colors(100, start = 1, end = 0),
  xlab = "Source Index",
  ylab = "Target Index",
  main = "Optimal Coupling",
  axes = FALSE
)
axis(1, at = 1:nrow(P), labels = 1:nrow(P), las = 2, cex.axis = 0.6, tick=FALSE)
axis(2, at = 1:ncol(P), labels = rev(1:ncol(P)), las = 2, cex.axis = 0.6, tick=FALSE)
```

Furthermore, this coupling can be displayed on the original scatterplot of the two datasets via its bipartite graph representation.

```{r compute2b, echo=FALSE, eval=TRUE, fig.alt="Optimal coupling scatterplot", fig.align="center"}
## --- Plot 2: Bipartite graph
plot(data1, col="blue", pch=19, cex=0.5, main="Bipartite Graph",
     xlim=c(-2, 7), xlab="x", ylab="y")
points(data2[,1], data2[,2], col="red", cex=0.5, pch=19)

maxP = max(P)
multiplier = 0.5
for (i in 1:nrow(data1)){
  for (j in 1:nrow(data2)){
    if (P[i,j] > 0){
      lines(x = c(data1[i, 1], data2[j, 1]),
            y = c(data1[i, 2], data2[j, 2]),
            col = "gray40",
            lwd = multiplier * P[i, j] / maxP)
    }
  }
}
```

In the figure, edges connecting the dots represent the optimal coupling between the two empirical measures. The thickness of each edge is proportional to the amount of mass transported between the corresponding points, which can be interpreted as the "flow" in the transportation problem.

# Alternative input

It is often the case that we have cross distances between two empirical measures rather than the measures themselves. For instance, the current `wasserstein()` function assumes atoms in Euclidean space, whereas in practice the distances or dissimilarities may be given by a user-defined metric. The `wassersteinD()` function is the choice for such scenarios.

```{r computeD, echo=TRUE, eval=TRUE, cache=TRUE, fig.alt="Wasserstein distance computation with distances"}
# compute the cross distance with a helper function
cross_dist <- function(X, Y) {
  X2 <- rowSums(X^2)
  Y2 <- rowSums(Y^2)
  # clamp tiny negative values from floating-point error before taking sqrt
  sqrt(pmax(outer(X2, Y2, "+") - 2 * tcrossprod(X, Y), 0))
}
cdist = cross_dist(data1, data2)

# call the function
crossed = wassersteinD(cdist, p=2)

# print the output
print(paste0("2-wasserstein distance: ",round(crossed$distance, 4)))
```

Supplying a cross distance matrix also returns a 2-Wasserstein distance of `r round(crossed$distance,4)`, which is identical to the value obtained before (`r round(output$distance,4)`).

# Disclaimer

Computing the Wasserstein distance with a linear programming (LP) solver is the classical approach based on the Kantorovich formulation, and it is included as a core method in most optimal transport libraries. However, this approach does not scale well to large data: the LP problem involves $n\times m$ variables and, when $n \approx m$, requires $\mathcal{O}(n^2)$ memory and at least $\mathcal{O}(n^3)$ time in the worst case. For this reason, the package also provides alternative algorithms, and we hope you enjoy exploring them.

# References