Implement the ability to infer the checkpoint type #254
Draft
blester125 wants to merge 1 commit into r-three:main
Currently this is just based on the file path. To make the sniffing extensible, it is implemented as a plugin. To do that without importing a deep learning framework just to sniff, the sniffers are separated out into their own files. This led to the question of why checkpoint plugins are stored in the main repo. If we move them to the actual plugins area, we can continue to support them without needing to install them when a framework isn't going to be used. Now, when you want to use git-theta with a specific framework, you can run either `pip install git-theta[framework]` or `pip install git-theta-checkpoints-framework` for your framework.
craffel
reviewed
Jun 12, 2024
def sniff_checkpoint(checkpoint_path) -> str:
    """En"""
    )
    for ckpt_type, ckpt_sniffer in loaded_plugins.items():
        logger.debug(f"Checking if {checkpoint_path} is a {ckpt_type} checkpoint.")
        if ckpt_sniffer(checkpoint_path):
Contributor
Should we have any kind of error handling in the case where two checkpoint sniffers denote that a given checkpoint is their type? Currently this will return the checkpoint type corresponding to the first checkpoint sniffer the loop encounters that matches (which I assume is in arbitrary order).
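One way to make the ambiguous case explicit is to collect every matching sniffer before deciding. The following is only a sketch of that idea, not code from this PR: `sniff_checkpoint_strict`, `AmbiguousCheckpointError`, and the plain `dict` of sniffers passed in are hypothetical names standing in for the loaded plugins.

```python
from typing import Callable, Dict, List

# A sniffer takes a checkpoint path and says whether it recognizes it.
Sniffer = Callable[[str], bool]


class AmbiguousCheckpointError(ValueError):
    """Raised when more than one sniffer claims the same checkpoint."""


def sniff_checkpoint_strict(checkpoint_path: str, sniffers: Dict[str, Sniffer]) -> str:
    """Return the single matching checkpoint type, or raise if 0 or >1 match.

    Hypothetical helper: `sniffers` stands in for the loaded sniffer plugins.
    """
    # Run every sniffer instead of stopping at the first hit, and sort the
    # results so the error message does not depend on plugin load order.
    matches: List[str] = sorted(
        ckpt_type
        for ckpt_type, sniffer in sniffers.items()
        if sniffer(checkpoint_path)
    )
    if not matches:
        raise ValueError(f"No sniffer recognized {checkpoint_path!r}.")
    if len(matches) > 1:
        raise AmbiguousCheckpointError(
            f"{checkpoint_path!r} matched multiple checkpoint types: {matches}"
        )
    return matches[0]
```

The trade-off is that every sniffer runs on every call, which is cheap while sniffers only inspect the file path but could matter if they start reading file headers.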
@@ -0,0 +1,9 @@
"""Infer if a checkpoint is flax based.

We put this in a different file to avoid importing dl frameworks for file sniffing.
Contributor
Not a relevant comment for this particular sniffer?
from setuptools import setup


def get_version(file_name: str, version_variable: str = "__version__") -> str:
Contributor
This does not seem to be specific to the flax checkpoint sniffer and is duplicated across the setup.py files - should it be pulled out into a more general utility function file?
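For illustration, a shared helper could look roughly like the sketch below. The body of `get_version` is not visible in this diff, so the regex-based implementation here is an assumption that only matches the signature shown above. Where such a helper would live so that each plugin's setup.py can import it before installation is left open.

```python
import re
from pathlib import Path


def get_version(file_name: str, version_variable: str = "__version__") -> str:
    """Read the string assigned to `version_variable` in `file_name`.

    Assumed behavior: the file contains a line such as `__version__ = "1.2.3"`.
    The file is parsed as text, so the package itself is never imported.
    """
    text = Path(file_name).read_text(encoding="utf-8")
    # Match `<variable> = "<value>"` (single or double quotes) at line start.
    pattern = r"^" + re.escape(version_variable) + r"""\s*=\s*["']([^"']+)["']"""
    match = re.search(pattern, text, flags=re.MULTILINE)
    if match is None:
        raise ValueError(f"Could not find {version_variable} in {file_name}.")
    return match.group(1)
```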
        return cls(model_dict)


def safetensors_sniffer(checkpoint_path: str) -> bool:
Contributor
Why is this here and in the sniffer file itself?
Currently the included sniffers are just based on the file path, but some formats, for example pytorch's pickle and newer zip format, have magic numbers in their headers we can check for in the future.
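As a rough sketch of what a header-based check might look like, the helpers below test for the ZIP local-file signature and for the opcode that starts a pickle written with protocol 2 or newer. The function names are hypothetical and none of this is part of the PR. Note that these checks only identify the container format: a file that is a ZIP archive or a pickle is not necessarily a pytorch checkpoint.

```python
import pickle

# Signature at the start of a non-empty ZIP archive (local file header).
ZIP_MAGIC = b"PK\x03\x04"
# Pickles written with protocol >= 2 begin with the PROTO opcode (0x80)
# followed by one byte holding the protocol number.
PICKLE_PROTO_OPCODE = 0x80


def read_header(path: str, size: int) -> bytes:
    """Read the first `size` bytes of a file without loading the rest."""
    with open(path, "rb") as f:
        return f.read(size)


def looks_like_zip(path: str) -> bool:
    """True if the file starts with the ZIP local-file signature."""
    return read_header(path, len(ZIP_MAGIC)) == ZIP_MAGIC


def looks_like_pickle(path: str) -> bool:
    """True if the file starts like a protocol >= 2 pickle stream."""
    header = read_header(path, 2)
    return (
        len(header) == 2
        and header[0] == PICKLE_PROTO_OPCODE
        and 2 <= header[1] <= pickle.HIGHEST_PROTOCOL
    )
```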
To make the sniffing extensible, it is implemented as a plugin. To do that without importing a deep learning framework just to sniff, the sniffers are separated out into their own files. This led to the question of why checkpoint plugins are stored in the main repo. If we move them to the actual plugins area, we can continue to support them without needing to install them when a framework isn't going to be used.
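To illustrate how sniffers shipped in separate packages could be discovered, here is a minimal sketch using standard Python entry points. The group name `git_theta.plugins.checkpoint.sniffers` and the `load_sniffers` helper are assumptions made for this example and may not match what the PR registers.

```python
from importlib.metadata import entry_points
from typing import Callable, Dict

# Hypothetical entry point group; the real group name may differ.
SNIFFER_GROUP = "git_theta.plugins.checkpoint.sniffers"


def load_sniffers(group: str = SNIFFER_GROUP) -> Dict[str, Callable[[str], bool]]:
    """Load every sniffer that installed packages register under `group`.

    Only packages that are actually installed contribute entry points, so a
    framework that is not installed is simply never sniffed for.
    """
    eps = entry_points()
    if hasattr(eps, "select"):
        # Python 3.10+: selectable entry points API.
        selected = eps.select(group=group)
    else:
        # Python 3.8/3.9: entry_points() returns a dict keyed by group.
        selected = eps.get(group, [])
    return {ep.name: ep.load() for ep in selected}
```

With this layout, a plugin package would declare its sniffer under the chosen group in its own setup.py, and the core package would never need to import the framework itself.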
An alternative implementation could include the sniffer plugin as part of the main repo, but the usage would be a bit clunkier as we would need to handle import errors (for example a pytorch sniffer might want to use torch but it isn't installed). The current solution would only sniff for checkpoints from frameworks that you have installed.
From a user perspective, this PR results in the following:
Currently, to use git-theta with a framework, say tensorflow, you need to have tensorflow installed in your python environment and you need to tell git-theta that a checkpoint is TF via an environment variable. `pip install git-theta[tensorflow]` is provided as an easy way to ensure both git-theta and tensorflow are installed, but it doesn't need to be used.
With this version, to use some framework, say tensorflow, you either need to install `git-theta-checkpoints-tensorflow` along with `git-theta` or you can use `pip install git-theta[tensorflow]` as a shortcut. If you later want to use git-theta with pytorch, you need to install `git-theta-checkpoints-pytorch`.
I'd like some feedback on whether people think this change to how things are installed would be too much for users.