Not working with regression problems
Hello and thank you for this package!
It seems that this package does not work when the problem is regression rather than classification. After getting errors on my own data, I realized that it always fails for regression problems. As an example, I slightly modified https://github.com/maxpumperla/elephas/blob/master/examples/ml_pipeline_otto.py to make it a regression problem, from line 61 onwards:
model = Sequential()
model.add(Dense(512, input_shape=(input_dim,)))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(1))
model.add(Activation('linear'))
model.compile(optimizer='adam', loss='mean_absolute_error')
sgd = optimizers.SGD(lr=0.01)
sgd_conf = optimizers.serialize(sgd)
# Initialize Elephas Spark ML Estimator
estimator = ElephasEstimator()
estimator.set_keras_model_config(model.to_yaml())
estimator.set_optimizer_config(sgd_conf)
estimator.set_mode("synchronous")
estimator.set_loss("mean_absolute_error")
estimator.set_metrics(['mae'])
estimator.setFeaturesCol("scaled_features")
estimator.setLabelCol("index_category")
estimator.set_epochs(10)
estimator.set_batch_size(128)
estimator.set_num_workers(1)
estimator.set_verbosity(0)
estimator.set_validation_split(0.15)
estimator.set_categorical_labels(False)
# estimator.set_nb_classes(nb_classes)
# Fitting a model returns a Transformer
pipeline = Pipeline(stages=[string_indexer, scaler, estimator])
fitted_pipeline = pipeline.fit(train_df)
# Evaluate Spark model
prediction = fitted_pipeline.transform(train_df)
pnl = prediction.select("index_category", "prediction")
pnl.show(2)
Unfortunately, it fails with the following error:
19/04/10 05:32:34 WARN TaskSetManager: Lost task 0.0 in stage 459.0 (TID 6450, local[0], executor 103): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
File "/usr/lib/spark/python/pyspark/worker.py", line 372, in main
process()
File "/usr/lib/spark/python/pyspark/worker.py", line 367, in process
serializer.dump_stream(func(split_index, iterator), outfile)
File "/usr/lib/spark/python/pyspark/serializers.py", line 390, in dump_stream
vs = list(itertools.islice(iterator, batch))
File "/usr/lib/spark/python/pyspark/util.py", line 99, in wrapper
return f(*args, **kwargs)
File "/usr/lib/spark/python/pyspark/sql/session.py", line 730, in prepare
verify_func(obj)
File "/usr/lib/spark/python/pyspark/sql/types.py", line 1389, in verify
verify_value(obj)
File "/usr/lib/spark/python/pyspark/sql/types.py", line 1368, in verify_struct
"length of fields (%d)" % (len(obj), len(verifiers))))
ValueError: Length of object (7) does not match with length of fields (5)
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:452)
at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:588)
at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:571)
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:406)
at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$11$$anon$1.hasNext(WholeStageCodegenExec.scala:619)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:255)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:836)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:836)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
at org.apache.spark.scheduler.Task.run(Task.scala:121)
at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:402)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:408)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:748)
Driver stacktrace:
at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:2039)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:2027)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:2026)
at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2026)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:966)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:966)
at scala.Option.foreach(Option.scala:257)
at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:966)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2260)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2209)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2198)
at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:777)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2082)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2101)
at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:365)
at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:38)
at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3384)
at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2545)
at org.apache.spark.sql.Dataset$$anonfun$53.apply(Dataset.scala:3365)
at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExecutionId$1.apply(SQLExecution.scala:78)
at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:125)
at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:73)
at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3364)
at org.apache.spark.sql.Dataset.head(Dataset.scala:2545)
at org.apache.spark.sql.Dataset.take(Dataset.scala:2759)
at org.apache.spark.sql.Dataset.getRows(Dataset.scala:255)
at org.apache.spark.sql.Dataset.showString(Dataset.scala:292)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
at py4j.Gateway.invoke(Gateway.java:282)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:238)
at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
File "/usr/lib/spark/python/pyspark/worker.py", line 372, in main
process()
File "/usr/lib/spark/python/pyspark/worker.py", line 367, in process
serializer.dump_stream(func(split_index, iterator), outfile)
File "/usr/lib/spark/python/pyspark/serializers.py", line 390, in dump_stream
vs = list(itertools.islice(iterator, batch))
File "/usr/lib/spark/python/pyspark/util.py", line 99, in wrapper
return f(*args, **kwargs)
File "/usr/lib/spark/python/pyspark/sql/session.py", line 730, in prepare
verify_func(obj)
File "/usr/lib/spark/python/pyspark/sql/types.py", line 1389, in verify
verify_value(obj)
File "/usr/lib/spark/python/pyspark/sql/types.py", line 1368, in verify_struct
"length of fields (%d)" % (len(obj), len(verifiers))))
ValueError: Length of object (7) does not match with length of fields (5)
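The ValueError at the bottom of the traceback comes from PySpark's schema verification (verify_struct in pyspark/sql/types.py): a row holding 7 values is being checked against a schema that declares 5 fields. The sketch below is a simplified, pure-Python stand-in for that length check, not the actual PySpark source. The function name, column names, and row values are hypothetical and chosen only to reproduce the message. It does not show which step of the pipeline produces the extra values.

```python
# Simplified stand-in for the row-length check that PySpark performs when
# verifying a tuple row against a StructType schema. This is an illustration,
# not the real PySpark implementation.
def verify_row_length(row, field_names):
    """Raise ValueError when a row has a different arity than the schema."""
    if len(row) != len(field_names):
        raise ValueError(
            "Length of object (%d) does not match with length of fields (%d)"
            % (len(row), len(field_names))
        )
    return dict(zip(field_names, row))


# Hypothetical schema with 5 fields (names are assumptions for illustration).
schema_fields = ["id", "features", "index_category", "scaled_features", "prediction"]

# A row that matches the schema passes the check.
good_row = (1, [0.1], 3.0, [0.2], 2.7)
verified = verify_row_length(good_row, schema_fields)

# A row with two extra values triggers the same message as in the traceback.
bad_row = (1, [0.1], 3.0, [0.2], 2.7, 0.0, 0.0)
try:
    verify_row_length(bad_row, schema_fields)
    message = None
except ValueError as exc:
    message = str(exc)

print(message)
```

Comparing the length of one collected row with the length of the DataFrame schema just before the failing step is a quick way to find where the two diverge.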
Top GitHub Comments
Any update on this issue?
It seems to be broken with newer versions. I ran the regression example on Google Colab and it gave these results: