间隙树排序算法的可视化

做文档翻译的OCR程序时,会遇到这样一个场景,因为通常OCR模型的输出都是按文本块逐行返回,当结果进入翻译模型时会丢失行与行之间的信息。为了解决这个问题,需要对OCR结果进行进一步的版面分析,将文本块合并成段落,再输入到翻译模型中去。

算法

间隙·树·排序算法

参考链接:https://github.com/hiroi-sora/GapTree_Sort_Algorithm

算法主要对文本块按间隙进行划分,再经过树形排序,将底层OCR的输出结果从子文本块转化成段落文本块。

算法重构

为了使文本块的定位更加准确,对原算法输出的每个段落(content_blocks)所包含的全部文本行,取其最上、最下、最左、最右四个边界,构成新的段落框(new_blocks)。

源码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def structure_ocr(self):
    """Sort OCR lines with GapTree and merge them into paragraph text.

    Returns:
        tuple: ``(node_tbs, content_blocks)``.

        ``node_tbs`` holds one axis-aligned quad
        ``[[x0, y0], [x1, y0], [x1, y1], [x0, y1]]`` per GapTree node that
        has units. The x values come from the node's ``x_left``/``x_right``;
        the y values come from the rows referenced by ``r_top``/``r_bottom``.

        ``content_blocks`` holds one string per group returned by
        ``get_nodes_text_blocks()``. Each string is every line's text
        followed by the separator that ParagraphParse predicted for it.
    """
    # GapTree reads each block's geometry from its "bbox" entry.
    for idx, block in enumerate(self.blocks_data):
        block["bbox"] = self.bboxes[idx]

    tree = GapTree(lambda block: block["bbox"])
    # The sorted list itself is not used below. The call is kept because the
    # node/row state read afterwards presumably comes from it.
    tree.sort(self.blocks_data)
    parser = ParagraphParse(self.get_info, self.set_end)

    content_blocks = []
    for group in tree.get_nodes_text_blocks():
        # parser.run fills in each line's trailing "end" separator.
        lines = parser.run(group)
        content_blocks.append("".join(line["text"] + line["end"] for line in lines))

    node_tbs = []
    for node in tree.current_nodes:
        if not node["units"]:
            continue  # skip nodes that carry no blocks (e.g. the root)
        left, right = node["x_left"], node["x_right"]
        top = tree.current_rows[node["r_top"]][0][0][1]
        bottom = tree.current_rows[node["r_bottom"]][0][0][3]
        node_tbs.append([[left, top], [right, top], [right, bottom], [left, bottom]])

    return node_tbs, content_blocks

修改后:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def structure_ocr(self):
    """Sort OCR lines with GapTree and merge them into paragraphs.

    Returns:
        tuple: ``(new_blocks, content_blocks)``, index-aligned.

        ``new_blocks[i]`` is the axis-aligned quad
        ``[[x1, y1], [x2, y1], [x2, y2], [x1, y2]]`` enclosing every corner
        of every line ``"box"`` in paragraph ``i``.

        ``content_blocks[i]`` is the paragraph text: each line's ``"text"``
        followed by the separator ParagraphParse predicted for it.

        Empty groups from GapTree are skipped in both lists, so the lists
        stay aligned.
    """
    # GapTree reads each block's geometry from its "bbox" entry.
    for idx, tb in enumerate(self.blocks_data):
        tb["bbox"] = self.bboxes[idx]

    gtree = GapTree(lambda tb: tb["bbox"])
    gtree.sort(self.blocks_data)  # sort text blocks (the returned list is not needed here)
    pp = ParagraphParse(self.get_info, self.set_end)

    new_blocks = []
    content_blocks = []
    for group in gtree.get_nodes_text_blocks():
        if not group:
            continue  # an empty group has neither text nor extent (min() would fail)
        # pp.run predicts each line's trailing separator ("end").
        content_blocks.append("".join(tb["text"] + tb["end"] for tb in pp.run(group)))
        # Use all corners: OCR quads may be tilted, so corners 0 and 2 alone
        # do not always give the outer extent of a line.
        xs = [pt[0] for tb in group for pt in tb["box"]]
        ys = [pt[1] for tb in group for pt in tb["box"]]
        x1, y1, x2, y2 = min(xs), min(ys), max(xs), max(ys)
        new_blocks.append([[x1, y1], [x2, y1], [x2, y2], [x1, y2]])

    return new_blocks, content_blocks

