diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/StringIndexerModelTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/StringIndexerModelTests.cs
new file mode 100644
index 000000000..b4ced6878
--- /dev/null
+++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/StringIndexerModelTests.cs
@@ -0,0 +1,87 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using Microsoft.Spark.ML.Feature;
+using Microsoft.Spark.Sql;
+using Microsoft.Spark.Sql.Types;
+using Microsoft.Spark.UnitTest.TestUtils;
+using Xunit;
+
+namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
+{
+    [Collection("Spark E2E Tests")]
+    public class StringIndexerModelTests : FeatureBaseTests
+    {
+        private readonly SparkSession _spark;
+
+        public StringIndexerModelTests(SparkFixture fixture) : base(fixture)
+        {
+            _spark = fixture.Spark;
+        }
+
+        ///
+        /// Create a DataFrame, fit a StringIndexerModel to it and test the
+        /// available methods.
+        ///
+        [Fact]
+        public void TestStringIndexerModel()
+        {
+            DataFrame input = _spark.CreateDataFrame(
+                new List<GenericRow>
+                {
+                    new GenericRow(new object[] {0, "a"}),
+                    new GenericRow(new object[] {1, "b"}),
+                    new GenericRow(new object[] {2, "c"}),
+                    new GenericRow(new object[] {3, "a"}),
+                    new GenericRow(new object[] {4, "a"}),
+                    new GenericRow(new object[] {5, "c"})
+                },
+                new StructType(new List<StructField>
+                {
+                    new StructField("id", new IntegerType()),
+                    new StructField("category", new StringType())
+                }));
+
+            string expectedUid = "theUid";
+            StringIndexer stringIndexer = new StringIndexer(expectedUid)
+                .SetInputCol("category")
+                .SetOutputCol("categoryIndex");
+
+            StringIndexerModel stringIndexerModel = stringIndexer.Fit(input);
+            DataFrame transformedDF = stringIndexerModel.Transform(input);
+            List<Row> observed = transformedDF.Select("category", new string[] { "categoryIndex" })
+                .Collect().ToList();
+            List<Row> expected = new List<Row> // NOTE(review): indices are strings here; confirm against the categoryIndex column type
+            {
+                new Row(new GenericRow(new object[] {"a", "0"})),
+                new Row(new GenericRow(new object[] {"b", "2"})),
+                new Row(new GenericRow(new object[] {"c", "1"})),
+                new Row(new GenericRow(new object[] {"a", "0"})),
+                new Row(new GenericRow(new object[] {"a", "0"})),
+                new Row(new GenericRow(new object[] {"c", "1"}))
+            };
+
+            observed.ForEach(a =>
+            { // NOTE(review): "b == a" is reference equality unless Row overloads ==; confirm a match is possible
+                Assert.Equal(a, expected.Where(b => b == a).FirstOrDefault());
+            }
+            );
+            Assert.Equal("category", stringIndexer.GetInputCol());
+            Assert.Equal("categoryIndex", stringIndexer.GetOutputCol());
+            Assert.Equal(expectedUid, stringIndexer.Uid());
+
+            using (var tempDirectory = new TemporaryDirectory())
+            {
+                string savePath = Path.Join(tempDirectory.Path, "stringIndexerModel");
+                stringIndexerModel.Save(savePath);
+
+                StringIndexerModel loadedModel = StringIndexerModel.Load(savePath);
+                Assert.Equal(stringIndexerModel.Uid(), loadedModel.Uid());
+            }
+        }
+    }
+}
diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/StringIndexerTests.cs
b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/StringIndexerTests.cs
new file mode 100644
index 000000000..8fc2b70f7
--- /dev/null
+++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/StringIndexerTests.cs
@@ -0,0 +1,52 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using Microsoft.Spark.ML.Feature;
+using Microsoft.Spark.Sql;
+using Microsoft.Spark.Sql.Types;
+using Microsoft.Spark.UnitTest.TestUtils;
+using Xunit;
+
+namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
+{
+    [Collection("Spark E2E Tests")]
+    public class StringIndexerTests : FeatureBaseTests
+    {
+        private readonly SparkSession _spark;
+
+        public StringIndexerTests(SparkFixture fixture) : base(fixture)
+        {
+            _spark = fixture.Spark;
+        }
+
+        ///
+        /// Create a StringIndexer, save and reload it, and test the
+        /// available methods.
+        ///
+        [Fact]
+        public void TestStringIndexer()
+        {
+            string expectedUid = "theUid";
+            StringIndexer stringIndexer = new StringIndexer(expectedUid)
+                .SetInputCol("category")
+                .SetOutputCol("categoryIndex");
+
+            Assert.Equal("category", stringIndexer.GetInputCol());
+            Assert.Equal("categoryIndex", stringIndexer.GetOutputCol());
+            Assert.Equal(expectedUid, stringIndexer.Uid());
+
+            using (var tempDirectory = new TemporaryDirectory())
+            {
+                string savePath = Path.Join(tempDirectory.Path, "stringIndexer");
+                stringIndexer.Save(savePath);
+
+                StringIndexer loadedstringIndexer = StringIndexer.Load(savePath);
+                Assert.Equal(stringIndexer.Uid(), loadedstringIndexer.Uid());
+            }
+        }
+    }
+}
diff --git a/src/csharp/Microsoft.Spark/ML/Feature/StringIndexer.cs b/src/csharp/Microsoft.Spark/ML/Feature/StringIndexer.cs
new file mode 100644
index 000000000..d5bbab383
--- /dev/null
+++ b/src/csharp/Microsoft.Spark/ML/Feature/StringIndexer.cs
@@ -0,0 +1,174 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.Spark.Interop;
+using Microsoft.Spark.Interop.Ipc;
+using Microsoft.Spark.Sql;
+using Microsoft.Spark.Sql.Types;
+
+namespace Microsoft.Spark.ML.Feature
+{
+    /// <summary>
+    /// <see cref="StringIndexer"/> encodes a string column of labels to a column of label indices.
+    /// </summary>
+    public class StringIndexer : FeatureBase, IJvmObjectReferenceProvider
+    {
+        private static readonly string s_StringIndexerClassName =
+            "org.apache.spark.ml.feature.StringIndexer";
+
+        /// <summary>
+        /// Create a <see cref="StringIndexer"/> without any parameters.
+        /// </summary>
+        public StringIndexer() : base(s_StringIndexerClassName)
+        {
+        }
+
+        /// <summary>
+        /// Create a <see cref="StringIndexer"/> with a UID that is used to give the
+        /// <see cref="StringIndexer"/> a unique ID.
+        /// </summary>
+        /// <param name="uid">An immutable unique ID for the object and its derivatives.</param>
+        public StringIndexer(string uid) : base(s_StringIndexerClassName, uid)
+        {
+        }
+
+        internal StringIndexer(JvmObjectReference jvmObject) : base(jvmObject)
+        {
+        }
+
+        JvmObjectReference IJvmObjectReferenceProvider.Reference => _jvmObject;
+
+        /// <summary>
+        /// Executes the <see cref="StringIndexer"/> and transforms the schema.
+        /// </summary>
+        /// <param name="value">The Schema to be transformed</param>
+        /// <returns>
+        /// New <see cref="StructType"/> object with the schema transformed.
+        /// </returns>
+        public StructType TransformSchema(StructType value) =>
+            new StructType(
+                (JvmObjectReference)_jvmObject.Invoke(
+                    "transformSchema",
+                    DataType.FromJson(_jvmObject.Jvm, value.Json)));
+
+        /// <summary>
+        /// Executes the <see cref="StringIndexer"/> and fits a model to the input data.
+        /// </summary>
+        /// <param name="source">The <see cref="DataFrame"/> to fit the model to.</param>
+        /// <returns>The <see cref="StringIndexerModel"/> produced by fitting.</returns>
+        public StringIndexerModel Fit(DataFrame source) =>
+            new StringIndexerModel((JvmObjectReference)_jvmObject.Invoke("fit", source));
+
+        /// <summary>
+        /// Gets the HandleInvalid.
+        /// </summary>
+        /// <returns>Handle Invalid option</returns>
+        public string GetHandleInvalid() => (string)_jvmObject.Invoke("getHandleInvalid");
+
+        /// <summary>
+        /// Sets the Handle Invalid option to <paramref name="handleInvalid"/>.
+        /// </summary>
+        /// <param name="handleInvalid">Handle Invalid option</param>
+        /// <returns>
+        /// <see cref="StringIndexer"/> with the Handle Invalid set.
+        /// </returns>
+        public StringIndexer SetHandleInvalid(string handleInvalid) =>
+            WrapAsStringIndexer((JvmObjectReference)_jvmObject.Invoke("setHandleInvalid", handleInvalid));
+
+        /// <summary>
+        /// Gets the InputCol.
+        /// </summary>
+        /// <returns>Input Col option</returns>
+        public string GetInputCol() => (string)_jvmObject.Invoke("getInputCol");
+
+        /// <summary>
+        /// Sets the Input Col option to <paramref name="inputCol"/>.
+        /// </summary>
+        /// <param name="inputCol">Input Col option</param>
+        /// <returns>
+        /// <see cref="StringIndexer"/> with the Input Col set.
+        /// </returns>
+        public StringIndexer SetInputCol(string inputCol) =>
+            WrapAsStringIndexer((JvmObjectReference)_jvmObject.Invoke("setInputCol", inputCol));
+
+        /// <summary>
+        /// Gets the InputCols array.
+        /// </summary>
+        /// <returns>Input Cols array option</returns>
+        public string[] GetInputCols() => (string[])_jvmObject.Invoke("getInputCols");
+
+        /// <summary>
+        /// Sets the Input Cols array option to <paramref name="inputCols"/>.
+        /// </summary>
+        /// <param name="inputCols">Input Cols array option</param>
+        ///
+        /// StringIndexer with the Input Cols array set.
+        ///
+        public StringIndexer SetInputCols(string[] inputCols) =>
+            WrapAsStringIndexer((JvmObjectReference)_jvmObject.Invoke("setInputCols", inputCols));
+
+        /// <summary>
+        /// Gets the OutputCol.
+        /// </summary>
+        /// <returns>Output Col option</returns>
+        public string GetOutputCol() => (string)_jvmObject.Invoke("getOutputCol");
+
+        /// <summary>
+        /// Sets the Output Col option to <paramref name="outputCol"/>.
+        /// </summary>
+        /// <param name="outputCol">Output Col option</param>
+        /// <returns>
+        /// <see cref="StringIndexer"/> with the Output Col set.
+        /// </returns>
+        public StringIndexer SetOutputCol(string outputCol) =>
+            WrapAsStringIndexer((JvmObjectReference)_jvmObject.Invoke("setOutputCol", outputCol));
+
+        /// <summary>
+        /// Gets the OutputCols array.
+        /// </summary>
+        /// <returns>Output Cols array option</returns>
+        public string[] GetOutputCols() => (string[])_jvmObject.Invoke("getOutputCols");
+
+        /// <summary>
+        /// Sets the Output Cols array option to <paramref name="outputCols"/> via the JVM-side "setOutputCols" setter.
+        /// </summary>
+        /// <param name="outputCols">Output Cols array option</param>
+        /// <returns>
+        /// <see cref="StringIndexer"/> with the Output Cols array set.
+        /// </returns>
+        public StringIndexer SetOutputCols(string[] outputCols) =>
+            WrapAsStringIndexer((JvmObjectReference)_jvmObject.Invoke("setOutputCols", outputCols));
+
+        /// <summary>
+        /// Gets the String Order Type.
+        /// </summary>
+        /// <returns>String Order Type</returns>
+        public string GetStringOrderType() => (string)_jvmObject.Invoke("getStringOrderType");
+
+        /// <summary>
+        /// Sets the String Order Type to <paramref name="stringOrderType"/>.
+        /// </summary>
+        /// <param name="stringOrderType">String Order Type</param>
+        /// <returns>
+        /// <see cref="StringIndexer"/> with the String Order Type set.
+        /// </returns>
+        public StringIndexer SetStringOrderType(string stringOrderType) =>
+            WrapAsStringIndexer((JvmObjectReference)_jvmObject.Invoke("setStringOrderType", stringOrderType));
+
+        ///
+        /// Loads the StringIndexer that was previously saved using Save.
+        ///
+        /// <param name="path">The path the previous <see cref="StringIndexer"/> was saved to</param>
+        /// <returns>New <see cref="StringIndexer"/> object, loaded from path</returns>
+        public static StringIndexer Load(string path) =>
+            WrapAsStringIndexer(
+                SparkEnvironment.JvmBridge.CallStaticJavaMethod(
+                    s_StringIndexerClassName,
+                    "load",
+                    path));
+
+        private static StringIndexer WrapAsStringIndexer(object obj) =>
+            new StringIndexer((JvmObjectReference)obj);
+    }
+}
diff --git a/src/csharp/Microsoft.Spark/ML/Feature/StringIndexerModel.cs b/src/csharp/Microsoft.Spark/ML/Feature/StringIndexerModel.cs
new file mode 100644
index 000000000..e16757a93
--- /dev/null
+++ b/src/csharp/Microsoft.Spark/ML/Feature/StringIndexerModel.cs
@@ -0,0 +1,92 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+using Microsoft.Spark.Interop;
+using Microsoft.Spark.Interop.Ipc;
+using Microsoft.Spark.Sql;
+using Microsoft.Spark.Sql.Types;
+
+namespace Microsoft.Spark.ML.Feature
+{
+    public class StringIndexerModel
+        : FeatureBase, IJvmObjectReferenceProvider
+    {
+        private static readonly string s_stringIndexerModelClassName =
+            "org.apache.spark.ml.feature.StringIndexerModel";
+
+        /// <summary>
+        /// Creates a <see cref="StringIndexerModel"/> from the given vocabulary
+        /// </summary>
+        /// <param name="vocabulary">The vocabulary to use (NOTE(review): List has no type argument here; confirm the element type)</param>
+        public StringIndexerModel(List vocabulary)
+            : this(SparkEnvironment.JvmBridge.CallConstructor(
+                s_stringIndexerModelClassName, vocabulary))
+        {
+        }
+
+        /// <summary>
+        /// Creates a <see cref="StringIndexerModel"/> with a UID that is used to give the
+        /// <see cref="StringIndexerModel"/> a unique ID
+        /// </summary>
+        /// <param name="uid">An immutable unique ID for the object and its derivatives.</param>
+        /// <param name="vocabulary">The vocabulary to use (NOTE(review): List has no type argument here; confirm the element type)</param>
+        public StringIndexerModel(string uid, List vocabulary)
+            : this(SparkEnvironment.JvmBridge.CallConstructor(
+                s_stringIndexerModelClassName, uid, vocabulary))
+        {
+        }
+
+        internal StringIndexerModel(JvmObjectReference jvmObject) : base(jvmObject)
+        {
+        }
+
+        JvmObjectReference IJvmObjectReferenceProvider.Reference => _jvmObject;
+
+        /// <summary>
+        /// Loads the <see cref="StringIndexerModel"/> that was previously saved using Save
+        /// </summary>
+        /// <param name="path">
+        /// The path the previous <see cref="StringIndexerModel"/> was saved to
+        /// </param>
+        /// <returns>New <see cref="StringIndexerModel"/> object</returns>
+        public static StringIndexerModel Load(string path) =>
+            WrapAsStringIndexerModel(
+                SparkEnvironment.JvmBridge.CallStaticJavaMethod(
+                    s_stringIndexerModelClassName, "load", path));
+
+        /// <summary>
+        /// Check transform validity and derive the output schema from the input schema.
+        ///
+        /// This checks for validity of interactions between parameters during Transform and
+        /// raises an exception if any parameter value is invalid.
+        ///
+        /// Typical implementation should first conduct verification on schema change and parameter
+        /// validity, including complex parameter interaction checks.
+        /// </summary>
+        /// <param name="value">
+        /// The <see cref="StructType"/> of the <see cref="DataFrame"/> which will be transformed.
+        /// </param>
+        /// <returns>
+        /// The <see cref="StructType"/> of the output schema that would have been derived from the
+        /// input schema, if Transform had been called.
+        /// </returns>
+        public StructType TransformSchema(StructType value) =>
+            new StructType(
+                (JvmObjectReference)_jvmObject.Invoke(
+                    "transformSchema",
+                    DataType.FromJson(_jvmObject.Jvm, value.Json)));
+
+        /// <summary>
+        /// Transforms the input <see cref="DataFrame"/> by invoking "transform" on the underlying JVM object.
+        /// </summary>
+        /// <param name="document"><see cref="DataFrame"/> to transform</param>
+        /// <returns><see cref="DataFrame"/> wrapping the result of the JVM-side transform call</returns>
+        public DataFrame Transform(DataFrame document) =>
+            new DataFrame((JvmObjectReference)_jvmObject.Invoke("transform", document));
+
+        private static StringIndexerModel WrapAsStringIndexerModel(object obj) =>
+            new StringIndexerModel((JvmObjectReference)obj);
+    }
+}