关键信息提取(KIE)

环境准备

服务器

· OS: CentOS7.9

· GPU: RTX3090 24G × 2

· CUDA: 11.7

· CUDNN: 8.9.2

### 飞桨

· paddlepaddle: paddlepaddle-gpu==2.4.2(cudatoolkit=11.7,建议conda安装)

1
conda install paddlepaddle-gpu==2.4.2 cudatoolkit=11.7 -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/Paddle/ -c conda-forge


· paddleocr: 2.9.1

· paddlenlp: 2.5.2

### 项目依赖包

requirements.txt

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# paddlepaddle 2.4.2 注:必须
# conda install paddlepaddle-gpu==2.4.2 cudatoolkit=11.7 -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/Paddle/ -c conda-forge

shapely
scikit-image
pyclipper
lmdb
tqdm
numpy
rapidfuzz
opencv-python
opencv-contrib-python
cython
Pillow
pyyaml
requests
albumentations==1.4.10
# to be compatible with albumentations
albucore==0.0.13

sentencepiece
yacs
seqeval
pypandoc
attrdict3
python_docx
paddlenlp==2.5.2


注: bugfix

1. `ModuleNotFoundError: No module named 'ppocr.**'`
前往官方 GitHub 主页(<https://github.com/PaddlePaddle/PaddleOCR>)补齐源码到 Python 环境。例如在 conda 虚拟环境下,需要补齐的源码根目录为:`/root/anaconda3/envs/py39_kie/lib/python3.9/site-packages/paddleocr/ppocr`

<font color='gold'>**2.**</font> `ModuleNotFoundError: No module named 'paddle.fluid'`
安装较低版本的 paddlepaddle(2.4.2 / 2.5.0)。

<font color='gold'>**3.**</font> 类型错误:
1
2
3
InvalidArgumentError: The type of data we are trying to retrieve does not match the type of data currently contained in the container.
[Hint: Expected dtype() == paddle::experimental::CppTypeToDataType<T>::Type(), but received dtype():10 != paddle::experimental::CppTypeToDataType<T>::Type():9.] (at /paddle/paddle/phi/core/dense_tensor.cc:143)
[operator < less_equal > error]
官方示例 bug,需要添加参数项:
`--use_visual_backbone=False`

<!--more-->

## 模型

参考:<https://paddlepaddle.github.io/PaddleOCR/latest/ppstructure/model_train/train_kie.html>

### 准备(inference模型)

SER任务模型下载链接:<https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar>

RE任务模型下载链接:<https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar>

创建项目根目录 demo 并添加测试图片 1.png,在 demo 目录下新建 models 文件夹,用于存放 OCR、SER、RE 模型以及字典文件、字体文件。其中 OCR 模型可使用 ppocrv3、v4 版本下的检测和识别(det/rec)模型;字典文件复制自源码目录下的 paddleocr/ppocr/utils/dict/kie_dict/xfund_class_list.txt;字体文件使用 simfang.ttf。

### SER/RE模型串联

<font color='gold'>**1.**</font> 创建官方代码示例 demo/test.py,用于测试和显示图像(源码:paddleocr/ppstructure/kie/predict_kie_token_ser_re.py):
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

# Absolute path of the directory that contains this script.
__dir__ = os.path.dirname(os.path.abspath(__file__))
# Make this directory, and the directory two levels above it, importable.
# NOTE(review): the "../.." entry presumably targeted the repository root at
# the script's original location; confirm it is still wanted in demo/.
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../..")))

# Set in the environment before the paddle-based modules are imported below.
os.environ["FLAGS_allocator_strategy"] = "auto_growth"

import cv2
import json
import numpy as np
import time

import paddleocr.tools.infer.utility as utility
from paddleocr.tools.infer_kie_token_ser_re import make_input
from paddleocr.ppocr.postprocess import build_post_process
from paddleocr.ppocr.utils.logging import get_logger
from paddleocr.ppocr.utils.visual import draw_ser_results, draw_re_results
from paddleocr.ppocr.utils.utility import get_image_file_list, check_and_read
from paddleocr.ppstructure.utility import parse_args
from paddleocr.ppstructure.kie.predict_kie_token_ser import SerPredictor

logger = get_logger()


class SerRePredictor(object):
    """Run the SER predictor and, when configured, the RE predictor on top of it.

    The SER stage is always executed through ``SerPredictor``. The RE stage is
    only built when ``args.re_model_dir`` is not ``None``; otherwise calls fall
    back to returning the SER output alone.
    """

    def __init__(self, args):
        """Build the SER engine and the optional RE predictor.

        Args:
            args: Parsed options. This class reads ``use_visual_backbone`` and
                ``re_model_dir`` and forwards the whole object to
                ``SerPredictor`` and ``utility.create_predictor``.
        """
        self.use_visual_backbone = args.use_visual_backbone
        self.ser_engine = SerPredictor(args)
        if args.re_model_dir is not None:
            postprocess_params = {"name": "VQAReTokenLayoutLMPostProcess"}
            self.postprocess_op = build_post_process(postprocess_params)
            (
                self.predictor,
                self.input_tensor,
                self.output_tensors,
                self.config,
            ) = utility.create_predictor(args, "re", logger)
        else:
            # No RE model configured: __call__ returns SER results only.
            self.predictor = None

    def __call__(self, img):
        """Run inference on one image.

        Args:
            img: Image passed unchanged to the SER engine.

        Returns:
            ``(ser_results, ser_elapse)`` when no RE predictor was built,
            otherwise ``(post_result, elapse)`` where ``elapse`` is the wall
            time of the whole call in seconds.
        """
        started = time.time()
        ser_results, ser_inputs, ser_elapse = self.ser_engine(img)
        if self.predictor is None:
            return ser_results, ser_elapse

        re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
        if not self.use_visual_backbone:
            # Drop the element at index 4 when the visual backbone is disabled.
            # NOTE(review): presumably the image tensor — confirm in make_input.
            re_input.pop(4)
        for idx, tensor in enumerate(self.input_tensor):
            tensor.copy_from_cpu(re_input[idx])

        self.predictor.run()
        outputs = [tensor.copy_to_cpu() for tensor in self.output_tensors]
        # Output order of the predictor: hidden states, loss, predicted relations.
        preds = dict(
            loss=outputs[1],
            pred_relations=outputs[2],
            hidden_states=outputs[0],
        )

        post_result = self.postprocess_op(
            preds, ser_results=ser_results, entity_idx_dict_batch=entity_idx_dict_batch
        )

        elapse = time.time() - started
        return post_result, elapse


def main(args):
    """Run SER(+RE) inference on every image found under ``args.image_dir``.

    For each readable image one line ``<path>\\t<json>`` is appended to
    ``<args.output>/infer.txt`` and a visualisation is saved next to it
    (``*_ser_re.jpg`` when an RE model is loaded, ``*_ser.jpg`` otherwise).
    Images that cannot be loaded are logged and skipped.

    Args:
        args: Parsed options; ``image_dir``, ``output`` and ``vis_font_path``
            are read here, the rest is consumed by ``SerRePredictor``.
    """
    image_file_list = get_image_file_list(args.image_dir)
    ser_re_predictor = SerRePredictor(args)

    os.makedirs(args.output, exist_ok=True)
    with open(
        os.path.join(args.output, "infer.txt"), mode="w", encoding="utf-8"
    ) as f_w:
        for image_file in image_file_list:
            img, flag, _ = check_and_read(image_file)
            if not flag:
                img = cv2.imread(image_file)
                if img is not None:
                    # Reverse the channel axis (cv2.imread returns BGR order).
                    # Only done after the None check so an unreadable file is
                    # reported below instead of raising TypeError here.
                    img = img[:, :, ::-1]
            if img is None:
                logger.info("error in loading image:{}".format(image_file))
                continue
            re_res, elapse = ser_re_predictor(img)
            # Results are batched; only the first (single) item is used.
            re_res = re_res[0]

            res_str = "{}\t{}\n".format(
                image_file,
                json.dumps(
                    {
                        "ocr_info": re_res,
                    },
                    ensure_ascii=False,
                ),
            )
            f_w.write(res_str)

            stem = os.path.splitext(os.path.basename(image_file))[0]
            if ser_re_predictor.predictor is not None:
                img_res = draw_re_results(
                    image_file, re_res, font_path=args.vis_font_path
                )
                img_save_path = os.path.join(args.output, stem + "_ser_re.jpg")
            else:
                img_res = draw_ser_results(
                    image_file, re_res, font_path=args.vis_font_path
                )
                img_save_path = os.path.join(args.output, stem + "_ser.jpg")

            cv2.imwrite(img_save_path, img_res)
            logger.info("save vis result to {}".format(img_save_path))
            logger.info("Predict time of {}: {}".format(image_file, elapse))


if __name__ == "__main__":
    args = parse_args()
    # Demo settings applied on top of the parsed command-line options; all
    # model, dictionary and font paths are relative to the demo directory.
    demo_overrides = {
        "mode": "kie",
        "use_visual_backbone": False,
        "kie_algorithm": "LayoutXLM",
        "re_model_dir": "./models/re_vi_layoutxlm_xfund_infer",
        "ser_model_dir": "./models/ser_vi_layoutxlm_xfund_infer",
        "ser_dict_path": "./models/xfund_class_list.txt",
        "vis_font_path": "./models/simfang.ttf",
        "ocr_order_method": "tb-yx",
        "det_model_dir": "./models/ch_PP-OCRv3_det_infer",
        "rec_model_dir": "./models/ch_PP-OCRv4_rec_infer",
        "image_dir": "./1.png",
    }
    for option, value in demo_overrides.items():
        setattr(args, option, value)
    main(args)
![](关键信息提取(KIE)\1.png)

<font color='gold'>**2.**</font> 创建 demo/main.py 用于测试:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import cv2
import time
from paddleocr.ppstructure.kie.predict_kie_token_ser_re import SerRePredictor
from paddleocr.ppstructure.utility import parse_args
from paddleocr.tools.infer_kie_token_ser_re import make_input


# 继承SerRePredictor类,重写__call__方法,实现同时输出Ser模型结果
class KIE(SerRePredictor):
    """``SerRePredictor`` variant whose call also returns the SER results.

    The call always yields three values so callers can unpack
    ``ser_results, re_results, elapse`` whether or not an RE model is loaded.
    """

    def __init__(self, para):
        """Initialise the parent predictor with the parsed options ``para``."""
        super().__init__(para)

    def __call__(self, im):
        """Run SER and, when available, RE on one image.

        Args:
            im: Image passed unchanged to the SER engine.

        Returns:
            ``(ser_results, post_result, elapse_time)``. When no RE predictor
            was built, ``post_result`` is ``None`` and the elapsed time is the
            one reported by the SER engine.
        """
        start_time = time.time()
        ser_results, ser_inputs, ser_elapse = self.ser_engine(im)
        if self.predictor is None:
            # Keep the same arity as the SER+RE path.
            return ser_results, None, ser_elapse

        re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
        if not self.use_visual_backbone:
            # Drop the element at index 4 when the visual backbone is disabled.
            # NOTE(review): presumably the image tensor — confirm in make_input.
            re_input.pop(4)
        for idx, tensor in enumerate(self.input_tensor):
            tensor.copy_from_cpu(re_input[idx])

        self.predictor.run()
        outputs = [tensor.copy_to_cpu() for tensor in self.output_tensors]
        preds = dict(
            loss=outputs[1],
            pred_relations=outputs[2],
            hidden_states=outputs[0],
        )

        post_result = self.postprocess_op(
            preds, ser_results=ser_results, entity_idx_dict_batch=entity_idx_dict_batch
        )

        elapse_time = time.time() - start_time
        return ser_results, post_result, elapse_time


if __name__ == "__main__":
    args = parse_args()
    # Demo settings applied on top of the parsed command-line options.
    args.mode = 'kie'
    args.show_log = True
    args.use_visual_backbone = False
    args.kie_algorithm = 'LayoutXLM'
    args.re_model_dir = './models/re_vi_layoutxlm_xfund_infer'
    args.ser_model_dir = './models/ser_vi_layoutxlm_xfund_infer'
    args.ser_dict_path = './models/xfund_class_list.txt'
    args.vis_font_path = './models/simfang.ttf'
    args.ocr_order_method = "tb-yx"
    args.det_model_dir = './models/ch_PP-OCRv3_det_infer'
    args.rec_model_dir = './models/ch_PP-OCRv4_rec_infer'
    # Must be a plain number: a trailing comma here would assign the
    # one-element tuple (1600,) instead.
    args.det_limit_side_len = 1600

    kie = KIE(args)
    img = cv2.imread('./1.png')
    if img is None:
        # cv2.imread signals an unreadable or missing file by returning None.
        raise FileNotFoundError("cannot read image: ./1.png")
    ser_res, re_res, elapse = kie(img)
    print(ser_res, '\n', re_res, '\n', elapse)
输出:
1
2
3
[[{'transcription': '证号', 'bbox': [239, 113, 267, 128], 'points': [[239.0, 113.0], [267.0, 113.0], [267.0, 128.0], [239.0, 128.0]], 'pred_id': 0, 'pred': 'O'}, {'transcription': 'T4105051990090', 'bbox': [240, 135, 441, 157], 'points': [[240.0, 135.0], [441.0, 136.0], [440.0, 157.0], [240.0, 156.0]], 'pred_id': 3, 'pred': 'ANSWER'}, {'transcription': '姓名', 'bbox': [239, 162, 266, 178], 'points': [[239.0, 162.0], [266.0, 162.0], [266.0, 178.0], [239.0, 178.0]], 'pred_id': 1, 'pred': 'QUESTION'}, {'transcription': '作业类别', 'bbox': [431, 162, 486, 179], 'points': [[431.0, 162.0], [486.0, 162.0], [486.0, 179.0], [431.0, 179.0]], 'pred_id': 1, 'pred': 'QUESTION'}, {'transcription': '毕宏', 'bbox': [237, 182, 296, 203], 'points': [[237.0, 182.0], [296.0, 182.0], [296.0, 203.0], [237.0, 203.0]], 'pred_id': 3, 'pred': 'ANSWER'}, {'transcription': '焊接与热切割作业', 'bbox': [432, 185, 586, 204], 'points': [[432.0, 185.0], [586.0, 185.0], [586.0, 204.0], [432.0, 204.0]], 'pred_id': 3, 'pred': 'ANSWER'}, {'transcription': '男', 'bbox': [237, 246, 259, 269], 'points': [[237.0, 246.0], [259.0, 246.0], [259.0, 269.0], [237.0, 269.0]], 'pred_id': 3, 'pred': 'ANSWER'}, {'transcription': '性别', 'bbox': [238, 227, 269, 246], 'points': [[238.0, 227.0], [269.0, 227.0], [269.0, 246.0], [238.0, 246.0]], 'pred_id': 1, 'pred': 'QUESTION'}, {'transcription': '操作项目', 'bbox': [430, 228, 484, 245], 'points': [[430.0, 228.0], [484.0, 228.0], [484.0, 245.0], [430.0, 245.0]], 'pred_id': 1, 'pred': 'QUESTION'}, {'transcription': '熔化焊接与热切割作业', 'bbox': [432, 251, 622, 268], 'points': [[432.0, 251.0], [622.0, 251.0], [622.0, 268.0], [432.0, 268.0]], 'pred_id': 3, 'pred': 'ANSWER'}, {'transcription': '2023-03-29', 'bbox': [71, 342, 173, 362], 'points': [[72.0, 342.0], [173.0, 345.0], [173.0, 362.0], [71.0, 360.0]], 'pred_id': 3, 'pred': 'ANSWER'}, {'transcription': '初领日期', 'bbox': [73, 325, 128, 341], 'points': [[73.0, 325.0], [128.0, 325.0], [128.0, 341.0], [73.0, 341.0]], 'pred_id': 1, 'pred': 'QUESTION'}, 
{'transcription': '有效期限', 'bbox': [226, 327, 278, 342], 'points': [[226.0, 327.0], [278.0, 327.0], [278.0, 342.0], [226.0, 342.0]], 'pred_id': 1, 'pred': 'QUESTION'}, {'transcription': '2023-03-29至2029-03-28', 'bbox': [226, 345, 448, 362], 'points': [[226.0, 345.0], [448.0, 345.0], [448.0, 362.0], [226.0, 362.0]], 'pred_id': 3, 'pred': 'ANSWER'}, {'transcription': '2026-03-28前', 'bbox': [71, 389, 192, 413], 'points': [[72.0, 389.0], [192.0, 392.0], [192.0, 413.0], [71.0, 410.0]], 'pred_id': 3, 'pred': 'ANSWER'}, {'transcription': '应复审日期', 'bbox': [74, 373, 140, 390], 'points': [[74.0, 373.0], [140.0, 373.0], [140.0, 390.0], [74.0, 390.0]], 'pred_id': 0, 'pred': 'O'}, {'transcription': '签发机关', 'bbox': [223, 374, 278, 393], 'points': [[223.0, 375.0], [277.0, 374.0], [278.0, 392.0], [224.0, 393.0]], 'pred_id': 1, 'pred': 'QUESTION'}, {'transcription': '河南省应急管理厅', 'bbox': [224, 394, 378, 414], 'points': [[224.0, 394.0], [378.0, 394.0], [378.0, 414.0], [224.0, 414.0]], 'pred_id': 3, 'pred': 'ANSWER'}]] 
[[({'transcription': '姓名', 'bbox': [239, 162, 266, 178], 'points': [[239.0, 162.0], [266.0, 162.0], [266.0, 178.0], [239.0, 178.0]], 'pred_id': 1, 'pred': 'QUESTION'}, {'transcription': '毕宏', 'bbox': [237, 182, 296, 203], 'points': [[237.0, 182.0], [296.0, 182.0], [296.0, 203.0], [237.0, 203.0]], 'pred_id': 3, 'pred': 'ANSWER'}), ({'transcription': '作业类别', 'bbox': [431, 162, 486, 179], 'points': [[431.0, 162.0], [486.0, 162.0], [486.0, 179.0], [431.0, 179.0]], 'pred_id': 1, 'pred': 'QUESTION'}, {'transcription': '焊接与热切割作业', 'bbox': [432, 185, 586, 204], 'points': [[432.0, 185.0], [586.0, 185.0], [586.0, 204.0], [432.0, 204.0]], 'pred_id': 3, 'pred': 'ANSWER'}), ({'transcription': '作业类别', 'bbox': [431, 162, 486, 179], 'points': [[431.0, 162.0], [486.0, 162.0], [486.0, 179.0], [431.0, 179.0]], 'pred_id': 1, 'pred': 'QUESTION'}, {'transcription': '熔化焊接与热切割作业', 'bbox': [432, 251, 622, 268], 'points': [[432.0, 251.0], [622.0, 251.0], [622.0, 268.0], [432.0, 268.0]], 'pred_id': 3, 'pred': 'ANSWER'}), ({'transcription': '性别', 'bbox': [238, 227, 269, 246], 'points': [[238.0, 227.0], [269.0, 227.0], [269.0, 246.0], [238.0, 246.0]], 'pred_id': 1, 'pred': 'QUESTION'}, {'transcription': '男', 'bbox': [237, 246, 259, 269], 'points': [[237.0, 246.0], [259.0, 246.0], [259.0, 269.0], [237.0, 269.0]], 'pred_id': 3, 'pred': 'ANSWER'}), ({'transcription': '初领日期', 'bbox': [73, 325, 128, 341], 'points': [[73.0, 325.0], [128.0, 325.0], [128.0, 341.0], [73.0, 341.0]], 'pred_id': 1, 'pred': 'QUESTION'}, {'transcription': '2023-03-29', 'bbox': [71, 342, 173, 362], 'points': [[72.0, 342.0], [173.0, 345.0], [173.0, 362.0], [71.0, 360.0]], 'pred_id': 3, 'pred': 'ANSWER'}), ({'transcription': '有效期限', 'bbox': [226, 327, 278, 342], 'points': [[226.0, 327.0], [278.0, 327.0], [278.0, 342.0], [226.0, 342.0]], 'pred_id': 1, 'pred': 'QUESTION'}, {'transcription': '2023-03-29至2029-03-28', 'bbox': [226, 345, 448, 362], 'points': [[226.0, 345.0], [448.0, 345.0], [448.0, 362.0], [226.0, 362.0]], 
'pred_id': 3, 'pred': 'ANSWER'}), ({'transcription': '签发机关', 'bbox': [223, 374, 278, 393], 'points': [[223.0, 375.0], [277.0, 374.0], [278.0, 392.0], [224.0, 393.0]], 'pred_id': 1, 'pred': 'QUESTION'}, {'transcription': '河南省应急管理厅', 'bbox': [224, 394, 378, 414], 'points': [[224.0, 394.0], [378.0, 394.0], [378.0, 414.0], [224.0, 414.0]], 'pred_id': 3, 'pred': 'ANSWER'})]]
0.959315299987793
## 训练(基于XFUND数据集)

参考:<https://paddlepaddle.github.io/PaddleOCR/latest/ppocr/model_train/kie.html>

### 数据下载

XFUND GitHub 主页:<https://github.com/doc-analysis/XFUND>

数据集下载地址:<https://github.com/doc-analysis/XFUND/releases/tag/v1.0>

选择中文数据集 zh.train.json、zh.train.zip、zh.val.json、zh.val.zip 并下载。

![](关键信息提取(KIE)\2.png)

### 数据格式转换

创建 XFUND 格式转 Paddle 训练格式的代码 demo/trans_xfun_data.py(源码:paddleocr/ppstructure/kie/tools/trans_xfun_data.py):
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import json


def transfer_xfun_data(json_path=None, output_file=None):
    """Convert an XFUND annotation file to the PaddleOCR KIE label format.

    Every entry of the input's ``"documents"`` list becomes one output line
    ``<image file name>\\t<JSON list of text segments>``. Each segment keeps
    its text, label, id and linking, and its ``[x1, y1, x2, y2]`` box is
    expanded to four corner points (top-left, top-right, bottom-right,
    bottom-left).

    Args:
        json_path: Path of the XFUND JSON file (UTF-8).
        output_file: Path of the label file to write (UTF-8, overwritten).
    """
    with open(json_path, "r", encoding="utf-8") as fin:
        raw_text = fin.read()

    # Decode the first JSON value in the file. Unlike parsing only the first
    # line, this also accepts pretty-printed (multi-line) JSON, while still
    # ignoring anything that follows a single-line document.
    json_info, _ = json.JSONDecoder().raw_decode(raw_text.lstrip())

    with open(output_file, "w", encoding="utf-8") as fout:
        for document in json_info["documents"]:
            image_path = document["img"]["fname"]
            label_info = []
            for doc in document["document"]:
                x1, y1, x2, y2 = doc["box"]
                label_info.append(
                    {
                        "transcription": doc["text"],
                        "label": doc["label"],
                        "points": [[x1, y1], [x2, y1], [x2, y2], [x1, y2]],
                        "id": doc["id"],
                        "linking": doc["linking"],
                    }
                )

            fout.write(
                image_path + "\t" + json.dumps(label_info, ensure_ascii=False) + "\n"
            )

    print("===ok====")

if __name__ == "__main__":
    # Example paths from the author's Windows machine; point them at your own
    # XFUND download. Guarded so that importing this module does not trigger
    # a conversion.
    ori_gt_path = r'D:\pycharmproject_2\kie\demo\dataset\zh.val.json'
    output_path = r'D:\pycharmproject_2\kie\demo\dataset\xfun_val.json'

    transfer_xfun_data(ori_gt_path, output_path)
原始格式:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
{
"lang": "zh",
"version": "0.1",
"split": "train",
"documents": [
{
"id": "zh_train_0",
"uid": "640a0301a1cb24331748b579405502b44d6791883b25ea0eafc8a68126ccdadd",
"document": [
{
"box": [
104,
114,
530,
175
],
"text": "汇丰晋信",
"label": "other",
"words": [
{
"box": [
110,
117,
152,
175
],
"text": "汇"
},
{
"box": [
189,
117,
229,
177
],
"text": "丰"
},
{
"box": [
385,
117,
426,
177
],
"text": "晋"
},
{
"box": [
466,
116,
508,
177
],
"text": "信"
}
],
"linking": [],
"id": 1
}
...
]
}
...
]
}
转换:
1
2
zh_train_0.jpg	[{"transcription": "汇丰晋信", "label": "other", "points": [[104, 114], [530, 114], [530, 175], [104, 175]], "id": 1, "linking": []}, {"transcription": "受理时间:", "label": "question", "points": [[126, 267], [266, 267], [266, 305], [126, 305]], "id": 7, "linking": [[7, 13]]}, {"transcription": "2020.6.15", "label": "answer", "points": [[321, 239], [537, 239], [537, 285], [321, 285]], "id": 13, "linking": [[7, 13]]}...]
zh_train_1.jpg [...]
### 配置字典文件

用于存储各段文本的标签类别,这里使用 models/xfund_class_list.txt 字典文件:
1
2
3
4
OTHER
QUESTION
ANSWER
HEADER
<font color='red'>**注:**</font> 字典中的标签和 json 文件中的标签不区分大小写。

### 配置训练文件

#### SER模型训练文件

创建并修改 SER 模型训练 yaml 文件 demo/ser_vi_layoutxlm_xfund_zh.yml,源文件路径:configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_vi_layoutxlm_xfund_zh
save_epoch_step: 2000
# evaluation is run every 19 iterations after the 0th iteration
eval_batch_step: [ 0, 19 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: dataset/xfund_train/zh_train/zh_val_42.jpg
d2s_train_image_shape: [3, 224, 224]
# if you want to predict using the groundtruth ocr info,
# you can use the following config
# infer_img: train_data/XFUND/zh_val/val.json
# infer_mode: False

save_res_path: ./output/ser/xfund_zh/res
kie_rec_model_dir:
kie_det_model_dir:
amp_custom_white_list: ['scale', 'concat', 'elementwise_add']

Architecture:
model_type: kie
algorithm: &algorithm "LayoutXLM"
Transform:
Backbone:
name: LayoutXLMForSer
pretrained: True
checkpoints:
# one of base or vi
mode: vi
num_classes: &num_classes 7 # <--------------------------------------当标签数为n时,存在OTHER标签,取2n-1;不存在OTHER标签,取2n+1

Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
key: "backbone_out"

Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: L2
factor: 0.00000

PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path models/xfund_class_list.txt # <--------------------------------------字典文件路径

Metric:
name: VQASerTokenMetric
main_indicator: hmean

Train:
dataset:
name: SimpleDataSet
data_dir: dataset/xfund_train/zh_train # <--------------------------------------训练集图像路径
label_file_list:
- dataset/xfund_train/xfun_train.json # <--------------------------------------训练集json文件路径
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
use_textline_bbox_info: &use_textline_bbox_info True
# one of [None, "tb-yx"]
order_method: &order_method "tb-yx"
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4

Eval:
dataset:
name: SimpleDataSet
data_dir: dataset/xfund_train/zh_val # <--------------------------------------验证集图片路径
label_file_list:
- dataset/xfund_train/xfun_val.json # <--------------------------------------验证集json文件路径
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
use_textline_bbox_info: *use_textline_bbox_info
order_method: *order_method
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
#### RE模型训练文件

创建并修改 RE 模型训练 yaml 文件 demo/re_vi_layoutxlm_xfund_zh.yml,源文件路径:configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
Global:
use_gpu: True
epoch_num: &epoch_num 130
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/re_vi_layoutxlm_xfund_zh
save_epoch_step: 2000
# evaluation is run every 19 iterations after the 0th iteration
eval_batch_step: [ 0, 19 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: dataset/xfund_train/zh_train/zh_val_21.jpg
save_res_path: ./output/re/xfund_zh/with_gt
kie_rec_model_dir:
kie_det_model_dir:

Architecture:
model_type: kie
algorithm: &algorithm "LayoutXLM"
Transform:
Backbone:
name: LayoutXLMForRe
pretrained: dataset/re_vi_layoutxlm_xfund_pretrained
mode: vi
checkpoints:

Loss:
name: LossFromOutput
key: loss
reduction: mean

Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
clip_norm: 10
lr:
learning_rate: 0.00005
warmup_epoch: 10
regularizer:
name: L2
factor: 0.00000

PostProcess:
name: VQAReTokenLayoutLMPostProcess

Metric:
name: VQAReTokenMetric
main_indicator: hmean

Train:
dataset:
name: SimpleDataSet
data_dir: dataset/xfund_train/zh_train # <--------------------------------------训练集图片路径
label_file_list:
- dataset/xfund_train/xfun_train.json # <--------------------------------------训练集json文件路径
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
class_path: &class_path models/xfund_class_list.txt # <--------------------------------------字典文件路径
use_textline_bbox_info: &use_textline_bbox_info True
order_method: &order_method "tb-yx"
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- TensorizeEntitiesRelations:
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'entities', 'relations'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
batch_size_per_card: 2
num_workers: 4

Eval:
dataset:
name: SimpleDataSet
data_dir: dataset/xfund_train/zh_val # <--------------------------------------验证集图片路径
label_file_list:
- dataset/xfund_train/xfun_val.json # <--------------------------------------验证集json文件路径
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
class_path: *class_path
use_textline_bbox_info: *use_textline_bbox_info
order_method: *order_method
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- TensorizeEntitiesRelations:
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'entities', 'relations'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 8
### 开始训练

创建训练代码 demo/train.py,训练源码位置:paddleocr/tools/train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys

# Absolute path of the directory that contains this script.
__dir__ = os.path.dirname(os.path.abspath(__file__))
# Make this directory and its parent directory importable.
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))

import yaml
import paddle
import paddle.distributed as dist

from paddleocr.ppocr.data import build_dataloader, set_signal_handlers
from paddleocr.ppocr.modeling.architectures import build_model
from paddleocr.ppocr.losses import build_loss
from paddleocr.ppocr.optimizer import build_optimizer
from paddleocr.ppocr.postprocess import build_post_process
from paddleocr.ppocr.metrics import build_metric
from paddleocr.ppocr.utils.save_load import load_model
from paddleocr.ppocr.utils.utility import set_seed
from paddleocr.ppocr.modeling.architectures import apply_to_static
import paddleocr.tools.program as program
import paddleocr.tools.naive_sync_bn as naive_sync_bn

# NOTE(review): the return value is discarded, so this call has no visible
# effect in this file; confirm whether it is needed before removing it.
dist.get_world_size()


def main(config, device, logger, vdl_writer, seed):
    """Build every training component from ``config`` and start training.

    Steps, in order: optional distributed init, train/eval dataloaders,
    post-process, head output sizes for recognition models, model, loss,
    optimizer, metric, optional AMP setup, checkpoint loading, and finally
    ``program.train``. Returns early (after logging an error) when the
    training dataloader is empty.

    Args:
        config: Nested configuration mapping (``Global``, ``Architecture``,
            ``Loss``, ``Optimizer``, ``PostProcess``, ``Metric``, ``Eval``).
            Some entries are updated in place.
        device: Device handle forwarded to the dataloaders and trainer.
        logger: Logger used for progress and error messages.
        vdl_writer: Writer object forwarded to ``program.train``.
        seed: Random seed forwarded to the dataloaders.
    """
    # init dist environment
    if config["Global"]["distributed"]:
        dist.init_parallel_env()

    global_config = config["Global"]

    # build dataloader
    set_signal_handlers()
    train_dataloader = build_dataloader(config, "Train", device, logger, seed)
    if len(train_dataloader) == 0:
        logger.error(
            "No Images in train dataset, please ensure\n"
            + "\t1. The images num in the train label_file_list should be larger than or equal with batch size.\n"
            + "\t2. The annotation file and path in the configuration file are provided normally."
        )
        return

    if config["Eval"]:
        valid_dataloader = build_dataloader(config, "Eval", device, logger, seed)
    else:
        valid_dataloader = None
    step_pre_epoch = len(train_dataloader)

    # build post process
    post_process_class = build_post_process(config["PostProcess"], global_config)

    # build model
    # for rec algorithm: the head's output size is derived from the size of
    # the post-process character set, so it is written back into the config
    # before the model is built.
    if hasattr(post_process_class, "character"):
        char_num = len(getattr(post_process_class, "character"))
        if config["Architecture"]["algorithm"] in [
            "Distillation",
        ]:  # distillation model
            for key in config["Architecture"]["Models"]:
                if (
                    config["Architecture"]["Models"][key]["Head"]["name"] == "MultiHead"
                ):  # for multi head
                    if config["PostProcess"]["name"] == "DistillationSARLabelDecode":
                        char_num = char_num - 2
                    if config["PostProcess"]["name"] == "DistillationNRTRLabelDecode":
                        char_num = char_num - 3
                    out_channels_list = {}
                    out_channels_list["CTCLabelDecode"] = char_num
                    # update SARLoss params
                    if (
                        list(config["Loss"]["loss_config_list"][-1].keys())[0]
                        == "DistillationSARLoss"
                    ):
                        config["Loss"]["loss_config_list"][-1]["DistillationSARLoss"][
                            "ignore_index"
                        ] = (char_num + 1)
                        out_channels_list["SARLabelDecode"] = char_num + 2
                    elif any(
                        "DistillationNRTRLoss" in d
                        for d in config["Loss"]["loss_config_list"]
                    ):
                        out_channels_list["NRTRLabelDecode"] = char_num + 3

                    config["Architecture"]["Models"][key]["Head"][
                        "out_channels_list"
                    ] = out_channels_list
                else:
                    config["Architecture"]["Models"][key]["Head"][
                        "out_channels"
                    ] = char_num
        elif config["Architecture"]["Head"]["name"] == "MultiHead":  # for multi head
            if config["PostProcess"]["name"] == "SARLabelDecode":
                char_num = char_num - 2
            if config["PostProcess"]["name"] == "NRTRLabelDecode":
                char_num = char_num - 3
            out_channels_list = {}
            out_channels_list["CTCLabelDecode"] = char_num
            # update SARLoss params
            if list(config["Loss"]["loss_config_list"][1].keys())[0] == "SARLoss":
                if config["Loss"]["loss_config_list"][1]["SARLoss"] is None:
                    config["Loss"]["loss_config_list"][1]["SARLoss"] = {
                        "ignore_index": char_num + 1
                    }
                else:
                    config["Loss"]["loss_config_list"][1]["SARLoss"]["ignore_index"] = (
                        char_num + 1
                    )
                out_channels_list["SARLabelDecode"] = char_num + 2
            elif list(config["Loss"]["loss_config_list"][1].keys())[0] == "NRTRLoss":
                out_channels_list["NRTRLabelDecode"] = char_num + 3
            config["Architecture"]["Head"]["out_channels_list"] = out_channels_list
        else:  # base rec model
            config["Architecture"]["Head"]["out_channels"] = char_num

        if config["PostProcess"]["name"] == "SARLabelDecode":  # for SAR model
            config["Loss"]["ignore_index"] = char_num - 1

    model = build_model(config["Architecture"])

    use_sync_bn = config["Global"].get("use_sync_bn", False)
    if use_sync_bn:
        if config["Global"].get("use_npu", False):
            naive_sync_bn.convert_syncbn(model)
        else:
            model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        logger.info("convert_sync_batchnorm")

    model = apply_to_static(model, config, logger)

    # build loss
    loss_class = build_loss(config["Loss"])

    # build optim
    optimizer, lr_scheduler = build_optimizer(
        config["Optimizer"],
        epochs=config["Global"]["epoch_num"],
        step_each_epoch=len(train_dataloader),
        model=model,
    )

    # build metric
    eval_class = build_metric(config["Metric"])

    logger.info("train dataloader has {} iters".format(len(train_dataloader)))
    if valid_dataloader is not None:
        logger.info("valid dataloader has {} iters".format(len(valid_dataloader)))

    use_amp = config["Global"].get("use_amp", False)
    amp_level = config["Global"].get("amp_level", "O2")
    amp_dtype = config["Global"].get("amp_dtype", "float16")
    amp_custom_black_list = config["Global"].get("amp_custom_black_list", [])
    amp_custom_white_list = config["Global"].get("amp_custom_white_list", [])

    # Remove a train_result.json left over from an earlier run. This is
    # best-effort: a failure is reported but does not stop training, and only
    # filesystem errors are tolerated so that e.g. KeyboardInterrupt still
    # propagates.
    train_result_path = os.path.join(
        config["Global"]["save_model_dir"], "train_result.json"
    )
    if os.path.exists(train_result_path):
        try:
            os.remove(train_result_path)
        except OSError as err:
            logger.warning(
                "failed to remove {}: {}".format(train_result_path, err)
            )

    if use_amp:
        AMP_RELATED_FLAGS_SETTING = {
            "FLAGS_max_inplace_grad_add": 8,
        }
        if paddle.is_compiled_with_cuda():
            AMP_RELATED_FLAGS_SETTING.update(
                {
                    "FLAGS_cudnn_batchnorm_spatial_persistent": 1,
                    "FLAGS_gemm_use_half_precision_compute_type": 0,
                }
            )
        paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
        scale_loss = config["Global"].get("scale_loss", 1.0)
        use_dynamic_loss_scaling = config["Global"].get(
            "use_dynamic_loss_scaling", False
        )
        scaler = paddle.amp.GradScaler(
            init_loss_scaling=scale_loss,
            use_dynamic_loss_scaling=use_dynamic_loss_scaling,
        )
        if amp_level == "O2":
            model, optimizer = paddle.amp.decorate(
                models=model,
                optimizers=optimizer,
                level=amp_level,
                master_weight=True,
                dtype=amp_dtype,
            )
    else:
        scaler = None

    # load pretrain model
    pre_best_model_dict = load_model(
        config, model, optimizer, config["Architecture"]["model_type"]
    )

    if config["Global"]["distributed"]:
        model = paddle.DataParallel(model)
    # start train
    program.train(
        config,
        train_dataloader,
        valid_dataloader,
        device,
        model,
        loss_class,
        optimizer,
        lr_scheduler,
        post_process_class,
        eval_class,
        pre_best_model_dict,
        logger,
        step_pre_epoch,
        vdl_writer,
        scaler,
        amp_level,
        amp_custom_black_list,
        amp_custom_white_list,
        amp_dtype,
    )


def test_reader(config, device, logger):
    """Walk the "Train" dataloader once and log per-batch timing.

    Debug helper. The loader is built with ``build_dataloader``; for every
    batch the running index, ``len(data[0])`` and the seconds elapsed since
    the previous batch are logged. An exception raised while iterating is
    logged at INFO level and swallowed, and the closing "finish reader"
    line is emitted whether or not an exception occurred.

    Args:
        config: Configuration object forwarded to ``build_dataloader``.
        device: Device handle forwarded to ``build_dataloader``.
        logger: Logger that receives the timing lines.
    """
    import time

    train_loader = build_dataloader(config, "Train", device, logger)
    log_every = 1  # report every batch
    count = 0
    last_tick = time.time()
    try:
        for count, data in enumerate(train_loader(), start=1):
            if count % log_every != 0:
                continue
            elapsed = time.time() - last_tick
            last_tick = time.time()
            logger.info("reader: {}, {}, {}".format(count, len(data[0]), elapsed))
    except Exception as err:
        logger.info(err)
    logger.info("finish reader: {}, Success!".format(count))


if __name__ == "__main__":
    # Script entry point: prepare the run via program.preprocess, seed, then train.
    config, device, logger, vdl_writer = program.preprocess(is_train=True)
    # Use 1024 when the Global section carries no "seed" entry.
    seed = config["Global"].get("seed", 1024)
    set_seed(seed)
    main(config, device, logger, vdl_writer, seed)
    # test_reader(config, device, logger)
根据训练目标的配置文件,分配显卡和训练任务:

<font color='gold'>**SER**</font>:

`CUDA_VISIBLE_DEVICES=0 python train.py -c ser_vi_layoutxlm_xfund_zh.yml`

<font color='gold'>**RE**</font>:

`CUDA_VISIBLE_DEVICES=1 python train.py -c re_vi_layoutxlm_xfund_zh.yml`

首次启动将自动下载预训练模型。日志输出样例如下:
1
2
3
4
5
6
7
[2024/11/26 17:35:39] ppocr INFO: epoch: [1/200], global_step: 10, lr: 0.000006, loss: 1.838789, avg_reader_cost: 0.39165 s, avg_batch_cost: 0.54732 s, avg_samples: 8.0, ips: 14.61671 samples/s, eta: 0:34:34, max_mem_reserved: 11694 MB, max_mem_allocated: 10523 MB
[2024/11/26 17:35:43] ppocr INFO: epoch: [1/200], global_step: 19, lr: 0.000018, loss: 1.446505, avg_reader_cost: 0.22235 s, avg_batch_cost: 0.32516 s, avg_samples: 6.9, ips: 21.22014 samples/s, eta: 0:28:56, max_mem_reserved: 11694 MB, max_mem_allocated: 10523 MB
eval model:: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:04<00:00, 1.61it/s]
[2024/11/26 17:35:47] ppocr INFO: cur metric, precision: 0.05993363749481543, recall: 0.10410662824207492, hmean: 0.07607265069755198, fps: 101.1334614167882
[2024/11/26 17:36:22] ppocr INFO: save best model is to ./output/ser_vi_layoutxlm_xfund_zh/best_accuracy
[2024/11/26 17:36:22] ppocr INFO: best metric, hmean: 0.07607265069755198, precision: 0.05993363749481543, recall: 0.10410662824207492, fps: 101.1334614167882, best_epoch: 1
[2024/11/26 17:36:39] ppocr INFO: save model in ./output/ser_vi_layoutxlm_xfund_zh/latest
### 模型导出

