Nebula-fastapi接口维护文档

普通接口

1. 导入 FastAPI，初始化 Nebula 连接池

from fastapi import FastAPI
from nebula2.gclient.net import ConnectionPool
from nebula2.Config import Config


# Disable the online API docs (/docs and /redoc) to avoid exposing the endpoint list to attackers
app_router = FastAPI(docs_url=None, redoc_url=None)

# Nebula graphd address ('host' is a placeholder - replace it with the real server address)
NEBULA_HOST = 'host'
NEBULA_PORT = 9669
# Address the development server binds to when this file is run directly.
# The uvicorn.run() call below referenced these names, but they were never defined.
host = '127.0.0.1'
port = 8888

config = Config()
# Maximum number of connections kept in the pool
config.max_connection_pool_size = 10
# Connection timeout
config.timeout = 60000
# Idle time after which an idle connection is closed
config.idle_time = 0
# Interval between idle-connection checks
config.interval_check = -1
# Initialise the connection pool
connection_pool = ConnectionPool()
# init() reports whether the given servers are usable (True on success)
ok = connection_pool.init([(NEBULA_HOST, NEBULA_PORT)], config)

if __name__ == "__main__":
    import uvicorn
    # uvicorn.run() no longer accepts `debug` (removed in uvicorn 0.19); log_level="debug"
    # keeps the verbose output. reload=True requires the app to be given as an import string.
    uvicorn.run(app="nebula_api:app_router", reload=True, log_level="debug", host=host, port=port)

2.CORS跨域访问设置

from fastapi.middleware.cors import CORSMiddleware


# CORS: accept cross-origin requests from any origin, with any method and any header.
# NOTE(review): a "*" origin combined with allow_credentials=True is generally
# considered unsafe; list the real front-end origins if cookies/auth are used.
origins = ["*"]

app_router.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=origins,
    allow_headers=["*"],
    allow_methods=["*"],
)

3.一般查询接口

1.单点查询

# Single-vertex query - body of an endpoint function (the enclosing `def` is
# omitted in this document, hence the bare `return` statements).
# NOTE(review): if execute()/as_string() raises, session.release() is never
# reached and the pooled session leaks; consider try/finally.
# Borrow a session from the connection pool
session = connection_pool.get_session('root', 'nebula')
session.execute('USE ai_project')

data_final = []
errorcode = 0
# result["data"] is the same list object as data_final, so rows appended below
# appear in the response without reassigning it.
result = {
    "data": data_final,
    "errorcode": errorcode
}
# Example: query every vertex with tag `example`
data_select = session.execute(
    str(f'match (v:example) return id(v),v.name;')
)

# Query failed -> errorcode 1
if not data_select.is_succeeded():
    session.release()
    result["errorcode"] = 1
    return result

# Non-empty result -> one {"entity_id", "name"} dict per row
if not data_select.is_empty():
    size = data_select.row_size()

    for index in range(size):
        data1 = data_select.row_values(index)[0].as_string()
        data2 = data_select.row_values(index)[1].as_string()
        data_dict = {
            "entity_id": data1,
            "name": data2,
        }
        data_final.append(data_dict)

    session.release()
    return result

# Query succeeded but no vertices of this tag exist -> errorcode 2
session.release()
result["errorcode"] = 2
return result

2.树型查询

主函数:

# 自定义类和函数
from api_class import *

# Tree query - body of an endpoint function (enclosing `def` omitted in this
# document). `example_id` is presumably an endpoint parameter - confirm.
# NOTE(review): `re` and `ApiFuncs` are presumably supplied by
# `from api_class import *`; confirm, since `re` is not imported here.
session = connection_pool.get_session('root', 'nebula')
session.execute('USE ai_project')

data_final = []
errorcode = 0
result = {
    "data": data_final,
    "errorcode": errorcode
}

# Use example_id as the root vertex; follow edge_example up to 15 hops.
# NOTE(review): example_id is interpolated straight into the nGQL text - if it
# comes from the request it can break out of the string literal; validate it.
data_select = session.execute(
    str(f'match Ret=(v)-[e:edge_example*0..15]->(p) where id(v)=="{example_id}" return nodes(Ret);')
)

# Query failed -> errorcode 1
if not data_select.is_succeeded():
    session.release()
    result["errorcode"] = 1
    return result

# Non-empty result -> build the tree
if not data_select.is_empty():

    size = data_select.row_size()
    data_list = []

    for index in range(size):
        # One row per matched path; column 0 holds the list of its vertices
        data1 = data_select.row_values(index)[0].as_list()
        # Parse each vertex from its text form: ("<vid>" :<tag>{<props>})
        item = re.finditer(r'\(\"(.*?)\" :(.*?)\{(.*?)\}\)', str(data1))

        # The first vertex of each path is treated as a root ('0' = no parent);
        # every following vertex gets the previous one as its parent
        # (presumably vertices come back in path order - the chaining relies on it).
        parent_id = '0'
        for match in item:
            kv_dict = {
                'parent_id': parent_id,
                'entity_id': match.group(1).strip(),
                'entity_type': match.group(2).strip()
            }
            # Property text with ALL spaces removed, split on ','.
            # NOTE(review): values containing spaces, ',' or ':' are mangled here.
            kvs = match.group(3).replace(" ", "").split(",")

            for kv in kvs:
                # key = text before the first ':'; value = text after it with its first
                # and last character dropped (presumably the surrounding quotes)
                kv_dict[kv.strip().split(':')[0]] = kv.split(':')[1][1:-1]
            # Paths share prefixes, so skip vertices already recorded with the same parent
            if kv_dict not in data_list:
                data_list.append(kv_dict)

            parent_id = match.group(1).strip()

    # Convert the flat list of node dicts into a nested tree
    data_final = ApiFuncs.list_to_tree(data_list)
    # Sort every level of the tree by its serial value
    ApiFuncs.data_sort(data_final)

    # Only the top-level nodes lose "parent_id"; nested nodes keep it
    for item in data_final:
        del item["parent_id"]

    session.release()
    result["data"] = data_final
    return result

# Query succeeded but returned nothing -> errorcode 2
session.release()
result["errorcode"] = 2
return result

将json对象列表转换为树形json对象的函数:

class ApiFuncs:
    """Helpers that reshape Nebula query results into API response structures."""

    @staticmethod
    def list_to_tree(data):
        """Nest a flat list of node dicts into a tree.

        A dict is a child of the dict whose ``entity_id`` equals its
        ``parent_id``; dicts whose ``parent_id`` is ``'0'`` are roots.

        :param data: list of node dicts carrying ``parent_id`` / ``entity_id``.
        :return: the root dicts (mutated in place). Each root always has a
            ``children`` list, possibly empty. Nested nodes only get a
            ``children`` key when they actually have children. If there is no
            root, the non-root dicts are returned as a flat list, untouched.
        """
        root = [d for d in data if d.get("parent_id") == '0']
        node = [d for d in data if d.get("parent_id") != '0']

        if not root:
            return node

        # Group children by parent id once; the old version rescanned the whole
        # list for every node (O(n^2)).
        children_by_parent = ApiFuncs._index_by_parent(node)
        for p in root:
            ApiFuncs._attach_children(p, children_by_parent)
        return root

    @staticmethod
    def add_node(p, node):
        """Recursively attach to ``p`` the dicts in ``node`` that descend from it.

        ``p`` always ends up with a ``children`` list (possibly empty).
        Descendants without children have no ``children`` key. Mutates the
        dicts in place and returns ``None``.

        :param p: parent node dict (matched on its ``entity_id``).
        :param node: candidate child dicts (matched on their ``parent_id``).
        """
        ApiFuncs._attach_children(p, ApiFuncs._index_by_parent(node))

    @staticmethod
    def _index_by_parent(node):
        """Map each ``parent_id`` to its child dicts, preserving input order."""
        index = {}
        for n in node:
            index.setdefault(n.get("parent_id"), []).append(n)
        return index

    @staticmethod
    def _attach_children(p, children_by_parent):
        """Recursive worker behind add_node/list_to_tree, using a prebuilt index."""
        # A fresh list per parent, so two parents sharing an entity_id never share
        # one children list object.
        p["children"] = list(children_by_parent.get(p.get("entity_id"), ()))
        for t in p["children"]:
            # NOTE(review): a cycle in parent_id/entity_id links still recurses
            # without bound, exactly as the previous implementation did.
            ApiFuncs._attach_children(t, children_by_parent)
            # Leaves carry no "children" key; only the node passed in keeps an empty list
            if not t["children"]:
                del t["children"]

将树形json对象按参数排序的函数:

class SerialNameError(Exception):
    """Raised when a tree node's ``serial`` value cannot be parsed as an integer."""

    def __init__(self, message):
        # Forward the message to Exception so str(exc), exc.args and tracebacks
        # show it (previously they were empty).
        super().__init__(message)
        self.message = message
...
@staticmethod
def data_sort(data):
    """Sort one tree level in place by integer ``serial``, then every nested level.

    A ``serial`` equal to ``"_NULL_"`` is first rewritten to ``"0"``.

    :param data: list of node dicts; each must carry a ``serial`` key.
    :raises SerialNameError: if a ``serial`` on a level is not an integer string.
    """
    for entry in data:
        if entry["serial"] == "_NULL_":
            entry["serial"] = "0"

    try:
        data.sort(key=lambda entry: int(entry["serial"]))
    except ValueError:
        raise SerialNameError("序号似乎无效")

    for parent in (entry for entry in data if "children" in entry):
        ApiFuncs.data_sort(parent["children"])

4.一般编辑接口

# 继承BaseModel类的参数
from typing import List
from pydantic import BaseModel

class SampleEdit(BaseModel):
    """Request body for the batch edit endpoint.

    ``example_kv`` is a list of operation dicts. Judging from the handler, each
    one carries ``type_operation`` ('add' / 'delete' / 'update') plus ``name``
    and/or ``id``. NOTE(review): the model does not validate those keys.
    """
    example_kv: List[dict]
# Batch edit - body of an endpoint function (enclosing `def` omitted in this
# document). `eg` is presumably the SampleEdit request body - confirm.
# NOTE(review): the results of the write statements are never checked, so
# "msg": True / errorcode 0 is returned even if a statement fails.
# NOTE(review): item["name"] / item["id"] are interpolated into nGQL text
# unescaped; request data containing '"' can inject statements.
session = connection_pool.get_session('root', 'nebula')
session.execute('USE ai_project')

result = {
    "msg": True,
    "errorcode": 0
}

for item in eg.example_kv:
    # Create
    if item["type_operation"] == 'add':
        """
        ==========增==========
        """
        # Vertex id: hash of "<name>_<timestamp>" (helpers defined in api_class,
        # not shown here - presumably create_time() returns a timestamp string)
        temp_id = item["name"] + "_" + ApiFuncs.create_time()
        temp_md5 = ApiFuncs.create_md5(temp_id)

        # Read the existing serial numbers; the new vertex gets max + 1 (or 1).
        # NOTE(review): read-then-insert is not atomic - concurrent adds can
        # receive the same seq.
        serial_list = session.execute(
            str(f'match (v:example) return v.seq;')
        )

        data = []
        size = serial_list.row_size()
        for index in range(size):
            # Only string-typed seq values are counted; others (e.g. NULL) are skipped
            if serial_list.row_values(index)[0].is_string():
                data1 = serial_list.row_values(index)[0].as_string()
                data.append(int(data1))
        if data:
            count = max(data)
            count += 1
        else:
            count = 1

        session.execute(
            str(f'INSERT VERTEX example ( name,seq ) VALUES "{temp_md5}":("{item["name"]}", "{count}");')
        )

    # Delete
    elif item["type_operation"] == 'delete':
        """
        ==========删==========
        """
        session.execute(str(f'DELETE VERTEX "{item["id"]}";'))

    # Update (rename)
    elif item["type_operation"] == 'update':
        """
        ==========改==========
        """
        temp_md5 = item["id"]

        session.execute(
            str(f'UPDATE VERTEX ON example "{temp_md5}" SET name = "{item["name"]}";')
        )

# Return the session to the pool (this releases the session; the pool itself stays open)
session.release()
return result

测试接口:

import requests
import json

# Request body for the batch edit endpoint; each list entry is one operation:
#   {"type_operation": "add", "name": <name>}
#   {"type_operation": "update", "id": <vertex id>, "name": <new name>}
#   {"type_operation": "delete", "id": <vertex id>}
# Only real dicts may go in the list - placeholders such as {...} are Python
# sets and make json.dumps raise TypeError.
data = {
    'example_kv': [
        {"type_operation": "add", "name": "test"},
    ],
}

body = json.dumps(data)
response = requests.post(
    'http://127.0.0.1:8888/example',
    data=body,
    # Declare the JSON payload explicitly, and bound the wait so the script
    # cannot hang forever when the server is unreachable.
    headers={"Content-Type": "application/json"},
    timeout=30,
)
print(response.text)

5.功能函数

查找树型json结构中关键字

import re

def find_name(json_data, e, id_collect=None):
    """Collect the ``entity_id`` of every tree node whose ``name`` matches ``e``.

    Walks the tree depth-first, visiting each parent before its ``children``.

    :param json_data: list of node dicts with ``name`` and ``entity_id`` keys and
        an optional nested ``children`` list.
    :param e: pattern passed to ``re.search``. It is a regular expression, not a
        plain substring; wrap user keywords in ``re.escape`` for literal matching.
    :param id_collect: list to append matches to. When omitted, a new list is
        created, so separate calls no longer share (and pile up) results through
        a module-level global.
    :return: the list of matching entity ids.
    """
    if id_collect is None:
        id_collect = []
    for element in json_data:
        if re.search(e, element["name"]):
            id_collect.append(element["entity_id"])
        if "children" in element:
            find_name(element["children"], e, id_collect)
    return id_collect

分页

from __future__ import annotations
import math
from fastapi import FastAPI, Depends
from fastapi_pagination import paginate, add_pagination
from typing import TypeVar, Generic, Sequence

from fastapi import Query
from fastapi_pagination.bases import AbstractPage, AbstractParams, RawParams
from pydantic import BaseModel


T = TypeVar("T")


class Params(BaseModel, AbstractParams):
    """Page-number pagination parameters read from the query string (?page=&size=)."""

    page: int = Query(1, ge=1, description="Page number")
    size: int = Query(17, gt=0, le=100, description="Page size")

    def to_raw_params(self) -> RawParams:
        # Pages are 1-based: page N starts after (N - 1) full pages
        skipped = (self.page - 1) * self.size
        return RawParams(limit=self.size, offset=skipped)


class Page(AbstractPage[T], Generic[T]):
    """Response model for page-number paginated endpoints.

    ``next`` / ``previous`` hold a relative query string for the neighbouring
    page, or the literal string ``"null"`` when there is no such page.
    """
    results: Sequence[T]
    total: int
    page: int
    size: int
    next: str
    previous: str
    total_pages: int

    __params_type__ = Params

    @classmethod
    def create(
        cls,
        results: Sequence[T],
        total: int,
        params: Params,
    ) -> Page[T]:
        """Build a page from the already-sliced ``results`` and the ``total`` item count."""
        page = params.page
        size = params.size
        total_pages = math.ceil(total / size)
        # Named *_link so the builtin next() is not shadowed
        next_link = f"?page={page + 1}&size={size}" if (page + 1) <= total_pages else "null"
        previous_link = f"?page={page - 1}&size={size}" if (page - 1) >= 1 else "null"

        return cls(results=results, total=total, page=page,
                   size=size,
                   next=next_link,
                   previous=previous_link,
                   total_pages=total_pages)


app = FastAPI()


class User(BaseModel):
    """One record of the demo data served by ``/sample``."""
    sample_id: str
    sample_name: str
    time: str


# (sample_id, sample_name, time) rows for the demo endpoint
_SAMPLE_ROWS = (
    ("000001", "test1", "2022-08-11 21:15:33"),
    ("000002", "test2", "2022-08-12 13:45:56"),
    ("000003", "test3", "2022-08-12 13:45:59"),
)
sample_data = [
    {"sample_id": sid, "sample_name": name, "time": ts}
    for sid, name, ts in _SAMPLE_ROWS
]


@app.get('/sample', response_model=Page[User])
async def get_users():
    """Return ``sample_data`` paginated with fastapi-pagination."""
    return paginate(sample_data)


# Enable fastapi-pagination on this app
add_pagination(app)


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app="1:app", reload=True, host='127.0.0.1', port=9999)

项目级

文件目录结构

│ api_class.py
│ main.py

├─api
│ │ nebuladb.py
│ │ proj_1.py
│ │ proj_2.py
│ │ __init__.py
│ │
│ └─__pycache__
│ nebuladb.cpython-39.pyc
│ proj_1.cpython-39.pyc
│ proj_2.cpython-39.pyc
│ __init__.cpython-39.pyc

├─data
│ nebula_data.txt

├─logs
│ 2022-xx-01.txt
│ 2022-xx-02.txt
│ 2022-xx-03.txt
│ 2022-xx-04.txt

├─scripts
│ post_1.py
│ test_1.py

├─utils
│ │ client.py
│ │ creat_data.py
│ │ log.py
│ │ snapshot_day_by_day.py
│ │ __init__.py
│ │
│ └─__pycache__
│ log.cpython-39.pyc
│ __init__.cpython-39.pyc

└─__pycache__
api_class.cpython-39.pyc
main.cpython-39.pyc

主函数main.py

from api import main
from fastapi_pagination import add_pagination

# Build the application (middleware and routers are wired up in api.main())
app_router = main()
# Enable fastapi-pagination on the assembled app
add_pagination(app_router)
if __name__ == "__main__":
    import uvicorn
    # uvicorn.run() no longer accepts `debug` (removed in uvicorn 0.19); log_level="debug"
    # keeps the verbose output. reload=True requires the app as an import string.
    uvicorn.run(app="main:app_router", reload=True, log_level="debug", host='0.0.0.0', port=8888)

日志utils/log.py

import os
import logging
import logging.handlers
import time


class LogInit(object):
    """Singleton that configures the project's file logger.

    Records at INFO and above are written to ``../logs/<YYYY-MM-DD>.txt``
    (relative to this file), rotated at 10 MB with 3 backups. The date in the
    file name is taken once, when the singleton is first created.
    """
    __instance = None

    def __init__(self):
        _log_dir = os.path.join(os.path.dirname(__file__), '../logs')
        # RotatingFileHandler opens the file immediately and does not create
        # missing directories; without this, importing fails on a fresh checkout.
        os.makedirs(_log_dir, exist_ok=True)
        _log_name = time.strftime('%Y-%m-%d', time.localtime(time.time())) + '.txt'

        # The logger is named after the log file, so each date has its own logger object
        self.logger = logging.getLogger(_log_name)
        self.logger.setLevel(logging.DEBUG)
        formatter = logging.Formatter(
            '%(asctime)s - %(levelname)s - %(filename)s - %(funcName)s - %(lineno)s - %(message)s')
        # Do not attach a second handler if this logger is already configured
        if not self.logger.handlers:
            file_log_handler = logging.handlers.RotatingFileHandler(os.path.join(_log_dir, _log_name),
                                                                    maxBytes=10 * 1024 * 1024, backupCount=3,
                                                                    encoding="utf-8")
            file_log_handler.setLevel(logging.INFO)
            file_log_handler.setFormatter(formatter)
            self.logger.addHandler(file_log_handler)

    @staticmethod
    def set_logger():
        """Return the shared LogInit instance, creating it on first use."""
        if not LogInit.__instance:
            LogInit.__instance = LogInit()
        return LogInit.__instance


logger = LogInit.set_logger().logger

使用方式:

from utils.log import logger

# One call per level. The file handler in utils/log.py is set to INFO, so all
# three messages reach the log file.
logger.info("信息")
logger.warning("警告")
logger.error("错误")

数据库配置api/nebuladb.py

from nebula2.Config import Config
from nebula2.gclient.net import ConnectionPool

# NEBULA
# Nebula graphd endpoint(s) this API connects to
NEBULA_ADDRESSES = [('192.168.80.128', 9669)]

# NEBULA connection-pool settings
config = Config()
config.max_connection_pool_size = 10  # maximum number of pooled connections
config.timeout = 60000                # connection timeout
config.idle_time = 0                  # idle time before an idle connection is closed
config.interval_check = -1            # interval between idle-connection checks

# Shared pool; the api modules borrow sessions from it via get_session()
connection_pool = ConnectionPool()
ok = connection_pool.init(NEBULA_ADDRESSES, config)

图谱快照保存脚本utils/snapshot_day_by_day.py

import schedule
import time

from nebula2.gclient.net import ConnectionPool
from nebula2.Config import Config


# Nebula graphd endpoint(s) used by the snapshot jobs
NEBULA_ADDRESSES = [('192.168.80.128', 9669)]

# Connection-pool settings
config = Config()
config.max_connection_pool_size = 10  # maximum number of pooled connections
config.timeout = 60000                # connection timeout
config.idle_time = 0                  # idle time before an idle connection is closed
config.interval_check = -1            # interval between idle-connection checks

connection_pool = ConnectionPool()
ok = connection_pool.init(NEBULA_ADDRESSES, config)


def add_snapshot():
    """Run ``CREATE SNAPSHOT`` (scheduled daily at 22:00 below).

    The session is always returned to the pool. It used to be leaked on every
    run, which would presumably exhaust the 10-connection pool after a few days.
    """
    session = connection_pool.get_session('root', 'nebula')
    try:
        session.execute('USE ai_project')
        session.execute('CREATE SNAPSHOT')
    finally:
        session.release()


def del_snapshot():
    """Drop every snapshot listed by ``SHOW SNAPSHOTS`` except the last three.

    NOTE(review): this assumes the listing is ordered oldest-first, so that the
    three kept snapshots are the newest - confirm against the server.
    The session is always returned to the pool (it used to be leaked).
    """
    session = connection_pool.get_session('root', 'nebula')
    try:
        session.execute('USE ai_project')
        data = session.execute('SHOW SNAPSHOTS')
        if data.is_empty():
            return
        # The first column of each row is used as the snapshot name
        names = [data.row_values(index)[0].as_string() for index in range(data.row_size())]
        # names[:-3] is empty when three or fewer snapshots exist
        for name in names[:-3]:
            session.execute(f'DROP SNAPSHOT {name}')
    finally:
        session.release()


# Register both jobs for 22:00 every day (add_snapshot is registered first)
for _job in (add_snapshot, del_snapshot):
    schedule.every().day.at("22:00").do(_job)

# Poll the scheduler once per second and run whatever is due
while True:
    schedule.run_pending()
    time.sleep(1)

api包api/__init__.py配置

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from api.stdlib import stdlib_router
from api.quota import quota_router


def main():
    """Build the FastAPI application: CORS middleware plus the project routers.

    NOTE(review): this used to call app_stdlib / app_quota / app_ocr, none of
    which is defined in this module, so main() raised NameError. It also never
    called app_proj_1 / app_proj_2. stdlib_router and quota_router are still
    imported at the top of the module; include them here too if they are meant
    to be served.
    """
    app = FastAPI(docs_url=None, redoc_url=None)
    app_cors(app)
    app_proj_1(app)
    app_proj_2(app)
    return app


def app_cors(app: FastAPI):
    """
    Install CORS middleware that accepts any origin, method and header.
    NOTE(review): a "*" origin combined with allow_credentials=True is generally
    considered unsafe; restrict origins if cookies/auth are used.
    :param app: application to configure
    :return: None
    """
    origins = [
        "*"
    ]

    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"]
    )


def app_proj_1(app: FastAPI):
    """
    Register the project 1 router.
    :param app: application to register the router on
    :return: None
    """
    # proj_1_router was referenced here without ever being imported
    from .proj_1 import proj_1_router
    app.include_router(proj_1_router)


def app_proj_2(app: FastAPI):
    """
    Register the project 2 router.
    :param app: application to register the router on
    :return: None
    """
    # proj_2_router was referenced here without ever being imported
    from .proj_2 import proj_2_router
    app.include_router(proj_2_router)

项目接口api/proj_1(2).py

from fastapi import APIRouter

from utils.log import logger
from api_class import *
from .nebuladb import *

proj_1_router = APIRouter(
    prefix="/api",
)


# Registered on proj_1_router (the old decorator used stdlib_router, which is
# not defined in this module, so importing it failed).
@proj_1_router.api_route("/project/standard", methods=['GET'])
def get_project_standard():
    """
    Get the engineering standards (example: every vertex tagged `example`).

    Declared as a plain ``def`` rather than ``async def``: the nebula2 session
    calls block, and FastAPI runs sync endpoints in a worker thread instead of
    stalling the event loop.

    :return: ``{"data": [{"entity_id": ..., "name": ...}, ...], "errorcode": int}``.
        errorcode is 0 on success, 1 if the query failed, and 2 if the query
        succeeded but returned no rows.
    """
    data_final = []
    result = {
        "data": data_final,
        "errorcode": 0
    }

    # Borrow a session from the pool; `finally` guarantees it is returned even
    # if execute()/as_string() raises.
    session = connection_pool.get_session('root', 'nebula')
    try:
        session.execute('USE ai_project')
        # Example: query every vertex with tag `example`
        data_select = session.execute('match (v:example) return id(v),v.name;')

        # Query failed -> errorcode 1
        if not data_select.is_succeeded():
            logger.error("get_project_standard query failed: %s", data_select.error_msg())
            result["errorcode"] = 1
            return result

        # Query succeeded but no vertices of this tag exist -> errorcode 2
        if data_select.is_empty():
            result["errorcode"] = 2
            return result

        for index in range(data_select.row_size()):
            row = data_select.row_values(index)
            data_final.append({
                "entity_id": row[0].as_string(),
                "name": row[1].as_string(),
            })
        return result
    finally:
        session.release()

Powered by Hexo and Hexo-theme-hiker

Copyright © 2017 - 2024 青域 All Rights Reserved.

UV : | PV :