提取图像中的高频信息
1. 傅里叶变换提取高频和低频【有损】
- 环境:集群210.30.98.11
- 效果:
2. 傅里叶变换提取振幅和相位【无损】
- 环境:集群210.30.98.11
- 效果:
3. 小波变换【不涉及恢复代码】
- 环境:集群210.30.98.11
- 效果:
代码1.
import torchvision.transforms as T
from PIL import Image
import torch
import matplotlib.pyplot as plt
import os
# Extract frequency components (reused from the earlier code).
def extract_frequency_components(image, cutoff_ratio=0):
    """Split ``image`` into low- and high-frequency parts with an ideal box filter.

    The 2-D spectrum of every channel is centred with ``fftshift``. A
    rectangular pass band around the DC term forms the low-frequency part;
    everything outside it forms the high-frequency part.

    Args:
        image: tensor unpacked as ``C, H, W = image.shape``; converted to float.
        cutoff_ratio: half-width of the pass band as a fraction of H and W.
            With ``k = int(cutoff_ratio * n)`` per axis, frequencies with
            ``|f| < k`` are treated as low frequency. A value of 0 gives an
            empty pass band, so all content is "high frequency". Ratios large
            enough for the band to reach the spectrum edges (slightly above
            0.5) keep the whole spectrum.

    Returns:
        ``(low_freq, high_freq)``: the magnitudes (``torch.abs``) of the
        inverse FFT of each part, both shaped like ``image``. Because
        magnitudes are returned, ``low_freq + high_freq`` only approximates
        ``image``.
    """
    image = image.float()
    C, H, W = image.shape
    fft_image = torch.fft.fftshift(torch.fft.fft2(image, dim=(-2, -1)), dim=(-2, -1))
    # After fftshift the DC term sits at index n // 2 (for both even and odd n).
    center_x, center_y = H // 2, W // 2
    cutoff_x, cutoff_y = int(cutoff_ratio * H), int(cutoff_ratio * W)
    # Keep frequencies -(k-1) .. (k-1). The band is symmetric around DC, so the
    # low-pass spectrum stays conjugate-symmetric (a real signal in image space).
    # Clamp the bounds: a negative slice start would silently count from the
    # end of the axis and produce a wrong mask for large ratios.
    x0, x1 = max(center_x - cutoff_x + 1, 0), min(center_x + cutoff_x, H)
    y0, y1 = max(center_y - cutoff_y + 1, 0), min(center_y + cutoff_y, W)
    mask = torch.zeros_like(fft_image)
    mask[:, x0:x1, y0:y1] = 1
    low_freq_fft = fft_image * mask
    high_freq_fft = fft_image * (1 - mask)
    low_freq = torch.abs(torch.fft.ifft2(torch.fft.ifftshift(low_freq_fft, dim=(-2, -1)), dim=(-2, -1)))
    high_freq = torch.abs(torch.fft.ifft2(torch.fft.ifftshift(high_freq_fft, dim=(-2, -1)), dim=(-2, -1)))
    return low_freq, high_freq
# Recover the image from its frequency components.
def recover_image(low_freq, high_freq):
    """Recombine two frequency components by element-wise addition.

    Args:
        low_freq: the low-frequency component.
        high_freq: the high-frequency component, broadcast-compatible with
            ``low_freq``.

    Returns:
        ``low_freq + high_freq``.
    """
    recombined = low_freq + high_freq
    return recombined
# Save a tensor image to a file (original helper kept in case single images
# need to be saved separately later).
def save_image(tensor_image, output_path):
    """Write ``tensor_image`` to ``output_path`` through PIL.

    Pixel values are clamped to [0, 1] first, then the tensor is converted
    with ``T.ToPILImage`` and saved.
    """
    clipped = tensor_image.clamp(0, 1)
    T.ToPILImage()(clipped).save(output_path)
