[Scala] 一次打十二個--透過 Future 和 For-Comprehensive 解決 Callback 地獄

話說這幾年用 Non-Blocking 的典範來寫程式很夯(例如 NodeJS 上幾乎都是 Non-Blocking),不過用 Non-Blocking 寫程式的時候,有時候程式碼會因為 Callback 的關係變得很雜亂,需要花一些時間來整理和維護,俗稱 Callback 地獄,後來大家也發展了不少的技巧來避免和減清這個問題。

不過這個問題在 Scala 上其實不太嚴重,主要因為 Scala 從函式編程引入的 Monad 的觀念和他的 for 迴圈的語法實在是太強大了……透過 Scala 標準函式庫內建 Future 類別和 for 迴圈,我們可以很輕鬆地把本來是 Blocking 的操作變成 Non-Blocking 操作,並且解決掉這類的 Callback 地獄問題。

口說為憑,我們直接來看一個例子好了。

話說台北市的開放資料網站上提供了行人垃圾筒的物件,但是卻是十二個行政區都是獨立的網址。如果今天我們想要取得台北市內所有的行人垃圾筒的資料並整合成一個單獨的 List 物件,最簡單的方法就是用 Blocking 的方式,從一個網址取回資料後,把回傳的資料塞到 List 之後,再接著抓第二個網址。

當然,這樣的作法顯然有點笨也很花時間。為什麼不一次抓十二個網址呢?然後等所有的網址都抓完之後再把資料整合到一個 List 物件中,然後等到全部都抓完之後再通知我們印出 List 的內容?

雖然這個東西看似很簡單,但如果是在其他的程式語言中,可能會需要有一些注意的事項(例如你怎麼知道十二個網址都抓完了,該執行列印的步驟了?),但 Scala 的 Future 和 for 語法可以讓我們很簡單的達成這件事。

在這邊要注意的是,因為 Scala 本身就是個多典範的程式語言,所以我們不需要侷限在一開始就要用 Non-Blocking 的思考模式來寫程式,可以直接從我們習慣的會 Blocking 程式開始,然後再透過 Future 來讓 Scala 幫我們把苦工做掉。

在開始之前,因為台北市政府開放資料回傳的資料是 JSON 格式,我們需要透過 json4s 這個函式庫來幫我們解所 JSON 格式,如果是使用 SBT 當做建置系統的話,只要在 build.sbt 檔裡加入下面這行就可以了:

libraryDependencies += "org.json4s" %% "json4s-native" % "3.3.0"

接下來我們先實作最基礎的部份,也就是專注在「把資料從單一網址上抓下來並轉換成 List」這件事,當我們在寫這段程式碼的時候,也不需要考慮什麼 Non-blocking 的問題。所以我們可以寫出下面的程式碼:

import scala.io.Source
import org.json4s._
import org.json4s.native.JsonMethods._
import scala.util.Try

case class TrashCan(road: String, section: String, latitude: Double, longitude: Double)

object DataGetter {

  def getDataList(jsonDataURL: String): Try[List[TrashCan]] = Try {

    println("Get data from " + jsonDataURL + "....")

    val getJsonData = parse(Source.fromURL(jsonDataURL).mkString)
    val jsonResult = getJsonData \\ "results"
    val JArray(jsonObjects) = jsonResult

    jsonObjects.map { jsonObject =>
      TrashCan(
        (jsonObject \ "路名").values.toString,
        (jsonObject \ "段、號及其他註明").values.toString,
        (jsonObject \ "緯度").values.toString.toDouble,
        (jsonObject \ "經度").values.toString.toDouble
      )
    }
  }

  def main(args: Array[String]) {
  }
}

如果熟悉 Scala,那麼對這個程式碼應該不會太陌生。我們透過 Source.fromURL 函式抓取網頁裡的內容,然後再把回傳的資料轉換成一個 List[TrashCan] 物件。

唯一比較特別的,是我們把整個函式用 Try {} 區塊包起來了,因為我們很有可能會遇到錯誤(網路不通,網頁回傳的資料不是正確的 JSON 格式等),所以需要用一個 Try 區塊包起來。這樣子的話,如果這個函式確實成功執行完畢,那麼會拿到一個 Success[List[TrashCan]] 的容器物件,其內容就是我們的 List 物件。另一方面,如果在過程中失敗了,那我們會得到一個 Failure 物件,告知我們沒有拿到正確的資料。

