Introduction: This article walks through building a table OCR application with ModelScope. The application uses three models:
- Table recognition (table_recognition)
- Text detection (ocr_detection)
- Text recognition (ocr_recognition)
Procedure
Refer to the Quick Start guide.
Environment setup
To get started quickly, this walkthrough uses the remote environment that ModelScope provides, i.e. the hosted Notebook, which is the most convenient option (a local alternative is sketched after the note below).
- The Quick Start documentation does not link to the Notebook directly; open it from your personal center: https://modelscope.cn/#/my/mynotebook
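If you prefer a local environment over the hosted Notebook, the dependencies used in this article can be installed from a terminal. A minimal sketch, assuming the standard PyPI package names (CV models may need additional extras; consult the official ModelScope installation guide):

pip install modelscope opencv-python pandas numpy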
Model debugging: table recognition
Reference: https://www.modelscope.cn/models/damo/cv_dla34_table-structure-recognition_cycle-centernet/summary
- Prepare an image file.
- Upload the image to the Notebook (drag and drop works), or use a URL.
- Enter the following code.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
import cv2

# read image from url
table_recognition = pipeline(Tasks.table_recognition, model='damo/cv_dla34_table-structure-recognition_cycle-centernet')
result = table_recognition('https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/table_recognition.jpg')
print(result)

# read image from local path
img_path = 'table_ocr.jpg'  # the image uploaded earlier is named table_ocr.jpg here
img = cv2.imread(img_path)
result = table_recognition(img)
print(result)
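To sanity-check the output, you can draw the returned cell polygons onto the image. A minimal sketch, assuming result['polygons'] is an N×8 array of cell corner coordinates (the chaining code later in this article relies on the same layout) and that img is the local image loaded above; the output filename is arbitrary:

import numpy as np

cells = np.array(result['polygons']).reshape([-1, 4, 2]).astype(np.int32)
vis = img.copy()
cv2.polylines(vis, cells, isClosed=True, color=(0, 255, 0), thickness=2)
cv2.imwrite('table_cells_vis.jpg', vis)  # open this file in the Notebook to inspect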
Model debugging: text detection
Reference: https://www.modelscope.cn/models/damo/cv_resnet18_ocr-detection-line-level_damo/summary
- Prepare an image file.
- Upload the image to the Notebook (drag and drop works), or use a URL.
- Enter the following code.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
import cv2

ocr_detection = pipeline(Tasks.ocr_detection, model='damo/cv_resnet18_ocr-detection-line-level_damo')
# word-level detection model, see the note below:
# ocr_detection = pipeline(Tasks.ocr_detection, model='damo/cv_resnet18_ocr-detection-word-level_damo')

# read a local file
img_path = 'ocr_detection.jpg'
img = cv2.imread(img_path)
result = ocr_detection(img)
print(result)

# or read from a URL
img_url = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/ocr_detection.jpg'
result_url = ocr_detection(img_url)
print(result_url)
The code above uses the text-line detection model.
To use the word-level detection model instead, switch to the commented-out pipeline in the code above and refer to https://www.modelscope.cn/models/damo/cv_resnet18_ocr-detection-word-level_damo/summary.
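Whichever model you use, the polygons come back in no particular order, so it often helps to sort them into rough reading order before further processing. A minimal sketch, assuming result['polygons'] is an N×8 array with one quadrilateral per detected line:

import numpy as np

polys = np.array(result['polygons']).reshape([-1, 4, 2])
reading_order = np.argsort(polys[:, :, 1].mean(axis=1))  # sort by vertical center
for k in reading_order:
    print(polys[k].round(1).tolist())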
Model debugging: text recognition
Reference: https://www.modelscope.cn/models/damo/cv_convnextTiny_ocr-recognition-general_damo/summary
- Prepare an image file.
- Upload the image to the Notebook (drag and drop works), or use a URL.
- Enter the following code.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
import cv2

ocr_recognition = pipeline(Tasks.ocr_recognition, model='damo/cv_convnextTiny_ocr-recognition-general_damo')
# alternative models for scene, document, and handwritten text:
# ocr_recognition = pipeline(Tasks.ocr_recognition, model='damo/cv_convnextTiny_ocr-recognition-scene_damo')
# ocr_recognition = pipeline(Tasks.ocr_recognition, model='damo/cv_convnextTiny_ocr-recognition-document_damo')
# ocr_recognition = pipeline(Tasks.ocr_recognition, model='damo/cv_convnextTiny_ocr-recognition-handwritten_damo')

# read a local file
img_path = 'ocr_recognition.jpg'
img = cv2.imread(img_path)
result = ocr_recognition(img)
print(result)

# or read from a URL
img_url = 'http://duguang-labelling.oss-cn-shanghai.aliyuncs.com/mass_img_tmp_20220922/ocr_recognition.jpg'
result_url = ocr_recognition(img_url)
print(result_url)
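Detection and recognition compose naturally: detect text lines, crop each region, and recognize the crop. A minimal sketch, assuming the ocr_detection pipeline from the previous section is still in scope and page_img is a full-page image (a hypothetical name); it uses simple axis-aligned crops, whereas the complete pipeline below uses a perspective warp instead:

import numpy as np

page_img = cv2.imread('table_ocr.jpg')  # hypothetical full-page input
det = ocr_detection(page_img)
for poly in det['polygons']:
    pts = np.array(poly).reshape(4, 2)
    x0, y0 = pts.min(axis=0).astype(int)
    x1, y1 = pts.max(axis=0).astype(int)
    crop = page_img[y0:y1, x0:x1]       # axis-aligned bounding-box crop
    print(ocr_recognition(crop)['text'])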
Model debugging: chaining table recognition, text detection, and text recognition
With the building blocks above in place, we can chain the modules together to implement the complete table OCR pipeline.
First, define the three models:
ocr_table = pipeline(Tasks.table_recognition, model='damo/cv_dla34_table-structure-recognition_cycle-centernet')
ocr_detection = pipeline(Tasks.ocr_detection, model='damo/cv_resnet18_ocr-detection-line-level_damo')
ocr_recognition = pipeline(Tasks.ocr_recognition, model='damo/cv_convnextTiny_ocr-recognition-general_damo')
Then run table recognition and text detection with the following two lines:
table_res = ocr_table(image)['polygons']
det_res = ocr_detection(image)['polygons']
Next, a single call takes the image, the text positions, and the table structure: it crops each detected text region, recognizes its content, and determines the text-to-cell relationship by checking whether each text region's center falls inside a cell.
img, result = text_recognition(det_res, table_res, image)
The function that decides whether a text center lies inside a cell is point_in_box, a standard same-side (cross-product sign) test for a point in a convex quadrilateral.
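As a quick illustration of the same-side test, a hypothetical unit-square check (point_in_box itself is defined in the complete code below):

box = [[0, 0], [1, 0], [1, 1], [0, 1]]  # unit square, corners in traversal order
print(point_in_box(box, [0.5, 0.5]))    # True: the center lies inside
print(point_in_box(box, [2.0, 2.0]))    # False: clearly outside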
At this point we have debugged all three models and chained them together in code, implementing the complete table OCR functionality.
The complete code follows:
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
import numpy as np
import cv2
import math
import pandas as pd

# crop a (possibly rotated) quadrilateral region out of the image
def crop_image(img, position):
    def distance(x1, y1, x2, y2):
        return math.sqrt(pow(x1 - x2, 2) + pow(y1 - y2, 2))
    position = position.tolist()
    # sort the four corners by x, then fix the vertical order on each side
    for i in range(4):
        for j in range(i + 1, 4):
            if position[i][0] > position[j][0]:
                tmp = position[j]
                position[j] = position[i]
                position[i] = tmp
    if position[0][1] > position[1][1]:
        tmp = position[0]
        position[0] = position[1]
        position[1] = tmp
    if position[2][1] > position[3][1]:
        tmp = position[2]
        position[2] = position[3]
        position[3] = tmp

    x1, y1 = position[0][0], position[0][1]
    x2, y2 = position[2][0], position[2][1]
    x3, y3 = position[3][0], position[3][1]
    x4, y4 = position[1][0], position[1][1]

    corners = np.zeros((4, 2), np.float32)
    corners[0] = [x1, y1]
    corners[1] = [x2, y2]
    corners[2] = [x4, y4]
    corners[3] = [x3, y3]

    img_width = distance((x1 + x4) / 2, (y1 + y4) / 2, (x2 + x3) / 2, (y2 + y3) / 2)
    img_height = distance((x1 + x2) / 2, (y1 + y2) / 2, (x4 + x3) / 2, (y4 + y3) / 2)

    corners_trans = np.zeros((4, 2), np.float32)
    corners_trans[0] = [0, 0]
    corners_trans[1] = [img_width - 1, 0]
    corners_trans[2] = [0, img_height - 1]
    corners_trans[3] = [img_width - 1, img_height - 1]

    # rectify the quadrilateral with a perspective warp
    transform = cv2.getPerspectiveTransform(corners, corners_trans)
    dst = cv2.warpPerspective(img, transform, (int(img_width), int(img_height)))
    return dst

# same-side test: is the point inside the quadrilateral?
def point_in_box(box, point):
    x1, y1 = box[0][0], box[0][1]
    x2, y2 = box[1][0], box[1][1]
    x3, y3 = box[2][0], box[2][1]
    x4, y4 = box[3][0], box[3][1]
    ctx, cty = point[0], point[1]
    a = (x2 - x1) * (cty - y1) - (y2 - y1) * (ctx - x1)
    b = (x3 - x2) * (cty - y2) - (y3 - y2) * (ctx - x2)
    c = (x4 - x3) * (cty - y3) - (y4 - y3) * (ctx - x3)
    d = (x1 - x4) * (cty - y4) - (y1 - y4) * (ctx - x4)
    if (a > 0 and b > 0 and c > 0 and d > 0) or (a < 0 and b < 0 and c < 0 and d < 0):
        return True
    else:
        return False

# order the four corners clockwise starting from the top-left
def order_point(coor):
    arr = np.array(coor).reshape([4, 2])
    sum_ = np.sum(arr, 0)
    centroid = sum_ / arr.shape[0]
    theta = np.arctan2(arr[:, 1] - centroid[1], arr[:, 0] - centroid[0])
    sort_points = arr[np.argsort(theta)]
    sort_points = sort_points.reshape([4, -1])
    if sort_points[0][0] > centroid[0]:
        sort_points = np.concatenate([sort_points[3:], sort_points[:3]])
    sort_points = sort_points.reshape([4, 2]).astype('float32')
    return sort_points

ocr_table = pipeline(Tasks.table_recognition, model='damo/cv_dla34_table-structure-recognition_cycle-centernet')
ocr_detection = pipeline(Tasks.ocr_detection, model='damo/cv_resnet18_ocr-detection-line-level_damo')
ocr_recognition = pipeline(Tasks.ocr_recognition, model='damo/cv_convnextTiny_ocr-recognition-general_damo')

# serialize four corner points as 'x1,y1;x2,y2;x3,y3;x4,y4'
def coord2str(box):
    out = []
    for i in range(0, 4):
        out.append(str(round(box[i][0], 1)) + ',' + str(round(box[i][1], 1)))
    return ';'.join(out)

def text_recognition(det_res, table_res, image):
    output = []
    table_res = np.array(table_res).reshape([len(table_res), 4, 2])
    for i in range(det_res.shape[0]):
        pts = order_point(det_res[i])
        image_crop = crop_image(image, pts)
        result = ocr_recognition(image_crop)
        find_cell = 0
        # center of the text box (fix: the original referenced undefined p0..p3)
        ctx = (pts[0][0] + pts[1][0] + pts[2][0] + pts[3][0]) / 4.0
        cty = (pts[0][1] + pts[1][1] + pts[2][1] + pts[3][1]) / 4.0
        for j in range(0, len(table_res)):
            if point_in_box(table_res[j], [ctx, cty]):
                output.append([str(i + 1), coord2str(pts), coord2str(table_res[j]), result['text'].replace(' ', '')])
                find_cell = 1
                break
        if find_cell == 0:
            # text that does not fall inside any cell keeps an empty cell column
            output.append([str(i + 1), coord2str(pts), '', result['text'].replace(' ', '')])
    result = pd.DataFrame(output, columns=['No.', 'Text coords', 'Cell coords', 'Recognized text'])
    return image, result

def table_test_parsing(image):
    table_res = ocr_table(image)['polygons']
    det_res = ocr_detection(image)['polygons']
    img, result = text_recognition(det_res, table_res, image)
    return img, result

img_path = 'table_ocr.jpg'
img = cv2.imread(img_path)
img, result = table_test_parsing(img)
print(result)
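Since text_recognition returns a pandas DataFrame, persisting the structured result takes one more line; a minimal follow-up sketch (the filename is arbitrary):

result.to_csv('table_ocr_result.csv', index=False)  # or result.to_excel(...) with openpyxl installed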