Commit 0ac67b89 by 杨艳磊

优化代码结构,增加解析gpu直通,audio解析等逻辑

1 parent 14bebfb0
Pipeline #24196 failed
in 0 seconds
Showing with 95 additions and 42 deletions
...@@ -5,15 +5,16 @@ import ( ...@@ -5,15 +5,16 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"libvirt.org/go/libvirt" "libvirt.org/go/libvirt"
libvirtxml "libvirt.org/go/libvirtxml" libvirtxml "libvirt.org/go/libvirtxml"
"strconv"
"strings" "strings"
) )
type QemuClient interface { type QemuClient interface {
CreateVM() error CreateVM(vmName string, memoryMB uint, vcpu uint, diskImagePath string, gpuPCIAddress string, audioPCIAddress string) error
StartVM() error StartVM() error
StopVM() error StopVM() error
DestroyVM() error DestroyVM() error
GetAllDomainStats() ([]libvirt.DomainStats,error) GetAllDomainStats() ([]libvirt.DomainStats, error)
ListAllDomains(vName string) ListAllDomains(vName string)
} }
...@@ -40,10 +41,10 @@ func NewClient() (QemuClient, error) { ...@@ -40,10 +41,10 @@ func NewClient() (QemuClient, error) {
} }
// 生成虚拟机 XML 配置(含 GPU 直通) // 生成虚拟机 XML 配置(含 GPU 直通)
func generateVMXML() string { func generateVMXML(vmName string, memoryMB uint, vcpu uint, diskImagePath string, gpuPCIAddress DomainAddressPCI, audioPCIAddress DomainAddressPCI) string {
domainCfg := &libvirtxml.Domain{ domainCfg := &libvirtxml.Domain{
Type: "kvm", Type: "kvm",
Name: vmName, Name: vmName,
Description: "My test VM", Description: "My test VM",
Memory: &libvirtxml.DomainMemory{ Memory: &libvirtxml.DomainMemory{
Unit: "MiB", Unit: "MiB",
...@@ -81,22 +82,29 @@ func generateVMXML() string { ...@@ -81,22 +82,29 @@ func generateVMXML() string {
}, },
Hostdevs: []libvirtxml.DomainHostdev{ Hostdevs: []libvirtxml.DomainHostdev{
{ {
Mode: "subsystem",
Type: "pci",
Managed: "yes", Managed: "yes",
Source: &libvirtxml.DomainHostdevSource{ SubsysPCI: &libvirtxml.DomainHostdevSubsysPCI{
Address: &libvirtxml.DomainAddress{ Source: &libvirtxml.DomainHostdevSubsysPCISource{
PCI: &libvirtxml.DomainAddressPCI{ Address: &libvirtxml.DomainAddressPCI{
Domain: parsePCIComponent(gpuPCIAddress, 0), Domain: gpuPCIAddress.domain,
Bus: parsePCIComponent(gpuPCIAddress, 1), Bus: gpuPCIAddress.bus,
Slot: parsePCIComponent(gpuPCIAddress, 2), Slot: gpuPCIAddress.slot,
Function: parsePCIComponent(gpuPCIAddress, 3), Function: gpuPCIAddress.function,
}, },
}, },
}, },
},
{
Managed: "yes",
SubsysPCI: &libvirtxml.DomainHostdevSubsysPCI{ SubsysPCI: &libvirtxml.DomainHostdevSubsysPCI{
Add Source: &libvirtxml.DomainHostdevSubsysPCISource{
Address: &libvirtxml.DomainAddressPCI{
Domain: audioPCIAddress.domain,
Bus: audioPCIAddress.bus,
Slot: audioPCIAddress.slot,
Function: audioPCIAddress.function,
},
},
}, },
}, },
}, },
...@@ -107,32 +115,79 @@ func generateVMXML() string { ...@@ -107,32 +115,79 @@ func generateVMXML() string {
return xml return xml
} }
// 解析PCI地址组件(例如 "0000:01:00.0" -> 各部分转换为十六进制) type DomainAddressPCI struct {
func parsePCIComponent(addr string, index int) string { domain, bus, slot, function *uint
}
// 解析完整PCI地址(支持0000:3b:00.0 和 3b:00.0两种格式)
func parsePCIAddress(addr string) (DomainAddressPCI, error) {
components := strings.Split(addr, ":") components := strings.Split(addr, ":")
if len(components) < 3 { var (
return "0x00" domainStr = "0000" // 默认domain
busStr string
slotFuncStr string
)
// 处理带domain的格式 (0000:3b:00.0)
if len(components) == 3 {
domainStr = components[0]
busStr = components[1]
slotFuncStr = components[2]
} else if len(components) == 2 { // 无domain格式 (3b:00.0)
busStr = components[0]
slotFuncStr = components[1]
} else {
return DomainAddressPCI{}, fmt.Errorf("invalid PCI address format: %s", addr)
}
// 解析slot和function
slotFuncParts := strings.Split(slotFuncStr, ".")
if len(slotFuncParts) != 2 {
return DomainAddressPCI{}, fmt.Errorf("invalid slot.function format: %s", slotFuncStr)
} }
slotFunc := strings.Split(components[2], ".") slotStr, functionStr := slotFuncParts[0], slotFuncParts[1]
if index == 3 && len(slotFunc) > 1 {
return "0x" + slotFunc[1] // 转换为uint (16进制)
parseHex := func(s string) (uint, error) {
val, err := strconv.ParseUint(s, 16, 16)
if err != nil {
return 0, fmt.Errorf("parse error: %s", s)
}
return uint(val), nil
} }
switch index { //var domain, bus, slot, function uint
case 0: domain, err := parseHex(domainStr)
return "0x" + components[0] if err != nil {
case 1: return DomainAddressPCI{}, err
return "0x" + components[1]
case 2:
return "0x" + slotFunc[0]
default:
return "0x00"
} }
bus, err := parseHex(busStr)
if err != nil {
return DomainAddressPCI{}, err
}
slot, err := parseHex(slotStr)
if err != nil {
return DomainAddressPCI{}, err
}
function, err := parseHex(functionStr)
if err != nil {
return DomainAddressPCI{}, err
}
return DomainAddressPCI{domain: &domain, bus: &bus, slot: &slot, function: &function}, nil
} }
// 创建虚拟机 // 创建虚拟机
func (c Client) CreateVM() error { func (c Client) CreateVM(vmName string, memoryMB uint, vcpu uint, diskImagePath string, gpuPCIAddress string, audioPCIAddress string) error {
xml := generateVMXML() pciAddress, err := parsePCIAddress(gpuPCIAddress)
dom, err := c.connect.DomainDefineXML(xml) if err != nil {
return err
}
audioAddress, err := parsePCIAddress(audioPCIAddress)
if err != nil {
return err
}
xml := generateVMXML(vmName, memoryMB, vcpu, diskImagePath, pciAddress, audioAddress)
dom, err := c.connect.DomainDefineXML(xml)
if err != nil { if err != nil {
return errors.Wrap(err, "Failed to define domain") return errors.Wrap(err, "Failed to define domain")
} }
...@@ -153,7 +208,7 @@ func (c Client) CreateVM() error { ...@@ -153,7 +208,7 @@ func (c Client) CreateVM() error {
// 启动虚拟机 // 启动虚拟机
func (c Client) StartVM() error { func (c Client) StartVM() error {
dom, err := c.connect.LookupDomainByName(vmName) dom, err := c.connect.LookupDomainByName(vmName)
if err != nil { if err != nil {
return errors.Wrap(err, "Failed to find domain") return errors.Wrap(err, "Failed to find domain")
} }
...@@ -168,7 +223,7 @@ func (c Client) StartVM() error { ...@@ -168,7 +223,7 @@ func (c Client) StartVM() error {
// 停止虚拟机 // 停止虚拟机
func (c Client) StopVM() error { func (c Client) StopVM() error {
dom, err := c.connect.LookupDomainByName(vmName) dom, err := c.connect.LookupDomainByName(vmName)
if err != nil { if err != nil {
return errors.Wrap(err, "Failed to find domain") return errors.Wrap(err, "Failed to find domain")
} }
...@@ -197,18 +252,16 @@ func (c Client) DestroyVM() error { ...@@ -197,18 +252,16 @@ func (c Client) DestroyVM() error {
return nil return nil
} }
// 通过虚拟机名称获取其全部状态 // 通过虚拟机名称获取其全部状态
func (c Client) GetAllDomainStats() ([]libvirt.DomainStats,error) { func (c Client) GetAllDomainStats() ([]libvirt.DomainStats, error) {
dom, err := c.connect.LookupDomainByName(vmName) dom, err := c.connect.LookupDomainByName(vmName)
if err != nil { if err != nil {
return nil,errors.Wrap(err, "Failed to find domain") return nil, errors.Wrap(err, "Failed to find domain")
} }
defer func(dom *libvirt.Domain) { defer func(dom *libvirt.Domain) {
err := dom.Free() err := dom.Free()
if err != nil { if err != nil {
fmt.Printf("err:%v \n",err) fmt.Printf("err:%v \n", err)
} }
}(dom) }(dom)
dobs := make([]*libvirt.Domain, 0) dobs := make([]*libvirt.Domain, 0)
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!