# Main routine: draw four images into one figure and save it.
def process_and_save_images(image_path, output_dir, cutoff_ratio=0.1):
    """Split an image into frequency bands and save a 2x2 comparison figure.

    The image at ``image_path`` is converted to RGB, resized to 256x256 and
    turned into a (C, H, W) tensor. It is then split with
    ``extract_frequency_components`` and recombined with ``recover_image``.
    The original, low-frequency, high-frequency and recovered images are drawn
    in a 2x2 grid and saved as ``combined_images2.png`` in ``output_dir``,
    which is created if missing.

    Args:
        image_path: path of the input image.
        output_dir: directory that receives the combined figure.
        cutoff_ratio: forwarded to ``extract_frequency_components``.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Load and preprocess; the context manager closes the underlying file.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    transform = T.Compose([
        T.Resize((256, 256)),
        T.ToTensor()
    ])
    image_tensor = transform(image)  # shape (C, H, W)

    low_freq, high_freq = extract_frequency_components(image_tensor, cutoff_ratio)
    recovered_image = recover_image(low_freq, high_freq)

    # Panel order fills the 2x2 grid row by row.
    panels = [
        ('Original Image', image_tensor),
        ('Low Frequency Image', low_freq),
        ('High Frequency Image', high_freq),
        ('Recovered Image', recovered_image),
    ]

    fig, axes = plt.subplots(2, 2, figsize=(8, 8))
    try:
        for ax, (title, tensor) in zip(axes.flat, panels):
            # imshow wants (H, W, C) floats in [0, 1]. It would clip anyway,
            # but clamping here avoids the clipping warning (the recovered
            # image can exceed 1).
            ax.imshow(tensor.clamp(0, 1).permute(1, 2, 0).numpy())
            ax.set_title(title)
            ax.axis('off')
        plt.tight_layout()
        combined_image_path = os.path.join(output_dir, "combined_images2.png")
        plt.savefig(combined_image_path)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    print(f"Processed images saved to {output_dir}")
# Example usage: adjust both paths before running.
output_dir = "output_images"  # where the combined figure is written
image_path = "lena.png"  # input image to analyse
process_and_save_images(image_path=image_path, output_dir=output_dir, cutoff_ratio=0.1)
代码2
import torchvision.transforms as T
from PIL import Image
import torch
import matplotlib.pyplot as plt
import os
def extract_frequency_components(image):
    """Decompose ``image`` into the magnitude and phase of its centred 2-D spectrum.

    Args:
        image: tensor whose last two dimensions are the spatial axes (the
            main routine passes a (C, H, W) tensor).

    Returns:
        ``(magnitude, phase)``, both shaped like ``image``. The spectrum is
        centred with ``fftshift`` over the last two dimensions only.
    """
    fft_image = torch.fft.fftn(image, dim=(-2, -1))
    # Shift only the spatial axes. Without ``dim``, fftshift also rolls the
    # channel axis, which permutes the colour channels of the spectra.
    fft_shifted = torch.fft.fftshift(fft_image, dim=(-2, -1))
    magnitude = torch.abs(fft_shifted)
    phase = torch.angle(fft_shifted)
    return magnitude, phase


def recover_image(magnitude, phase):
    """Invert ``extract_frequency_components``: rebuild the image from magnitude and phase.

    Args:
        magnitude: centred spectral magnitude, shaped like the original image.
        phase: centred spectral phase in radians, same shape.

    Returns:
        The real part of the inverse FFT over the last two dimensions.
    """
    real = magnitude * torch.cos(phase)
    imag = magnitude * torch.sin(phase)
    complex_freq = torch.complex(real, imag)
    # Undo exactly the spatial-only shift applied during extraction.
    complex_freq_shifted = torch.fft.ifftshift(complex_freq, dim=(-2, -1))
    recovered_image = torch.fft.ifftn(complex_freq_shifted, dim=(-2, -1))
    return recovered_image.real
# Save a tensor image to a file.
def save_image(tensor_image, output_path):
    """Persist ``tensor_image`` as an image file at ``output_path``.

    The tensor is clamped to [0, 1] so it is a valid float image, converted
    via ``T.ToPILImage`` and written with PIL's ``save``.
    """
    converter = T.ToPILImage()
    converter(tensor_image.clamp(0, 1)).save(output_path)
# Main routine: draw four images into one figure and save it.
def process_and_save_images(image_path, output_dir, cutoff_ratio=0.1):
    """Show an image's spectral magnitude and phase and the lossless reconstruction.

    The image at ``image_path`` is converted to RGB, resized to 256x256 and
    turned into a (C, H, W) tensor. It is split into magnitude and phase with
    ``extract_frequency_components`` and rebuilt with ``recover_image``. The
    original, magnitude, phase and recovered images are drawn in a 2x2 grid
    and saved as ``combined_images3.png`` in ``output_dir``, which is created
    if missing.

    Args:
        image_path: path of the input image.
        output_dir: directory that receives the combined figure.
        cutoff_ratio: unused by this magnitude/phase variant; kept so the
            signature matches the low/high-pass version.
    """
    import math

    os.makedirs(output_dir, exist_ok=True)

    # Load and preprocess; the context manager closes the underlying file.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    transform = T.Compose([
        T.Resize((256, 256)),
        T.ToTensor()
    ])
    image_tensor = transform(image)  # shape (C, H, W)

    magnitude, phase = extract_frequency_components(image_tensor)
    recovered_image = recover_image(magnitude, phase)

    def _rescale_per_channel(t):
        """Min-max rescale each channel of a (C, H, W) tensor to [0, 1] for display."""
        flat = t.reshape(t.shape[0], -1)
        lo = flat.min(dim=1).values.view(-1, 1, 1)
        hi = flat.max(dim=1).values.view(-1, 1, 1)
        return (t - lo) / (hi - lo).clamp_min(1e-12)

    # Raw magnitude spans several orders of magnitude (the DC term dominates)
    # and phase lies in [-pi, pi]. Shown directly, both saturate, so map them
    # into [0, 1] first.
    panels = [
        ('Original Image', image_tensor),
        ('Magnitude Spectrum (log)', _rescale_per_channel(torch.log1p(magnitude))),
        ('Phase Spectrum', (phase + math.pi) / (2 * math.pi)),
        ('Recovered Image', recovered_image),
    ]

    fig, axes = plt.subplots(2, 2, figsize=(8, 8))
    try:
        for ax, (title, tensor) in zip(axes.flat, panels):
            # imshow wants (H, W, C) floats in [0, 1]; the clamp removes tiny
            # FFT round-off outside that range.
            ax.imshow(tensor.clamp(0, 1).permute(1, 2, 0).numpy())
            ax.set_title(title)
            ax.axis('off')
        plt.tight_layout()
        combined_image_path = os.path.join(output_dir, "combined_images3.png")
        plt.savefig(combined_image_path)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    print(f"Processed images saved to {output_dir}")
# Example usage: adjust both paths before running.
image_path = "lena.png"  # replace with your image path
output_dir = "output_images"  # replace with your output directory
# (A stray character that followed this call in the pasted source made the
# line a syntax error; it has been removed.)
process_and_save_images(image_path, output_dir, cutoff_ratio=0.1)
代码3
import torch
import torchvision.transforms as T
import pywt
import matplotlib.pyplot as plt
import os
from PIL import Image
# Single-level 2-D wavelet decomposition of every colour channel, drawn as a
# (channels x 4) grid of sub-band images and saved to disk.

# Make sure the output folder for the combined figure exists.
output_folder = "wavelet_images"
os.makedirs(output_folder, exist_ok=True)

# Load the colour image (replace with your actual image path); the context
# manager closes the underlying file.
with Image.open('lena.png') as _source_img:
    image = _source_img.convert('RGB')

# ToTensor yields a (C, H, W) float tensor already scaled to [0, 1], so no
# further normalisation is needed.
transform = T.Compose([T.ToTensor()])
image_tensor = transform(image)

# Number of channels (3 for an RGB image).
num_channels = image_tensor.shape[0]

# Wavelet basis ('db4' as the example).
wavelet = 'db4'

# One row per channel, one column per sub-band (LL, HL, LH, HH).
num_rows = num_channels
num_cols = 4
print(f"num_rows: {num_rows}, num_cols: {num_cols}")  # print the layout for inspection

# squeeze=False keeps `axes` 2-D even for a single row, so one indexing
# scheme (axes[row, col]) always works.
fig, axes = plt.subplots(num_rows, num_cols, figsize=(16, num_rows * 4), squeeze=False)
print(f"axes shape: {axes.shape}")  # print the axes array shape for inspection

try:
    for channel in range(num_channels):
        # pywt.dwt2 takes a 2-D (height, width) array.
        channel_image = image_tensor[channel, :, :].numpy()
        cA, (cH, cV, cD) = pywt.dwt2(channel_image, wavelet)
        # Approximation, then horizontal, vertical and diagonal detail
        # coefficients, with the labels used by the original figure.
        sub_bands = ((cA, "LL"), (cH, "HL"), (cV, "LH"), (cD, "HH"))
        for col, (coeffs, label) in enumerate(sub_bands):
            ax = axes[channel, col]
            # With a colormap, imshow maps the data's own min/max to the full
            # gray range, so the raw coefficients need no manual scaling.
            # (The former "/ 255.0" re-scaled data that was already in
            # [0, 1] and had no visible effect.)
            ax.imshow(coeffs, cmap='gray')
            ax.set_title(f"Channel {channel} - ({label})")
            ax.axis('off')

    plt.tight_layout()
    # Save the combined figure as wavelet_decomposition_color.png.
    output_path = os.path.join(output_folder, "wavelet_decomposition_color.png")
    plt.savefig(output_path)
finally:
    # Close the figure to free it and avoid stray windows, even on failure.
    plt.close(fig)
原文地址:https://blog.csdn.net/cp_oldy/article/details/143952594
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!