Key Information Extraction (KIE)

Environment Setup

Server

· OS: CentOS7.9

· GPU: RTX 3090 24G ×2

· CUDA: 11.7

· CUDNN: 8.9.2

### PaddlePaddle

· paddlepaddle: paddlepaddle-gpu==2.4.2 (cudatoolkit=11.7; conda install recommended)

conda install paddlepaddle-gpu==2.4.2 cudatoolkit=11.7 -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/Paddle/ -c conda-forge


· paddleocr: 2.9.1

· paddlenlp: 2.5.2
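
After installing, it is worth confirming that the GPU build of PaddlePaddle actually works before going any further:

import paddle
paddle.utils.run_check()           # should end with "PaddlePaddle is installed successfully!"
print(paddle.device.get_device())  # expect something like gpu:0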

### Project dependencies

requirements.txt

# paddlepaddle 2.4.2 (note: required)
# conda install paddlepaddle-gpu==2.4.2 cudatoolkit=11.7 -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/Paddle/ -c conda-forge

shapely
scikit-image
pyclipper
lmdb
tqdm
numpy
rapidfuzz
opencv-python
opencv-contrib-python
cython
Pillow
pyyaml
requests
albumentations==1.4.10
# to be compatible with albumentations
albucore==0.0.13

sentencepiece
yacs
seqeval
pypandoc
attrdict3
python_docx
paddlenlp==2.5.2
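
Install the remaining dependencies with the same mirror used for Paddle, e.g.: pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple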


Note: bug fixes

1. `ModuleNotFoundError: No module named 'ppocr.*'`
Copy the missing source from the official GitHub repo (<https://github.com/PaddlePaddle/PaddleOCR>) into the Python environment. For a conda virtual environment, the directory to complete is: /root/anaconda3/envs/py39_kie/lib/python3.9/site-packages/paddleocr/ppocr
2. `ModuleNotFoundError: No module named 'paddle.fluid'`
Install an older paddlepaddle (2.4.2/2.5.0).
3. Type error:
InvalidArgumentError: The type of data we are trying to retrieve does not match the type of data currently contained in the container.
[Hint: Expected dtype() == paddle::experimental::CppTypeToDataType<T>::Type(), but received dtype():10 != paddle::experimental::CppTypeToDataType<T>::Type():9.] (at /paddle/paddle/phi/core/dense_tensor.cc:143)
[operator < less_equal > error]
This is a bug in the official example; add the argument --use_visual_backbone=False.

## Models

Reference: <https://paddlepaddle.github.io/PaddleOCR/latest/ppstructure/model_train/train_kie.html>

### Preparation (inference models)

SER model download: <https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar>

RE model download: <https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar>

Create a project root directory demo and add a test image 1.png. Inside demo, create a models folder to hold the OCR, SER and RE models, the dictionary file and the font file. The OCR models can be the PP-OCRv3/v4 detection and recognition (det/rec) models; copy the dictionary file from paddleocr/ppocr/utils/dict/kie_dict/xfund_class_list.txt in the source tree, and use simfang.ttf as the font file.

### Chaining the SER and RE models

1. Create demo/test.py from the official example to run inference and visualize the result (source: paddleocr/ppstructure/kie/predict_kie_token_ser_re.py).
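
Before running the script below, fetch and unpack the SER/RE models into demo/models, for example:

wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar
wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar
tar -xf ser_vi_layoutxlm_xfund_infer.tar && tar -xf re_vi_layoutxlm_xfund_infer.tar

Then place the PP-OCRv3 det / PP-OCRv4 rec inference models, xfund_class_list.txt and simfang.ttf alongside them, matching the paths used in the script.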
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../..")))

os.environ["FLAGS_allocator_strategy"] = "auto_growth"

import cv2
import json
import numpy as np
import time

import paddleocr.tools.infer.utility as utility
from paddleocr.tools.infer_kie_token_ser_re import make_input
from paddleocr.ppocr.postprocess import build_post_process
from paddleocr.ppocr.utils.logging import get_logger
from paddleocr.ppocr.utils.visual import draw_ser_results, draw_re_results
from paddleocr.ppocr.utils.utility import get_image_file_list, check_and_read
from paddleocr.ppstructure.utility import parse_args
from paddleocr.ppstructure.kie.predict_kie_token_ser import SerPredictor

logger = get_logger()


class SerRePredictor(object):
    def __init__(self, args):
        self.use_visual_backbone = args.use_visual_backbone
        self.ser_engine = SerPredictor(args)
        if args.re_model_dir is not None:
            postprocess_params = {"name": "VQAReTokenLayoutLMPostProcess"}
            self.postprocess_op = build_post_process(postprocess_params)
            (
                self.predictor,
                self.input_tensor,
                self.output_tensors,
                self.config,
            ) = utility.create_predictor(args, "re", logger)
        else:
            self.predictor = None

    def __call__(self, img):
        starttime = time.time()
        ser_results, ser_inputs, ser_elapse = self.ser_engine(img)
        if self.predictor is None:
            return ser_results, ser_elapse

        re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
        if self.use_visual_backbone == False:
            re_input.pop(4)
        for idx in range(len(self.input_tensor)):
            self.input_tensor[idx].copy_from_cpu(re_input[idx])

        self.predictor.run()
        outputs = []
        for output_tensor in self.output_tensors:
            output = output_tensor.copy_to_cpu()
            outputs.append(output)
        preds = dict(
            loss=outputs[1],
            pred_relations=outputs[2],
            hidden_states=outputs[0],
        )

        post_result = self.postprocess_op(
            preds, ser_results=ser_results, entity_idx_dict_batch=entity_idx_dict_batch
        )

        elapse = time.time() - starttime
        return post_result, elapse


def main(args):
    image_file_list = get_image_file_list(args.image_dir)
    ser_re_predictor = SerRePredictor(args)
    count = 0
    total_time = 0

    os.makedirs(args.output, exist_ok=True)
    with open(
        os.path.join(args.output, "infer.txt"), mode="w", encoding="utf-8"
    ) as f_w:
        for image_file in image_file_list:
            img, flag, _ = check_and_read(image_file)
            if not flag:
                img = cv2.imread(image_file)
                img = img[:, :, ::-1]
            if img is None:
                logger.info("error in loading image:{}".format(image_file))
                continue
            re_res, elapse = ser_re_predictor(img)
            re_res = re_res[0]

            res_str = "{}\t{}\n".format(
                image_file,
                json.dumps(
                    {
                        "ocr_info": re_res,
                    },
                    ensure_ascii=False,
                ),
            )
            f_w.write(res_str)
            if ser_re_predictor.predictor is not None:
                img_res = draw_re_results(
                    image_file, re_res, font_path=args.vis_font_path
                )
                img_save_path = os.path.join(
                    args.output,
                    os.path.splitext(os.path.basename(image_file))[0] + "_ser_re.jpg",
                )
            else:
                img_res = draw_ser_results(
                    image_file, re_res, font_path=args.vis_font_path
                )
                img_save_path = os.path.join(
                    args.output,
                    os.path.splitext(os.path.basename(image_file))[0] + "_ser.jpg",
                )

            cv2.imwrite(img_save_path, img_res)
            logger.info("save vis result to {}".format(img_save_path))
            if count > 0:
                total_time += elapse
            count += 1
            logger.info("Predict time of {}: {}".format(image_file, elapse))


if __name__ == "__main__":
    args = parse_args()
    args.mode = 'kie'
    args.use_visual_backbone = False
    args.kie_algorithm = 'LayoutXLM'
    args.re_model_dir = './models/re_vi_layoutxlm_xfund_infer'
    args.ser_model_dir = './models/ser_vi_layoutxlm_xfund_infer'
    args.ser_dict_path = './models/xfund_class_list.txt'
    args.vis_font_path = './models/simfang.ttf'
    args.ocr_order_method = "tb-yx"
    args.det_model_dir = './models/ch_PP-OCRv3_det_infer'
    args.rec_model_dir = './models/ch_PP-OCRv4_rec_infer'
    args.image_dir = './1.png'
    main(args)
![](关键信息提取(KIE)/1.png)

2. Create demo/main.py for testing:
import cv2
import time
from paddleocr.ppstructure.kie.predict_kie_token_ser_re import SerRePredictor
from paddleocr.ppstructure.utility import parse_args
from paddleocr.tools.infer_kie_token_ser_re import make_input


# Subclass SerRePredictor and override __call__ so the SER results are returned together with the RE results
class KIE(SerRePredictor):
    def __init__(self, para):
        super().__init__(para)

    def __call__(self, im):
        start_time = time.time()
        ser_results, ser_inputs, ser_elapse = self.ser_engine(im)
        if self.predictor is None:
            return ser_results, ser_elapse

        re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
        if not self.use_visual_backbone:
            re_input.pop(4)
        for idx in range(len(self.input_tensor)):
            self.input_tensor[idx].copy_from_cpu(re_input[idx])

        self.predictor.run()
        outputs = []
        for output_tensor in self.output_tensors:
            output = output_tensor.copy_to_cpu()
            outputs.append(output)
        preds = dict(
            loss=outputs[1],
            pred_relations=outputs[2],
            hidden_states=outputs[0],
        )

        post_result = self.postprocess_op(
            preds, ser_results=ser_results, entity_idx_dict_batch=entity_idx_dict_batch
        )

        elapse_time = time.time() - start_time
        return ser_results, post_result, elapse_time


if __name__ == "__main__":
    args = parse_args()
    args.mode = 'kie'
    args.show_log = True
    args.use_visual_backbone = False
    args.kie_algorithm = 'LayoutXLM'
    args.re_model_dir = './models/re_vi_layoutxlm_xfund_infer'
    args.ser_model_dir = './models/ser_vi_layoutxlm_xfund_infer'
    args.ser_dict_path = './models/xfund_class_list.txt'
    args.vis_font_path = './models/simfang.ttf'
    args.ocr_order_method = "tb-yx"
    args.det_model_dir = './models/ch_PP-OCRv3_det_infer'
    args.rec_model_dir = './models/ch_PP-OCRv4_rec_infer'
    args.det_limit_side_len = 1600  # no trailing comma here, otherwise the value becomes a tuple

    kie = KIE(args)
    img = cv2.imread('./1.png')
    ser_res, re_res, elapse = kie(img)
    print(ser_res, '\n', re_res, '\n', elapse)
Output:
[[{'transcription': '证号', 'bbox': [239, 113, 267, 128], 'points': [[239.0, 113.0], [267.0, 113.0], [267.0, 128.0], [239.0, 128.0]], 'pred_id': 0, 'pred': 'O'}, {'transcription': 'T4105051990090', 'bbox': [240, 135, 441, 157], 'points': [[240.0, 135.0], [441.0, 136.0], [440.0, 157.0], [240.0, 156.0]], 'pred_id': 3, 'pred': 'ANSWER'}, {'transcription': '姓名', 'bbox': [239, 162, 266, 178], 'points': [[239.0, 162.0], [266.0, 162.0], [266.0, 178.0], [239.0, 178.0]], 'pred_id': 1, 'pred': 'QUESTION'}, {'transcription': '作业类别', 'bbox': [431, 162, 486, 179], 'points': [[431.0, 162.0], [486.0, 162.0], [486.0, 179.0], [431.0, 179.0]], 'pred_id': 1, 'pred': 'QUESTION'}, {'transcription': '毕宏', 'bbox': [237, 182, 296, 203], 'points': [[237.0, 182.0], [296.0, 182.0], [296.0, 203.0], [237.0, 203.0]], 'pred_id': 3, 'pred': 'ANSWER'}, {'transcription': '焊接与热切割作业', 'bbox': [432, 185, 586, 204], 'points': [[432.0, 185.0], [586.0, 185.0], [586.0, 204.0], [432.0, 204.0]], 'pred_id': 3, 'pred': 'ANSWER'}, {'transcription': '男', 'bbox': [237, 246, 259, 269], 'points': [[237.0, 246.0], [259.0, 246.0], [259.0, 269.0], [237.0, 269.0]], 'pred_id': 3, 'pred': 'ANSWER'}, {'transcription': '性别', 'bbox': [238, 227, 269, 246], 'points': [[238.0, 227.0], [269.0, 227.0], [269.0, 246.0], [238.0, 246.0]], 'pred_id': 1, 'pred': 'QUESTION'}, {'transcription': '操作项目', 'bbox': [430, 228, 484, 245], 'points': [[430.0, 228.0], [484.0, 228.0], [484.0, 245.0], [430.0, 245.0]], 'pred_id': 1, 'pred': 'QUESTION'}, {'transcription': '熔化焊接与热切割作业', 'bbox': [432, 251, 622, 268], 'points': [[432.0, 251.0], [622.0, 251.0], [622.0, 268.0], [432.0, 268.0]], 'pred_id': 3, 'pred': 'ANSWER'}, {'transcription': '2023-03-29', 'bbox': [71, 342, 173, 362], 'points': [[72.0, 342.0], [173.0, 345.0], [173.0, 362.0], [71.0, 360.0]], 'pred_id': 3, 'pred': 'ANSWER'}, {'transcription': '初领日期', 'bbox': [73, 325, 128, 341], 'points': [[73.0, 325.0], [128.0, 325.0], [128.0, 341.0], [73.0, 341.0]], 'pred_id': 1, 'pred': 'QUESTION'}, {'transcription': '有效期限', 'bbox': [226, 327, 278, 342], 'points': [[226.0, 327.0], [278.0, 327.0], [278.0, 342.0], [226.0, 342.0]], 'pred_id': 1, 'pred': 'QUESTION'}, {'transcription': '2023-03-29至2029-03-28', 'bbox': [226, 345, 448, 362], 'points': [[226.0, 345.0], [448.0, 345.0], [448.0, 362.0], [226.0, 362.0]], 'pred_id': 3, 'pred': 'ANSWER'}, {'transcription': '2026-03-28前', 'bbox': [71, 389, 192, 413], 'points': [[72.0, 389.0], [192.0, 392.0], [192.0, 413.0], [71.0, 410.0]], 'pred_id': 3, 'pred': 'ANSWER'}, {'transcription': '应复审日期', 'bbox': [74, 373, 140, 390], 'points': [[74.0, 373.0], [140.0, 373.0], [140.0, 390.0], [74.0, 390.0]], 'pred_id': 0, 'pred': 'O'}, {'transcription': '签发机关', 'bbox': [223, 374, 278, 393], 'points': [[223.0, 375.0], [277.0, 374.0], [278.0, 392.0], [224.0, 393.0]], 'pred_id': 1, 'pred': 'QUESTION'}, {'transcription': '河南省应急管理厅', 'bbox': [224, 394, 378, 414], 'points': [[224.0, 394.0], [378.0, 394.0], [378.0, 414.0], [224.0, 414.0]], 'pred_id': 3, 'pred': 'ANSWER'}]] 
[[({'transcription': '姓名', 'bbox': [239, 162, 266, 178], 'points': [[239.0, 162.0], [266.0, 162.0], [266.0, 178.0], [239.0, 178.0]], 'pred_id': 1, 'pred': 'QUESTION'}, {'transcription': '毕宏', 'bbox': [237, 182, 296, 203], 'points': [[237.0, 182.0], [296.0, 182.0], [296.0, 203.0], [237.0, 203.0]], 'pred_id': 3, 'pred': 'ANSWER'}), ({'transcription': '作业类别', 'bbox': [431, 162, 486, 179], 'points': [[431.0, 162.0], [486.0, 162.0], [486.0, 179.0], [431.0, 179.0]], 'pred_id': 1, 'pred': 'QUESTION'}, {'transcription': '焊接与热切割作业', 'bbox': [432, 185, 586, 204], 'points': [[432.0, 185.0], [586.0, 185.0], [586.0, 204.0], [432.0, 204.0]], 'pred_id': 3, 'pred': 'ANSWER'}), ({'transcription': '作业类别', 'bbox': [431, 162, 486, 179], 'points': [[431.0, 162.0], [486.0, 162.0], [486.0, 179.0], [431.0, 179.0]], 'pred_id': 1, 'pred': 'QUESTION'}, {'transcription': '熔化焊接与热切割作业', 'bbox': [432, 251, 622, 268], 'points': [[432.0, 251.0], [622.0, 251.0], [622.0, 268.0], [432.0, 268.0]], 'pred_id': 3, 'pred': 'ANSWER'}), ({'transcription': '性别', 'bbox': [238, 227, 269, 246], 'points': [[238.0, 227.0], [269.0, 227.0], [269.0, 246.0], [238.0, 246.0]], 'pred_id': 1, 'pred': 'QUESTION'}, {'transcription': '男', 'bbox': [237, 246, 259, 269], 'points': [[237.0, 246.0], [259.0, 246.0], [259.0, 269.0], [237.0, 269.0]], 'pred_id': 3, 'pred': 'ANSWER'}), ({'transcription': '初领日期', 'bbox': [73, 325, 128, 341], 'points': [[73.0, 325.0], [128.0, 325.0], [128.0, 341.0], [73.0, 341.0]], 'pred_id': 1, 'pred': 'QUESTION'}, {'transcription': '2023-03-29', 'bbox': [71, 342, 173, 362], 'points': [[72.0, 342.0], [173.0, 345.0], [173.0, 362.0], [71.0, 360.0]], 'pred_id': 3, 'pred': 'ANSWER'}), ({'transcription': '有效期限', 'bbox': [226, 327, 278, 342], 'points': [[226.0, 327.0], [278.0, 327.0], [278.0, 342.0], [226.0, 342.0]], 'pred_id': 1, 'pred': 'QUESTION'}, {'transcription': '2023-03-29至2029-03-28', 'bbox': [226, 345, 448, 362], 'points': [[226.0, 345.0], [448.0, 345.0], [448.0, 362.0], [226.0, 362.0]], 'pred_id': 3, 'pred': 'ANSWER'}), ({'transcription': '签发机关', 'bbox': [223, 374, 278, 393], 'points': [[223.0, 375.0], [277.0, 374.0], [278.0, 392.0], [224.0, 393.0]], 'pred_id': 1, 'pred': 'QUESTION'}, {'transcription': '河南省应急管理厅', 'bbox': [224, 394, 378, 414], 'points': [[224.0, 394.0], [378.0, 394.0], [378.0, 414.0], [224.0, 414.0]], 'pred_id': 3, 'pred': 'ANSWER'})]]
0.959315299987793
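
For downstream use it is often handy to flatten the RE result into question/answer pairs. A minimal sketch based on the structure printed above (not part of the official API):

def re_pairs_to_dict(re_result):
    # re_result is batched per image: [[(question_dict, answer_dict), ...]]
    kv = {}
    for question, answer in re_result[0]:
        key = question['transcription']
        val = answer['transcription']
        kv[key] = kv[key] + ' ' + val if key in kv else val
    return kv

# re_pairs_to_dict(re_res) -> {'姓名': '毕宏', '性别': '男', '初领日期': '2023-03-29', ...}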
## Training (on the XFUND dataset)

Reference: <https://paddlepaddle.github.io/PaddleOCR/latest/ppocr/model_train/kie.html>

### Downloading the data

XFUND GitHub page: <https://github.com/doc-analysis/XFUND>

Dataset download: <https://github.com/doc-analysis/XFUND/releases/tag/v1.0>

Choose the Chinese subset and download zh.train.json, zh.train.zip, zh.val.json and zh.val.zip.

![](关键信息提取(KIE)/2.png)

### Converting the data format

Create demo/trans_xfun_data.py to convert the XFUND format into the Paddle training format (source: paddleocr/ppstructure/kie/tools/trans_xfun_data.py):
import json


def transfer_xfun_data(json_path=None, output_file=None):
    with open(json_path, "r", encoding="utf-8") as fin:
        lines = fin.readlines()

    json_info = json.loads(lines[0])
    documents = json_info["documents"]
    with open(output_file, "w", encoding="utf-8") as fout:
        for idx, document in enumerate(documents):
            label_info = []
            img_info = document["img"]
            document = document["document"]
            image_path = img_info["fname"]

            for doc in document:
                x1, y1, x2, y2 = doc["box"]
                points = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
                label_info.append(
                    {
                        "transcription": doc["text"],
                        "label": doc["label"],
                        "points": points,
                        "id": doc["id"],
                        "linking": doc["linking"],
                    }
                )

            fout.write(
                image_path + "\t" + json.dumps(label_info, ensure_ascii=False) + "\n"
            )

    print("===ok====")

ori_gt_path = r'D:\pycharmproject_2\kie\demo\dataset\zh.val.json'
output_path = r'D:\pycharmproject_2\kie\demo\dataset\xfun_val.json'

transfer_xfun_data(ori_gt_path, output_path)
Original format:
{
"lang": "zh",
"version": "0.1",
"split": "train",
"documents": [
{
"id": "zh_train_0",
"uid": "640a0301a1cb24331748b579405502b44d6791883b25ea0eafc8a68126ccdadd",
"document": [
{
"box": [
104,
114,
530,
175
],
"text": "汇丰晋信",
"label": "other",
"words": [
{
"box": [
110,
117,
152,
175
],
"text": "汇"
},
{
"box": [
189,
117,
229,
177
],
"text": "丰"
},
{
"box": [
385,
117,
426,
177
],
"text": "晋"
},
{
"box": [
466,
116,
508,
177
],
"text": "信"
}
],
"linking": [],
"id": 1
}
...
]
}
...
]
}
After conversion:
zh_train_0.jpg	[{"transcription": "汇丰晋信", "label": "other", "points": [[104, 114], [530, 114], [530, 175], [104, 175]], "id": 1, "linking": []}, {"transcription": "受理时间:", "label": "question", "points": [[126, 267], [266, 267], [266, 305], [126, 305]], "id": 7, "linking": [[7, 13]]}, {"transcription": "2020.6.15", "label": "answer", "points": [[321, 239], [537, 239], [537, 285], [321, 285]], "id": 13, "linking": [[7, 13]]}...]
zh_train_1.jpg [...]
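
A quick sanity check of the converted file catches path or format problems before training; a small sketch using the output path from the script above:

import json

with open(r'D:\pycharmproject_2\kie\demo\dataset\xfun_val.json', encoding='utf-8') as f:
    for n, line in enumerate(f, 1):
        image_path, label_json = line.rstrip('\n').split('\t')  # one image per line
        boxes = json.loads(label_json)                           # must parse as JSON
        assert all({'transcription', 'label', 'points', 'id', 'linking'} <= set(b) for b in boxes)
print(n, 'lines checked')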
### Configuring the dictionary file

The dictionary file stores the label classes of the text segments; here the models/xfund_class_list.txt dictionary is used:
OTHER
QUESTION
ANSWER
HEADER
Note: labels in the dictionary and labels in the JSON files are matched case-insensitively.

### Configuring the training files

#### SER training config

Create and modify the SER training YAML file demo/ser_vi_layoutxlm_xfund_zh.yml (source file: configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml):
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_vi_layoutxlm_xfund_zh
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 19 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: dataset/xfund_train/zh_train/zh_val_42.jpg
d2s_train_image_shape: [3, 224, 224]
# if you want to predict using the groundtruth ocr info,
# you can use the following config
# infer_img: train_data/XFUND/zh_val/val.json
# infer_mode: False

save_res_path: ./output/ser/xfund_zh/res
kie_rec_model_dir:
kie_det_model_dir:
amp_custom_white_list: ['scale', 'concat', 'elementwise_add']

Architecture:
model_type: kie
algorithm: &algorithm "LayoutXLM"
Transform:
Backbone:
name: LayoutXLMForSer
pretrained: True
checkpoints:
# one of base or vi
mode: vi
num_classes: &num_classes 7 # <-------------------------------------- with n labels: use 2n-1 if an OTHER label exists, 2n+1 if it does not

Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
key: "backbone_out"

Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: L2
factor: 0.00000

PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path models/xfund_class_list.txt # <-------------------------------------- dictionary file path

Metric:
name: VQASerTokenMetric
main_indicator: hmean

Train:
dataset:
name: SimpleDataSet
data_dir: dataset/xfund_train/zh_train # <-------------------------------------- training image directory
label_file_list:
- dataset/xfund_train/xfun_train.json # <-------------------------------------- training label (json) file
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
use_textline_bbox_info: &use_textline_bbox_info True
# one of [None, "tb-yx"]
order_method: &order_method "tb-yx"
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4

Eval:
dataset:
name: SimpleDataSet
data_dir: dataset/xfund_train/zh_val # <-------------------------------------- validation image directory
label_file_list:
- dataset/xfund_train/xfun_val.json # <-------------------------------------- validation label (json) file
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
use_textline_bbox_info: *use_textline_bbox_info
order_method: *order_method
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
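
The num_classes rule noted in the config deserves a worked example: the dictionary above has n = 4 labels (OTHER, QUESTION, ANSWER, HEADER) and OTHER is present, so num_classes = 2n - 1 = 7, exactly the value set in ser_vi_layoutxlm_xfund_zh.yml; without an OTHER label it would be 2n + 1.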
#### RE training config

Create and modify the RE training YAML file demo/re_vi_layoutxlm_xfund_zh.yml (source file: configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml):
Global:
use_gpu: True
epoch_num: &epoch_num 130
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/re_vi_layoutxlm_xfund_zh
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 19 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: dataset/xfund_train/zh_train/zh_val_21.jpg
save_res_path: ./output/re/xfund_zh/with_gt
kie_rec_model_dir:
kie_det_model_dir:

Architecture:
model_type: kie
algorithm: &algorithm "LayoutXLM"
Transform:
Backbone:
name: LayoutXLMForRe
pretrained: dataset/re_vi_layoutxlm_xfund_pretrained
mode: vi
checkpoints:

Loss:
name: LossFromOutput
key: loss
reduction: mean

Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
clip_norm: 10
lr:
learning_rate: 0.00005
warmup_epoch: 10
regularizer:
name: L2
factor: 0.00000

PostProcess:
name: VQAReTokenLayoutLMPostProcess

Metric:
name: VQAReTokenMetric
main_indicator: hmean

Train:
dataset:
name: SimpleDataSet
data_dir: dataset/xfund_train/zh_train # <-------------------------------------- training image directory
label_file_list:
- dataset/xfund_train/xfun_train.json # <-------------------------------------- training label (json) file
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
class_path: &class_path models/xfund_class_list.txt # <-------------------------------------- dictionary file path
use_textline_bbox_info: &use_textline_bbox_info True
order_method: &order_method "tb-yx"
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- TensorizeEntitiesRelations:
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'entities', 'relations'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
batch_size_per_card: 2
num_workers: 4

Eval:
dataset:
name: SimpleDataSet
data_dir: dataset/xfund_train/zh_val # <-------------------------------------- validation image directory
label_file_list:
- dataset/xfund_train/xfun_val.json # <-------------------------------------- validation label (json) file
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
class_path: *class_path
use_textline_bbox_info: *use_textline_bbox_info
order_method: *order_method
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- TensorizeEntitiesRelations:
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'entities', 'relations'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 8
### Starting training

Create the training script demo/train.py (training source: paddleocr/tools/train.py):
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))

import yaml
import paddle
import paddle.distributed as dist

from paddleocr.ppocr.data import build_dataloader, set_signal_handlers
from paddleocr.ppocr.modeling.architectures import build_model
from paddleocr.ppocr.losses import build_loss
from paddleocr.ppocr.optimizer import build_optimizer
from paddleocr.ppocr.postprocess import build_post_process
from paddleocr.ppocr.metrics import build_metric
from paddleocr.ppocr.utils.save_load import load_model
from paddleocr.ppocr.utils.utility import set_seed
from paddleocr.ppocr.modeling.architectures import apply_to_static
import paddleocr.tools.program as program
import paddleocr.tools.naive_sync_bn as naive_sync_bn

dist.get_world_size()


def main(config, device, logger, vdl_writer, seed):
# init dist environment
if config["Global"]["distributed"]:
dist.init_parallel_env()

global_config = config["Global"]

# build dataloader
set_signal_handlers()
train_dataloader = build_dataloader(config, "Train", device, logger, seed)
if len(train_dataloader) == 0:
logger.error(
"No Images in train dataset, please ensure\n"
+ "\t1. The images num in the train label_file_list should be larger than or equal with batch size.\n"
+ "\t2. The annotation file and path in the configuration file are provided normally."
)
return

if config["Eval"]:
valid_dataloader = build_dataloader(config, "Eval", device, logger, seed)
else:
valid_dataloader = None
step_pre_epoch = len(train_dataloader)

# build post process
post_process_class = build_post_process(config["PostProcess"], global_config)

# build model
# for rec algorithm
if hasattr(post_process_class, "character"):
char_num = len(getattr(post_process_class, "character"))
if config["Architecture"]["algorithm"] in [
"Distillation",
]: # distillation model
for key in config["Architecture"]["Models"]:
if (
config["Architecture"]["Models"][key]["Head"]["name"] == "MultiHead"
): # for multi head
if config["PostProcess"]["name"] == "DistillationSARLabelDecode":
char_num = char_num - 2
if config["PostProcess"]["name"] == "DistillationNRTRLabelDecode":
char_num = char_num - 3
out_channels_list = {}
out_channels_list["CTCLabelDecode"] = char_num
# update SARLoss params
if (
list(config["Loss"]["loss_config_list"][-1].keys())[0]
== "DistillationSARLoss"
):
config["Loss"]["loss_config_list"][-1]["DistillationSARLoss"][
"ignore_index"
] = (char_num + 1)
out_channels_list["SARLabelDecode"] = char_num + 2
elif any(
"DistillationNRTRLoss" in d
for d in config["Loss"]["loss_config_list"]
):
out_channels_list["NRTRLabelDecode"] = char_num + 3

config["Architecture"]["Models"][key]["Head"][
"out_channels_list"
] = out_channels_list
else:
config["Architecture"]["Models"][key]["Head"][
"out_channels"
] = char_num
elif config["Architecture"]["Head"]["name"] == "MultiHead": # for multi head
if config["PostProcess"]["name"] == "SARLabelDecode":
char_num = char_num - 2
if config["PostProcess"]["name"] == "NRTRLabelDecode":
char_num = char_num - 3
out_channels_list = {}
out_channels_list["CTCLabelDecode"] = char_num
# update SARLoss params
if list(config["Loss"]["loss_config_list"][1].keys())[0] == "SARLoss":
if config["Loss"]["loss_config_list"][1]["SARLoss"] is None:
config["Loss"]["loss_config_list"][1]["SARLoss"] = {
"ignore_index": char_num + 1
}
else:
config["Loss"]["loss_config_list"][1]["SARLoss"]["ignore_index"] = (
char_num + 1
)
out_channels_list["SARLabelDecode"] = char_num + 2
elif list(config["Loss"]["loss_config_list"][1].keys())[0] == "NRTRLoss":
out_channels_list["NRTRLabelDecode"] = char_num + 3
config["Architecture"]["Head"]["out_channels_list"] = out_channels_list
else: # base rec model
config["Architecture"]["Head"]["out_channels"] = char_num

if config["PostProcess"]["name"] == "SARLabelDecode": # for SAR model
config["Loss"]["ignore_index"] = char_num - 1

model = build_model(config["Architecture"])

use_sync_bn = config["Global"].get("use_sync_bn", False)
if use_sync_bn:
if config["Global"].get("use_npu", False):
naive_sync_bn.convert_syncbn(model)
else:
model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
logger.info("convert_sync_batchnorm")

model = apply_to_static(model, config, logger)

# build loss
loss_class = build_loss(config["Loss"])

# build optim
optimizer, lr_scheduler = build_optimizer(
config["Optimizer"],
epochs=config["Global"]["epoch_num"],
step_each_epoch=len(train_dataloader),
model=model,
)

# build metric
eval_class = build_metric(config["Metric"])

logger.info("train dataloader has {} iters".format(len(train_dataloader)))
if valid_dataloader is not None:
logger.info("valid dataloader has {} iters".format(len(valid_dataloader)))

use_amp = config["Global"].get("use_amp", False)
amp_level = config["Global"].get("amp_level", "O2")
amp_dtype = config["Global"].get("amp_dtype", "float16")
amp_custom_black_list = config["Global"].get("amp_custom_black_list", [])
amp_custom_white_list = config["Global"].get("amp_custom_white_list", [])
if os.path.exists(
os.path.join(config["Global"]["save_model_dir"], "train_result.json")
):
try:
os.remove(
os.path.join(config["Global"]["save_model_dir"], "train_result.json")
)
except:
pass
if use_amp:
AMP_RELATED_FLAGS_SETTING = {
"FLAGS_max_inplace_grad_add": 8,
}
if paddle.is_compiled_with_cuda():
AMP_RELATED_FLAGS_SETTING.update(
{
"FLAGS_cudnn_batchnorm_spatial_persistent": 1,
"FLAGS_gemm_use_half_precision_compute_type": 0,
}
)
paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
scale_loss = config["Global"].get("scale_loss", 1.0)
use_dynamic_loss_scaling = config["Global"].get(
"use_dynamic_loss_scaling", False
)
scaler = paddle.amp.GradScaler(
init_loss_scaling=scale_loss,
use_dynamic_loss_scaling=use_dynamic_loss_scaling,
)
if amp_level == "O2":
model, optimizer = paddle.amp.decorate(
models=model,
optimizers=optimizer,
level=amp_level,
master_weight=True,
dtype=amp_dtype,
)
else:
scaler = None

# load pretrain model
pre_best_model_dict = load_model(
config, model, optimizer, config["Architecture"]["model_type"]
)

if config["Global"]["distributed"]:
model = paddle.DataParallel(model)
# start train
program.train(
config,
train_dataloader,
valid_dataloader,
device,
model,
loss_class,
optimizer,
lr_scheduler,
post_process_class,
eval_class,
pre_best_model_dict,
logger,
step_pre_epoch,
vdl_writer,
scaler,
amp_level,
amp_custom_black_list,
amp_custom_white_list,
amp_dtype,
)


def test_reader(config, device, logger):
loader = build_dataloader(config, "Train", device, logger)
import time

starttime = time.time()
count = 0
try:
for data in loader():
count += 1
if count % 1 == 0:
batch_time = time.time() - starttime
starttime = time.time()
logger.info(
"reader: {}, {}, {}".format(count, len(data[0]), batch_time)
)
except Exception as e:
logger.info(e)
logger.info("finish reader: {}, Success!".format(count))


if __name__ == "__main__":
config, device, logger, vdl_writer = program.preprocess(is_train=True)
seed = config["Global"]["seed"] if "seed" in config["Global"] else 1024
set_seed(seed)
main(config, device, logger, vdl_writer, seed)
# test_reader(config, device, logger)
Launch the training jobs, assigning a GPU to each according to its config file:

SER: CUDA_VISIBLE_DEVICES=0 python train.py -c ser_vi_layoutxlm_xfund_zh.yml

RE: CUDA_VISIBLE_DEVICES=1 python train.py -c re_vi_layoutxlm_xfund_zh.yml

The pretrained models are downloaded automatically on first launch. Sample log output:
[2024/11/26 17:35:39] ppocr INFO: epoch: [1/200], global_step: 10, lr: 0.000006, loss: 1.838789, avg_reader_cost: 0.39165 s, avg_batch_cost: 0.54732 s, avg_samples: 8.0, ips: 14.61671 samples/s, eta: 0:34:34, max_mem_reserved: 11694 MB, max_mem_allocated: 10523 MB
[2024/11/26 17:35:43] ppocr INFO: epoch: [1/200], global_step: 19, lr: 0.000018, loss: 1.446505, avg_reader_cost: 0.22235 s, avg_batch_cost: 0.32516 s, avg_samples: 6.9, ips: 21.22014 samples/s, eta: 0:28:56, max_mem_reserved: 11694 MB, max_mem_allocated: 10523 MB
eval model:: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:04<00:00, 1.61it/s]
[2024/11/26 17:35:47] ppocr INFO: cur metric, precision: 0.05993363749481543, recall: 0.10410662824207492, hmean: 0.07607265069755198, fps: 101.1334614167882
[2024/11/26 17:36:22] ppocr INFO: save best model is to ./output/ser_vi_layoutxlm_xfund_zh/best_accuracy
[2024/11/26 17:36:22] ppocr INFO: best metric, hmean: 0.07607265069755198, precision: 0.05993363749481543, recall: 0.10410662824207492, fps: 101.1334614167882, best_epoch: 1
[2024/11/26 17:36:39] ppocr INFO: save model in ./output/ser_vi_layoutxlm_xfund_zh/latest
### Exporting the models

After training finishes, go to the best-model directory demo/output/re_vi_layoutxlm_xfund_zh/best_accuracy. The checkpoint there has no .pdmodel file and cannot be used for prediction directly, so it has to be exported as an inference model (model evaluation is skipped in this post).

![](关键信息提取(KIE)/3.png)

Create the export script demo/export_model.py (export source: paddleocr/tools/export_model.py):
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))

import argparse

from paddleocr.tools.program import load_config, merge_config, ArgsParser
from paddleocr.ppocr.utils.export_model import export


def main():
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    config = merge_config(config, FLAGS.opt)
    # export model
    export(config)


if __name__ == "__main__":
    main()
Export commands:

SER:
python export_model.py -c ser_vi_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./output/ser_vi_layoutxlm_xfund_zh/best_accuracy Global.save_inference_dir=./inference/ser_vi_layoutxlm
RE:
python export_model.py -c re_vi_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./output/re_vi_layoutxlm_xfund_zh/best_accuracy Global.save_inference_dir=./inference/re_vi_layoutxlm
The exported inference models are written to the demo/inference directory; to verify them, point the model-path arguments in args at the new directories.
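
For example, in test.py/main.py above only the two model directories need to change (a sketch):

args.ser_model_dir = './inference/ser_vi_layoutxlm'
args.re_model_dir = './inference/re_vi_layoutxlm'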
## Customization

### Annotation tool: PPOCRLabel

Back on Windows, create a demo/dataset/temp folder to hold the training images.

Install: pip install PPOCRLabel==2.1.3 -i https://pypi.tuna.tsinghua.edu.cn/simple

Launch: PPOCRLabel --lang ch --kie True

The key-class list now shows up in the lower-left corner of the annotation tool.



Note: the latest release may contain bugs, so version 2.1.3 is recommended. Also, as of now (Nov 2024, v2.1.12), the tool still cannot annotate RE-task relations, so a post-processing script is needed; this is the tricky part.

Post-processing:

SER

Create the SER post-processing script demo/trans+ppocrlabel.py:

import json

with open('./dataset/temp/Label.txt', 'r', encoding='utf-8') as l:
    content = l.readlines()
    for line in content:
        image = line.split('\t')[0]
        label = json.loads(line.split('\t')[1])
        label_set = []
        for i, lab in enumerate(label):
            new_lab = {
                'transcription': lab.get('transcription'),
                'points': lab.get('points'),
                'label': lab.get('key_cls'),
                'id': i,
                'linking': []
            }
            label_set.append(new_lab)
        label_set_json = json.dumps(label_set, ensure_ascii=False)
        with open('./dataset/project/train.json', 'a+', encoding='utf-8') as f:
            f.write(image)
            f.write('\t')
            f.write(label_set_json)
            f.write('\n')
RE

Still in testing; to be continued.

Visualizing the Gap-Tree Sort Algorithm

When building an OCR pipeline for document translation, a common problem is that OCR models usually return text block by block, line by line, so the information about how lines relate to each other is lost by the time the result reaches the translation model. To solve this, the OCR result needs an extra layout-analysis pass that merges text blocks into paragraphs before they are fed to the translation model.

Algorithm

The gap-tree sort algorithm

Reference: https://github.com/hiroi-sora/GapTree_Sort_Algorithm

The algorithm partitions text blocks by the gaps between them and then applies a tree-based sort, turning the raw OCR output from line-level sub-blocks into paragraph-level blocks.

Reworking the algorithm

To locate text blocks more precisely, the top/bottom/left/right boundary points of the paragraph blocks (content_blocks) produced by the original algorithm are used to build new paragraph blocks (new_blocks).

Source:

def structure_ocr(self):
    for line, tb in enumerate(self.blocks_data):
        tb["bbox"] = self.bboxes[line]

    gtree = GapTree(lambda tb: tb["bbox"])
    sorted_text_blocks = gtree.sort(self.blocks_data)  # sort the text blocks
    # print(sorted_text_blocks)
    pp = ParagraphParse(self.get_info, self.set_end)
    # collect the text blocks belonging to each region node
    nodes_text_blocks = gtree.get_nodes_text_blocks()
    content_blocks = []
    for tbs in nodes_text_blocks:
        content = ""
        tbs = pp.run(tbs)  # predict the trailing separator of each block
        for tb in tbs:  # concatenate text and trailing separator
            content += tb["text"] + tb["end"]
        content_blocks.append(content)

    node_tbs = []
    for node in gtree.current_nodes:
        if not node["units"]:
            continue  # skip root nodes that contain no blocks
        x0 = node["x_left"]
        x1 = node["x_right"]
        y0 = gtree.current_rows[node["r_top"]][0][0][1]
        y1 = gtree.current_rows[node["r_bottom"]][0][0][3]
        node_tbs.append([[x0, y0], [x1, y0], [x1, y1], [x0, y1]])

    return node_tbs, content_blocks


Surya OCR

Environment Setup

Versions:

python3.9 + surya-ocr 0.4.15

Models:

Detection model: surya_det3

Recognition model: surya_rec

Layout model: surya_layout3

Modifying the source

Because the model downloads on first use are blocked, the models are placed into a local models folder ahead of time and the relevant settings in the source are changed:

(source location: ...Python39/Lib/site-packages/surya/settings.py)

from typing import Dict, Optional

from dotenv import find_dotenv
from pydantic import computed_field
from pydantic_settings import BaseSettings
import torch
import os


class Settings(BaseSettings):
# General
TORCH_DEVICE: Optional[str] = None
IMAGE_DPI: int = 96
IN_STREAMLIT: bool = False # Whether we're running in streamlit

# Paths
DATA_DIR: str = "data"
RESULT_DIR: str = "results"
BASE_DIR: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
FONT_DIR: str = os.path.join(BASE_DIR, "static", "fonts")

@computed_field
def TORCH_DEVICE_MODEL(self) -> str:
if self.TORCH_DEVICE is not None:
return self.TORCH_DEVICE

if torch.cuda.is_available():
return "cuda"

if torch.backends.mps.is_available():
return "mps"

return "cpu"

# Text detection
DETECTOR_BATCH_SIZE: Optional[int] = None # Defaults to 2 for CPU/MPS, 32 otherwise
DETECTOR_MODEL_CHECKPOINT: str = r"D:\pycharmproject_2\translate_plat\surya_ocr\models\surya_det3"
DETECTOR_BENCH_DATASET_NAME: str = "vikp/doclaynet_bench"
DETECTOR_IMAGE_CHUNK_HEIGHT: int = 1400 # Height at which to slice images vertically
DETECTOR_TEXT_THRESHOLD: float = 0.6 # Threshold for text detection (above this is considered text)
DETECTOR_BLANK_THRESHOLD: float = 0.35 # Threshold for blank space (below this is considered blank)
DETECTOR_POSTPROCESSING_CPU_WORKERS: int = min(8, os.cpu_count()) # Number of workers for postprocessing
DETECTOR_MIN_PARALLEL_THRESH: int = 3 # Minimum number of images before we parallelize

# Text recognition
RECOGNITION_MODEL_CHECKPOINT: str = r"D:\pycharmproject_2\translate_plat\surya_ocr\models\surya_rec"
RECOGNITION_MAX_TOKENS: int = 175
RECOGNITION_BATCH_SIZE: Optional[int] = None # Defaults to 8 for CPU/MPS, 256 otherwise
RECOGNITION_IMAGE_SIZE: Dict = {"height": 196, "width": 896}
RECOGNITION_RENDER_FONTS: Dict[str, str] = {
"all": os.path.join(FONT_DIR, "GoNotoCurrent-Regular.ttf"),
"zh": os.path.join(FONT_DIR, "GoNotoCJKCore.ttf"),
"ja": os.path.join(FONT_DIR, "GoNotoCJKCore.ttf"),
"ko": os.path.join(FONT_DIR, "GoNotoCJKCore.ttf"),
}
RECOGNITION_FONT_DL_BASE: str = "https://github.com/satbyy/go-noto-universal/releases/download/v7.0"
RECOGNITION_BENCH_DATASET_NAME: str = "vikp/rec_bench"
RECOGNITION_PAD_VALUE: int = 255 # Should be 0 or 255
RECOGNITION_STATIC_CACHE: bool = False # Static cache for torch compile
RECOGNITION_MAX_LANGS: int = 4

# Layout
LAYOUT_MODEL_CHECKPOINT: str = r"D:\pycharmproject_2\translate_plat\surya_ocr\models\surya_layout3"
LAYOUT_BENCH_DATASET_NAME: str = "vikp/publaynet_bench"

# Ordering
ORDER_MODEL_CHECKPOINT: str = "vikp/surya_order"
ORDER_IMAGE_SIZE: Dict = {"height": 1024, "width": 1024}
ORDER_MAX_BOXES: int = 256
ORDER_BATCH_SIZE: Optional[int] = None # Defaults to 4 for CPU/MPS, 32 otherwise
ORDER_BENCH_DATASET_NAME: str = "vikp/order_bench"

# Tesseract (for benchmarks only)
TESSDATA_PREFIX: Optional[str] = None

@computed_field
@property
def MODEL_DTYPE(self) -> torch.dtype:
return torch.float32 if self.TORCH_DEVICE_MODEL == "cpu" else torch.float16

class Config:
env_file = find_dotenv("local.env")
extra = "ignore"


settings = Settings()


doccano

Doccano is an open-source text-annotation tool designed to simplify and speed up annotation work. It provides an intuitive UI that lets annotators label text data easily and build high-quality training datasets for machine learning and NLP tasks.

Link: https://github.com/doccano/doccano

I. Installation and deployment

Environment

OS: CentOS 7.9

Python: 3.10

doccano: 1.6.2

pip install

(the Baidu mirror does not carry this package)

pip install doccano==1.6.2 -i https://pypi.tuna.tsinghua.edu.cn/simple

Initialization

doccano init

Create the superuser account and password

doccano createuser --username admin --password 123456

Start the service

doccano webserver --port 8000


BERT + GRU Address Normalization Algorithm

I. Overview

This address-normalization algorithm (hereafter "the algorithm") performs structured extraction of address information appearing in input text and outputs the standardized address that it maps to in the database.

II. Modules

The algorithm consists of the following modules, each of which plays an important role at a different stage of its life cycle:

1. Building the raw address database

A PostgreSQL database (hereafter "pg") stores raw addresses at every administrative level across the country. A web crawler continuously collects and refreshes address data (mainly from AMap/Gaode Maps) and writes it into the raw address database, which holds the address name, POI information, category, latitude/longitude and other raw fields, providing the data foundation for the later steps.

2. Extracting addresses from text

The UIE (Universal Information Extraction) framework is used together with the ERNIE 3.0 model, giving the model the ability to extract address information from unstructured or semi-structured text.

3. Address-level classification

i) Level standard

First, a detailed address-level standard has to be fixed. The standard used here is based on Alibaba's 《中文地址要素解析标注规范》 with some modifications, and splits an address into 18 levels:

① Prov: provincial-level division (province, autonomous region, municipality)

② City: prefecture-level division (prefecture-level city, prefecture, autonomous prefecture, etc.)

③ District: county-level division (municipal district, county-level city, county, etc.)

④ Devzone: development zone in the broad sense, including general industrial parks and resort areas

⑤ Town: township-level division (town, sub-district, township, etc.)

⑥ Community: community, administrative village (production brigade, village committee), natural village

⑦ Village Group: restricted to "xx 组", "xx 队", "xx 社"

⑧ Road: formally named roads, including tunnels, elevated roads, streets, lanes and alleys, pedestrian and commercial streets

⑨ Roadno: street number

⑩ Poi: target point of interest

⑪ Subpoi: sub-POI of the target point of interest

⑫ Houseno: building number, including door numbers of rural addresses (and descriptions such as "south building", "north building")

⑬ Cellno: unit number, including 甲/乙/丙/丁 style units

⑭ Floorno: floor number

⑮ Roomno: room number

⑯ Assist: locating words, including directions and explanatory nouns

⑰ Intersection: road/bridge intersections, junctions, crossroads, etc.

⑱ Distance: distance

ii) Algorithm details

A structured address dataset is built from the raw address database collected by the crawler, split into a training set and a validation set, and then annotated. Annotation uses the B-I-E-O-S five-tag sequence-labelling scheme, which preserves as much of the level information of the annotated address as possible.
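
As a small illustration of the B/I/E/O/S scheme (a made-up example, not taken from the dataset): for the address 浙江省杭州市西湖区, the characters would be tagged 浙=B-prov, 江=I-prov, 省=E-prov, 杭=B-city, 州=I-city, 市=E-city, 西=B-district, 湖=I-district, 区=E-district; a single-character element would get an S- tag, and non-address characters get O.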

An address-level model is then built. A BERT model (a Transformer architecture based on multi-head attention) is used as the pre-trained backbone and combined with CRF and GRU layers. CRF (Conditional Random Fields) is an undirected graphical model that combines the strengths of maximum-entropy models and hidden Markov models and works well for sequence-labelling tasks such as word segmentation, POS tagging and named-entity recognition. GRU (Gated Recurrent Unit) is a neural network commonly used for sequence modelling; it handles the long-range dependency problem of recurrent networks well, captures long-term features of the sequence, and avoids vanishing and exploding gradients during training, so the resulting model generalizes better and fits real address data more closely.


yolov8-pose: Keypoint Pose Detection

Environment & installation

Same as the yolov8 fire-detection post.

The model used is yolov8n-pose.

Data annotation

Annotation tool: labelme

Mark the targets (people) and their keypoints in each image: 1 object class and 17 keypoint classes.

Data format conversion

Convert the labelme format to the YOLO format; generic conversion code:

# TODO:
# see the yolov8 fire-detection post; to be continued
...

Create the training YAML file

Based on yolov8n-pose.yaml:

train: /exp/work/video/yolov8/datasets/human-pose/images/train # training image folder
val: /exp/work/video/yolov8/datasets/human-pose/images/val # validation image folder
test: /exp/work/video/yolov8/datasets/human-pose/images/val # test image folder
nc: 1 # number of classes

# keypoints: each keypoint has three values: X, Y, visibility
# visibility: 2 = visible, 1 = occluded, 0 = point absent
kpt_shape: [17, 3]

# box classes (for keypoint detection there is only one)
names:
  0: people
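
With the dataset and YAML in place, training can then be launched through the ultralytics CLI, for example (model and epoch count are placeholders): yolo pose train data=human-pose.yaml model=yolov8n-pose.pt epochs=100 imgsz=640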


yolov8: Fire Detection

Environment

GPU

  • NVIDIA 3090 ×2
  • Driver 535.104.05
  • CUDA 12.2
  • CUDA Toolkit (cuda_12.2.2_535.104.05_linux)
  • cuDNN (v8.9.7)

YOLO version

  • v8.1.5 (ultralytics yolov8)

PyTorch version

  • v2.1.2

Python environment

  • CentOS 7.9
  • anaconda3
  • python3.9

Installation

Source repository: https://github.com/ultralytics/ultralytics

Official docs: https://docs.ultralytics.com/zh

Clone the source

git clone https://hub.nuaa.cf/ultralytics/ultralytics.git

Install dependencies

pip install ultralytics -i https://mirror.baidu.com/pypi/simple

Verify the environment

python

import ultralytics
ultralytics.checks()

cli

yolo predict model=yolov8n.pt source=ultralytics/assets/zidane.jpg

After it finishes, the output looks like this:

(py39_yolov8) [root@jdz yolov8]# yolo predict model=yolov8n.pt source=ultralytics/assets/zidane.jpg 
Ultralytics YOLOv8.1.5 🚀 Python-3.9.18 torch-2.1.2+cu121 CUDA:0 (NVIDIA GeForce RTX 3090, 24260MiB)
YOLOv8n summary (fused): 168 layers, 3151904 parameters, 0 gradients, 8.7 GFLOPs

image 1/1 /exp/work/video/yolov8/ultralytics/assets/zidane.jpg: 384x640 2 persons, 1 tie, 216.9ms
Speed: 7.3ms preprocess, 216.9ms inference, 762.4ms postprocess per image at shape (1, 3, 384, 640)
Results saved to runs/detect/predict
💡 Learn more at https://docs.ultralytics.com/modes/predict

The output is saved under runs/detect/predict, as reported by the "Results saved to" line.


NetworkX: Applying Graph-Theory Algorithms

NetworkX

NetworkX is a Python package for creating and manipulating complex networks and for studying their structure, dynamics and function. With NetworkX you can load and store networks in standard or non-standard data formats, generate many kinds of random or classic networks, analyze network structure, build network models, design new network algorithms, draw networks, and more.

Reference: https://www.osgeo.cn/networkx/reference/index.html

Comparing graph-computation approaches

1. Nebula + Spark

Reads data and runs graph algorithms through the nebula-spark-connector package, the nebula-algorithm package and a Spark cluster.

2. ClickHouse + NetworkX

Because nebula-algorithm depends on a Spark cluster and nebula-console's native data-reading ability is limited, in a constrained environment with a limited amount of computation it is preferable to skip the Spark cluster and the Nebula graph database and use ClickHouse + NetworkX instead. ClickHouse here is a columnar, distributed table holding the Nebula source data; it plays the same role as importing the Nebula cluster data into a Spark DataFrame via nebula-spark-connector in approach 1, i.e. it is used only for reading, and the data is then converted into a NetworkX graph structure for the graph computation.
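
A minimal sketch of approach 2 (the ClickHouse query and column names are assumptions; only the NetworkX part reflects the described pipeline):

import networkx as nx
# from clickhouse_driver import Client                              # assumed ClickHouse client
# rows = Client('ch-host').execute('SELECT src, dst FROM graph_edges')  # read-only source
rows = [('a', 'b'), ('b', 'c'), ('x', 'y')]                          # stand-in for the ClickHouse result

G = nx.Graph()
G.add_edges_from(rows)                                               # edge list -> NetworkX graph

for component in nx.connected_components(G):                        # any NetworkX algorithm can run here
    print(component)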


