发布日期:2022-08-03
VIP内容
图分析示例:简单航班数据分析
本节通过一个简单的示例,分析3个航线信息,每条航线中所有列信息见下表。
用网络图表示的话,如下图所示。
在构建的图模型中,将机场表示为顶点,航线表示为边。图中有三个顶点,每个顶点代表一个机场。每个顶点都有机场代码作为ID,机场所在城市名称作为属性。表示机场的顶点见下表。
边具有源ID、目标ID和作为属性的距离及延误时间。表示航线的边见下表。
接下来使用GraphFrames进行分析。请按以下步骤操作。
(1) 首先,导入相关的依赖包,代码如下:
import org.apache.spark.sql.SparkSession import org.apache.spark.sql.functions._ import org.graphframes.GraphFrame
(2) 定义两个case class分别用来表示顶点和边的Schema,代码如下:
// case class定义顶点和边的schema case class Airport(id: String, city: String) extends Serializable case class Flight(id: String, src: String,dst: String, dist: Double, delay: Double) extends Serializable
(3) 定义顶点。
将机场定义为顶点。顶点DataFrame必须有一个ID列,并且可能有多个属性列。在本例中,每个机场顶点由(顶点id->id,顶点属性->city)组成,代码如下:
// 创建顶点DataFrame
val vertices = spark.createDataFrame(
Array(
Airport("SFO","San Francisco"),
Airport("ORD","Chicago"),
Airport("DFW","Dallas Fort Worth")
)
)
// 查看顶点
vertices.show()
执行以上代码,输出内容如下:
+---+-----------------+ | id| city| +---+-----------------+ |SFO| San Francisco| |ORD| Chicago| |DFW|Dallas Fort Worth| +---+-----------------+
(4) 定义边
边是机场之间的航班。一个边DataFrame必须有src和dst列,并且可能有多个关系列。在本例中,边包括如下内容:
(1) 源ID → src。 (2) 目标ID → dst。 (3) 边属性:距离 → dist。 (4) 边属性:延误 → delay。
定义边的代码如下:
val edges = spark.createDataFrame(
Array(
Flight("SFO_ORD_2017-01-01_AA","SFO","ORD",1800, 40),
Flight("ORD_DFW_2017-01-01_UA","ORD","DFW",800, 0),
Flight("DFW_SFO_2017-01-01_DL","DFW","SFO",1400, 10)
)
)
edges.show(false)
执行代码,输出结果如下:
+---------------------+---+---+------+-----+ |id |src|dst|dist |delay| +---------------------+---+---+------+-----+ |SFO_ORD_2017-01-01_AA|SFO|ORD|1800.0|40.0 | |ORD_DFW_2017-01-01_UA|ORD|DFW|800.0 |0.0 | |DFW_SFO_2017-01-01_DL|DFW|SFO|1400.0|10.0 | +---------------------+---+---+------+-----+
(5) 构建GraphFrame。
通过提供顶点DataFrame和边DataFrame来创建一个GraphFrame。也可以只使用一个边DataFrame创建一个GraphFrame,然后从边DataFrame的src和dst列获得顶点,代码如下:
// 定义图 val graph = GraphFrame(vertices, edges) // 显示图的顶点 graph.vertices.show() // 显示图的边 graph.edges.show(false)
6)现在可以查询GraphFrame来回答以下问题:
// 有多少个机场?
println("\n有多少个机场?" + graph.vertices.count())
// 机场之间有多少航线?
println("\n机场之间有多少航线?" + graph.edges.count())
// 哪条航线的距离大于1000英里?
println("\n哪些航线的距离大于1000英里?")
graph.edges.filter("dist > 1000").show(false)
// 距离最长的航线是?
println("\n距离最长的航线是?")
graph.edges
.groupBy("src", "dst")
.agg(max("dist").as("longest"))
.sort(desc("longest"))
.show(false)
执行上面的代码,输出内容如下:
有多少个机场?3 机场之间有多少航线?3 哪些航线的距离大于1000英里? +---------------------+---+---+------+-----+ |id | src|dst|dist |delay| +---------------------+---+---+------+-----+ |SFO_ORD_2017-01-01_AA|SFO|ORD|1800.0|40.0 | |DFW_SFO_2017-01-01_DL|DFW|SFO|1400.0|10.0 | +---------------------+---+---+------+-----+ 距离最长的航线是? +---+---+-------+ |src|dst|longest| +---+---+-------+ |SFO|ORD|1800.0 | |DFW|SFO|1400.0 | |ORD|DFW|800.0 | +---+---+-------+