Browse source

Add and use the Transform stage

theenglishway committed 2 years ago
Commit
6012561273
2 files changed: 77 additions and 1 deletion
  1. +7 −1
      src/main/scala/Main.scala
  2. +70 −0
      src/main/scala/Transform.scala

+ 7 - 1
src/main/scala/Main.scala

@@ -2,6 +2,7 @@ import org.apache.spark.sql.SparkSession
 
 import game._
 import extract.Extract
+import transform.GamesAnalysis
 
 object Main extends App {
   val teams_output = os.pwd / "teams.json"
@@ -15,12 +16,17 @@ object Main extends App {
     "Milwaukee Bucks"
   )
 
-  println("Hello, World!")
   val teamFilter = (team: Team) => selectedTeams.contains(team.full_name)
 
   val teams = Extract.getTeams(teams_output, teamFilter)
   println(teams)
+
   val games = Extract.getGames(games_output, 2021, teams.map(_.id))
   println(games.size)
+
   val stats = Extract.getStats(stats_output, games.map(_.id))
   println(stats.size)
+
+  val analysis = GamesAnalysis(teams_output, games_output, stats_output)
+  analysis.writeToCsv("output")
+}

+ 70 - 0
src/main/scala/Transform.scala

@@ -0,0 +1,70 @@
+package transform
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.functions.{when, sum}
+
+/** Transform stage: joins the extracted teams, games and per-player stats
+  * JSON dumps into one per-team-per-game summary and writes it out as CSV.
+  *
+  * @param teams path to the teams JSON file produced by the Extract stage
+  * @param games path to the games JSON file produced by the Extract stage
+  * @param stats path to the per-player stats JSON file produced by the Extract stage
+  */
+case class GamesAnalysis(teams: os.Path, games: os.Path, stats: os.Path) {
+  // NOTE(review): building a SparkSession in a case-class constructor is a
+  // heavyweight side effect, and the master is hard-coded to local mode —
+  // consider injecting the session instead.
+  val spark = SparkSession
+    .builder()
+    .appName("balldontlie")
+    .config("spark.master", "local")
+    .getOrCreate()
+  import spark.implicits._
+
+  val teams_df = readInput(teams)
+  val games_df = readInput(games)
+  val stats_df = readInput(stats)
+
+  // One row per (game, team): the team appears once for each game it played
+  // (home or visitor), with `team_score` picked from home_team_score or
+  // visitor_team_score depending on which side the team was on.
+  // NOTE(review): join-then-where is a filtered cross join; some Spark
+  // versions require spark.sql.crossJoin.enabled for this — confirm, or
+  // move the predicate into the join condition.
+  val teams_games = teams_df
+    .as("t")
+    .join(games_df.as("g"))
+    .where(
+      $"g.home_team.id" === $"t.id"
+        || $"g.visitor_team.id" === $"t.id"
+    )
+    .withColumn(
+      "team_score",
+      when(
+        $"t.id" === $"g.home_team.id",
+        $"g.home_team_score"
+      )
+        .otherwise($"g.visitor_team_score")
+    )
+    .select(
+      $"g.id".alias("game_id"),
+      $"t.id".alias("team_id"),
+      $"t.full_name".alias("team_full_name"),
+      $"team_score"
+    )
+
+  // Attach the full game and team rows to every player stat line.
+  // `game_id` and `team.id` here are fields of the stats records
+  // (presumably `game`/`team` are nested structs in the stats JSON — verify).
+  val stats_games =
+    stats_df
+      .join(games_df, stats_df("game_id") === games_df("id"))
+      .join(teams_df, stats_df("team.id") === teams_df("id"))
+
+  // Roll the per-player box-score lines up to team totals per game:
+  // points, assists, blocks and rebounds.
+  val stats_games_pts = stats_games
+    .groupBy($"game.id".alias("game_id"), $"team.id".alias("team_id"))
+    .agg(
+      sum($"pts").alias("pts"),
+      sum($"ast").alias("ast"),
+      sum($"blk").alias("blk"),
+      sum($"reb").alias("reb")
+    )
+
+  // Combine the score view with the aggregated totals on (game_id, team_id),
+  // then drop the duplicated join keys from the stats side.
+  // NOTE(review): same filtered-cross-join pattern as above; also the
+  // drop(Column, Column) overload needs Spark >= 3.3 — verify build version.
+  val merged = teams_games
+    .as("tg")
+    .join(stats_games_pts.as("sgp"))
+    .where(
+      $"tg.game_id" === $"sgp.game_id"
+        && $"tg.team_id" === $"sgp.team_id"
+    )
+    .drop($"sgp.game_id", $"sgp.team_id")
+
+  /** Writes the merged analysis as headered CSV, overwriting `output` if it exists. */
+  def writeToCsv(output: String) = {
+    merged.write.option("header", true).mode("overwrite").csv(output)
+  }
+
+  // Reads a multiline JSON file into a DataFrame, inferring the schema.
+  private def readInput(filePath: os.Path) = {
+    spark.read.option("multiline", "true").json(filePath.toString())
+  }
+}