封装

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
class StructureOCR:
    """Group line-level OCR results into paragraphs using GapTree + ParagraphParse.

    ``blocks_data`` is a list of dicts. Each dict must have at least
    ``"box"`` (a list of ``[x, y]`` corner points) and ``"text"``.
    ``structure_ocr`` adds ``"bbox"`` to each dict, and ParagraphParse
    adds ``"end"`` via ``set_end``.
    """

    def __init__(self, blocks_data):
        self.blocks_data = blocks_data
        # NOTE(review): linePreprocessing is defined elsewhere; structure_ocr
        # assumes it returns one bbox per entry of blocks_data.
        self.bboxes = linePreprocessing(self.blocks_data)

    @staticmethod
    def get_info(tb):
        """Return ``((x0, y0, x2, y2), text)`` for ParagraphParse.

        The bbox is built from corners 0 and 2 of the block's box.
        """
        b = tb["box"]
        return (b[0][0], b[0][1], b[2][0], b[2][1]), tb["text"]

    @staticmethod
    def set_end(tb, end):
        """Store the block-trailing separator predicted by ParagraphParse."""
        # Alternative: append it to the text directly (tb["text"] += end).
        tb["end"] = end

    @staticmethod
    def _merge_boxes(text_blocks):
        """Return the axis-aligned quad enclosing every corner of every block's box."""
        # All corners are used because OCR quads may be tilted.
        xs = [pt[0] for tb in text_blocks for pt in tb["box"]]
        ys = [pt[1] for tb in text_blocks for pt in tb["box"]]
        x1, y1, x2, y2 = min(xs), min(ys), max(xs), max(ys)
        return [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]

    def structure_ocr(self):
        """Sort the OCR lines and merge them into paragraphs.

        Returns:
            tuple: ``(new_blocks, content_blocks)``, index-aligned.

            ``new_blocks[i]`` is the quad enclosing paragraph ``i``.
            ``content_blocks[i]`` is its text: each line's text followed by
            its predicted separator.

            Empty groups from GapTree are skipped in both lists.
        """
        # GapTree reads each block's geometry from its "bbox" entry.
        for idx, tb in enumerate(self.blocks_data):
            tb["bbox"] = self.bboxes[idx]

        gtree = GapTree(lambda tb: tb["bbox"])
        gtree.sort(self.blocks_data)  # sort text blocks (the returned list is not needed here)
        pp = ParagraphParse(self.get_info, self.set_end)

        new_blocks = []
        content_blocks = []
        for group in gtree.get_nodes_text_blocks():
            if not group:
                continue  # an empty group has neither text nor extent
            # pp.run predicts each line's trailing separator ("end").
            content_blocks.append("".join(tb["text"] + tb["end"] for tb in pp.run(group)))
            new_blocks.append(self._merge_boxes(group))

        return new_blocks, content_blocks

调用

使用paddleOCR,将OCR结果转成算法输入的json_data

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
if __name__ == '__main__':
    import base64  # binds a module-level name; kept in case other helpers rely on it
    import requests

    # TODO: replace host/port with the address of your PaddleOCR HTTP service.
    # The original value lacked the "http://" scheme (requests rejects that)
    # and repeated the path segment.
    ocr_url = 'http://localhost:8866/ai/ppocr'
    test_image = './test/t2.png'

    origin_image = cv2.imread(test_image)
    if origin_image is None:  # cv2.imread returns None instead of raising
        raise SystemExit('cannot read image: %s' % test_image)
    encoded = cv2_base64(origin_image)

    response = requests.post(ocr_url, json={"img_b64": encoded, "lang": "cn"}, timeout=60)
    response.raise_for_status()
    ocr_result = response.json().get('data')
    if not ocr_result or not ocr_result[0]:
        raise SystemExit('OCR service returned no text lines')

    # Each OCR line is [quad, (text, score)], quad being a list of [x, y] points.
    json_data = [
        {
            "box": [[int(pt[0]), int(pt[1])] for pt in line[0]],
            "score": line[1][1],
            "text": line[1][0],
        }
        for line in ocr_result[0]
    ]

    so = StructureOCR(json_data)
    blocks, paragraphs = so.structure_ocr()