接下來,我們加入「查詢十二個網頁並整合資料成一個單獨的 List」的功能,這方面一樣是透過 Scala 的 for 語法來達成的:

def getAggregateData: Try[List[TrashCan]] = {

  val data = Array(
    getDataList("http://data.taipei/opendata/datalist/apiAccess?scope=resourceAquire&rid=97cc923a-e9ee-4adc-8c3d-335567dc15d3"),
    getDataList("http://data.taipei/opendata/datalist/apiAccess?scope=resourceAquire&rid=5fa14e06-018b-4851-8316-1ff324384f79"),
    getDataList("http://data.taipei/opendata/datalist/apiAccess?scope=resourceAquire&rid=f40cd66c-afba-4409-9289-e677b6b8d00e"),
    getDataList("http://data.taipei/opendata/datalist/apiAccess?scope=resourceAquire&rid=33b2c4c5-9870-4ee9-b280-a3a297c56a22"),
    getDataList("http://data.taipei/opendata/datalist/apiAccess?scope=resourceAquire&rid=0b544701-fb47-4fa9-90f1-15b1987da0f5"),
    getDataList("http://data.taipei/opendata/datalist/apiAccess?scope=resourceAquire&rid=37eac6d1-6569-43c9-9fcf-fc676417c2cd"),
    getDataList("http://data.taipei/opendata/datalist/apiAccess?scope=resourceAquire&rid=46647394-d47f-4a4d-b0f0-14a60ac2aade"),
    getDataList("http://data.taipei/opendata/datalist/apiAccess?scope=resourceAquire&rid=05d67de9-a034-4177-9f53-10d6f79e02cf"),
    getDataList("http://data.taipei/opendata/datalist/apiAccess?scope=resourceAquire&rid=179d0fe1-ef31-4775-b9f0-c17b3adf0fbc"),
    getDataList("http://data.taipei/opendata/datalist/apiAccess?scope=resourceAquire&rid=8cbb344b-83d2-4176-9abd-d84508e7dc73"),
    getDataList("http://data.taipei/opendata/datalist/apiAccess?scope=resourceAquire&rid=7b955414-f460-4472-b1a8-44819f74dc86"),
    getDataList("http://data.taipei/opendata/datalist/apiAccess?scope=resourceAquire&rid=5697d81f-7c9d-43fc-a202-ae8804bbd34b")
  )

  for {
    dataOfArea01 <- data(0)
    dataOfArea02 <- data(1)
    dataOfArea03 <- data(2)
    dataOfArea04 <- data(3)
    dataOfArea05 <- data(4)
    dataOfArea06 <- data(5)
    dataOfArea07 <- data(6)
    dataOfArea08 <- data(7)
    dataOfArea09 <- data(8)
    dataOfArea10 <- data(9)
    dataOfArea11 <- data(10)
    dataOfArea12 <- data(11)
  } yield {
    dataOfArea01 ++ dataOfArea02 ++
    dataOfArea03 ++ dataOfArea04 ++
    dataOfArea05 ++ dataOfArea06 ++
    dataOfArea07 ++ dataOfArea08 ++
    dataOfArea09 ++ dataOfArea10 ++
    dataOfArea11 ++ dataOfArea12
  }

}

這段程式碼是這樣的,他依照順序先執行每一個 getDataList 函式,並且把結果放到 Array 當中。

而這個 for 迴圈會先檢查 data(0),如果這個是 Success 物件,那麼就把 Success 物件裡的 List 指定給 dataOfArea1 這個變數,然後再繼續檢查 data(1),如果一樣是 Success 物件,再把其中的 List 物件指定給 dataOfArea02 變數。依此類推,如果陣列裡 十二個元全都是 Success 物件,那麼就執行 yield 區塊裡的程式碼,並且把這個區塊的回傳值轉成一個 Success 物件,並且做為整個 for 區塊的回傳值(是的,Scala 裡的 for 是有回傳值的)。

