The softmax function
September 13, 2024 — September 16, 2024
A function which maps an arbitrary \(\mathbb{R}^d\)-vector to the weights of a categorical distribution (i.e. the \((d-1)\)-simplex).
The \((K-1)\)-simplex is defined as the set of \(K\)-dimensional vectors whose elements are non-negative and sum to one. Specifically,
\[ \Delta^{K-1} = \left\{ \mathbf{p} \in \mathbb{R}^K : p_i \geq 0 \text{ for all } i, \text{ and } \sum_{i=1}^K p_i = 1 \right\} \]
This set describes all possible probability distributions over \(K\) outcomes, which aligns with the purpose of the softmax function in generating probabilities from “logits” (un-normalised log-probabilities) in classification problems.
Ubiquitous in modern classification tasks, particularly in neural networks.
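Why? Well, for one, it turns the slightly fiddly problem of estimating a constrained quantity into an unconstrained one, in a computationally expedient way. It’s not the only such option, but it is simple and has lots of nice mathematical symmetries, e.g. \(\sigma(\mathbf{z} + c\mathbf{1}) = \sigma(\mathbf{z})\) for any constant \(c\). It is kinda-sorta convex in its arguments: the softmax itself is not convex, but it is the gradient of the log-sum-exp function \(\log \sum_k e^{z_k}\), which is convex. It falls out in variational inference via KL: since \(\operatorname{KL}(\mathbf{p} \,\|\, \sigma(\mathbf{z})) = \log \sum_k e^{z_k} - \mathbf{p}^T \mathbf{z} - H(\mathbf{p})\), where \(H\) is the entropy, the softmax is the maximiser of \(\mathbf{p}^T \mathbf{z} + H(\mathbf{p})\) over the simplex, etc.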
1 Basic
The softmax function transforms a vector of real numbers into a probability distribution over predicted output classes for classification tasks. Given a vector \(\mathbf{z} = (z_1, z_2, \dots, z_K)\), the softmax function \(\sigma(\mathbf{z})_i\) for the \(i\)-th component is
\[ \sigma(\mathbf{z})_i = \frac{e^{z_i}}{\sum_{k=1}^K e^{z_k}}.\]
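For concreteness, here is a minimal plain-Python sketch of the definition (the function name `softmax` is our own, not a library API). It subtracts \(\max_k z_k\) before exponentiating. This does not change the result, since the factor \(e^{-\max_k z_k}\) cancels between numerator and denominator, but it stops \(e^{z_k}\) overflowing for large logits.

```python
import math


def softmax(z):
    """Softmax of a list of floats, computed stably.

    Subtracting max(z) first leaves the result unchanged, because the common
    factor exp(-max(z)) cancels between numerator and denominator, but it
    keeps every exponent <= 0 so nothing overflows.
    """
    m = max(z)
    exps = [math.exp(zk - m) for zk in z]
    total = sum(exps)
    return [e / total for e in exps]


p = softmax([1000.0, 1001.0, 1002.0])  # math.exp(1000.0) alone would overflow
print([round(pk, 4) for pk in p])  # [0.09, 0.2447, 0.6652]
```

The output lies in the simplex \(\Delta^{K-1}\) defined above: every entry is non-negative and the entries sum to one, up to floating-point rounding.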
2 Derivatives
Writing \(\sigma_i = \sigma(\mathbf{z})_i\), the first derivative with respect to \(z_j\) is \[ \begin{aligned} \frac{\partial \sigma_i}{\partial z_j} = \sigma_i \left( \delta_{ij} - \sigma_j \right) \end{aligned} \]
where \(\delta_{ij}\) is the Kronecker delta.
The second derivative is then \[ \begin{aligned} \frac{\partial^2 \sigma_i}{\partial z_j \partial z_k} = \sigma_i (\delta_{ik} - \sigma_k)(\delta_{ij} - \sigma_j) - \sigma_i \sigma_j (\delta_{jk} - \sigma_k) \end{aligned} \] i.e.
- \(i = j = k\): \(\sigma_i (1 - \sigma_i)(1 - 2 \sigma_i)\)
- \(i = j \neq k\): \(\sigma_i \sigma_k (2 \sigma_i - 1)\); symmetrically, \(i = k \neq j\): \(\sigma_i \sigma_j (2 \sigma_i - 1)\)
- \(j = k \neq i\): \(\sigma_i \sigma_j (2 \sigma_j - 1)\)
- \(i, j, k\) all distinct: \(2 \sigma_i \sigma_j \sigma_k\)
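A sketch that checks these formulas against central finite differences: first the Jacobian \(\sigma_i(\delta_{ij} - \sigma_j)\) against differences of the softmax itself, then the general second-derivative formula against differences of the analytic Jacobian. Plain Python; the logits and step sizes are arbitrary choices and the helper names are our own.

```python
import math


def softmax(z):
    m = max(z)
    exps = [math.exp(zk - m) for zk in z]
    total = sum(exps)
    return [e / total for e in exps]


def delta(a, b):
    """Kronecker delta."""
    return 1.0 if a == b else 0.0


def jacobian(s):
    """d sigma_i / d z_j = sigma_i (delta_ij - sigma_j), given s = softmax(z)."""
    K = len(s)
    return [[s[i] * (delta(i, j) - s[j]) for j in range(K)] for i in range(K)]


def second_derivative(s, i, j, k):
    """d^2 sigma_i / (d z_j d z_k) from the general formula."""
    return (s[i] * (delta(i, k) - s[k]) * (delta(i, j) - s[j])
            - s[i] * s[j] * (delta(j, k) - s[k]))


def shifted(z, k, h):
    """Copy of z with h added to component k."""
    out = list(z)
    out[k] += h
    return out


z = [0.3, -1.2, 2.0, 0.5]
K = len(z)
s = softmax(z)
J = jacobian(s)

# central differences of the softmax approximate the Jacobian
h1 = 1e-6
err1 = max(abs(J[i][j] - (softmax(shifted(z, j, h1))[i]
                          - softmax(shifted(z, j, -h1))[i]) / (2 * h1))
           for i in range(K) for j in range(K))

# central differences of the analytic Jacobian approximate the second derivatives
h2 = 1e-4
err2 = max(abs(second_derivative(s, i, j, k)
               - (jacobian(softmax(shifted(z, k, h2)))[i][j]
                  - jacobian(softmax(shifted(z, k, -h2)))[i][j]) / (2 * h2))
           for i in range(K) for j in range(K) for k in range(K))
print(err1, err2)  # both small: only finite-difference error remains
```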
3 Non-exponential
Suppose we do not use the \(\exp\) map, but generalize the softmax to use some other invertible, differentiable, increasing function \(\phi:\mathbb{R}\to\mathbb{R}^+\). Given a vector \(\mathbf{z} = (z_1, z_2, \dots, z_K)\), the \(i\)-th component of the generalized softmax function \(\Phi_{\phi}(\mathbf{z})\) is defined as
\[ \Phi_{\phi}(\mathbf{z})_i = \frac{\phi(z_i)}{\sum_{k=1}^K \phi(z_k)}.\]
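A minimal sketch of the generalized softmax, assuming softplus, \(\phi(x) = \log(1 + e^x)\), as an illustrative non-exponential choice (it is positive, increasing and differentiable). With \(\phi = \exp\) it reduces to the ordinary softmax. Function names are our own.

```python
import math


def generalized_softmax(z, phi):
    """Phi_phi(z)_i = phi(z_i) / sum_k phi(z_k), for phi: R -> (0, inf)."""
    weights = [phi(zk) for zk in z]
    total = sum(weights)
    return [w / total for w in weights]


def softplus(x):
    """log(1 + e^x), rearranged so that large |x| does not overflow."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


z = [-2.0, 0.0, 3.0]
p_exp = generalized_softmax(z, math.exp)   # phi = exp: the ordinary softmax
p_soft = generalized_softmax(z, softplus)  # phi = softplus: grows only linearly
print(p_exp)
print(p_soft)
```

Because softplus grows only linearly for large arguments, the resulting distribution is flatter than the softmax of the same logits. It also loses the shift invariance \(\sigma(\mathbf{z} + c\mathbf{1}) = \sigma(\mathbf{z})\), since \(\phi(z + c)\) is not in general proportional to \(\phi(z)\).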
4 log-Taylor softmax
TBD
5 Via Gumbel
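Sampling from the categorical distribution with weights \(\sigma(\mathbf{z})\) can be done with the Gumbel-max trick: if \(g_1, \dots, g_K\) are i.i.d. standard Gumbel variables, then \(\arg\max_k (z_k + g_k)\) is distributed exactly as \(\operatorname{Categorical}(\sigma(\mathbf{z}))\). The Gumbel-softmax (or Concrete) relaxation replaces this non-differentiable \(\arg\max\) with a softmax at temperature \(\tau > 0\), \(\sigma((\mathbf{z} + \mathbf{g})/\tau)\), which approaches a one-hot sample as \(\tau \to 0\). This is useful for training neural networks with discrete outputs or latent variables by gradient descent.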
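A plain-Python sketch that draws categorical samples as \(\arg\max_k(z_k + g_k)\), compares their frequencies with \(\sigma(\mathbf{z})\), and draws one relaxed sample \(\sigma((\mathbf{z}+\mathbf{g})/\tau)\). The Gumbel noise comes from the inverse CDF, \(g = -\log(-\log u)\) with \(u \sim \operatorname{Uniform}(0,1)\); the helper names, seed, logits and temperature are illustrative assumptions.

```python
import math
import random


def softmax(z):
    m = max(z)
    exps = [math.exp(zk - m) for zk in z]
    total = sum(exps)
    return [e / total for e in exps]


def sample_gumbel(rng):
    """Standard Gumbel draw by inverse CDF: g = -log(-log u), u ~ Uniform(0, 1)."""
    u = rng.random()
    while u == 0.0:  # random() lies in [0, 1); log(0) would fail
        u = rng.random()
    return -math.log(-math.log(u))


def gumbel_max_sample(z, rng):
    """Exact draw from Categorical(softmax(z)): argmax_k (z_k + g_k)."""
    perturbed = [zk + sample_gumbel(rng) for zk in z]
    return max(range(len(z)), key=perturbed.__getitem__)


def gumbel_softmax_sample(z, tau, rng):
    """Relaxed 'soft one-hot' draw: softmax((z + g) / tau)."""
    return softmax([(zk + sample_gumbel(rng)) / tau for zk in z])


rng = random.Random(0)  # fixed seed for reproducibility
z = [1.0, 0.0, -1.0]
n = 100_000
counts = [0] * len(z)
for _ in range(n):
    counts[gumbel_max_sample(z, rng)] += 1
print([c / n for c in counts])  # empirical frequencies
print(softmax(z))               # should be close to the line above
print(gumbel_softmax_sample(z, tau=0.1, rng=rng))  # typically close to one-hot
```

Only the relaxed sample is differentiable in \(\mathbf{z}\); the temperature trades off closeness to a true one-hot sample against the variance of the gradients.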
6 Entropy
6.1 Softmax
We consider the entropy \(H(\mathbf{p})\) of a categorical distribution with probabilities \(\mathbf{p} = [p_1, p_2, \dots, p_K]^T\), where the probabilities are given by the softmax function, \[ \begin{aligned} p_k = \sigma_k(\mathbf{z}) = \frac{e^{z_k}}{\sum_{j=1}^K e^{z_j}} = \frac{e^{z_k}}{Z}, \end{aligned} \] with \(Z = \sum_{j=1}^K e^{z_j}.\)
The entropy \(H(\mathbf{p})\) is by definition \[ \begin{aligned} H(\mathbf{p}) = -\sum_{k=1}^K p_k \log p_k. \end{aligned} \] Substituting \(\log p_k = z_k - \log Z\) into the entropy expression, we obtain \[ \begin{aligned} H(\mathbf{p}) &= -\sum_{k=1}^K p_k \log p_k \\ &= -\sum_{k=1}^K p_k \left( z_k - \log Z \right) \\ &= -\sum_{k=1}^K p_k z_k + \log Z \sum_{k=1}^K p_k \\ &= -\sum_{k=1}^K p_k z_k + \log Z. \end{aligned} \]
Thus, the entropy of the softmax distribution simplifies to \[ \begin{aligned} H(\sigma(\mathbf{z})) = \log Z - \sum_{k=1}^K p_k z_k. \end{aligned} \]
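A minimal sketch comparing the two expressions numerically (plain Python; function names are our own). Computing \(\log Z\) with the usual max-shift keeps the \(\log Z - \sum_k p_k z_k\) form finite even for large logits, and it never needs \(\log p_k\) of an underflowed probability.

```python
import math


def log_sum_exp(z):
    """log sum_k exp(z_k), shifted by max(z) to avoid overflow."""
    m = max(z)
    return m + math.log(sum(math.exp(zk - m) for zk in z))


def softmax_entropy(z):
    """H(softmax(z)) via the identity H = log Z - sum_k p_k z_k."""
    log_Z = log_sum_exp(z)
    p = [math.exp(zk - log_Z) for zk in z]
    return log_Z - sum(pk * zk for pk, zk in zip(p, z))


def entropy_direct(p):
    """-sum_k p_k log p_k, using the convention 0 log 0 = 0."""
    return -sum(pk * math.log(pk) for pk in p if pk > 0)


z = [0.5, -1.0, 2.0, 0.0]
log_Z = log_sum_exp(z)
p = [math.exp(zk - log_Z) for zk in z]
print(softmax_entropy(z), entropy_direct(p))    # agree up to rounding
print(softmax_entropy([3.0] * 4), math.log(4))  # uniform logits give log K
```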
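If we are using softmax we probably care about derivatives, so let us compute the gradient of the entropy with respect to \(z_i\), \[ \begin{aligned} \frac{\partial H}{\partial z_i} &= \frac{\partial}{\partial z_i} \left( \log Z - \sum_{k=1}^K p_k z_k \right) \\ &= \frac{1}{Z} \frac{\partial Z}{\partial z_i} - \sum_{k=1}^K \left( \frac{\partial p_k}{\partial z_i} z_k + p_k \delta_{ik} \right) \\ &= p_i - \sum_{k=1}^K \left( p_k (\delta_{ik} - p_i) z_k + p_k \delta_{ik} \right) \\ &= p_i - \left( p_i z_i - p_i \sum_{k=1}^K p_k z_k + p_i \right) \\ &= -p_i \left( z_i - \sum_{k=1}^K p_k z_k \right), \end{aligned} \] where we used \(\frac{\partial Z}{\partial z_i} = e^{z_i} = Z p_i\) and \(\frac{\partial p_k}{\partial z_i} = p_k (\delta_{ik} - p_i)\).
Write \(\bar{z} = \sum_{k=1}^K p_k z_k = \mathbf{p}^T \mathbf{z}\) for the \(\mathbf{p}\)-weighted mean logit and \(u_i = z_i - \bar{z}\). Thus, the gradient vector is \[ \begin{aligned} \nabla_\mathbf{z} H = -\mathbf{p} \odot \mathbf{u} = -\left( \operatorname{diag}(\mathbf{p}) - \mathbf{p} \mathbf{p}^T \right) \mathbf{z}, \end{aligned} \] i.e. minus the softmax Jacobian applied to the logits. Its components sum to zero, as they must, since adding a constant to every logit leaves \(\mathbf{p}\), and hence \(H\), unchanged; and it vanishes at uniform logits, where the entropy attains its maximum \(\log K\). Differentiating once more, with \(\frac{\partial \bar{z}}{\partial z_j} = \sum_{k=1}^K p_k (\delta_{jk} - p_j) z_k + p_j = p_j (1 + u_j)\), the Hessian matrix \(\nabla^2 H\) has entries \[ \begin{aligned} \frac{\partial^2 H}{\partial z_i \partial z_j} &= \frac{\partial}{\partial z_j} \left( -p_i u_i \right) = -\frac{\partial p_i}{\partial z_j} u_i - p_i \left( \delta_{ij} - \frac{\partial \bar{z}}{\partial z_j} \right) \\ &= -p_i (\delta_{ij} - p_j) u_i - p_i \delta_{ij} + p_i p_j (1 + u_j) \\ &= -\delta_{ij} p_i (1 + u_i) + p_i p_j (1 + u_i + u_j), \end{aligned} \] so that \[ \begin{aligned} \nabla^2 H = -\operatorname{diag}\left( \mathbf{p} \odot (\mathbf{1} + \mathbf{u}) \right) + \mathbf{p} \mathbf{p}^T + \mathbf{p} (\mathbf{p} \odot \mathbf{u})^T + (\mathbf{p} \odot \mathbf{u}) \mathbf{p}^T. \end{aligned} \] At uniform logits (\(\mathbf{u} = \mathbf{0}\)) this is \(-\left( \operatorname{diag}(\mathbf{p}) - \mathbf{p} \mathbf{p}^T \right)\), which is negative semi-definite, as befits a maximum.
Using the Taylor expansion, we approximate the entropy after a small change \(\Delta \mathbf{z}\) of the logits, writing \(H(\mathbf{z})\) for \(H(\sigma(\mathbf{z})) = H(\mathbf{p})\): \[ \begin{aligned} H(\mathbf{z} + \Delta \mathbf{z}) &\approx H(\mathbf{z}) + (\nabla_\mathbf{z} H)^T \Delta \mathbf{z} + \frac{1}{2} \Delta \mathbf{z}^T (\nabla^2 H) \Delta \mathbf{z} \\ &= H(\mathbf{p}) - \sum_{i=1}^K p_i u_i \Delta z_i - \frac{1}{2} \sum_{i=1}^K p_i (1 + u_i) (\Delta z_i)^2 + \frac{1}{2} \left( \sum_{i=1}^K p_i \Delta z_i \right)^2 + \left( \sum_{i=1}^K p_i u_i \Delta z_i \right) \left( \sum_{j=1}^K p_j \Delta z_j \right). \end{aligned} \] At uniform logits the first-order term vanishes and this reduces to \(H(\mathbf{p}) - \frac{1}{2} \left( \sum_{i=1}^K p_i (\Delta z_i)^2 - \left( \sum_{i=1}^K p_i \Delta z_i \right)^2 \right)\), i.e. the entropy drops by half the \(\mathbf{p}\)-weighted variance of the perturbation.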
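A finite-difference check of the gradient \(\partial H / \partial z_i = -p_i u_i\), the Hessian entries \(-\delta_{ij} p_i (1 + u_i) + p_i p_j (1 + u_i + u_j)\), and the second-order Taylor approximation, where \(\bar z = \sum_k p_k z_k\) and \(u_i = z_i - \bar z\). This is a sketch in plain Python with arbitrary test logits and step sizes; the helper names are hypothetical.

```python
import math


def softmax(z):
    m = max(z)
    exps = [math.exp(zk - m) for zk in z]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(z):
    """H(softmax(z)) = log Z - sum_k p_k z_k, with log Z shifted by max(z)."""
    m = max(z)
    log_Z = m + math.log(sum(math.exp(zk - m) for zk in z))
    return log_Z - sum(pk * zk for pk, zk in zip(softmax(z), z))


def centred(z):
    """Return p = softmax(z) and u_i = z_i - zbar, where zbar = sum_k p_k z_k."""
    p = softmax(z)
    zbar = sum(pk * zk for pk, zk in zip(p, z))
    return p, [zk - zbar for zk in z]


def grad_entropy(z):
    """dH/dz_i = -p_i u_i."""
    p, u = centred(z)
    return [-pi * ui for pi, ui in zip(p, u)]


def hess_entropy(z):
    """d2H/dz_i dz_j = -delta_ij p_i (1 + u_i) + p_i p_j (1 + u_i + u_j)."""
    p, u = centred(z)
    K = len(z)
    return [[(-p[i] * (1 + u[i]) if i == j else 0.0)
             + p[i] * p[j] * (1 + u[i] + u[j])
             for j in range(K)] for i in range(K)]


def central_diff(f, z, j, h):
    """(f(z + h e_j) - f(z - h e_j)) / 2h, for scalar- or list-valued f."""
    zp, zm = list(z), list(z)
    zp[j] += h
    zm[j] -= h
    fp, fm = f(zp), f(zm)
    if isinstance(fp, list):
        return [(a - b) / (2 * h) for a, b in zip(fp, fm)]
    return (fp - fm) / (2 * h)


z = [0.3, -1.2, 2.0, 0.5]
K = len(z)
g, Hm = grad_entropy(z), hess_entropy(z)

g_fd = [central_diff(entropy, z, i, 1e-6) for i in range(K)]
grad_err = max(abs(a - b) for a, b in zip(g, g_fd))

# cols[j][i] approximates d2H / dz_i dz_j by differencing the gradient along z_j
cols = [central_diff(grad_entropy, z, j, 1e-5) for j in range(K)]
hess_err = max(abs(Hm[i][j] - cols[j][i]) for i in range(K) for j in range(K))

# second-order Taylor prediction for a small step dz
dz = [1e-4, -2e-4, 5e-5, 1.5e-4]
lin = sum(gi * di for gi, di in zip(g, dz))
quad = sum(dz[i] * Hm[i][j] * dz[j] for i in range(K) for j in range(K))
exact = entropy([zi + di for zi, di in zip(z, dz)])
taylor_err = abs(exact - (entropy(z) + lin + 0.5 * quad))
print(grad_err, hess_err, taylor_err)
```

Shift invariance is also easy to see numerically: the gradient components, and each row of the Hessian, sum to (numerically) zero.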
6.2 Non-exponential
Let’s extend the reasoning to category probabilities given by the generalized softmax function. \[ \begin{aligned} p_k = \Phi_k(\mathbf{z}) = \frac{\phi(z_k)}{\sum_{j=1}^K \phi(z_j)} = \frac{\phi(z_k)}{Z}, \end{aligned} \] where \(\phi: \mathbb{R} \rightarrow \mathbb{R}^+\) is an increasing, differentiable function, and \(Z = \sum_{j=1}^K \phi(z_j)\).
The entropy becomes \[ \begin{aligned} H(\mathbf{p}) = -\sum_{k=1}^K p_k \log p_k = -\sum_{k=1}^K p_k \left( \log \phi(z_k) - \log Z \right) = -\sum_{k=1}^K p_k \log \phi(z_k) + \log Z. \end{aligned} \]
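A quick numerical sanity check of this identity, assuming \(\phi\) = softplus, \(\phi(x) = \log(1 + e^x)\), which is positive, increasing and differentiable. The choice of \(\phi\) and the logits are arbitrary illustrations.

```python
import math


def softplus(x):
    """phi(x) = log(1 + e^x), a positive, increasing stand-in for exp."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


z = [-1.0, 0.5, 2.0]
weights = [softplus(zk) for zk in z]
Z = sum(weights)
p = [w / Z for w in weights]

H_direct = -sum(pk * math.log(pk) for pk in p)
H_identity = math.log(Z) - sum(pk * math.log(softplus(zk)) for pk, zk in zip(p, z))
print(H_direct, H_identity)  # agree up to rounding
```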
To compute the gradient \(\nabla_\mathbf{z} H\), we note that, since \(\partial Z / \partial z_i = \phi'(z_i)\), \[ \begin{aligned} \frac{\partial p_k}{\partial z_i} = \frac{\phi'(z_k) \delta_{ik}}{Z} - \frac{\phi(z_k) \phi'(z_i)}{Z^2} = p_k \left( s_k \delta_{ik} - p_i s_i \right) = p_k s_k \delta_{ik} - p_k p_i s_i, \end{aligned} \] where \(s_i = \frac{\phi'(z_i)}{\phi(z_i)}\) is the log-derivative of \(\phi\) at \(z_i\).
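This Jacobian can be checked against central finite differences. Again this is a sketch assuming \(\phi\) = softplus, whose derivative is the logistic sigmoid, so that \(s_i = \operatorname{sigmoid}(z_i)/\operatorname{softplus}(z_i)\); the helper names are hypothetical.

```python
import math


def softplus(x):
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sigmoid(x):
    """Derivative of softplus, evaluated without overflow."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def gen_softmax(z):
    weights = [softplus(zk) for zk in z]
    total = sum(weights)
    return [w / total for w in weights]


def jacobian(z):
    """J[k][i] = dp_k/dz_i = p_k (s_k delta_ik - p_i s_i), s_i = phi'(z_i)/phi(z_i)."""
    p = gen_softmax(z)
    s = [sigmoid(zk) / softplus(zk) for zk in z]
    K = len(z)
    return [[p[k] * ((s[k] if i == k else 0.0) - p[i] * s[i]) for i in range(K)]
            for k in range(K)]


def fd_jacobian(z, h=1e-6):
    """Central-difference approximation, one column per perturbed z_i."""
    K = len(z)
    J = [[0.0] * K for _ in range(K)]
    for i in range(K):
        zp, zm = list(z), list(z)
        zp[i] += h
        zm[i] -= h
        pp, pm = gen_softmax(zp), gen_softmax(zm)
        for k in range(K):
            J[k][i] = (pp[k] - pm[k]) / (2 * h)
    return J


z = [-1.0, 0.5, 2.0]
J, J_fd = jacobian(z), fd_jacobian(z)
err = max(abs(J[k][i] - J_fd[k][i]) for k in range(3) for i in range(3))
print(err)
```

Each column of the Jacobian sums to zero, reflecting \(\sum_k p_k = 1\).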
Then, the gradient is \[ \begin{aligned} \frac{\partial H}{\partial z_i} &= -\sum_{k=1}^K \left( \frac{\partial p_k}{\partial z_i} \log \phi(z_k) + p_k \frac{\phi'(z_k)}{\phi(z_k)} \delta_{ik} \right) + \frac{1}{Z} \phi'(z_i) \\ &= -\sum_{k=1}^K \left( (p_k s_k \delta_{ik} - p_k p_i s_i) \log \phi(z_k) + p_k s_k \delta_{ik} \right) + \frac{1}{Z} \phi'(z_i). \end{aligned} \]
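The last line simplifies. The \(\delta_{ik}\) terms contribute \(-p_i s_i \log \phi(z_i) - p_i s_i\), and since \(\frac{1}{Z}\phi'(z_i) = p_i s_i\) the second of these cancels, leaving \[ \begin{aligned} \frac{\partial H}{\partial z_i} = -p_i s_i \left( \log \phi(z_i) - \sum_{k=1}^K p_k \log \phi(z_k) \right). \end{aligned} \] With \(\phi = \exp\) we have \(s_i = 1\) and \(\log \phi(z_k) = z_k\), so this becomes \(-p_i \left( z_i - \sum_{k=1}^K p_k z_k \right)\). A plain-Python sketch checking the simplified formula against finite differences of the entropy, again assuming \(\phi\) = softplus purely for illustration; the helper names are hypothetical.

```python
import math


def softplus(x):
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sigmoid(x):
    """Derivative of softplus, evaluated without overflow."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def entropy(z, phi):
    """H(p) with p_k = phi(z_k) / sum_j phi(z_j)."""
    weights = [phi(zk) for zk in z]
    total = sum(weights)
    return -sum((w / total) * math.log(w / total) for w in weights)


def grad_entropy(z, phi, dphi):
    """dH/dz_i = -p_i s_i (log phi(z_i) - sum_k p_k log phi(z_k))."""
    weights = [phi(zk) for zk in z]
    total = sum(weights)
    p = [w / total for w in weights]
    s = [dphi(zk) / w for zk, w in zip(z, weights)]
    log_phi = [math.log(w) for w in weights]
    mean_log_phi = sum(pk * lk for pk, lk in zip(p, log_phi))
    return [-pi * si * (li - mean_log_phi) for pi, si, li in zip(p, s, log_phi)]


def fd_grad(z, phi, h=1e-6):
    """Central differences of the entropy, one component at a time."""
    grad = []
    for i in range(len(z)):
        zp, zm = list(z), list(z)
        zp[i] += h
        zm[i] -= h
        grad.append((entropy(zp, phi) - entropy(zm, phi)) / (2 * h))
    return grad


z = [-1.0, 0.5, 2.0]
g = grad_entropy(z, softplus, sigmoid)
g_fd = fd_grad(z, softplus)
print(max(abs(a - b) for a, b in zip(g, g_fd)))

# phi = exp (so s_i = 1) gives -p_i (z_i - sum_k p_k z_k)
g_exp = grad_entropy(z, math.exp, math.exp)
print(g_exp)
```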