可视化

当文本成段后,不能直接通过draw.text等方法将整段文本写成一行,而是要限制每行文本的宽度不超过段落框的宽度,并且所有行的总高度不超过段落框的高度。因此需要设计算法:当一行文本超出宽度限制时,自动插入换行符。

段落换行算法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
class VisualizeOCR:
    """Lay out OCR paragraph text so it fits inside its boxes."""

    def __init__(self, im, boxes, texts):
        """Store the inputs and convert ``im`` to an RGBA PIL image.

        Args:
            im: path to an image file, or an image array (e.g. the result of
                ``cv2.imread``).
            boxes: per-paragraph quads (or None).
            texts: per-paragraph strings, index-aligned with ``boxes``.
        """
        self.boxes = boxes
        self.texts = texts
        if isinstance(im, str):
            # Load with PIL (RGB) and swap to BGR, presumably so it matches
            # the channel order of arrays passed in directly (cv2.imread).
            # The original code passed the path string itself to numpy/cv2.
            with Image.open(im) as src:
                rgb = np.asarray(src.convert('RGB'))
            arr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
        else:
            arr = im
        self.im = Image.fromarray(np.ascontiguousarray(np.copy(arr))).convert('RGBA')
        self.size = (int(self.im.size[0]), int(self.im.size[1]))

    def split_text(self, width, sentence, font, text_scale):
        """Wrap every newline-separated piece of ``sentence`` to ``width`` pixels.

        Returns:
            tuple: ``(pieces, total_height, line_height)``.
            ``pieces`` is a list of ``(wrapped_text, line_count)`` per piece.
            ``line_height`` is the tallest glyph height seen.
            ``total_height`` is the summed line count times ``line_height``.
        """
        pieces = []
        total_lines = 0
        line_height = 0
        for sen in sentence.split('\n'):
            paragraph, height, count = self.get_paragraph(sen, width, font, text_scale)
            line_height = max(line_height, height)
            total_lines += count
            pieces.append((paragraph, count))
        return pieces, total_lines * line_height, line_height

    @staticmethod
    def get_paragraph(text, width, font, text_scale):
        """Wrap ``text`` to ``width`` pixels using ``font`` glyph metrics.

        See ``_wrap`` for the return value.
        """
        # Scratch surface used only for measuring. textbbox does not depend
        # on the canvas size, so a 1x1 image is enough.
        draw = ImageDraw.Draw(Image.new('RGBA', (1, 1), (255, 255, 255, 0)))

        def measure(ch):
            # textbbox from origin (0, 0): right/bottom are used as width/height.
            _, _, w, h = draw.textbbox((0, 0), ch, font=font)
            return w, h

        return VisualizeOCR._wrap(text, width, measure, text_scale)

    @staticmethod
    def _wrap(text, width, measure, text_scale):
        """Insert newlines into ``text`` so that no line exceeds ``width``.

        Args:
            text: the text to wrap.
            width: maximum line width in pixels.
            measure: callable mapping a character to ``(width, height)``.
            text_scale: font size, used to pick the extra spacing per wrap.

        Returns:
            tuple: ``(paragraph, line_height, line_count)``.
            ``paragraph`` always ends with a newline.
            ``line_height`` is the tallest glyph measured.
            ``line_count`` starts at 1 and grows by ``1 + spacing + 0.1``
            per wrap.
        """
        # Small fonts wrap poorly, so each wrap counts a bit more than one
        # line (+0.05 per 5 px step down). Thresholds are checked smallest
        # first; the original order made the <=10 and <=5 cases unreachable.
        if text_scale <= 5:
            spacing = 0.2
        elif text_scale <= 10:
            spacing = 0.15
        elif text_scale <= 15:
            spacing = 0.1
        else:
            spacing = 0
        pieces = []
        line_width = 0
        chars_on_line = 0
        line_count = 1
        line_height = 0
        for ch in text:
            w, h = measure(ch)
            # Wrap before a char that would overflow. An over-wide char that
            # starts a line stays on it instead of producing an empty line.
            if chars_on_line and line_width + w > width:
                pieces.append('\n')
                line_count += 1
                line_count += spacing + 0.1
                line_width = 0
                chars_on_line = 0
            pieces.append(ch)
            line_width += w  # the char now on this line counts toward its width
            chars_on_line += 1
            line_height = max(h, line_height)
        paragraph = ''.join(pieces)
        if not paragraph.endswith('\n'):
            paragraph += '\n'
        return paragraph, line_height, line_count

