线性变换
This commit is contained in:
42
transform/trans.py
Normal file
42
transform/trans.py
Normal file
@ -0,0 +1,42 @@
|
||||
def linear_trans2d(arr:list[list[float]],x:list[float]) -> list[float]:
|
||||
'''
|
||||
@param arr:
|
||||
[
|
||||
[a,b],
|
||||
[c,d]
|
||||
]
|
||||
|
||||
@param x:
|
||||
[x1,x2]
|
||||
|
||||
@return
|
||||
arr times x
|
||||
'''
|
||||
return [x[0]*arr[0][0]+x[1]*arr[0][1], x[0]*arr[1][0]+x[1]*arr[1][1]]
|
||||
|
||||
def linear_trans(arr:list[list[float]],x:list[float]) -> list[float]:
|
||||
if len(arr) != len(x):
|
||||
raise ValueError("Shape must be same.")
|
||||
return [sum(map(lambda x: x[0]*x[1],zip(line,x))) for line in arr]
|
||||
|
||||
import numpy as np
|
||||
|
||||
def linear_trans_np(arr:np.ndarray,x:np.ndarray) -> np.ndarray:
|
||||
if not arr.shape[0] == arr.shape[1] == x.shape[0]:
|
||||
raise ValueError("Shape must be (k,k) and (k,)")
|
||||
return np.array([np.sum(line * x) for line in arr])
|
||||
|
||||
if __name__ == "__main__":
|
||||
# check code with np
|
||||
import numpy as np
|
||||
np.random.seed()
|
||||
|
||||
for i in range(100):
|
||||
arr = np.random.random_sample((4,4))
|
||||
x = np.random.random_sample((4,))
|
||||
|
||||
ans = np.dot(arr,x)
|
||||
out = linear_trans_np(arr,x)
|
||||
# print(out)
|
||||
# out = np.array(out)
|
||||
assert np.all(np.abs(ans-out)<0.0001)
|
Reference in New Issue
Block a user