Frequency Based Constrained Decoding for Language Model Watermarks
Dec 24, 24
`pip install constraint-watermark`
Check out the GitHub repository.
Inspiration
A few months ago I read OpenAI's blog post on how they use constrained decoding to achieve 100% reliability on structured outputs like JSON. The basic idea, using JSON for simplicity, is as follows:
Suppose we want to generate the first token of a JSON structured output given a prompt \(x\). For the first token, a transformer models the distribution \(p(y_0 \mid x)\), where \(y_i\) is the \(i\)-th token generated by the model, indexed from 0. Let \(m\) be the size of the vocabulary, and assume the model's output is a tensor of shape \((m,)\) of raw (unnormalized) logits. Call this tensor \(d\).
We know the first token of a JSON object or array must be either `{` or `[`. More formally, let \(S_i\) be the set of "allowed" tokens that \(y_i\) can take on. In this context, \(S_0\) contains exactly two elements: the `{` token and the `[` token.
So, given our tensor, for every token \(c \notin S_0\) we set \(d_c = -\infty\). After softmax, these tokens have probability \(0\), so only tokens in \(S_0\) (i.e., just `{` and `[`) can be selected, regardless of the sampling scheme (greedy argmax, temperature scaling, top-k, etc.).
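The masking step can be sketched in plain Python. This is a minimal illustration, not the library's implementation: the five-token vocabulary, the logit values, and the helper names `mask_logits` and `softmax` are hypothetical, and a real model would produce a tensor with \(m\) entries rather than a short list.

```python
import math

def mask_logits(logits, allowed):
    """Copy `logits`, setting every index outside `allowed` to -inf."""
    return [x if i in allowed else -math.inf for i, x in enumerate(logits)]

def softmax(logits):
    """Numerically stable softmax; a logit of -inf maps to probability 0."""
    peak = max(logits)  # assumes at least one finite logit survives masking
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical five-token vocabulary with made-up raw logits d.
vocab = ["{", "[", "hello", '"', "}"]
d = [1.0, 0.5, 3.0, 2.0, 0.1]
S0 = {vocab.index("{"), vocab.index("[")}  # allowed set for the first token
probs = softmax(mask_logits(d, S0))
print([round(p, 3) for p in probs])  # [0.622, 0.378, 0.0, 0.0, 0.0]
```

Note that `hello` has the largest raw logit, yet after masking it has probability 0.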
Now consider \(S_1\). In JSON, an initial `{` restricts what can come next (for example, whitespace, a quoted key, or a closing `}`), and an initial `[` imposes its own, different restrictions. So the first generated token determines the elements of \(S_1\). Encoding these rules in some structure, such as a grammar or state machine, lets us mask disallowed tokens at every step so that we only sample from valid ones.
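Step-by-step constrained decoding can be sketched with a toy example. It assumes a hand-written transition table, `ALLOWED_AFTER` (hypothetical), that maps the previous token to the allowed set for the next step, and a stand-in `fake_logits` function in place of a real model. The table only accepts flat objects; a real JSON constraint needs a stack or a full grammar to handle nesting.

```python
import math

# Toy vocabulary and hypothetical transition table: which tokens may follow
# each previous token (None is the start state). Only flat objects such as
# {"key": 1, "key": 1} are accepted; nesting would require a stack.
VOCAB = ["{", "}", '"key"', ":", "1", ",", "hello", "<eos>"]
ALLOWED_AFTER = {
    None: {"{"},
    "{": {'"key"', "}"},
    '"key"': {":"},
    ":": {"1"},
    "1": {",", "}"},
    ",": {'"key"'},
    "}": {"<eos>"},
}

def fake_logits(prefix):
    """Stand-in for a model forward pass: ignores the prefix and always
    scores the invalid token "hello" highest."""
    scores = {"{": 0.1, "}": 0.95, '"key"': 1.0, ":": 0.2,
              "1": 0.3, ",": 0.9, "hello": 5.0, "<eos>": 0.0}
    return [scores[tok] for tok in VOCAB]

def constrained_greedy_decode(max_steps=20):
    out = []
    for _ in range(max_steps):
        prev = out[-1] if out else None
        allowed = ALLOWED_AFTER[prev]  # the allowed set S_i for this step
        logits = [x if VOCAB[j] in allowed else -math.inf
                  for j, x in enumerate(fake_logits(out))]
        tok = VOCAB[max(range(len(VOCAB)), key=logits.__getitem__)]
        if tok == "<eos>":
            break
        out.append(tok)
    return out

print(" ".join(constrained_greedy_decode()))  # { "key" : 1 }
```

Even though the stand-in model always prefers `hello`, the mask ensures it is never chosen, and the output is always accepted by the table.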
Watermarking Motivation
This constrained decoding scheme has mainly been used for structured outputs, but it is also well suited to watermarking. By steering which tokens can be generated, we can embed a watermark directly into the text. This matters for AI safety: it lets us encode identifiers or signatures within generated text that support checks of authenticity and traceability. As multi-agent systems become more common, output watermarking could become critical for agents to verify that their inputs came from their peer agents rather than from a bad actor.
Approach
The frequency-based approach prioritizes customizability and test-time flexibility. After loading a model and tokenizer, a user specifies a window size (call this \(r\)) and a constraint dictionary that maps tokens to limits. For a token \(t\) in the dictionary with limit \(z\), any group of \(r\) consecutive tokens generated by the model will contain \(t\) at most \(z\) times.
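To pin down the definition, here is a brute-force check that scans every window of \(r\) tokens. It is a reference sketch rather than part of the library: words stand in for token IDs, and `satisfies_constraints` is a hypothetical name. It costs roughly \(O(n \cdot r)\) time for \(n\) tokens, which is fine for illustration.

```python
from collections import Counter

def satisfies_constraints(tokens, constraints, r):
    """Brute-force check: every run of r consecutive tokens contains each
    constrained token at most constraints[token] times. Tokens absent from
    `constraints` are unlimited. A sequence shorter than r is treated as a
    single window."""
    for start in range(max(1, len(tokens) - r + 1)):
        window = Counter(tokens[start:start + r])
        if any(window[tok] > limit for tok, limit in constraints.items()):
            return False
    return True

words = "the cat saw the dog".split()  # words stand in for token IDs
print(satisfies_constraints(words, {"the": 1}, r=3))  # True
print(satisfies_constraints(words, {"the": 1}, r=4))  # False
```

With \(r = 4\), the window `the cat saw the` contains `the` twice, which breaks the limit of 1.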
How is this implemented? We maintain a dictionary of remaining limits, initialized from the user's constraint dictionary; tokens without a user-specified limit default to `float('inf')`. At each step, every token whose remaining limit is 0 is masked (its logit is set to `-float('inf')`). After sampling, the selected token's limit is decremented. Once at least \(r\) tokens have been generated, we also increment the limit of the \(r\)-th most recent token (counting the one just sampled as the most recent), because that token is about to slide out of the window.
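More formally, let \(s_i\) be the \(i\)-th generated token, indexed from 0. After sampling \(s_i\), we decrement the limit of \(s_i\). If \(i \geq r-1\), we also increment the limit of \(s_{i-r+1}\), so that the dictionary accounts for exactly the \(r-1\) most recent tokens \(s_{i-r+2}, \dots, s_i\). Whatever token is sampled next then completes a window of \(r\) tokens without exceeding any limit.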
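The generation-side bookkeeping can be sketched with greedy decoding. This is a hedged illustration, not the library's code: `frequency_constrained_decode`, the `next_logits` callback, and the three-token toy model are hypothetical. A real implementation would mask a logit tensor and could sample instead of taking the argmax.

```python
import math

def frequency_constrained_decode(next_logits, vocab, constraints, r, steps):
    """Greedy decoding under a sliding-window frequency constraint.

    next_logits(prefix) returns one logit per entry of `vocab` (a stand-in
    for a model forward pass). `constraints` maps a token to the maximum
    number of times it may appear in any r consecutive tokens; tokens
    without an entry default to float('inf').
    """
    remaining = {tok: constraints.get(tok, float("inf")) for tok in vocab}
    out = []
    for i in range(steps):
        logits = next_logits(out)
        # Mask every token whose remaining budget in the window is exhausted.
        masked = [x if remaining[tok] > 0 else -math.inf
                  for tok, x in zip(vocab, logits)]
        if max(masked) == -math.inf:
            raise RuntimeError("every token is masked; loosen the constraints")
        tok = vocab[max(range(len(vocab)), key=masked.__getitem__)]
        out.append(tok)
        remaining[tok] -= 1
        # s_{i-r+1} slides out of the window before s_{i+1} is chosen.
        if i >= r - 1:
            remaining[out[i - r + 1]] += 1
    return out

# Hypothetical toy model: always ranks Santa > elf > sleigh.
vocab = ["Santa", "elf", "sleigh"]
out = frequency_constrained_decode(lambda prefix: [3.0, 2.0, 1.0],
                                   vocab, {"Santa": 1, "elf": 2}, r=3, steps=6)
print(" ".join(out))  # Santa elf elf Santa elf elf
```

With `Santa` limited to 1 and `elf` to 2 in any window of 3, each token becomes available again as soon as its earlier occurrence leaves the window, so the output never needs the lowest-ranked `sleigh`.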
With this technique, outputs can be verified with a simple token-wise loop over the text (tokenized with the same tokenizer). Initialize a dictionary in which every token has a count of 0. For each token \(s_i\), increment its count; if \(i \geq r\), decrement the count of \(s_{i-r}\), so the dictionary always counts the window \(s_{i-r+1}, \dots, s_i\). If at any point a token's count exceeds its limit in the constraint dictionary, we know the output could not have come from a peer agent decoding under these constraints.
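The verification loop can be sketched in a single pass. As before, this is an illustrative sketch with a hypothetical function name (`verify`), and whitespace-split words stand in for the tokenizer's output.

```python
from collections import defaultdict

def verify(tokens, constraints, r):
    """Single-pass check that no r consecutive tokens exceed their limits."""
    counts = defaultdict(int)  # occurrences within the current window
    for i, tok in enumerate(tokens):
        counts[tok] += 1
        if i >= r:
            counts[tokens[i - r]] -= 1  # s_{i-r} slides out of the window
        # Only the count of `tok` can have grown, so only it needs checking.
        if counts[tok] > constraints.get(tok, float("inf")):
            return False
    return True

story = "Santa packed the sleigh and Santa waved goodbye".split()
print(verify(story, {"Santa": 1}, r=50))  # False: two Santas in one window
print(verify(story, {"Santa": 1}, r=5))   # True: the Santas are 5 positions apart
```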
Advantages
The core advantage of this technique is its customizability at test time. For example, suppose we prompt a model with `Tell me a story about Santa.` and set the limit for the `Santa` token to 1 with a relatively large window size. A model that isn't decoding under this constraint will very likely mention Santa more than once within a window, so its output is easy to flag. A robust use case for this library could be to pair it with a model that, given a prompt, identifies the most influential token, on the assumption that this token would naturally appear repeatedly in a response. By heavily restricting that token's frequency, a bad actor's model is unlikely to go undetected: without knowing the frequency constraints, it would have to avoid repeating the prompt's central token by chance.
Feedback and Contributions
I would really appreciate any feedback on this library and would love to discuss it with anybody interested. If you have any suggestions or ideas, please feel free to contribute to the project on GitHub or email me.