Soft prompts #231
Conversation
return model

Model.register("transformers::with_prefix")(make_prefix_transformer)
As this will probably be difficult to change later, it's worth thinking about the terminology. Prefix tuning? Prompt tuning? Something else?
Maybe we'll call it `with_soft_prompt`?
Idk. At least some people use the prompt vs prefix tuning distinction to refer to the shallow (input layer only) vs deep distinction. I have no strong preference, but worth thinking carefully about and maybe asking for wider opinions.
# Because PyTorch hooks don't support kwargs, we monkey patch the forward method 🙈
old_forward = model.forward

def new_forward(*args, **kwargs):
What made me turn away from monkey patching in my own code is noticing that the new forward doesn't need/have a `self`, so there might be some fundamental differences between the old and new forward. If I were two years younger I probably would have voted for monkey patching, but the older me is less adventurous and worries more about safety. Go ahead if you're confident that this is safe, but I would at least suggest some sort of assertion to check that the forward has not already been monkey patched (because if it had, the logic would be incorrect).
I am a little uneasy about it, but I think it beats the alternatives. At the very least I want to see where it goes and where it falls down, if it does. Also, apparently there is movement on the PyTorch side to allow `kwargs` in hooks. When that comes true, we can do this properly.

As for your specific concern, this will work fine even if `forward()` has already been monkey patched before.
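(For reference: newer PyTorch releases, 2.0 and later, do expose keyword arguments to forward pre-hooks via `with_kwargs=True`. A minimal sketch of what a hook-based version could look like, assuming for illustration that only the attention mask needs extending; the names here are not from the PR.)

```python
import torch
import torch.nn as nn


def add_prompt_pre_hook(model: nn.Module, prompt_length: int):
    """Hypothetical hook-based alternative to monkey patching (requires PyTorch >= 2.0)."""

    def pre_hook(module, args, kwargs):
        # With with_kwargs=True the hook receives the call's kwargs and may
        # return (new_args, new_kwargs) to rewrite them before forward() runs.
        mask = kwargs.get("attention_mask")
        if mask is not None:
            ones = torch.ones(mask.shape[0], prompt_length, dtype=mask.dtype, device=mask.device)
            kwargs = {**kwargs, "attention_mask": torch.cat([ones, mask], dim=1)}
        return args, kwargs

    return model.register_forward_pre_hook(pre_hook, with_kwargs=True)
```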
Will it? Wouldn't the patching happen multiple times, once at each recursion level?
If the thing you pass into this function was monkey patched before, then `old_forward` ends up being the first level of monkey patching, and it will get called when we go one level down. `old_forward` becomes part of the closure of `new_forward`. That's how the chain of forward methods is maintained.
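(To make the chain concrete, here is a stripped-down illustration of the pattern, not the PR's actual code: each application of the patcher closes over whatever `forward` currently is, so applying it twice simply nests the wrappers.)

```python
def patch_forward(model, tag):
    # Capture whatever `forward` currently is -- possibly an earlier patch.
    old_forward = model.forward

    def new_forward(*args, **kwargs):
        print(f"patch {tag}: adjust inputs")    # stand-in for patch_tensor(...)
        result = old_forward(*args, **kwargs)   # calls the previous forward in the chain
        print(f"patch {tag}: adjust outputs")
        return result

    model.forward = new_forward
    return model

# Patching twice nests the wrappers:
#   new_forward(tag=2) -> new_forward(tag=1) -> original forward
```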
Right. So it would be like this, no?

forward:  # monkey patch lvl 1
    patch_tensor
    forward:  # monkey patch lvl 2
        patch_tensor
        forward  # original
Yes
So the tensors will be patched twice
The inputs and outputs should be patched twice. That is correct.

What won't work is that you can't call `set_input_embeddings()` twice with the way I have it here, because `_WithPromptEmbedding` reaches into the original embedding's internals.
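(For readers following along, a rough sketch of the general shape of such a wrapper; this is an illustrative reconstruction, not the actual `_WithPromptEmbedding`. The detail relevant to this thread is the line that reads the wrapped embedding's internals, which is what makes naive double-wrapping fragile.)

```python
import torch
import torch.nn as nn


class PromptEmbeddingSketch(nn.Module):
    """Illustrative only: prepends trainable prompt vectors to the token embeddings."""

    def __init__(self, original_embedding: nn.Embedding, prompt_length: int):
        super().__init__()
        self.original_embedding = original_embedding
        # Reaching into the wrapped embedding for its dimensionality -- the kind of
        # dependence on internals that breaks if you wrap a wrapper.
        embedding_dim = original_embedding.embedding_dim
        self.prompt = nn.Parameter(torch.randn(prompt_length, embedding_dim) * 0.02)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        embedded = self.original_embedding(input_ids)               # (batch, seq, dim)
        prompt = self.prompt.unsqueeze(0).expand(embedded.size(0), -1, -1)
        return torch.cat([prompt, embedded], dim=1)                 # (batch, prompt + seq, dim)
```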
Oh, what I was worried about is the method being unintentionally called twice. I can't think of a case where it is intentionally called twice.
c8d1b86 should make it possible to stack two prompt-enabled transformers on top of each other.
I think it's important that we ensure this pattern works for other forms as well. What if we implement adapters the same way, and we want to run both at the same time? The whole point of trying for this "looks like a normal huggingface transformer" approach is that it should be easy to combine with other components that do the same thing.
result = old_forward(*args, **kwargs)

if isinstance(result, CausalLMOutputWithCrossAttentions):
Some comment for what this is doing?
Yeah, I have to go through and write docs and whatnot.
Oh no, I found a big problem with this. It doesn't work with T5.
This does not work for T5 at all 😭. I'm no longer sure this approach of patching the model will work. The huggingface generation code makes calls into the middle of their model, instead of always going through the top-level `forward()`.
Is this problematic for generation only?
This is what I was worrying about above.
Copying from Slack: I can patch just the encoder for T5. Then the soft prompt has the opportunity to change how the rest of the prompt is encoded. But the encoded soft tokens are not part of the encoder output, and cannot be attended to by the decoder. @ZhaofengWu, is that important?
Just to resolve this chain of comments: I made it work with T5.
CHANGELOG.md
Outdated
@@ -262,6 +262,7 @@ instead of `ModuleNotFound`.
- Added the "-n/--name" option to `tango run`. This option allows the user to give the run an arbitrary name.
- Added a convenience property `.workspace` to `Step` class that can be called from a step's `.run()` method to get the current `Workspace` being used.
- Gave `FromParams` objects (which includes all `Registrable` objects) the ability to version themselves.
- Added the `transformers::with_soft_prompt` integration, to make soft-prompted prefix transformers easy.
Should move this up in the changelog.
)
)
r = random.Random(random_seed)
indices = torch.tensor(r.sample(range(5000), prompt_length))
Where does `5000` come from?
It's a number that Zhaofeng used in his code. He got it from some paper.
That sure is a little weird. Maybe it should sample from the entire original embedding.
It was originally used in https://arxiv.org/abs/2104.08691 and subsequently in other papers such as https://arxiv.org/abs/2108.04106, and of course ours. The idea is to only use the representations of the top-5000 tokens.
I'd keep 5000 or at least have some flag to control this.
Is the idea that the top 5000 most frequent tokens have received more training data and are therefore better?
Yeah, that's my understanding
I made it configurable, with a default of 5000.
This is ready for another review.
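(As an aside, a sketch of the initialization scheme being discussed; the parameter name `num_init_tokens` and its defaults are placeholders, not necessarily what the PR uses. The idea is to sample `prompt_length` ids from the first `num_init_tokens` vocabulary entries, roughly the most frequent tokens, and copy their embeddings as the starting point for the soft prompt.)

```python
import random

import torch
import torch.nn as nn


def init_prompt_from_vocab(
    embedding: nn.Embedding,
    prompt_length: int,
    num_init_tokens: int = 5000,  # configurable, defaulting to 5000 as discussed
    random_seed: int = 0,         # hypothetical default
) -> nn.Parameter:
    r = random.Random(random_seed)
    # Sample without replacement from the first `num_init_tokens` vocabulary entries,
    # which for most tokenizers roughly correspond to the most frequent tokens.
    indices = torch.tensor(r.sample(range(num_init_tokens), prompt_length))
    with torch.no_grad():
        init = embedding.weight[indices].clone()
    return nn.Parameter(init)
```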
patch_tensor(kwargs, "labels")
patch_tensor(kwargs, "attention_mask", 1)
patch_tensor(kwargs, "token_type_ids")
patch_tensor_with_indices(kwargs, "position_ids", prompt_length)
So, if the position ids are originally `[0, 1, 2, 3, 4]`, they will now be `[0, 1, 2, ..., prompt_len-1, 0, 1, 2, 3, 4]`?
Yes, that's right. I could see it going the other way, but I think it's important that the output does not change in the case where the soft prompt is configured to do nothing. Also, if we offset the position ids, we would decrease the max length that the model can handle, which is uncomfortable.
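(To make the exchange above concrete, here is a guess at what the two helpers do; the actual PR code differs at least in that `prompt_length` is presumably captured from the enclosing scope rather than passed in.)

```python
import torch


def patch_tensor(kwargs: dict, name: str, value: int = 0, prompt_length: int = 8):
    # Prepend `prompt_length` columns filled with `value` along the sequence
    # dimension, e.g. 1s for the attention mask over the soft prompt positions.
    t = kwargs.get(name)
    if t is None:
        return
    prefix = t.new_full((t.shape[0], prompt_length), value)
    kwargs[name] = torch.cat([prefix, t], dim=1)


def patch_tensor_with_indices(kwargs: dict, name: str, prompt_length: int = 8):
    # Prepend fresh position ids 0..prompt_length-1 and keep the originals, so
    # [0, 1, 2, 3, 4] becomes [0, 1, ..., prompt_length-1, 0, 1, 2, 3, 4].
    t = kwargs.get(name)
    if t is None:
        return
    prefix = torch.arange(prompt_length, dtype=t.dtype, device=t.device)
    prefix = prefix.unsqueeze(0).expand(t.shape[0], -1)
    kwargs[name] = torch.cat([prefix, t], dim=1)
```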
In code:
In config files:
Missing: