-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathpy_lib.rs
148 lines (134 loc) · 5.36 KB
/
py_lib.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
// Tauri Python Plugin
// © Copyright 2024, by Marco Mengelkoch
// Licensed under MIT License, see License file for more details
// git clone https://github.com/marcomq/tauri-plugin-python
use std::sync::atomic::AtomicBool;
use std::{collections::HashSet, sync::Mutex};
use rustpython_vm::py_serde;
use lazy_static::lazy_static;
use crate::{models::*, Error};
/// Produces the scope stored in `GLOBALS`.
///
/// A short-lived interpreter without the standard library is spun up solely
/// to call `vm.new_scope_with_builtins()`; the interpreter is dropped once the
/// scope has been returned.
fn create_globals() -> rustpython_vm::scope::Scope {
    let interpreter = rustpython_vm::Interpreter::without_stdlib(Default::default());
    interpreter.enter(|vm| vm.new_scope_with_builtins())
}
/// Flipped to `true` by `call_function`; once set, `register_function_str`
/// refuses to register any further functions.
///
/// `AtomicBool::new` is a `const fn`, so this flag needs no lazy
/// initialisation (and no per-access init check) — a plain `static` suffices.
static INIT_BLOCKED: AtomicBool = AtomicBool::new(false);

lazy_static! {
    /// Dotted names accepted by `register_function_str`; `call_function`
    /// rejects any name not present in this set.
    static ref FUNCTION_MAP: Mutex<HashSet<String>> = Mutex::new(HashSet::new());
    /// Scope handed to every `run_code_obj` call in this module and used for
    /// all global name lookups; built on first access by `create_globals`.
    static ref GLOBALS: rustpython_vm::scope::Scope = create_globals();
}
/// Initialisation hook for this backend.
///
/// Currently a no-op: the shared Python globals are created lazily on first
/// access rather than here. NOTE(review): presumably invoked by the plugin
/// setup code — the caller is not visible in this file.
pub fn init() {}
/// Executes the Python source carried in `payload.value` in the shared
/// global scope, attributing it to the pseudo file name `<embedded>`.
///
/// # Errors
///
/// Propagates any compile or runtime error from `run_python_internal`.
pub fn run_python(payload: StringRequest) -> crate::Result<()> {
    let source = payload.value;
    run_python_internal(source, String::from("<embedded>"))
}
/// Compiles `code` in exec mode under `filename` and runs it against `GLOBALS`.
///
/// The value produced by the executed code object is discarded.
///
/// # Errors
///
/// A compile failure is turned into a Python `SyntaxError`; that, or any
/// exception raised while running, is converted into the crate error type.
pub fn run_python_internal(code: String, filename: String) -> crate::Result<()> {
    let interpreter = rustpython_vm::Interpreter::without_stdlib(Default::default());
    interpreter.enter(|vm| {
        let compiled = match vm.compile(&code, rustpython_vm::compiler::Mode::Exec, filename) {
            Ok(obj) => obj,
            Err(err) => return Err(vm.new_syntax_error(&err, Some(&code))),
        };
        vm.run_code_obj(compiled, GLOBALS.clone())
    })?;
    Ok(())
}
/// Unpacks a `RegisterRequest` and forwards it to `register_function_str`.
///
/// # Errors
///
/// Whatever `register_function_str` reports for the given name/arity.
pub fn register_function(payload: RegisterRequest) -> crate::Result<()> {
    let name = payload.python_function_call;
    let arg_count = payload.number_of_args;
    register_function_str(name, arg_count)
}
pub fn register_function_str(
function_name: String,
number_of_args: Option<u8>,
) -> crate::Result<()> {
if INIT_BLOCKED.load(std::sync::atomic::Ordering::Relaxed) {
return Err("Cannot register after function called".into());
}
rustpython_vm::Interpreter::without_stdlib(Default::default()).enter(|vm| {
let var_dot_split: Vec<&str> = function_name.split(".").collect();
let func = GLOBALS
.globals
.get_item(var_dot_split[0], vm)
.unwrap_or_else(|_| {
panic!("Cannot find '{}' in globals", var_dot_split[0]);
});
if var_dot_split.len() > 2 {
func.get_attr(&vm.ctx.new_str(var_dot_split[1]), vm)
.unwrap()
.get_attr(&vm.ctx.new_str(var_dot_split[2]), vm)
.unwrap();
} else if var_dot_split.len() > 1 {
func.get_attr(&vm.ctx.new_str(var_dot_split[1]), vm)
.unwrap_or_else(|_| {
panic!(
"Cannot find sub function '{}' in '{}'",
var_dot_split[1], var_dot_split[0]
);
});
}
if let Some(num_args) = number_of_args {
let py_analyze_sig = format!(
r#"
from inspect import signature
if len(signature({}).parameters) != {}:
raise Exception("Function parameters don't match in 'registerFunction'")
"#,
function_name, num_args
);
let code_obj = vm
.compile(
&py_analyze_sig,
rustpython_vm::compiler::Mode::Exec,
"<embedded>".to_owned(),
)
.map_err(|err| vm.new_syntax_error(&err, Some(&py_analyze_sig)))?;
vm.run_code_obj(code_obj, GLOBALS.clone())
.unwrap_or_else(|_| {
panic!("Number of args doesn't match signature of {function_name}.")
});
}
// dbg!(format!("Added '{function_name}'"));
FUNCTION_MAP.lock().unwrap().insert(function_name);
Ok(())
})
}
pub fn call_function(payload: RunRequest) -> crate::Result<String> {
INIT_BLOCKED.store(true, std::sync::atomic::Ordering::Relaxed);
let function_name = payload.function_name;
if FUNCTION_MAP.lock().unwrap().get(&function_name).is_none() {
return Err(Error::String(format!(
"Function {function_name} has not been registered yet"
)));
}
rustpython_vm::Interpreter::without_stdlib(Default::default()).enter(|vm| {
let posargs: Vec<_> = payload
.args
.into_iter()
.map(|value| py_serde::deserialize(vm, value).unwrap())
.collect();
let var_dot_split: Vec<&str> = function_name.split(".").collect();
let func = GLOBALS.globals.get_item(var_dot_split[0], vm)?;
Ok(if var_dot_split.len() > 2 {
func.get_attr(&vm.ctx.new_str(var_dot_split[1]), vm)?
.get_attr(&vm.ctx.new_str(var_dot_split[2]), vm)?
} else if var_dot_split.len() > 1 {
func.get_attr(&vm.ctx.new_str(var_dot_split[1]), vm)?
} else {
func
}
.call(posargs, vm)?
.str(vm)?
.to_string())
})
}
/// Reads a Python value by dotted path and returns its `str()` representation.
///
/// The first segment of `payload.value` is looked up in `GLOBALS` and each
/// following segment is resolved as an attribute. All segments are used;
/// earlier revisions silently stopped after the third.
///
/// # Errors
///
/// Fails if any segment cannot be resolved or if `str()` raises.
pub fn read_variable(payload: StringRequest) -> crate::Result<String> {
    rustpython_vm::Interpreter::without_stdlib(Default::default()).enter(|vm| {
        let mut segments = payload.value.split('.');
        // `split` always yields at least one item, so the default is never used.
        let mut var = GLOBALS
            .globals
            .get_item(segments.next().unwrap_or_default(), vm)?;
        for attr in segments {
            var = var.get_attr(&vm.ctx.new_str(attr), vm)?;
        }
        Ok(var.str(vm)?.to_string())
    })
}