实现换行后,还需要考虑字体的大小,将文字锁定在文本框内。在pillow库中draw.text方法中中文字符的大小近似于其正方矩形像素块的边长。接下来需要考虑的参数变量有三个:

1.每行最多的字符x(x_num);

2.总行数y(y_num);

3.字体像素s(text_scale)

同时考虑到并不是每行都能写满段落框的宽度,可能存在提前换行、空行、段落结束等情况。可以统计段落中换行符的数量(text.count('\n')),将每个换行符视作一个空行;当换行符多于一个时再×2,作为换行及留白所占用的像素面积。那么已知常量包括:

1.换行次数b(blank_scale);

2.总字数l(len(text));

3.段落框宽度w(width_p,取width*0.97,以降低文字出界的概率);

4.段落框高度h(height)

设立方程组:

$$
\begin{cases}
(y - b)\,x = l & \text{①} \\
x\,s = w & \text{②} \\
y\,s = h & \text{③}
\end{cases}
$$

转换到代码中:

1
2
3
(y_num - blank_scale) * x_num = len(text)
x_num * text_scale = width_p
y_num * text_scale = height

简化方程组:

1
2
3
4
5
6
7
8
9
(y_num - blank_scale) * x_num = len(text)
x_num * text_scale = width_p
y_num * text_scale = height
===>
x_num = len(text) / (y_num - blank_scale)
x_num = width_p / text_scale
width_p / text_scale = len(text) / ((height / text_scale) - blank_scale)
===>
len(text) * text_scale * text_scale + blank_scale * width_p * text_scale - width_p * height = 0

解方程

1
2
3
4
5
6
7
8
9
def quadratic(a, b, c):
    """Return a real root of ``a*x**2 + b*x + c = 0``, preferring the positive one.

    The root ``(-b + sqrt(D)) / (2a)`` is returned when it is positive;
    otherwise ``(-b - sqrt(D)) / (2a)`` is returned. When ``a == 0`` the
    linear equation ``b*x + c = 0`` is solved instead.

    Raises:
        ValueError: if there is no real root, or if ``a`` and ``b`` are both 0.
    """
    import math

    if a == 0:
        if b == 0:
            raise ValueError("degenerate equation: a and b are both zero")
        return -c / b
    discriminant = b * b - 4 * a * c
    if discriminant < 0:
        # The original used a bare `raise`, which outside an except block
        # fails with an unrelated RuntimeError.
        raise ValueError('该一元二次方程无解')
    root = math.sqrt(discriminant)
    x1 = (-b + root) / (2 * a)
    x2 = (-b - root) / (2 * a)
    return x1 if x1 > 0 else x2

获取最佳像素值

经过实际观察,对于中文字符,将计算得到的像素大小-2后视觉效果更佳。同时可以根据图像大小、单文本行不换行等因素,选择最佳字体大小

