Files
linear_math_note/transform/trans.py
2025-07-10 15:58:14 +08:00

43 lines
1.1 KiB
Python

def linear_trans2d(arr: list[list[float]], x: list[float]) -> list[float]:
    """Multiply a 2x2 matrix by a 2-vector.

    Only arr[0][0], arr[0][1], arr[1][0], arr[1][1], x[0] and x[1] are read;
    any further rows/columns/entries are ignored.

    @param arr: [[a, b], [c, d]]
    @param x: [x1, x2]
    @return: [a*x1 + b*x2, c*x1 + d*x2], i.e. arr times x
    """
    x1, x2 = x[0], x[1]
    top_row, bottom_row = arr[0], arr[1]
    first = x1 * top_row[0] + x2 * top_row[1]
    second = x1 * bottom_row[0] + x2 * bottom_row[1]
    return [first, second]
def linear_trans(arr: list[list[float]], x: list[float]) -> list[float]:
    """Multiply a square matrix by a vector using plain Python lists.

    @param arr: k x k matrix given as a list of k rows, each of length k
    @param x: vector of length k
    @return: arr times x, as a list of k values (row i dotted with x)
    @raise ValueError: if arr does not have len(x) rows, or any row's
        length differs from len(x)
    """
    n = len(x)
    if len(arr) != n:
        raise ValueError("Shape must be same.")
    # zip() silently stops at the shorter input, so a ragged row would
    # otherwise produce a wrong dot product instead of an error.
    for i, row in enumerate(arr):
        if len(row) != n:
            raise ValueError(f"Row {i} has length {len(row)}, expected {n}.")
    return [sum(a * b for a, b in zip(row, x)) for row in arr]
import numpy as np
def linear_trans_np(arr: np.ndarray, x: np.ndarray) -> np.ndarray:
    """Multiply a square matrix by a vector, row by row, with NumPy.

    @param arr: array of shape (k, k); array-likes such as nested lists are
        converted with np.asarray
    @param x: array of shape (k,); converted with np.asarray likewise
    @return: array of shape (k,) whose i-th entry is row i of arr dotted
        with x
    @raise ValueError: if arr is not 2-D, x is not 1-D, or the shapes are
        not (k, k) and (k,)
    """
    arr = np.asarray(arr)
    x = np.asarray(x)
    # Check ndim before comparing shapes: arr.shape[1] on a 1-D array would
    # raise IndexError, and a 3-D arr of shape (k, k, m) would otherwise
    # pass the shape comparison and broadcast into a meaningless result.
    if (
        arr.ndim != 2
        or x.ndim != 1
        or not arr.shape[0] == arr.shape[1] == x.shape[0]
    ):
        raise ValueError("Shape must be (k,k) and (k,)")
    # Broadcasting multiplies every row of arr element-wise by x; summing
    # along axis 1 then yields one dot product per row.
    return (arr * x).sum(axis=1)
if __name__ == "__main__":
    # Self-check: compare each hand-written implementation with NumPy's
    # matrix-vector product (np.dot) on random inputs.
    # np.testing.assert_allclose is used instead of a bare `assert`, which
    # is stripped under `python -O` and would turn this check into a no-op.
    # rtol=0, atol=1e-4 mirrors the original absolute-error threshold.
    rng = np.random.default_rng()  # unseeded: fresh inputs on every run
    for _ in range(100):
        arr = rng.random((4, 4))
        x = rng.random(4)
        ans = np.dot(arr, x)
        np.testing.assert_allclose(
            linear_trans_np(arr, x), ans, rtol=0, atol=1e-4
        )
        np.testing.assert_allclose(
            linear_trans(arr.tolist(), x.tolist()), ans, rtol=0, atol=1e-4
        )

        # linear_trans2d only handles the 2x2 case, so check it separately.
        arr2 = rng.random((2, 2))
        x2 = rng.random(2)
        np.testing.assert_allclose(
            linear_trans2d(arr2.tolist(), x2.tolist()),
            np.dot(arr2, x2),
            rtol=0,
            atol=1e-4,
        )