窗口函数over

窗口函数会对一组数据(以某种方式相关联)执行计算. 窗口函数不会像非窗口聚合调用那样导致行被分组为单个输出(group_by的结果的行数是分组的数量).

我们来看下代码代码演示加深下理解

Example1

我们有一组学生成绩和分组的数据,现在我们想得到每个人在自己组内的排名, 注意不是在所有人中的排名

1import polars as pl
2
3df = pl.DataFrame({
4    "group": ["A", "A", "A", "B", "B"],
5    "name": ["Alice", "Bob", "Charlie", "David", "Eve"],
6    "score": [85, 95, 90, 75, 80],
7})
8res = df.with_columns(
9    pl.col("score")
10    .rank("dense", descending=True)     # 在每组内按照分数降序进行连续排名
11    .over("group")                      # 表示这个排名是在每个group内进行的
12    .alias("rank")
13)
14print(res)

我们可以观察到下面的结果, 总行数不变, 新增了一列, 每行数据都有自己在组内的排名, 这就是窗口函数的作用

1shape: (5, 4)
2┌───────┬─────────┬───────┬──────┐
3│ group ┆ name    ┆ score ┆ rank │
4│ ---   ┆ ---     ┆ ---   ┆ ---  │
5│ str   ┆ str     ┆ i64   ┆ u32  │
6╞═══════╪═════════╪═══════╪══════╡
7│ A     ┆ Alice   ┆ 85    ┆ 3    │
8│ A     ┆ Bob     ┆ 95    ┆ 1    │
9│ A     ┆ Charlie ┆ 90    ┆ 2    │
10│ B     ┆ David   ┆ 75    ┆ 2    │
11│ B     ┆ Eve     ┆ 80    ┆ 1    │
12└───────┴─────────┴───────┴──────┘

Example2

计算在每个a分组内, c列的最大值, 并且将这个最大值赋给每一行

1import polars as pl
2df = pl.DataFrame(
3    {
4        "a": ["a", "a", "b", "b", "b"],
5        "b": [1, 2, 3, 5, 3],
6        "c": [5, 4, 3, 2, 1],
7    }
8)
9res = df.with_columns(
10	pl.col("c").max().over("a").alias("c_max")
11)
12print(res)
1shape: (5, 4)
2┌─────┬─────┬─────┬───────┐
3│ a   ┆ b   ┆ c   ┆ c_max │
4│ --- ┆ --- ┆ --- ┆ ---   │
5│ str ┆ i64 ┆ i64 ┆ i64   │
6╞═════╪═════╪═════╪═══════╡
7│ a   ┆ 1   ┆ 5   ┆ 5     │
8│ a   ┆ 2   ┆ 4   ┆ 5     │
9│ b   ┆ 3   ┆ 3   ┆ 3     │
10│ b   ┆ 5   ┆ 2   ┆ 3     │
11│ b   ┆ 3   ┆ 1   ┆ 3     │
12└─────┴─────┴─────┴───────┘

Example3

下面代码高亮的第10行的意思是对每个group组中的value做求和, 然后把每组的结果广播到这个组内的所有行, 下面紧接着就会说怎么广播回去

1import polars as pl
2
3df = pl.DataFrame({
4    "group": ["a", "a", "b", "b"],
5    "value": [1, 2, 3, 4]
6})
7
8res = df.select(
9	pl.all(),
10	pl.col("value").sum().over("group").alias("sum")
11)
12print(res)
1shape: (4, 3)
2┌───────┬───────┬─────┐
3│ group ┆ value ┆ sum │
4│ ---   ┆ ---   ┆ --- │
5│ str   ┆ i64   ┆ i64 │
6╞═══════╪═══════╪═════╡
7│ a     ┆ 1     ┆ 3   │
8│ a     ┆ 2     ┆ 3   │
9│ b     ┆ 3     ┆ 7   │
10│ b     ┆ 4     ┆ 7   │
11└───────┴───────┴─────┘