如果陣列中其中任何一個元素是 Failure 物件,那整個 for 區塊的回傳值就會是一個 Failure 物件。

接下來再加上主程式碼:

def main(args: Array[String]) {

  val dataListHolder = getAggregateData

  for {
    dataList <- dataListHolder
    trashCan <- dataList
  } {
    println(trashCan)
  }
}

這個 for 迴圈一樣很簡單,如果 dataListHolder 是 Success 的(也就是全部十二個網址都抓到了),那麼就把資料放到 dataList 中,然後再把 dataList 中的每一個元素拉出來放到 trashCan 變數中,然後再印出來。

如果執行這隻程式的話,就會看見先印出一行 Get data from XXX... 後,等了好一陣子才印出第二行的 Get data from XXX... 字樣,等到十二個網址都處理完後,再一口氣印出所有的 TrashCan 物件。

很標準的 Blocking 程式碼,那我們要怎樣讓他變成一次打十二個呢?很簡單,我們只要把程式嗎裡所有的 Try 改成 Future 就好了!

啥?這樣就好了?!沒錯,就是這樣就好了!唯一要注意的只是要 import 相對應的東西而已。

import org.json4s._
import org.json4s.native.JsonMethods._
import scala.concurrent.duration._
import scala.concurrent.Future
import scala.concurrent.ExecutionContext
import scala.concurrent.Await
import scala.io.Source

case class TrashCan(road: String, section: String, latitude: Double, longitude: Double)

object DataGetter {

  import scala.concurrent.ExecutionContext.Implicits.global

  def getDataList(jsonDataURL: String): Future[List[TrashCan]] = Future {

    println("Get data from " + jsonDataURL + "....")

    val getJsonData = parse(Source.fromURL(jsonDataURL).mkString)
    val jsonResult = getJsonData \\ "results"
    val JArray(jsonObjects) = jsonResult

    jsonObjects.map { jsonObject =>
      TrashCan(
        (jsonObject \ "路名").values.toString,
        (jsonObject \ "段、號及其他註明").values.toString,
        (jsonObject \ "緯度").values.toString.toDouble,
        (jsonObject \ "經度").values.toString.toDouble
      )
    }
  }

  def getAggregateData: Future[List[TrashCan]] = {

    val data = Array(
      getDataList("http://data.taipei/opendata/datalist/apiAccess?scope=resourceAquire&rid=97cc923a-e9ee-4adc-8c3d-335567dc15d3"),
      getDataList("http://data.taipei/opendata/datalist/apiAccess?scope=resourceAquire&rid=5fa14e06-018b-4851-8316-1ff324384f79"),
      getDataList("http://data.taipei/opendata/datalist/apiAccess?scope=resourceAquire&rid=f40cd66c-afba-4409-9289-e677b6b8d00e"),
      getDataList("http://data.taipei/opendata/datalist/apiAccess?scope=resourceAquire&rid=33b2c4c5-9870-4ee9-b280-a3a297c56a22"),
      getDataList("http://data.taipei/opendata/datalist/apiAccess?scope=resourceAquire&rid=0b544701-fb47-4fa9-90f1-15b1987da0f5"),
      getDataList("http://data.taipei/opendata/datalist/apiAccess?scope=resourceAquire&rid=37eac6d1-6569-43c9-9fcf-fc676417c2cd"),
      getDataList("http://data.taipei/opendata/datalist/apiAccess?scope=resourceAquire&rid=46647394-d47f-4a4d-b0f0-14a60ac2aade"),
      getDataList("http://data.taipei/opendata/datalist/apiAccess?scope=resourceAquire&rid=05d67de9-a034-4177-9f53-10d6f79e02cf"),
      getDataList("http://data.taipei/opendata/datalist/apiAccess?scope=resourceAquire&rid=179d0fe1-ef31-4775-b9f0-c17b3adf0fbc"),
      getDataList("http://data.taipei/opendata/datalist/apiAccess?scope=resourceAquire&rid=8cbb344b-83d2-4176-9abd-d84508e7dc73"),
      getDataList("http://data.taipei/opendata/datalist/apiAccess?scope=resourceAquire&rid=7b955414-f460-4472-b1a8-44819f74dc86"),
      getDataList("http://data.taipei/opendata/datalist/apiAccess?scope=resourceAquire&rid=5697d81f-7c9d-43fc-a202-ae8804bbd34b")
    )

    for {
      dataOfArea01 <- data(0)
      dataOfArea02 <- data(1)
      dataOfArea03 <- data(2)
      dataOfArea04 <- data(3)
      dataOfArea05 <- data(4)
      dataOfArea06 <- data(5)
      dataOfArea07 <- data(6)
      dataOfArea08 <- data(7)
      dataOfArea09 <- data(8)
      dataOfArea10 <- data(9)
      dataOfArea11 <- data(10)
      dataOfArea12 <- data(11)
    } yield {
      dataOfArea01 ++ dataOfArea02 ++
      dataOfArea03 ++ dataOfArea04 ++
      dataOfArea05 ++ dataOfArea06 ++
      dataOfArea07 ++ dataOfArea08 ++
      dataOfArea09 ++ dataOfArea10 ++
      dataOfArea11 ++ dataOfArea12
    }
  }

