折叠

Polars提供了许多表达式来执行跨列计算, 例如sum_horizontal,mean_horizontal,min_horizontal. 然而这些只是称为折叠的通用算法的 特殊情况, 当Polars提供的不够用时, 我们可以使用Polars提供的通用机制来自定义折叠

使用fold函数计算的折叠操作会针对整列进行操作, 以实现最高速度. 它们能够高效地利用数据布局, 并且通常矢量化执行

工作方式

对于每一行来说, 设置一个初始值acc, 然后遍历exprs中的每个表达式, 每一步都调用我们定义的函数, 把acc和表达式的结果(这一行的值)传进去, 再更新acc, 这样最后的acc就是这一行的计算结果

1pl.fold(
2    acc=初始值,                     # 第一个传入 add 的 acc 值
3    function=add,                   # 每一步使用的函数
4    exprs=[expr1, expr2, ...]       # 要依次传入的表达式
5)
1acc = 初始值
2for expr in exprs:
3    acc = add(acc, expr的每行结果)

Example1

我们来计算a+b+c

1import polars as pl
2
3df = pl.DataFrame({
4    "a": [1, 2, 3],
5    "b": [4, 5, 6],
6    "c": [7, 8, 9],
7})
8
9# 使用 fold 对每一行进行水平求和
10result = df.select(
11    pl.fold(
12        acc=pl.lit(0),  # 初始累加器为0
13        function=lambda acc, x: acc + x, # 定义累加函数:当前累加器 + 当前列的值
14        exprs=pl.all()  # 对所有列执行计算, 把每一列中的每一行依次传入function中
15    ).alias("sum_fold"),
16	pl.sum_horizontal(pl.col("a","b","c")).alias("sum_horz")    # 对这三列按行求和
17)
18
19print(result)
1shape: (3, 2)
2┌──────────┬──────────┐
3│ sum_fold ┆ sum_horz │
4│ ---      ┆ ---      │
5│ i64      ┆ i64      │
6╞══════════╪══════════╡
7│ 12       ┆ 12       │
8│ 15       ┆ 15       │
9│ 18       ┆ 18       │
10└──────────┴──────────┘

Example2

这个例子我们来计算a^2+b-c, 使用两种方式来定义表达式

1import polars as pl
2
3df = pl.DataFrame({
4    "a": [1, 2, 3],
5    "b": [4, 5, 6],
6    "c": [7, 8, 9],
7})
8def add(acc: pl.Series, x: pl.Series) -> pl.Series:
9	return acc+x
10df.select(
11	pl.all(),
12	pl.fold(
13		acc=pl.col("a")*pl.col("a"),
14		function=add,
15		exprs=[pl.col("b"), -pl.col("c")]
16    ).alias("expr_list_a^2+b-c"),
17	pl.fold(
18		acc=pl.col("a")*pl.col("a"),
19		function=add,
20		exprs=pl.col("b")-pl.col("c")
21    ).alias("expr_a^2+b-c")
22)
1shape: (3, 5)
2┌─────┬─────┬─────┬───────────────────┬──────────────┐
3│ a   ┆ b   ┆ c   ┆ expr_list_a^2+b-c ┆ expr_a^2+b-c │
4│ --- ┆ --- ┆ --- ┆ ---               ┆ ---          │
5│ i64 ┆ i64 ┆ i64 ┆ i64               ┆ i64          │
6╞═════╪═════╪═════╪═══════════════════╪══════════════╡
7│ 1   ┆ 4   ┆ 7   ┆ -2                ┆ -2           │
8│ 2   ┆ 5   ┆ 8   ┆ 1                 ┆ 1            │
9│ 3   ┆ 6   ┆ 9   ┆ 6                 ┆ 6            │
10└─────┴─────┴─────┴───────────────────┴──────────────┘

初始值acc

如果要进行乘法运算, 我们的累加器就不能设置为0, 我们再来看一个例子, 加深理解

1import polars as pl
2
3df = pl.DataFrame({
4    "a": [1, 2, 3],
5    "b": [4, 5, 6],
6    "c": [7, 8, 9],
7})
8res = df.select(
9	pl.fold(
10		acc=pl.lit(0),
11		function=lambda acc, x: acc * x,
12		exprs=pl.all()
13    ).alias("prod")
14)
15print(res)

可以看到每一行的结果都是0, 因为对于每一行来说, 都是0 × 第一列的值 × 第二列的值 × 第三列的值, 结果必然是0

1shape: (3, 1)
2┌──────┐
3│ prod │
4│ ---  │
5│ i64  │
6╞══════╡
7│ 0    │
8│ 0    │
9│ 0    │
10└──────┘

想解决问题非常简单, acc设置为1即可, 这里我们使用了operator.mul, 可以简化我们的代码

1import polars as pl
2import operator
3df = pl.DataFrame({
4    "a": [1, 2, 3],
5    "b": [4, 5, 6],
6    "c": [7, 8, 9],
7})
8res = df.select(
9	pl.fold(
10		acc=pl.lit(1),
11		function=operator.mul,
12		exprs=pl.all()
13    ).alias("prod")
14)
15print(res)
1shape: (3, 1)
2┌──────┐
3│ prod │
4│ ---  │
5│ i64  │
6╞══════╡
7│ 28   │
8│ 80   │
9│ 162  │
10└──────┘

使用条件/过滤

我们可以很方便的使用条件语句来选择我们想要的行

1df = pl.DataFrame(
2    {
3        "a": [1, 2, 3],
4        "b": [0, 1, 2],
5    }
6)
7
8result = df.filter(
9    pl.fold(
10        acc=pl.lit(True),
11        function=lambda acc, x: acc & x,
12        exprs=pl.all() > 1,
13    )
14)
15print(result)
1shape: (1, 2)
2┌─────┬─────┐
3│ a   ┆ b   │
4│ --- ┆ --- │
5│ i64 ┆ i64 │
6╞═════╪═════╡
7│ 3   ┆ 2   │
8└─────┴─────┘

字符串

折叠操作可以用于连接字符串数据, 然而由于fold在每一步都会生成中间结果, 中间结果(字符串是不可变数据), 每次拼接都会生成新的字符串, 占用内存会很多, 所以这里的场景可以使用pl.concat_str来代替, 优化由Polars来完成

1df = pl.DataFrame(
2    {
3        "a": ["a", "b", "c"],
4        "b": [1, 2, 3],
5    }
6)
7
8result = df.select(
9    pl.concat_str(["a", "b"]).alias("ab"),
10	pl.concat_str(pl.col("a"),pl.lit("_"),pl.col("b")).alias("a_b")
11)
12print(result)
1shape: (3, 2)
2┌─────┬─────┐
3│ ab  ┆ a_b │
4│ --- ┆ --- │
5│ str ┆ str │
6╞═════╪═════╡
7│ a1  ┆ a_1 │
8│ b2  ┆ b_2 │
9│ c3  ┆ c_3 │
10└─────┴─────┘