1
2
3
4
5
# Solve len(text)*s^2 + blank_scale*width_p*s - width_p*height = 0 for the
# glyph size s. The linear term follows from substituting x = width_p/s and
# y = height/s into (y - b)*x = len(text).
# The result is shrunk by 2 px (looks better for CJK), capped by image
# width / 66 and half the box height, and kept >= 1 because font sizes must
# be positive.
text_scale = max(1, min(
    int(quadratic(len(text), blank_scale * width_p, -width_p * height)) - 2,
    int(self.size[0] / 66),
    int(height / 2),
))

封装

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
class VisualizeOCR:
    """Render OCR paragraphs inside their boxes, next to the source image.

    ``boxes[i]`` is the quad for ``texts[i]``. Corners 0, 1 and 2 are used as
    top-left, top-right and bottom-right.
    """

    # Font file used for rendering; override on the class or an instance.
    font_path = "SourceHanSansCN-Medium.otf"

    def __init__(self, im, boxes, texts):
        """Store the inputs and convert ``im`` to an RGBA PIL image.

        Args:
            im: path to an image file, or an image array (e.g. the result of
                ``cv2.imread``).
            boxes: per-paragraph quads, or None (then nothing is drawn).
            texts: per-paragraph strings, index-aligned with ``boxes``.
        """
        self.boxes = boxes
        self.texts = texts
        if isinstance(im, str):
            # Load with PIL (RGB) and swap to BGR, presumably so it matches
            # the channel order of arrays passed in directly (cv2.imread).
            # The original code passed the path string itself to numpy/cv2.
            with Image.open(im) as src:
                rgb = np.asarray(src.convert('RGB'))
            arr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
        else:
            arr = im
        self.im = Image.fromarray(np.ascontiguousarray(np.copy(arr))).convert('RGBA')
        self.size = (int(self.im.size[0]), int(self.im.size[1]))

    def split_text(self, width, sentence, font, text_scale):
        """Wrap every newline-separated piece of ``sentence`` to ``width`` pixels.

        Returns:
            tuple: ``(pieces, total_height, line_height)``.
            ``pieces`` is a list of ``(wrapped_text, line_count)`` per piece.
            ``line_height`` is the tallest glyph height seen.
            ``total_height`` is the summed line count times ``line_height``.
        """
        pieces = []
        total_lines = 0
        line_height = 0
        for sen in sentence.split('\n'):
            paragraph, height, count = self.get_paragraph(sen, width, font, text_scale)
            line_height = max(line_height, height)
            total_lines += count
            pieces.append((paragraph, count))
        return pieces, total_lines * line_height, line_height

    @staticmethod
    def get_paragraph(text, width, font, text_scale):
        """Wrap ``text`` to ``width`` pixels using ``font`` glyph metrics.

        See ``_wrap`` for the return value.
        """
        # Scratch surface used only for measuring. textbbox does not depend
        # on the canvas size, so a 1x1 image is enough.
        draw = ImageDraw.Draw(Image.new('RGBA', (1, 1), (255, 255, 255, 0)))

        def measure(ch):
            # textbbox from origin (0, 0): right/bottom are used as width/height.
            _, _, w, h = draw.textbbox((0, 0), ch, font=font)
            return w, h

        return VisualizeOCR._wrap(text, width, measure, text_scale)

    @staticmethod
    def _wrap(text, width, measure, text_scale):
        """Insert newlines into ``text`` so that no line exceeds ``width``.

        Args:
            text: the text to wrap.
            width: maximum line width in pixels.
            measure: callable mapping a character to ``(width, height)``.
            text_scale: font size, used to pick the extra spacing per wrap.

        Returns:
            tuple: ``(paragraph, line_height, line_count)``.
            ``paragraph`` always ends with a newline.
            ``line_height`` is the tallest glyph measured.
            ``line_count`` starts at 1 and grows by ``1 + spacing + 0.1``
            per wrap.
        """
        # Small fonts wrap poorly, so each wrap counts a bit more than one
        # line (+0.05 per 5 px step down). Thresholds are checked smallest
        # first; the original order made the <=10 and <=5 cases unreachable.
        if text_scale <= 5:
            spacing = 0.2
        elif text_scale <= 10:
            spacing = 0.15
        elif text_scale <= 15:
            spacing = 0.1
        else:
            spacing = 0
        pieces = []
        line_width = 0
        chars_on_line = 0
        line_count = 1
        line_height = 0
        for ch in text:
            w, h = measure(ch)
            # Wrap before a char that would overflow. An over-wide char that
            # starts a line stays on it instead of producing an empty line.
            if chars_on_line and line_width + w > width:
                pieces.append('\n')
                line_count += 1
                line_count += spacing + 0.1
                line_width = 0
                chars_on_line = 0
            pieces.append(ch)
            line_width += w  # the char now on this line counts toward its width
            chars_on_line += 1
            line_height = max(h, line_height)
        paragraph = ''.join(pieces)
        if not paragraph.endswith('\n'):
            paragraph += '\n'
        return paragraph, line_height, line_count

    def _fit_text_scale(self, text, width_p, height):
        """Return the font size (px, >= 1) that lets ``text`` fill its box.

        Model: x glyphs per line, y lines, glyph size s, l = len(text), and
        b = blank_scale (newline count, doubled when there is more than one):

            (y - b) * x = l,   x * s = width_p,   y * s = height

        Substituting x = width_p / s and y = height / s gives

            l * s**2 + b * width_p * s - width_p * height = 0

        Callers must pass a non-empty ``text`` and positive ``width_p`` and
        ``height``. That guarantees a positive root.
        """
        newline_count = text.count('\n')
        blank_scale = newline_count * 2 if newline_count > 1 else newline_count
        solved = quadratic(len(text), blank_scale * width_p, -width_p * height)
        # -2 px looks better for CJK glyphs. Also cap by image width / 66 and
        # half the box height, and keep the size positive for ImageFont.
        return max(1, min(int(solved) - 2, int(self.size[0] / 66), int(height / 2)))

    def _draw_paragraph(self, draw, box, text, fonts):
        """Draw ``text`` fitted inside ``box`` on ``draw``, then outline the box.

        ``fonts`` caches loaded fonts by size across calls.
        """
        x, y = box[0][0], box[0][1]
        height = int(box[2][1] - box[1][1])
        width_p = int((box[1][0] - box[0][0]) * 0.97)  # 3 % margin against overflow
        # A degenerate box cannot hold text. Sizing it would give a
        # non-positive font size, which ImageFont rejects.
        if width_p > 0 and height > 0:
            text_scale = self._fit_text_scale(text, width_p, height)
            font = fonts.get(text_scale)
            if font is None:
                font = fonts[text_scale] = ImageFont.truetype(self.font_path, text_scale)
            pieces, _, line_height = self.split_text(width_p, text, font, text_scale)
            for sen, line_count in pieces:
                draw.text((x, y), sen, fill=(255, 0, 0), font=font)
                y += line_height * line_count
        draw.rectangle(
            ((box[0][0], box[0][1]), (box[2][0], box[2][1])),
            fill=None,
            outline=(139, 0, 139),
            width=1)

    def visualize_ocr(self):
        """Render every non-empty paragraph and join it to the right of the image.

        Returns:
            numpy.ndarray: the original image and the rendered canvas placed
            side by side, converted to RGB.
        """
        # Opaque white canvas (alpha must be within 0-255).
        canvas = Image.new('RGBA', self.size, (255, 255, 255, 255))
        draw = ImageDraw.Draw(canvas)
        fonts = {}  # font size -> loaded font, so each size is read from disk once
        if self.boxes is not None:
            for i, text in enumerate(self.texts):
                if text == "":
                    continue
                self._draw_paragraph(draw, self.boxes[i], text, fonts)

        im = image_join(self.im, canvas, 'x').convert('RGB')
        # Hand back a contiguous array copy.
        return np.ascontiguousarray(np.copy(im))