训练结束后,进入最佳模型目录 `demo/output/re_vi_layoutxlm_xfund_zh/best_accuracy`。此时模型不具备 `.pdmodel` 文件,不能直接用于预测,需要导出为inference模型(本文适当略过模型评估步骤)。

![](关键信息提取(KIE)\3.png)

创建模型导出代码 `demo/export_model.py`,导出源码位置: `paddleocr/tools/export_model.py`
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))

import argparse

from paddleocr.tools.program import load_config, merge_config, ArgsParser
from paddleocr.ppocr.utils.export_model import export


def main():
    """Load the config named on the command line and export the model.

    The parsed arguments supply ``config`` (path passed to ``load_config``)
    and ``opt`` (passed to ``merge_config`` together with the loaded config);
    the merged result is handed to ``export``.
    """
    cli_args = ArgsParser().parse_args()
    merged_config = merge_config(load_config(cli_args.config), cli_args.opt)
    # export model
    export(merged_config)


if __name__ == "__main__":
    main()
导出命令: <font color='gold'>**SER**</font>:
1
2
python export_model.py -c ser_vi_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./output/ser_vi_layoutxlm_xfund_zh/best_accuracy Global.save_inference_dir=./inference/ser_vi_layoutxlm
<font color='gold'>**RE**</font>:
1
2
python export_model.py -c re_vi_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./output/re_vi_layoutxlm_xfund_zh/best_accuracy Global.save_inference_dir=./inference/re_vi_layoutxlm
将在 `demo/inference` 目录下生成导出的推理模型,可修改args参数中模型路径进行验证。

