-
Notifications
You must be signed in to change notification settings - Fork 1
Process positional arguments #134
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 23 commits
df69fd5
3b881a6
2899de3
731b664
a4cf2cd
d6b16ed
aeca489
61f1b1b
e8b508c
109e0e4
b7d9d5c
6dbbadd
e5c5b80
1046cff
8c489aa
130f807
55cadea
c83eccf
4361e2e
53c44f1
23ae949
94a005f
528a3a7
a3cfae6
92cf918
b145a53
72d53c6
d05963b
fffd0d0
49065c9
f53fe07
e52d371
ca0c4d2
914c417
ebd3880
aa4f6cd
19a8a72
f92872c
1eea97a
f7cd804
a107678
a835e8b
594ea9d
dbf19ad
c182a37
d3239b4
9a60cbf
1cb7502
f710611
dacd17a
5673690
474acdf
ccce776
4e16587
5d72476
b4358fd
54ba275
afd01ff
1ca2997
afdbbb0
4161ad7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,17 +3,21 @@ | |
| import static org.eclipse.core.runtime.Platform.getLog; | ||
|
|
||
| import java.io.File; | ||
| import java.util.ArrayList; | ||
| import java.util.Objects; | ||
| import java.util.Set; | ||
|
|
||
| import org.eclipse.core.runtime.ILog; | ||
| import org.eclipse.core.runtime.IProgressMonitor; | ||
| import org.eclipse.jface.text.BadLocationException; | ||
| import org.eclipse.jface.text.IDocument; | ||
| import org.python.pydev.ast.codecompletion.revisited.visitors.Definition; | ||
| import org.python.pydev.core.IPythonNature; | ||
| import org.python.pydev.core.docutils.PySelection; | ||
| import org.python.pydev.parser.jython.ast.Attribute; | ||
| import org.python.pydev.parser.jython.ast.Call; | ||
| import org.python.pydev.parser.jython.ast.FunctionDef; | ||
| import org.python.pydev.parser.jython.ast.Name; | ||
| import org.python.pydev.parser.jython.ast.NameTok; | ||
| import org.python.pydev.parser.jython.ast.argumentsType; | ||
| import org.python.pydev.parser.jython.ast.decoratorsType; | ||
|
|
@@ -105,6 +109,9 @@ public HybridizationParameters(IProgressMonitor monitor) throws BadLocationExcep | |
| // Will contain the last tf.function decorator | ||
| decoratorsType tfFunctionDecorator = null; | ||
|
|
||
| // Declaring definitions of the decorator | ||
| Set<Definition> declaringDefinitions = null; | ||
|
|
||
| // Iterate through the decorators of the function | ||
| for (decoratorsType decorator : decoratorArray) { | ||
| IDocument document = Function.this.getContainingDocument(); | ||
|
|
@@ -113,54 +120,213 @@ public HybridizationParameters(IProgressMonitor monitor) throws BadLocationExcep | |
| // Save the hybrid decorator | ||
| try { | ||
| if (Function.isHybrid(decorator, Function.this.containingModuleName, Function.this.containingFile, selection, | ||
| Function.this.nature, monitor)) // TODO: Cache this from a previous call (#118). | ||
| Function.this.nature, monitor)) { // TODO: Cache this from a previous call (#118). | ||
| tfFunctionDecorator = decorator; | ||
| declaringDefinitions = Util.getDeclaringDefinition(selection, Function.this.containingModuleName, | ||
tatianacv marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| Function.this.containingFile, Function.this.nature, monitor); | ||
| } | ||
| } catch (AmbiguousDeclaringModuleException e) { | ||
| throw new IllegalStateException("Can't determine whether decorator: " + decorator + " is hybrid.", e); | ||
| } | ||
| } // We expect to have the last tf.function decorator in tfFunctionDecorator | ||
|
|
||
| // Declaring definition of the decorator | ||
| Definition declaringDefinition = null; | ||
tatianacv marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| // Getting the definition, there should only be one in the set. | ||
tatianacv marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if (declaringDefinitions != null) { | ||
| declaringDefinition = declaringDefinitions.iterator().next(); | ||
| } | ||
|
|
||
| // Python source arguments from the declaring definition | ||
| exprType[] declaringArguments = null; | ||
|
|
||
| // Getting the arguments from TensorFlow source | ||
| if (declaringDefinition != null) { | ||
| if (declaringDefinition.ast instanceof FunctionDef) { | ||
| FunctionDef declaringFunctionDefinition = (FunctionDef) declaringDefinition.ast; | ||
| argumentsType declaringArgumentTypes = declaringFunctionDefinition.args; | ||
| declaringArguments = declaringArgumentTypes.args; | ||
| } | ||
| } | ||
|
|
||
| // Python source arguments from the declaring definition | ||
| ArrayList<String> argumentIdDeclaringDefintion = new ArrayList<>(); | ||
|
|
||
| // Getting the arguments from the definition | ||
| if (declaringArguments != null) { | ||
| for (exprType declaredArgument : declaringArguments) { | ||
| if (declaredArgument instanceof Name) { | ||
| Name argumentName = (Name) declaredArgument; | ||
| argumentIdDeclaringDefintion.add(argumentName.id); | ||
| } | ||
| } | ||
| } | ||
tatianacv marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| if (tfFunctionDecorator != null) | ||
| // tfFunctionDecorator must be an instance of Call, because that's the only way we have parameters. | ||
| if (tfFunctionDecorator.func instanceof Call) { | ||
| Call callFunction = (Call) tfFunctionDecorator.func; | ||
| // We only care about the actual keywords for now. | ||
| // TODO: Parse positional arguments (#108). | ||
|
|
||
| // Processing positional arguments for tf.function a | ||
| exprType[] arguments = callFunction.args; | ||
| for (int i = 0; i < arguments.length; i++) { | ||
| String argumentDeclaringDefinition = argumentIdDeclaringDefintion.get(i); | ||
|
||
|
|
||
| // Matching the arguments from the definition and the arguments from the code being analyzed. | ||
| if (argumentDeclaringDefinition.equals(FUNC)) { | ||
| // Not considering the default values | ||
| if (arguments[i] instanceof Name) { | ||
| Name nameArgument = (Name) arguments[i]; | ||
| if (nameArgument.id != "None") | ||
| // Found parameter func | ||
| this.funcParamExists = true; | ||
| } else { | ||
| // Found parameter func | ||
| this.funcParamExists = true; | ||
| } | ||
| } else if (argumentDeclaringDefinition.equals(INPUT_SIGNATURE)) { | ||
| // Not considering the default values | ||
| if (arguments[i] instanceof Name) { | ||
| Name nameArgument = (Name) arguments[i]; | ||
| if (nameArgument.id != "None") | ||
| // Found parameter input_signature | ||
| this.inputSignatureParamExists = true; | ||
| } else { | ||
| // Found parameter input_signature | ||
| this.inputSignatureParamExists = true; | ||
| } | ||
| } else if (argumentDeclaringDefinition.equals(AUTOGRAPH)) { | ||
| // Not considering the default values | ||
| if (arguments[i] instanceof Name) { | ||
| Name nameArgument = (Name) arguments[i]; | ||
| if (nameArgument.id != "True") | ||
| // Found parameter autograph | ||
| this.autoGraphParamExists = true; | ||
| } else { | ||
| // Found parameter autograph | ||
| this.autoGraphParamExists = true; | ||
| } | ||
| // The latest version of the API we are using allows | ||
| // parameter names jit_compile and | ||
| // deprecated name experimental_compile | ||
| } else if (argumentDeclaringDefinition.equals(JIT_COMPILE) | ||
| || argumentDeclaringDefinition.equals(EXPERIMENTAL_COMPILE)) { | ||
| // Not considering the default values | ||
| if (arguments[i] instanceof Name) { | ||
| Name nameArgument = (Name) arguments[i]; | ||
| if (nameArgument.id != "None") | ||
| // Found parameter jit_compile/experimental_compile | ||
| this.jitCompileParamExists = true; | ||
| } else { | ||
| // Found parameter jit_compile/experimental_compile | ||
| this.jitCompileParamExists = true; | ||
| } | ||
| // The latest version of the API we are using allows | ||
| // parameter names reduce_retracing | ||
| // and deprecated name experimental_relax_shapes | ||
| } else if (argumentDeclaringDefinition.equals(REDUCE_RETRACING)) { | ||
| // Not considering the default values | ||
| if (arguments[i] instanceof Name) { | ||
| Name nameArgument = (Name) arguments[i]; | ||
| if (nameArgument.id != "False") | ||
| // Found parameter reduce_retracing | ||
| this.reduceRetracingParamExists = true; | ||
| } else { | ||
| // Found parameter reduce_retracing | ||
| this.reduceRetracingParamExists = true; | ||
| } | ||
| } else if (argumentDeclaringDefinition.equals(EXPERIMENTAL_RELAX_SHAPES)) { | ||
| // Not considering the default values | ||
| if (arguments[i] instanceof Name) { | ||
| Name nameArgument = (Name) arguments[i]; | ||
| if (nameArgument.id != "None") | ||
| // Found parameter experimental_relax_shapes | ||
| this.reduceRetracingParamExists = true; | ||
| } else { | ||
| // Found parameter experimental_relax_shapes | ||
| this.reduceRetracingParamExists = true; | ||
| } | ||
| } else if (argumentDeclaringDefinition.equals(EXPERIMENTAL_IMPLEMENTS)) { | ||
| // Not considering the default values | ||
| if (arguments[i] instanceof Name) { | ||
| Name nameArgument = (Name) arguments[i]; | ||
| if (nameArgument.id != "None") | ||
| // Found parameter experimental_implements | ||
| this.experimentalImplementsParamExists = true; | ||
| } else { | ||
| // Found parameter experimental_implements | ||
| this.experimentalImplementsParamExists = true; | ||
| } | ||
| } else if (argumentDeclaringDefinition.equals(EXPERIMENTAL_AUTOGRAPH_OPTIONS)) { | ||
| // Not considering the default values | ||
| if (arguments[i] instanceof Name) { | ||
| Name nameArgument = (Name) arguments[i]; | ||
| if (nameArgument.id != "None") | ||
| // Found parameter experimental_autograph_options | ||
| this.experimentalAutographOptionsParamExists = true; | ||
| } else { | ||
| // Found parameter experimental_autograph_options | ||
| this.experimentalAutographOptionsParamExists = true; | ||
| } | ||
| } else if (argumentDeclaringDefinition.equals(EXPERIMENTAL_FOLLOW_TYPE_HINTS)) { | ||
| // Not considering the default values | ||
| if (arguments[i] instanceof Name) { | ||
| Name nameArgument = (Name) arguments[i]; | ||
| if (nameArgument.id != "None") | ||
| // Found parameter experimental_follow_type_hints | ||
| this.experimentaFollowTypeHintsParamExists = true; | ||
| } else { | ||
| // Found parameter experimental_follow_type_hints | ||
| this.experimentaFollowTypeHintsParamExists = true; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // Processing keywords arguments | ||
| // If we have keyword parameter, afterwards, we cannot have positional parameters because it would result in invalid | ||
| // Python code. | ||
| keywordType[] keywords = callFunction.keywords; | ||
| for (keywordType keyword : keywords) { | ||
| if (keyword.arg instanceof NameTok) { | ||
| NameTok name = (NameTok) keyword.arg; | ||
| if (name.id.equals(FUNC)) | ||
| if (name.id.equals(FUNC) && argumentIdDeclaringDefintion.contains(name.id)) | ||
| // Found parameter func | ||
| this.funcParamExists = true; | ||
| else if (name.id.equals(INPUT_SIGNATURE)) | ||
| else if (name.id.equals(INPUT_SIGNATURE) && argumentIdDeclaringDefintion.contains(name.id)) | ||
| // Found parameter input_signature | ||
| this.inputSignatureParamExists = true; | ||
| else if (name.id.equals(AUTOGRAPH)) | ||
| else if (name.id.equals(AUTOGRAPH) && argumentIdDeclaringDefintion.contains(name.id)) | ||
| // Found parameter autograph | ||
| this.autoGraphParamExists = true; | ||
| // The version of the API we are using allows | ||
| // The latest version of the API we are using allows | ||
tatianacv marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| // parameter names jit_compile and | ||
| // deprecated name experimental_compile | ||
| else if (name.id.equals(JIT_COMPILE) || name.id.equals(EXPERIMENTAL_COMPILE)) | ||
| else if ((name.id.equals(JIT_COMPILE) || name.id.equals(EXPERIMENTAL_COMPILE)) | ||
| && argumentIdDeclaringDefintion.contains(name.id)) | ||
| // Found parameter jit_compile/experimental_compile | ||
| this.jitCompileParamExists = true; | ||
| // The version of the API we are using allows | ||
| // The latest version of the API we are using allows | ||
| // parameter names reduce_retracing | ||
| // and deprecated name experimental_relax_shapes | ||
| else if (name.id.equals(REDUCE_RETRACING) || name.id.equals(EXPERIMENTAL_RELAX_SHAPES)) | ||
| else if ((name.id.equals(REDUCE_RETRACING) || name.id.equals(EXPERIMENTAL_RELAX_SHAPES)) | ||
| && argumentIdDeclaringDefintion.contains(name.id)) | ||
| // Found parameter reduce_retracing | ||
| // or experimental_relax_shapes | ||
| this.reduceRetracingParamExists = true; | ||
| else if (name.id.equals(EXPERIMENTAL_IMPLEMENTS)) | ||
| else if (name.id.equals(EXPERIMENTAL_IMPLEMENTS) && argumentIdDeclaringDefintion.contains(name.id)) | ||
| // Found parameter experimental_implements | ||
| this.experimentalImplementsParamExists = true; | ||
| else if (name.id.equals(EXPERIMENTAL_AUTOGRAPH_OPTIONS)) | ||
| else if (name.id.equals(EXPERIMENTAL_AUTOGRAPH_OPTIONS) && argumentIdDeclaringDefintion.contains(name.id)) | ||
| // Found parameter experimental_autograph_options | ||
| this.experimentalAutographOptionsParamExists = true; | ||
| else if (name.id.equals(EXPERIMENTAL_FOLLOW_TYPE_HINTS)) | ||
| else if (name.id.equals(EXPERIMENTAL_FOLLOW_TYPE_HINTS) && argumentIdDeclaringDefintion.contains(name.id)) | ||
| // Found parameter experimental_follow_type_hints | ||
| this.experimentaFollowTypeHintsParamExists = true; | ||
| else { | ||
| throw new IllegalArgumentException(String.format("The tf.function argument " + name.id) | ||
| + " is not supported in this tool. This tool supports up to v2.9"); | ||
| } | ||
| } | ||
| } | ||
| } // else, tf.function is used without parameters. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
| import tensorflow as tf | ||
|
|
||
| @tf.function(reduce_retracing=True) | ||
| def test(x): | ||
| return x | ||
|
|
||
| if __name__ == '__main__': | ||
| number = tf.constant([1.0, 1.0]) | ||
| test(number) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| tensorflow==2.8.0 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| import tensorflow as tf | ||
|
|
||
| @tf.function(None) | ||
| def test(): | ||
| pass | ||
|
|
||
| if __name__ == '__main__': | ||
| test() | ||
tatianacv marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| tensorflow==2.9.3 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
| import tensorflow as tf | ||
|
|
||
| @tf.function(None, (tf.TensorSpec(shape=[None], dtype=tf.float32),), False, True, True, "google.matmul_low_rank_matrix", tf.autograph.experimental.Feature.EQUALITY_OPERATORS, True, None, False) | ||
| def test(x: tf.Tensor): | ||
| return x | ||
|
|
||
| if __name__ == '__main__': | ||
| number = tf.constant([1.0, 1.0]) | ||
| print(test(number)) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| tensorflow==2.9.3 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| import tensorflow as tf | ||
|
|
||
| @tf.function(None,None,True,None,False,None,None,None, None,None) | ||
tatianacv marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| def test(): | ||
| pass | ||
|
|
||
| if __name__ == '__main__': | ||
| test() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| tensorflow==2.9.3 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
| import tensorflow as tf | ||
|
|
||
| @tf.function(None, (tf.TensorSpec(shape=[None], dtype=tf.float32),), False, True, "google.matmul_low_rank_matrix") | ||
| def test(x): | ||
| return x | ||
|
|
||
| if __name__ == '__main__': | ||
| number = tf.constant([1.0, 1.0]) | ||
| test(number) |
Uh oh!
There was an error while loading. Please reload this page.