def quadratic(a, b, c):
    """Return a real root of ``a*x**2 + b*x + c = 0``, preferring the positive one.

    The root ``(-b + sqrt(D)) / (2a)`` is returned when it is positive;
    otherwise ``(-b - sqrt(D)) / (2a)`` is returned. When ``a == 0`` the
    linear equation ``b*x + c = 0`` is solved instead.

    Raises:
        ValueError: if there is no real root, or if ``a`` and ``b`` are both 0.
            Previously a message string was returned instead, which callers
            then passed to ``int()``.
    """
    import math

    if a == 0:
        if b == 0:
            raise ValueError("degenerate equation: a and b are both zero")
        return -c / b
    discriminant = b * b - 4 * a * c
    if discriminant < 0:
        raise ValueError('该一元二次方程无解')
    root = math.sqrt(discriminant)
    x1 = (-b + root) / (2 * a)
    x2 = (-b - root) / (2 * a)
    return x1 if x1 > 0 else x2


def image_join(img1, img2, flag='y'):
    """Join two PIL images into a new RGB image.

    Args:
        img1: image placed at the origin.
        img2: image placed to the right of ``img1`` when ``flag == 'x'``,
            otherwise below it.
        flag: ``'x'`` for side by side; anything else stacks vertically.

    Returns:
        PIL.Image.Image: an RGB canvas large enough to hold both images.
        The original sized the cross axis from ``img1`` only, which cropped
        a larger ``img2``.
    """
    (w1, h1), (w2, h2) = img1.size, img2.size
    if flag == 'x':
        size = (w1 + w2, max(h1, h2))
        offset = (w1, 0)
    else:
        size = (max(w1, w2), h1 + h2)
        offset = (0, h1)
    joined = Image.new("RGB", size)
    joined.paste(img1, (0, 0))
    joined.paste(img2, offset)
    return joined

