Getting started with Machine Learning using Apache Spark MLlib

Write your first classification problem using a Logistic Regression model


1.) I am using Eclipse here, and you need to add the dependency given below to your pom.xml.


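    <!-- Spark MLlib (pulls in spark-core and spark-sql transitively) -->
    <!-- Assumed Scala 2.11 build; set spark.version in <properties> to your Spark 2.x version -->
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-mllib_2.11</artifactId>
        <version>${spark.version}</version>
    </dependency>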

2.) App.java (your application class)

We are doing classification based on the features country and hour, and our label is clicked.
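Logistic regression in Spark works on a numeric feature vector, so the string column country has to be turned into numbers before it can sit next to hour. The plain-Java sketch below (no Spark dependency) imitates what the StringIndexer and VectorAssembler steps in App.java do. It assumes StringIndexer's default "frequencyDesc" ordering (most frequent string gets index 0) and the alphabetical tie-breaking of recent Spark versions; older versions may order equally frequent strings differently. The class FeatureSketch and its methods are hypothetical names for illustration only.

```java
import java.util.*;
import java.util.stream.Collectors;

// Hypothetical helper that mimics StringIndexer + VectorAssembler on plain Java data
public class FeatureSketch {

    // Mimics StringIndexer's default "frequencyDesc" ordering: the most frequent
    // string gets index 0.0; ties are broken alphabetically (assumption: recent Spark)
    static Map<String, Double> fitIndexer(List<String> values) {
        Map<String, Long> counts = values.stream()
                .collect(Collectors.groupingBy(v -> v, Collectors.counting()));
        List<String> ordered = new ArrayList<>(counts.keySet());
        ordered.sort(Comparator.comparing((String s) -> counts.get(s)).reversed()
                .thenComparing(Comparator.<String>naturalOrder()));
        Map<String, Double> index = new HashMap<>();
        for (int i = 0; i < ordered.size(); i++) {
            index.put(ordered.get(i), (double) i);
        }
        return index;
    }

    // Mimics VectorAssembler: concatenates the input columns into one feature vector
    static double[] assemble(double countryIndex, int hour) {
        return new double[] {countryIndex, hour};
    }

    public static void main(String[] args) {
        List<String> countries = Arrays.asList("US", "CA", "NZ", "FR", "IT", "CH", "AU");
        int[] hours = {18, 12, 15, 8, 16, 5, 20};
        Map<String, Double> index = fitIndexer(countries);
        for (int i = 0; i < countries.size(); i++) {
            double[] features = assemble(index.get(countries.get(i)), hours[i]);
            System.out.println(countries.get(i) + " -> " + Arrays.toString(features));
        }
    }
}
```

Since every country occurs exactly once in this toy dataset, all counts tie, and the indices simply follow alphabetical order under the stated assumption.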

    package com.predection.classification.logisitcRegression;

    import java.util.Arrays;
    import java.util.List;

    import org.apache.spark.SparkConf;
    import org.apache.spark.ml.classification.LogisticRegression;
    import org.apache.spark.ml.classification.LogisticRegressionModel;
    import org.apache.spark.ml.feature.StringIndexer;
    import org.apache.spark.ml.feature.VectorAssembler;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.RowFactory;
    import org.apache.spark.sql.SparkSession;
    import org.apache.spark.sql.types.StructField;
    import org.apache.spark.sql.types.StructType;

    import static org.apache.spark.sql.types.DataTypes.*;

    /**
     * Classification problem using Logistic Regression Model
     */
    public class App {
        public static void main(String[] args) {
            SparkConf sparkConf = new SparkConf().setAppName("JavaLogisticRegressionExample");
            // Creating spark session (the master URL is passed by spark-submit)
            SparkSession sparkSession = SparkSession.builder().config(sparkConf).getOrCreate();

            // Schema of the toy dataset: clicked (0.0 or 1.0) is the label
            StructType schema = createStructType(new StructField[]{
                    createStructField("id", IntegerType, false),
                    createStructField("country", StringType, false),
                    createStructField("hour", IntegerType, false),
                    createStructField("clicked", DoubleType, false)
            });

            List<Row> data = Arrays.asList(
                    RowFactory.create(7, "US", 18, 1.0),
                    RowFactory.create(8, "CA", 12, 0.0),
                    RowFactory.create(9, "NZ", 15, 1.0),
                    RowFactory.create(10, "FR", 8, 0.0),
                    RowFactory.create(11, "IT", 16, 1.0),
                    RowFactory.create(12, "CH", 5, 0.0),
                    RowFactory.create(13, "AU", 20, 1.0)
            );

            Dataset<Row> dataset = sparkSession.createDataFrame(data, schema);

            // StringIndexer learns a country -> index mapping (fit) and applies it (transform)
            dataset = new StringIndexer().setInputCol("country").setOutputCol("countryIndex")
                    .fit(dataset).transform(dataset);

            // Build the feature vector: countryIndex and hour are the (independent) features,
            // clicked is the label
            VectorAssembler assembler = new VectorAssembler()
                    .setInputCols(new String[] {"countryIndex", "hour"})
                    .setOutputCol("features");

            Dataset<Row> finalDS = assembler.transform(dataset);

            // Split the data into training and test sets (30% held out for testing)
            Dataset<Row>[] splits = finalDS.randomSplit(new double[] {0.7, 0.3});
            Dataset<Row> trainingData = splits[0];
            Dataset<Row> testData = splits[1];
            trainingData.show();
            testData.show();

            // Building the LogisticRegression model (reads the default "features" column)
            LogisticRegression lr = new LogisticRegression()
                    .setMaxIter(10)
                    .setRegParam(0.3)
                    .setElasticNetParam(0.8)
                    .setLabelCol("clicked");

            // Fit the model on the training data
            LogisticRegressionModel lrModel = lr.fit(trainingData);

            // Apply the model to the test data; the predicted class is in the "prediction" column
            Dataset<Row> output = lrModel.transform(testData);
            output.show();

            sparkSession.stop();
        }
    }
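For a binary label such as clicked, the fitted LogisticRegressionModel turns a feature vector x into a probability p = 1 / (1 + exp(-(w·x + b))) and, with the default threshold of 0.5, predicts 1.0 when p is above the threshold; transform adds these as the rawPrediction, probability and prediction columns. The plain-Java sketch below reproduces only that decision rule. The weights and intercept are made-up numbers for illustration, not values learned from this dataset (in Spark you would read them from lrModel.coefficients() and lrModel.intercept()), and LogisticSketch is a hypothetical name.

```java
// Hypothetical sketch of the binary logistic-regression decision rule
public class LogisticSketch {

    // Linear margin w·x + b
    static double margin(double[] weights, double intercept, double[] features) {
        double m = intercept;
        for (int i = 0; i < weights.length; i++) {
            m += weights[i] * features[i];
        }
        return m;
    }

    // The sigmoid turns the margin into the probability of class 1.0
    static double probability(double[] weights, double intercept, double[] features) {
        return 1.0 / (1.0 + Math.exp(-margin(weights, intercept, features)));
    }

    // Predict 1.0 when the probability is strictly above the threshold (Spark's default is 0.5)
    static double predict(double[] weights, double intercept, double[] features, double threshold) {
        return probability(weights, intercept, features) > threshold ? 1.0 : 0.0;
    }

    public static void main(String[] args) {
        double[] weights = {0.0, 0.2};   // illustrative values for [countryIndex, hour]
        double intercept = -2.6;         // illustrative value
        double[] evening = {6.0, 18.0};  // margin = 1.0, p is about 0.73
        double[] morning = {3.0, 8.0};   // margin = -1.0, p is about 0.27
        System.out.println(predict(weights, intercept, evening, 0.5)); // prints 1.0
        System.out.println(predict(weights, intercept, morning, 0.5)); // prints 0.0
    }
}
```

The regularisation settings (setRegParam, setElasticNetParam) only affect how w and b are learned during fit; the prediction step itself stays this simple.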

3.) To run this application, first run mvn clean package on the application project; it creates the jar file under target/.

4.) Open the Spark root directory and submit this job:

    bin/spark-submit --class com.predection.classification.logisitcRegression.App --master local[2] ./regression-0.0.1-SNAPSHOT.jar

(Replace ./regression-0.0.1-SNAPSHOT.jar with the path to your jar file.)
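The job calls randomSplit without a seed, so each submission can put different rows into the training and test sets, and with only seven rows the 70/30 proportions hold only approximately (the test set can even come out empty). The plain-Java sketch below illustrates the idea of a weighted random split: each row is assigned independently by drawing a uniform number and comparing it with cumulative weight bounds. It is a simplified illustration, not Spark's exact implementation, and SplitSketch is a hypothetical name. In the Spark job itself, the overload randomSplit(weights, seed) makes the split reproducible.

```java
import java.util.*;

// Hypothetical, simplified weighted random split (not Spark's exact algorithm)
public class SplitSketch {

    static <T> List<List<T>> randomSplit(List<T> rows, double[] weights, long seed) {
        double total = Arrays.stream(weights).sum();
        List<List<T>> splits = new ArrayList<>();
        for (int i = 0; i < weights.length; i++) {
            splits.add(new ArrayList<>());
        }
        Random random = new Random(seed);
        for (T row : rows) {
            // Draw u in [0, 1) and pick the first split whose cumulative bound exceeds it
            double u = random.nextDouble();
            double cumulative = 0.0;
            int target = weights.length - 1;
            for (int i = 0; i < weights.length; i++) {
                cumulative += weights[i] / total;
                if (u < cumulative) {
                    target = i;
                    break;
                }
            }
            splits.get(target).add(row);
        }
        return splits;
    }

    public static void main(String[] args) {
        List<Integer> ids = Arrays.asList(7, 8, 9, 10, 11, 12, 13);
        List<List<Integer>> splits = randomSplit(ids, new double[] {0.7, 0.3}, 42L);
        System.out.println("training ids: " + splits.get(0));
        System.out.println("test ids: " + splits.get(1));
    }
}
```

Fixing the seed makes the assignment repeatable, which is handy when comparing runs; on a dataset this small, however, any single split says little about model quality.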

5.) After submitting, you can see the training data it builds.

6.) In the same way, you can see the test data.

7.) Finally, the prediction result appears under the prediction column.
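To judge the predictions beyond reading the prediction column, you can compare it with the clicked label. Spark ships evaluators for this (for example MulticlassClassificationEvaluator with the "accuracy" metric); the plain-Java sketch below only shows the underlying arithmetic on hypothetical label/prediction pairs, not results from this job. AccuracySketch is a hypothetical name.

```java
// Hypothetical sketch: accuracy = correct predictions / total predictions
public class AccuracySketch {

    static double accuracy(double[] labels, double[] predictions) {
        if (labels.length != predictions.length || labels.length == 0) {
            throw new IllegalArgumentException("need equally sized, non-empty arrays");
        }
        int correct = 0;
        for (int i = 0; i < labels.length; i++) {
            if (labels[i] == predictions[i]) {
                correct++;
            }
        }
        return (double) correct / labels.length;
    }

    public static void main(String[] args) {
        double[] clicked = {1.0, 0.0, 1.0, 0.0};     // hypothetical labels
        double[] prediction = {1.0, 0.0, 0.0, 0.0};  // hypothetical predictions
        System.out.println(accuracy(clicked, prediction)); // prints 0.75
    }
}
```

With a test set of only two or three rows, accuracy jumps in large steps, so treat it as a sanity check rather than a measure of model quality.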