结果值映射策略mapping_strategy

over用于: 分组后广播结果, 那么怎么广播映射回去呢

join

我们先看最简单的, join将结果值组织成一个列表, 然后映射到每一行中

下面高亮的第11行代码对c这一列按分组排名, 然后把每个分组内的所有rank值聚合为一个列表, 然后把这个列表结果复制到组内的所有行上, 每个组的rank列是一样的

WARNING

这可能很占内存!

1import polars as pl
2df = pl.DataFrame(
3    {
4        "a": ["a", "a", "b", "b", "b"],
5        "c": [5, 4, 3, 2, 1],
6    }
7)
8res = df.select(
9	pl.col("a"),
10	pl.col("c"),
11	pl.col("c").rank("dense").over("a", mapping_strategy="join").alias("rank")
12)
13print(res)
1shape: (5, 3)
2┌─────┬─────┬───────────┐
3│ a   ┆ c   ┆ rank      │
4│ --- ┆ --- ┆ ---       │
5│ str ┆ i64 ┆ list[u32] │
6╞═════╪═════╪═══════════╡
7│ a   ┆ 5   ┆ [2, 1]    │
8│ a   ┆ 4   ┆ [2, 1]    │
9│ b   ┆ 3   ┆ [3, 2, 1] │
10│ b   ┆ 2   ┆ [3, 2, 1] │
11│ b   ┆ 1   ┆ [3, 2, 1] │
12└─────┴─────┴───────────┘

需要注意的是, 如果聚合函数只有一个值,那么最后不会被聚合为一个列表, 即使我们指定了mapping_strategy="join", 比如pl.col("c").max(), 我们看下面代码

1import polars as pl
2df = pl.DataFrame(
3    {
4        "a": ["a", "a", "b", "b", "b"],
5        "c": [5, 4, 3, 2, 1],
6    }
7)
8res = df.select(
9	pl.col("a"),
10	pl.col("c"),
11	pl.col("c").max().over("a", mapping_strategy="join").alias("max")
12)
13print(res)
1shape: (5, 3)
2┌─────┬─────┬─────┐
3│ a   ┆ c   ┆ max │
4│ --- ┆ --- ┆ --- │
5│ str ┆ i64 ┆ i64 │
6╞═════╪═════╪═════╡
7│ a   ┆ 5   ┆ 5   │
8│ a   ┆ 4   ┆ 5   │
9│ b   ┆ 3   ┆ 3   │
10│ b   ┆ 2   ┆ 3   │
11│ b   ┆ 1   ┆ 3   │
12└─────┴─────┴─────┘

group_to_rows

会把聚合结果映射回去, 并且组内元素不变, 顺序不乱, 要求聚合操作是可以一一映射回去的

1import polars as pl
2df = pl.DataFrame(
3    {
4        "a": ["a", "a", "b", "b", "b"],
5        "c": [5, 4, 3, 2, 1],
6    }
7)
8res = df.select(
9	pl.col("a"),
10	pl.col("c"),
11	pl.col("c").max().over("a", mapping_strategy="group_to_rows").alias("max")
12)
13print(res)

我们看到原来的顺序是没有改变的

1shape: (5, 3)
2┌─────┬─────┬─────┐
3│ a   ┆ c   ┆ max │
4│ --- ┆ --- ┆ --- │
5│ str ┆ i64 ┆ i64 │
6╞═════╪═════╪═════╡
7│ a   ┆ 5   ┆ 5   │
8│ a   ┆ 4   ┆ 5   │
9│ b   ┆ 3   ┆ 3   │
10│ b   ┆ 2   ┆ 3   │
11│ b   ┆ 1   ┆ 3   │
12└─────┴─────┴─────┘

explode

就像group_by+agg+explode一样, 可能会改变原来行的顺序.

将聚合结果展开到一个新的列表中, 要求展开后的结果长度与原始行数相同