  def main(args: Array[String]) {

    val dataListHolder = getAggregateData

    for {
      dataList <- dataListHolder
      trashCan <- dataList
    } {
      println(trashCan)
    }
  }

}

這段程式其實是這樣的,當我們在初始化 Array 的時候,getDataList 就會被非同步的執行(透過 Thread Pool 的方式來處理),所以這宣告 Array 這一段就會變成 Non-Blocking。

而因為 data(0)data(11) 都是 Future 物件,所以實際上整個 for 區塊也都會變成 Non-blocking 的,會立刻回傳而不會被 Blocking 住。就這樣,我們的 getAggregateData 函式也自動從 Blocking 變成 Non-Blocking 了。

如果我們再執行一次程式,會發現資料還沒印出來程式就結束了。這是因為 getAggregateData 已經是 non-blocking 了,所以程式不會等他回傳的資料,讓我們改一下 main 函式,改成用 Await.result 這個函式,讓我們可以回到 blocking 模式,他會一直 blocking 到 Future 回傳資料,並且把 Future 回傳的資料做為回傳值。

def main(args: Array[String]) {

  val dataListHolder = getAggregateData
  val dataList = Await.result(dataListHolder, Duration.Inf)

  dataList.foreach(println _)
}

重新執行後,這次應該會發現程式碼會等到資料都印出來後才結束。不過似乎好使還是哪裡怪怪的?!喔……似乎好像沒有一口氣印出 12 個 Get data from XXX 的訊息啊?

沒錯,因為實際上 Future 是用 Thread Pool 下去實作的,而 Scala 預設的 Thread Pool 沒開到 12 個,自然不會一口氣跑 12 個非同步的工作。不過要修改這個設定也很簡單,我們只要拿掉原本的 import scala.concurrent.ExecutionContext.Implicits.global 並改成 implicit val ec = ExecutionContext.fromExecutor(new java.util.concurrent.ForkJoinPool(12)) 就可以了。

另外,既然是 Non-Blocking,沒有看到 Callback 似乎就混身不對勁呢,所以也來把 main 函式改成用 Callback 方來處理資料吧,順便加上如果失敗時要印出的錯誤錯息:

def main(args: Array[String]) {

  val nonBlockingDataList = getAggregateData

  // 功功時的 Callback
  nonBlockingDataList.onSuccess { case dataList =>
    dataList.foreach(println _)
  }

  // 失敗時的 Callback
  nonBlockingDataList.onFailure { case e: Exception =>
    println("出錯了……")
    e.printStackTrace()
  }

  // 一樣等到拿到資料再結束唄
  Await.result(nonBlockingDataList, Duration.Inf)
}

就這樣,我們完成了一次打十二個的任務。如何,是否覺得原來 Non-blocking 的程式也可以很清楚簡單呢?真的很建議大家可以學學 Scala 喔,真的很有趣也很方便喲。

最後,如果你好奇為什麼 Scala 的 for 迴圈可以做到這樣的功能的話,這是因為 Scala 的 for 語法其實不是迴圈,而只是一套轉譯成函式呼叫的規則,對於這個議題有興趣的朋友,可以看這篇文章 和這篇 FAQ

回響