@ottonemo ottonemo commented Jul 19, 2023

Based on the paper

    Sanchez, Guillaume, et al.
    "Stay on topic with Classifier-Free Guidance."
    arXiv preprint arXiv:2306.17806 (2023).

this is a draft implementation of classifier-free guidance.

This is simply for sharing internally and might very well be completely wrong. It is debatable whether we should expose such a feature as a flag on the network or make it a separate classifier instance (or a mixin). In the past we were very much against special (potentially short-lived) feature flags, and it was much nicer to have this implemented as an addon/callback. We might need to do something similar here as well.

Open tasks:

  • evaluate existing examples
  • write explicit test cases
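
For context, the guidance rule from the paper combines the logits of a conditioned and an unconditioned forward pass. A minimal sketch of that combination (function and argument names are illustrative, not the draft's actual API):

```python
import torch

def cfg_logits(cond_logits: torch.Tensor,
               uncond_logits: torch.Tensor,
               gamma: float = 1.5) -> torch.Tensor:
    """Classifier-free guidance as in Sanchez et al. (2023):
    extrapolate from the unconditioned towards the conditioned logits.
    gamma=1.0 recovers the plain conditioned logits."""
    return uncond_logits + gamma * (cond_logits - uncond_logits)
```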

@BenjaminBossan BenjaminBossan (Collaborator) commented

The paper in question is this one:

https://arxiv.org/abs/2306.17806

Note that this method should have a greater effect the longer the labels are.

Some random comments:

  • At the moment, two forward passes are needed. Shouldn't we be able to pre-compute (or cache) `P_wi_wji`, since the labels are always the same and known from the start?
  • How about, instead of exposing `use_cfg`, we expose `cfg_gamma`? If it is 1 (or `None`), don't use CFG; otherwise apply that gamma instead of basically hard-coding it to 1.5. (A sketch of both ideas follows after this list.)
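
A hypothetical sketch of both suggestions combined: cache the unconditioned logits per label (they depend only on the fixed label tokens) and gate guidance on `cfg_gamma` rather than a boolean flag. Apart from `generate_logits`/`token_ids` (mentioned in the commits below), every name here is invented:

```python
class CfgScorer:
    """Hypothetical helper: caches unconditioned label logits and
    applies CFG only when cfg_gamma is set. Not the draft's API."""

    def __init__(self, model, cfg_gamma=None):
        self.model = model
        self.cfg_gamma = cfg_gamma
        self._uncond_cache = {}  # label token ids -> unconditioned logits

    def _uncond_logits(self, token_ids):
        key = tuple(token_ids)
        if key not in self._uncond_cache:
            # one unconditioned forward pass per distinct label, ever
            self._uncond_cache[key] = self.model.generate_logits(token_ids=token_ids)
        return self._uncond_cache[key]

    def score(self, cond_logits, token_ids):
        if self.cfg_gamma is None or self.cfg_gamma == 1.0:
            return cond_logits  # CFG disabled, single forward pass
        uncond = self._uncond_logits(token_ids)
        return uncond + self.cfg_gamma * (cond_logits - uncond)
```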

> It is debatable if we should expose such a feature as a flag to the network or make it a separate classifier instance (or a mixin). In the past we were very much against special (potentially short-lived) feature flags and it was much nicer to have this implemented as an addon/callback.

If this method works really well, I can see it being added explicitly. Alternatively, we could have a callbacks equivalent for logits processors, with `_LogitsRecorder` being the default.
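
One speculative shape for such a hook, mirroring how skorch callbacks are passed in; `_LogitsRecorder` is the existing internal, while the processor class and the `logits_processors` argument are invented here:

```python
class CfgLogitsProcessor:
    """Speculative logits processor applying classifier-free guidance."""

    def __init__(self, gamma=1.5):
        self.gamma = gamma

    def __call__(self, cond_logits, uncond_logits):
        return uncond_logits + self.gamma * (cond_logits - uncond_logits)

# hypothetical wiring, analogous to the callbacks API:
# clf = ZeroShotClassifier(model_name, logits_processors=[CfgLogitsProcessor()])
```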

- Makes it possible to set the gamma parameter
- Setting it to `None` disables the functionality completely
@ottonemo ottonemo force-pushed the feature/llm-classifier-free-guidance branch from 96de091 to 1c34aca on July 19, 2023 18:45
- `label_id` was misleading since it is actually a list of token ids
  related to a label, not a scalar value. Also, the general process
  of generating logits is not related to labels at all but rather just
  to tokens.

- `kwargs` was named to match the transformers `generate`
  convention, but it is meant to be passed to `generate` and is
  therefore, in the context of `generate_logits`, a model input. This
  should help the reader distinguish between the expected input
  (`token_ids`) and the model input (`model_input`)
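
Putting the two renames together, the signature would look roughly like this (reconstructed from the commit message, not copied from the diff):

```python
def generate_logits(self, *, token_ids, **model_input):
    """Generate logits for the given token ids.

    token_ids: the token ids to score (formerly ``label_id``)
    model_input: inputs forwarded to the model's ``generate`` call
        (formerly ``kwargs``)
    """
    ...
```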