## 客制化

### 标注工具:PPOCRLabel

操作系统回到Windows,并创建 `demo/dataset/temp` 文件夹存放训练图像。

安装: `pip install PPOCRLabel==2.1.3 -i https://pypi.tuna.tsinghua.edu.cn/simple`

启动:PPOCRLabel --lang ch --kie True

此时标注工具左下角显示关键词列表项



注: 最新版可能存在bug,建议使用2.1.3版本。且截至目前(2024年11月,v2.1.12版本),工具仍不具备RE任务关系标注功能,需自行制作后处理脚本(难点)。

后处理:

SER

创建SER任务处理代码demo/trans+ppocrlabel.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
"""Convert a PPOCRLabel ``Label.txt`` (KIE mode) into the SER ``train.json`` layout.

Each input line is ``<image path>\t<JSON list of boxes>``. Every box is
rewritten to carry ``transcription``, ``points``, ``label`` (taken from
``key_cls``), a sequential ``id`` and an empty ``linking`` list.
"""
import json
import os

SRC_LABEL_FILE = './dataset/temp/Label.txt'
DST_LABEL_FILE = './dataset/project/train.json'


def to_ser_annotations(labels):
    """Map PPOCRLabel box dicts to SER annotation dicts.

    Args:
        labels: Iterable of dicts as stored by PPOCRLabel; missing keys
            become ``None`` in the output.

    Returns:
        A list with one dict per box; ``id`` is the box's position in the
        input and ``linking`` is always an empty list.
    """
    return [
        {
            'transcription': lab.get('transcription'),
            'points': lab.get('points'),
            'label': lab.get('key_cls'),
            'id': i,
            'linking': [],
        }
        for i, lab in enumerate(labels)
    ]


