折叠
Polars提供了许多表达式来执行跨列计算, 例如sum_horizontal
,mean_horizontal
,min_horizontal
. 然而这些只是称为折叠的通用算法的
特殊情况, 当Polars提供的不够用时, 我们可以使用Polars提供的通用机制来自定义折叠
使用fold
函数计算的折叠操作会针对整列进行操作, 以实现最高速度. 它们能够高效地利用数据布局, 并且通常矢量化执行
工作方式
对于每一行来说, 设置一个初始值acc
, 然后遍历exprs
中的每个表达式, 每一步都调用我们定义的函数, 把acc
和表达式的结果(这一行的值)传进去,
再更新acc
, 这样最后的acc
就是这一行的计算结果
1pl.fold(
2 acc=初始值, # 第一个传入 add 的 acc 值
3 function=add, # 每一步使用的函数
4 exprs=[expr1, expr2, ...] # 要依次传入的表达式
5)
1acc = 初始值
2for expr in exprs:
3 acc = add(acc, expr的每行结果)
Example1
我们来计算a+b+c
1import polars as pl
2
3df = pl.DataFrame({
4 "a": [1, 2, 3],
5 "b": [4, 5, 6],
6 "c": [7, 8, 9],
7})
8
9# 使用 fold 对每一行进行水平求和
10result = df.select(
11 pl.fold(
12 acc=pl.lit(0), # 初始累加器为0
13 function=lambda acc, x: acc + x, # 定义累加函数:当前累加器 + 当前列的值
14 exprs=pl.all() # 对所有列执行计算, 把每一列中的每一行依次传入function中
15 ).alias("sum_fold"),
16 pl.sum_horizontal(pl.col("a","b","c")).alias("sum_horz") # 对这三列按行求和
17)
18
19print(result)
1shape: (3, 2)
2┌──────────┬──────────┐
3│ sum_fold ┆ sum_horz │
4│ --- ┆ --- │
5│ i64 ┆ i64 │
6╞══════════╪══════════╡
7│ 12 ┆ 12 │
8│ 15 ┆ 15 │
9│ 18 ┆ 18 │
10└──────────┴──────────┘
Example2
这个例子我们来计算a^2+b-c
, 使用两种方式来定义表达式
1import polars as pl
2
3df = pl.DataFrame({
4 "a": [1, 2, 3],
5 "b": [4, 5, 6],
6 "c": [7, 8, 9],
7})
8def add(acc: pl.Series, x: pl.Series) -> pl.Series:
9 return acc+x
10df.select(
11 pl.all(),
12 pl.fold(
13 acc=pl.col("a")*pl.col("a"),
14 function=add,
15 exprs=[pl.col("b"), -pl.col("c")]
16 ).alias("expr_list_a^2+b-c"),
17 pl.fold(
18 acc=pl.col("a")*pl.col("a"),
19 function=add,
20 exprs=pl.col("b")-pl.col("c")
21 ).alias("expr_a^2+b-c")
22)
1shape: (3, 5)
2┌─────┬─────┬─────┬───────────────────┬──────────────┐
3│ a ┆ b ┆ c ┆ expr_list_a^2+b-c ┆ expr_a^2+b-c │
4│ --- ┆ --- ┆ --- ┆ --- ┆ --- │
5│ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 │
6╞═════╪═════╪═════╪═══════════════════╪══════════════╡
7│ 1 ┆ 4 ┆ 7 ┆ -2 ┆ -2 │
8│ 2 ┆ 5 ┆ 8 ┆ 1 ┆ 1 │
9│ 3 ┆ 6 ┆ 9 ┆ 6 ┆ 6 │
10└─────┴─────┴─────┴───────────────────┴──────────────┘
初始值acc
如果要进行乘法运算, 我们的累加器就不能设置为0, 我们再来看一个例子, 加深理解
1import polars as pl
2
3df = pl.DataFrame({
4 "a": [1, 2, 3],
5 "b": [4, 5, 6],
6 "c": [7, 8, 9],
7})
8res = df.select(
9 pl.fold(
10 acc=pl.lit(0),
11 function=lambda acc, x: acc * x,
12 exprs=pl.all()
13 ).alias("prod")
14)
15print(res)
可以看到每一行的结果都是0, 因为对于每一行来说, 都是0
× 第一列的值
× 第二列的值
× 第三列的值
, 结果必然是0
1shape: (3, 1)
2┌──────┐
3│ prod │
4│ --- │
5│ i64 │
6╞══════╡
7│ 0 │
8│ 0 │
9│ 0 │
10└──────┘
想解决问题非常简单, acc设置为1即可, 这里我们使用了operator.mul
, 可以简化我们的代码
1import polars as pl
2import operator
3df = pl.DataFrame({
4 "a": [1, 2, 3],
5 "b": [4, 5, 6],
6 "c": [7, 8, 9],
7})
8res = df.select(
9 pl.fold(
10 acc=pl.lit(1),
11 function=operator.mul,
12 exprs=pl.all()
13 ).alias("prod")
14)
15print(res)
1shape: (3, 1)
2┌──────┐
3│ prod │
4│ --- │
5│ i64 │
6╞══════╡
7│ 28 │
8│ 80 │
9│ 162 │
10└──────┘
使用条件/过滤
我们可以很方便的使用条件语句来选择我们想要的行
1df = pl.DataFrame(
2 {
3 "a": [1, 2, 3],
4 "b": [0, 1, 2],
5 }
6)
7
8result = df.filter(
9 pl.fold(
10 acc=pl.lit(True),
11 function=lambda acc, x: acc & x,
12 exprs=pl.all() > 1,
13 )
14)
15print(result)
1shape: (1, 2)
2┌─────┬─────┐
3│ a ┆ b │
4│ --- ┆ --- │
5│ i64 ┆ i64 │
6╞═════╪═════╡
7│ 3 ┆ 2 │
8└─────┴─────┘
字符串
折叠操作可以用于连接字符串数据, 然而由于fold
在每一步都会生成中间结果, 中间结果(字符串是不可变数据), 每次拼接都会生成新的字符串,
占用内存会很多, 所以这里的场景可以使用pl.concat_str
来代替, 优化由Polars来完成
1df = pl.DataFrame(
2 {
3 "a": ["a", "b", "c"],
4 "b": [1, 2, 3],
5 }
6)
7
8result = df.select(
9 pl.concat_str(["a", "b"]).alias("ab"),
10 pl.concat_str(pl.col("a"),pl.lit("_"),pl.col("b")).alias("a_b")
11)
12print(result)
1shape: (3, 2)
2┌─────┬─────┐
3│ ab ┆ a_b │
4│ --- ┆ --- │
5│ str ┆ str │
6╞═════╪═════╡
7│ a1 ┆ a_1 │
8│ b2 ┆ b_2 │
9│ c3 ┆ c_3 │
10└─────┴─────┘