DEMix Layers: Disentangling Domains for Modular Language Modeling

MoE model where each expert corresponds to a different domain, i.e. a source text dataset (Reddit, medical papers, etc.). It can extrapolate to new domains by copying and training the nearest expert.

The main contribution is the **DEMix layer**, a collection of expert feedforward networks, each specialized to a domain, that can be used as a drop-in replacement for a standard feedforward network. This makes the LM modular, since experts can be mixed, added, or removed after initial training. It also enables new experts to be added without forgetting and without retraining the entire network. DEMix layers explicitly condition the LM on the domain of the input text, or estimate it at inference. *Note* that a domain here is simply the dataset the text comes from; it's quite non-rigorous and there's high overlap between domains.

For historical reference, Shazeer et al. 2017, *Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer*, proposed replacing one or more MLP layers with an ensemble of $n$ experts $FFN_1, ..., FFN_n$ assigned weights by functions $g_1, ..., g_n$:

$$ FFN(h_{t,l-1}) = \sum_{j=1}^n g_j (h_{t,l-1}) \cdot FFN_j(h_{t,l-1})$$

The $g$ function routes tokens to different experts, which are usually instances of the original FFN. This routing usually happens at the token level, necessitating load balancing across experts.

**DEMix layers** instead route by the dataset the tokens come from, i.e. every token in a sequence goes to the same expert. Indexing experts by domain, with $d \in D$ the domain label for the sequence:

$$ g_j(h_{t,l}) = \begin{cases} 1 & \text{if } j = d \\ 0 & \text{otherwise} \end{cases}$$

Note that at test time a sequence can be treated as a mixture of domains, which improves model performance. I also believe they routed at the sequence level to make training more efficient, since which GPUs will be utilized is now predictable.

Also interesting to note: their experiments showed that interleaving shared fully connected layers with DEMix layers led to worse in-domain performance, so effectively having $n$ separate feedforward stacks performed better during training. Their hypothesis was that shared layers are a bottleneck for finding shared features between domains and impede performance when training on highly different domains, since the necessary abstractions are different.

DEMix models outperformed dense (standard) models on in-domain test-set perplexity, i.e. perplexity on held-out data from the training domains. Additionally, DEMix's advantage over dense models is largest at lower parameter counts; as you scale up, dense models catch up.

Two DEMix inference variants are worth highlighting: **DEMix (naive)** and **DEMix (cached)**. DEMix cached basically says that prior to testing, we look at a held-out slice of the test data, calculate the posterior over domain labels from it, and fix the prior to that estimate. For context, look at Equations 5 and 6 in the paper, where we calculate $p(D_t = j \mid x_{\lt t})$, i.e. the probability we're in domain $j$ given the sequence so far. Equations 5 and 6 are Bayes' rule for this, with the marginal expanded in the second:

$$p(D_t = j \mid x_{\lt t}) = \frac{p(x_{\lt t} \mid D_t = j) \cdot p(D_t = j)}{p(x_{\lt t})}$$

$$p(D_t = j \mid x_{\lt t}) = \frac{p(x_{\lt t} \mid D_t = j) \cdot p(D_t = j)}{\sum_{j'=1}^n p(x_{\lt t} \mid D_t = j') \cdot p(D_t = j')}$$

We can estimate the likelihood, prior, and marginal from the data.
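As a concrete illustration, here is a minimal sketch (my own, not the authors' code) of how Equation 6 can be computed in practice, assuming each expert's log-likelihood of the prefix $x_{\lt t}$ is already available; the `update_prior_ema` helper and the decay value are illustrative stand-ins for the exponentially weighted moving average over posteriors.

```python
# Minimal sketch of the Bayes-rule posterior over domains (Eq. 6) with an
# EMA-updated prior. Assumes per-domain log-likelihoods of the prefix x_{<t}
# are already available from the experts; names here are illustrative.
import torch

def domain_posterior(log_likelihoods: torch.Tensor, prior: torch.Tensor) -> torch.Tensor:
    """p(D_t = j | x_{<t}) via Bayes' rule.

    log_likelihoods: (n_domains,) log p(x_{<t} | D_t = j) from each expert.
    prior:           (n_domains,) p(D_t = j), e.g. uniform or cached.
    """
    # Work in log space for stability; the softmax normalization is exactly
    # the sum over domains in the denominator of Eq. 6.
    return torch.softmax(log_likelihoods + torch.log(prior), dim=-1)

def update_prior_ema(prior: torch.Tensor, posterior: torch.Tensor, decay: float = 0.3) -> torch.Tensor:
    """Exponentially weighted moving average of the prior toward recent
    posteriors (the 'updating' alternative to a uniform or cached prior)."""
    return decay * prior + (1.0 - decay) * posterior

# Usage: start from a uniform prior and refine it as more of the sequence is seen.
n_domains = 8
prior = torch.full((n_domains,), 1.0 / n_domains)
log_liks = torch.randn(n_domains)            # stand-in for real expert log-likelihoods
posterior = domain_posterior(log_liks, prior)
prior = update_prior_ema(prior, posterior)   # the cached variant instead fixes this from held-out data
```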
In particular, DEMix cached calculates the posterior from held-out test data's likelihood, prior, and marginal and then fixes the prior to that estimate, which of course greatly increases performance (almost cheating). The other variants use a uniform prior or an exponentially weighted moving average over posteriors. This $p(D_t = j \mid x_{\lt t})$ is used directly in the router and for recombining the expert outputs.

To expand to a new domain, we simply apply Equation 6 to the new domain's data, see which existing domain gets the highest posterior $p(D_t = j \mid x_{\lt t})$, and copy that expert to initialize training for the new domain. This approach is quite simple, and in my opinion it works because the domains are all fairly close; the big assumption is that your existing domains are close enough to the new one for this to make sense. It would be **interesting to study this in a robotics context**, where experts are different actions or possibly diffusion policies. Notably, this approach doesn't cause forgetting on previous domains. Similarly, to restrict bad behavior from unwanted domains, just take that expert out of the model. Note that this is not foolproof or rigorous.
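Putting the pieces together, here is a minimal sketch of a DEMix-style layer under my own assumptions (the class and method names `DEMixLayer`, `add_domain`, `remove_domain`, and the layer sizes are illustrative, not the paper's code): hard routing by domain label at training time, posterior-weighted mixing at inference (a simplification of the paper's expert ensembling), expansion by copying the nearest expert, and removal by deleting one.

```python
# Minimal sketch of a DEMix-style layer: one FFN per domain, hard-routed by the
# sequence's domain label during training, posterior-mixed at inference, and
# extended to a new domain by copying the closest existing expert.
import copy
import torch
import torch.nn as nn

def make_ffn(d_model: int = 512, d_ff: int = 2048) -> nn.Module:
    return nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))

class DEMixLayer(nn.Module):
    def __init__(self, n_domains: int, d_model: int = 512, d_ff: int = 2048):
        super().__init__()
        self.experts = nn.ModuleList([make_ffn(d_model, d_ff) for _ in range(n_domains)])

    def forward(self, h, domain=None, posterior=None):
        if domain is not None:
            # Training: g_j is an indicator on the sequence's domain label,
            # so every token in the sequence hits the same expert.
            return self.experts[domain](h)
        # Inference without a label: weight each expert's output by the estimated
        # posterior p(D_t = j | x_{<t}). (A simplification -- the paper ensembles
        # experts with these weights rather than mixing inside each layer.)
        outs = torch.stack([ffn(h) for ffn in self.experts])      # (n, batch, seq, d_model)
        return (posterior.view(-1, *([1] * h.dim())) * outs).sum(dim=0)

    def add_domain(self, nearest: int) -> int:
        """Expand to a new domain by copying the expert with the highest
        posterior on the new domain's data, then fine-tune only that copy."""
        self.experts.append(copy.deepcopy(self.experts[nearest]))
        return len(self.experts) - 1

    def remove_domain(self, j: int) -> None:
        """Drop an unwanted domain by deleting its expert."""
        del self.experts[j]

# Usage: route a batch to expert 2 during training, then add a new domain
# initialized from whichever expert Eq. 6 says is closest.
layer = DEMixLayer(n_domains=8)
h = torch.randn(4, 128, 512)                  # (batch, seq_len, d_model)
out = layer(h, domain=2)
new_expert_id = layer.add_domain(nearest=5)
```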