调用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
if __name__ == '__main__':
    import base64  # binds a module-level name; kept in case other helpers rely on it
    import requests

    # TODO: replace host/port with the address of your PaddleOCR HTTP service.
    # The original value lacked the "http://" scheme, which requests rejects.
    ocr_url = 'http://localhost:8866/ai/ppocr'
    test_image = './test/t2.png'

    origin_image = cv2.imread(test_image)
    if origin_image is None:  # cv2.imread returns None instead of raising
        raise SystemExit('cannot read image: %s' % test_image)
    encoded = cv2_base64(origin_image)

    response = requests.post(ocr_url, json={"img_b64": encoded, "lang": "cn"}, timeout=60)
    response.raise_for_status()
    ocr_result = response.json().get('data')
    if not ocr_result or not ocr_result[0]:
        raise SystemExit('OCR service returned no text lines')

    # Each OCR line is [quad, (text, score)], quad being a list of [x, y] points.
    json_data = [
        {
            "box": [[int(pt[0]), int(pt[1])] for pt in line[0]],
            "score": line[1][1],
            "text": line[1][0],
        }
        for line in ocr_result[0]
    ]

    so = StructureOCR(json_data)
    blocks, paragraphs = so.structure_ocr()

    vo = VisualizeOCR(origin_image, blocks, paragraphs)
    image = vo.visualize_ocr()

    # Fit the result on a 1920x1080 screen, keeping the aspect ratio and never
    # enlarging. The original used int(h / 1080), which is 1 for heights
    # 1081-2159, so such images were never shrunk.
    scale = min(1.0, 1080 / image.shape[0], 1920 / image.shape[1])
    if scale < 1.0:
        image = cv2.resize(image, (int(image.shape[1] * scale), int(image.shape[0] * scale)))

    cv2.imshow("image", image)
    cv2.waitKey(-1)
    cv2.destroyAllWindows()

效果演示:

Powered by Hexo and Hexo-theme-hiker

Copyright © 2017 - 2024 青域 All Rights Reserved.

UV : | PV :