def convert_line(line):
    """Convert one ``Label.txt`` line to one ``train.json`` line.

    Args:
        line: Raw line, with or without its trailing newline.

    Returns:
        The converted line terminated by ``\\n``, or ``None`` when the line
        is blank (blank lines used to crash the script with ``IndexError``).

    Raises:
        ValueError: If a non-blank line has no tab separator.
    """
    record = line.rstrip('\r\n')
    if not record.strip():
        return None
    # Split once only: everything after the first tab is the JSON payload.
    image, sep, raw_labels = record.partition('\t')
    if not sep:
        raise ValueError('missing tab separator in label line: %r' % record)
    annotations = to_ser_annotations(json.loads(raw_labels))
    return image + '\t' + json.dumps(annotations, ensure_ascii=False) + '\n'


def convert_label_file(src_path=SRC_LABEL_FILE, dst_path=DST_LABEL_FILE):
    """Convert ``src_path`` and append the result to ``dst_path``.

    All lines are converted before anything is written, so a malformed
    line no longer leaves a half-written output behind.

    Args:
        src_path: PPOCRLabel label file to read.
        dst_path: Output file; its directory is created when missing.

    Returns:
        Number of records written.
    """
    with open(src_path, 'r', encoding='utf-8') as src:
        records = [rec for rec in map(convert_line, src) if rec is not None]

    out_dir = os.path.dirname(dst_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Append mode is kept from the original script: re-running adds the
    # records again, so delete the output first for a clean rebuild.
    with open(dst_path, 'a', encoding='utf-8') as dst:
        dst.writelines(records)
    return len(records)


if __name__ == '__main__':
    convert_label_file()
RE

测试中,未完待续

Powered by Hexo and Hexo-theme-hiker

Copyright © 2017 - 2024 青域 All Rights Reserved.

UV : | PV :