VowpalWabbitClassifier does not work with --oaa (One Against All) argument
Describe the bug
Vowpal Wabbit’s One Against All classifier does not work via the MMLSpark interface.
To Reproduce
import com.microsoft.ml.spark.vw.VowpalWabbitClassifier

// Configure VW for One Against All with 2 classes
val vwClassifier = new VowpalWabbitClassifier()
  .setFeaturesCol("features")
  .setLabelCol("label")
  .setProbabilityCol("predictedProb")
  .setPredictionCol("predictedLabel")
  .setRawPredictionCol("rawPrediction")
  .setArgs("--oaa=2 --quiet --holdout_off")
Here, features is a column of sparse vectors (constructed via VowpalWabbitFeaturizer in my case) and label is a column of integers with values {1, 2}.
Expected behavior
val predictions = vwClassifier.fit(trainDF).transform(testDF)
predictions.show
should show my testDF with a predictedLabel column containing the predictions.
Info (please complete the following information):
- MMLSpark Version: 1.0.0-rc1
- Spark Version: 2.4.3
- Spark Platform: AWS EMR 5.26.0 (Zeppelin 0.8.1)
Stacktrace
org.apache.spark.SparkException: Job aborted due to stage failure: Task 95 in stage 46.0 failed 4 times, most recent failure: Lost task 95.3 in stage 46.0 (TID 2609, ip-10-5-29-73.ec2.internal, executor 7): org.apache.spark.SparkException: Failed to execute user defined function($anonfun$2: (struct<features:struct<type:tinyint,size:int,indices:array<int>,values:array<double>>>) => double)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$13$$anon$1.hasNext(WholeStageCodegenExec.scala:636)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:291)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:283)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:836)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:836)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
at org.apache.spark.scheduler.Task.run(Task.scala:121)
at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:408)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:414)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:748)
Caused by: java.lang.ClassCastException: java.lang.Integer cannot be cast to org.vowpalwabbit.spark.prediction.ScalarPrediction
at com.microsoft.ml.spark.vw.VowpalWabbitBaseModel$class.predictInternal(VowpalWabbitBaseModel.scala:84)
at com.microsoft.ml.spark.vw.VowpalWabbitClassificationModel.predictInternal(VowpalWabbitClassifier.scala:61)
at com.microsoft.ml.spark.vw.VowpalWabbitBaseModel$$anonfun$2.apply(VowpalWabbitBaseModel.scala:49)
at com.microsoft.ml.spark.vw.VowpalWabbitBaseModel$$anonfun$2.apply(VowpalWabbitBaseModel.scala:49)
... 21 more
Driver stacktrace:
at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:2041)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:2029)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:2028)
at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2028)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:966)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:966)
at scala.Option.foreach(Option.scala:257)
at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:966)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2262)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2211)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2200)
at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:777)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2082)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2101)
at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:401)
at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:38)
at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3383)
at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2544)
at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2544)
at org.apache.spark.sql.Dataset$$anonfun$53.apply(Dataset.scala:3364)
at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExecutionId$1.apply(SQLExecution.scala:78)
at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:125)
at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:73)
at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3363)
at org.apache.spark.sql.Dataset.head(Dataset.scala:2544)
at org.apache.spark.sql.Dataset.take(Dataset.scala:2758)
at org.apache.spark.sql.Dataset.getRows(Dataset.scala:254)
at org.apache.spark.sql.Dataset.showString(Dataset.scala:291)
at org.apache.spark.sql.Dataset.show(Dataset.scala:745)
at org.apache.spark.sql.Dataset.show(Dataset.scala:704)
at org.apache.spark.sql.Dataset.show(Dataset.scala:713)
... 51 elided
Caused by: org.apache.spark.SparkException: Failed to execute user defined function($anonfun$2: (struct<features:struct<type:tinyint,size:int,indices:array<int>,values:array<double>>>) => double)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$13$$anon$1.hasNext(WholeStageCodegenExec.scala:636)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:291)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:283)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:836)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:836)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
at org.apache.spark.scheduler.Task.run(Task.scala:121)
at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:408)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:414)
... 3 more
Caused by: java.lang.ClassCastException: java.lang.Integer cannot be cast to org.vowpalwabbit.spark.prediction.ScalarPrediction
at com.microsoft.ml.spark.vw.VowpalWabbitBaseModel$class.predictInternal(VowpalWabbitBaseModel.scala:84)
at com.microsoft.ml.spark.vw.VowpalWabbitClassificationModel.predictInternal(VowpalWabbitClassifier.scala:61)
at com.microsoft.ml.spark.vw.VowpalWabbitBaseModel$$anonfun$2.apply(VowpalWabbitBaseModel.scala:49)
at com.microsoft.ml.spark.vw.VowpalWabbitBaseModel$$anonfun$2.apply(VowpalWabbitBaseModel.scala:49)
... 21 more
To me, it looks like the root cause is this exception:
Caused by: java.lang.ClassCastException: java.lang.Integer cannot be cast to org.vowpalwabbit.spark.prediction.ScalarPrediction at com.microsoft.ml.spark.vw.VowpalWabbitBaseModel$class.predictInternal(VowpalWabbitBaseModel.scala:84)
Could it be that --oaa outputs integers instead of the doubles expected by MMLSpark?
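The suspected failure mode can be sketched in isolation. The snippet below is a minimal illustration, not MMLSpark's actual code: ScalarPrediction here is a hypothetical stand-in for org.vowpalwabbit.spark.prediction.ScalarPrediction, and predictUntyped, unsafeScore, and safeScore are invented names. It assumes the native predict call hands back an untyped object that is a boxed integer class label when --oaa is used.

```scala
// Hypothetical stand-in for org.vowpalwabbit.spark.prediction.ScalarPrediction;
// the real class is not shown in this issue, so its shape here is an assumption.
final case class ScalarPrediction(value: Float)

object OaaCastSketch {
  // Assumption: the native call returns an untyped object whose runtime type
  // depends on the VW arguments (boxed Integer for --oaa, scalar otherwise).
  def predictUntyped(multiclass: Boolean): Any =
    if (multiclass) Integer.valueOf(2) else ScalarPrediction(0.5f)

  // Unchecked cast, analogous to what the stack trace suggests:
  // throws ClassCastException when the prediction is a boxed Integer.
  def unsafeScore(pred: Any): Double =
    pred.asInstanceOf[ScalarPrediction].value.toDouble

  // Pattern match on the runtime type instead of casting blindly.
  def safeScore(pred: Any): Either[String, Double] = pred match {
    case s: ScalarPrediction  => Right(s.value.toDouble)
    case i: java.lang.Integer => Right(i.doubleValue())
    case other                => Left(s"unsupported prediction type: ${other.getClass.getName}")
  }
}
```

Calling OaaCastSketch.unsafeScore on the multiclass result fails with the same kind of ClassCastException as in the trace, while safeScore returns the class label as a double. Whether the real fix should look like this depends on how the VW JNI layer types its predictions, which this issue does not show.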
Additional context
For context, this works fine in my setup on the same dataset with the same VowpalWabbitFeaturizer (although I have to convert labels to {1, 0}):
val vwClassifier = new VowpalWabbitClassifier()
  .setFeaturesCol("features")
  .setLabelCol("label")
  .setProbabilityCol("predictedProb")
  .setPredictionCol("predictedLabel")
  .setRawPredictionCol("rawPrediction")
  .setArgs("--loss_function=logistic --link=logistic --quiet --holdout_off")
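The label conversion mentioned for this binary workaround might look like the sketch below. The exact mapping is not stated in the issue, so this assumes that label 2 becomes 0 and label 1 stays 1; LabelRemapSketch, remapLabel, and toBinaryLabels are hypothetical names, and the column name "label" is taken from the classifier configuration above.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, when}

object LabelRemapSketch {
  // Assumed per-value rule: 2 -> 0, anything else (here: 1) -> 1.
  def remapLabel(label: Int): Int = if (label == 2) 0 else 1

  // Same rule expressed as a Spark column expression, overwriting the label column.
  def toBinaryLabels(df: DataFrame, labelCol: String = "label"): DataFrame =
    df.withColumn(labelCol, when(col(labelCol) === 2, 0).otherwise(1))
}
```

Under these assumptions, LabelRemapSketch.toBinaryLabels(trainDF) would be applied to the training data before calling fit on the logistic-loss classifier.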
Issue Analytics
- Created 4 years ago
- Comments:13 (6 by maintainers)
Top GitHub Comments
Waiting for the VW PR to be approved; then we can look into merging this for SynapseML.
It’s not implemented yet: https://github.com/microsoft/SynapseML/blob/9d16166314ef42767e1c50b9c831c163050c15d3/vw/src/main/scala/com/microsoft/azure/synapse/ml/vw/VowpalWabbitClassifier.scala#L64
Let me follow up with Jack. Is there a minimal dataset we can use to repro?