下面代码会报错, 因为一共2个分组, 每个分组各有一个sum值, 这个结果的总长度是2, 但是原来有4行, 长度不匹配就会报错

1import polars as pl
2from polars.exceptions import ShapeError
3
4df = pl.DataFrame({
5    "group": ["a", "a", "b", "b"],
6    "value": [1, 2, 3, 4]
7})
8try:
9    res = df.select(
10        pl.all(),
11        pl.col("value").sum().over("group", mapping_strategy="explode").alias("sum")
12    )
13    print(res)
14except ShapeError as e:
15	print(e)
1Series length 2 doesn't match the DataFrame height of 4

我们换一个聚合函数rank就可以解决了

1import polars as pl
2from polars.exceptions import ShapeError
3
4df = pl.DataFrame({
5    "group": ["a", "a", "b", "b"],
6    "value": [1, 2, 3, 4]
7})
8try:
9    res = df.select(
10        pl.all(),
11        pl.col("value").rank().over("group", mapping_strategy="explode").alias("sum")
12    )
13    print(res)
14except ShapeError as e:
15	print(e)
1shape: (4, 3)
2┌───────┬───────┬─────┐
3│ group ┆ value ┆ sum │
4│ ---   ┆ ---   ┆ --- │
5│ str   ┆ i64   ┆ f64 │
6╞═══════╪═══════╪═════╡
7│ a     ┆ 1     ┆ 1.0 │
8│ a     ┆ 2     ┆ 2.0 │
9│ b     ┆ 3     ┆ 1.0 │
10│ b     ┆ 4     ┆ 2.0 │
11└───────┴───────┴─────┘

我们接着来看下面的代码, 如果不指定排序, 可能会出现预期之外的错误, 顺序发生了错误

1import polars as pl
2
3df = pl.DataFrame({
4    "group": ["a", "b", "a", "b"],
5    "value": [1, 2, 3, 4]
6})
7
8res = df.select(
9    pl.all(),
10    pl.col("value").over("group", mapping_strategy="explode").alias("no_sort")
11)
12print(res)

注意原本顺序是abab, 但我们看到第三列的结果, 无法和第一列对应上了, 为什么呢, 我们来说下: 一共分为2组:

  • a: [1,3]
  • b: [2,4] 然后将结果拼接成一个新的列表: [1,3,2,4],然后再把新的列表映射回原来的行!
1shape: (4, 3)
2┌───────┬───────┬─────────┐
3│ group ┆ value ┆ no_sort │
4│ ---   ┆ ---   ┆ ---     │
5│ str   ┆ i64   ┆ i64     │
6╞═══════╪═══════╪═════════╡
7│ a     ┆ 1     ┆ 1       │
8│ b     ┆ 2     ┆ 3       │
9│ a     ┆ 3     ┆ 2       │
10│ b     ┆ 4     ┆ 4       │
11└───────┴───────┴─────────┘

我们来使用排序来解决一下

1import polars as pl
2
3df = pl.DataFrame({
4    "group": ["a", "b", "a", "b"],
5    "value": [1, 2, 3, 4]
6})
7
8res = df.select(
9    pl.all().sort_by(pl.col("group")),
10    pl.col("value").over("group", mapping_strategy="explode").alias("no_sort")
11)
12print(res)
1shape: (4, 3)
2┌───────┬───────┬─────────┐
3│ group ┆ value ┆ no_sort │
4│ ---   ┆ ---   ┆ ---     │
5│ str   ┆ i64   ┆ i64     │
6╞═══════╪═══════╪═════════╡
7│ a     ┆ 1     ┆ 1       │
8│ a     ┆ 3     ┆ 3       │
9│ b     ┆ 2     ┆ 2       │
10│ b     ┆ 4     ┆ 4       │
11└───────┴───────┴─────────┘
NOTE

explode模式下, Polars不跟踪每行的位置, 所以通常会比group_to_rows更快